ASICs designed for LLM inference are arriving. Groq's LPU, Cerebras's WSE, and a wave of startups are all chasing the same insight: autoregressive token generation is memory-bound, so build hardware with massive on-chip SRAM and skip the DRAM bottleneck entirely. The pitch is compelling: if your weights live on-chip, you eliminate the memory wall and inference becomes compute-limited.
But here's a question worth asking: what happens when you simulate this on a commodity GPU today? NVIDIA's RTX 5090 ships with 96 MB of L2 cache. A quantized 135M-parameter model fits in 85 MB. If you pin those weights in L2, you've effectively built a poor man's ASIC — all weights on-chip, no DRAM round-trips during generation.
This article documents what we found when we tried it. Spoiler: the memory wall does disappear. What replaces it is more interesting.
The Setup: SmolLM2-135M on RTX 5090
We built a custom CUDA inference engine from scratch for SmolLM2-135M, a 30-layer transformer with 576-dimensional hidden state, 9 query heads, 3 KV heads (GQA), and a 1536-dimensional FFN. The architecture is standard — RMSNorm, RoPE, grouped-query attention, SwiGLU MLP — just small enough to be interesting.
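The model's weights are stored in GGUF's IQ4_NL and IQ4_XS quantization formats. IQ4_NL packs 32 values into 18 bytes: a half-precision scale factor and 16 bytes of 4-bit indices into a non-linear lookup table. The 16-entry table is declared in CUDA constant memory. The mat-vec kernels work from a shared-memory copy of it (s_kv), because each lane looks up a different entry and constant-memory reads are serialized when lanes in a warp access different addresses: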
```cuda
__device__ __constant__ float d_kvalues_iq4nl[16] = {
    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f,
       1.f,   13.f,  25.f,  38.f,  53.f,  69.f,  89.f, 113.f
};
```
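To make the block layout concrete, here is a host-side C++ reference dequantizer for a single IQ4_NL block. It is a sketch rather than the engine's code: `BlockIQ4NL`, `kIQ4NLValues`, and `half_to_float` are illustrative names, and the nibble order (byte j carries element j in its low nibble and element j + 16 in its high nibble) is assumed from ggml's reference dequantizer, which is also what the lane arithmetic in the mat-vec kernel later in this article implies.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// IQ4_NL codebook: 16 non-linear levels, indexed by a 4-bit value.
static const float kIQ4NLValues[16] = {
    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f,
       1.f,   13.f,  25.f,  38.f,  53.f,  69.f,  89.f, 113.f};

// One 18-byte IQ4_NL block: FP16 scale (raw bits) + 32 packed 4-bit indices.
struct BlockIQ4NL {
    std::uint16_t d;
    std::uint8_t qs[16];
};
static_assert(sizeof(BlockIQ4NL) == 18, "IQ4_NL block must be 18 bytes");

// Minimal IEEE 754 binary16 -> float conversion.
float half_to_float(std::uint16_t h) {
    const std::uint32_t exp = (h >> 10) & 0x1fu;
    const std::uint32_t mant = h & 0x3ffu;
    float v;
    if (exp == 0) {
        v = std::ldexp(static_cast<float>(mant), -24);  // zero or subnormal
    } else if (exp == 31) {
        v = mant ? NAN : INFINITY;                       // NaN or infinity
    } else {
        // (1024 + mant) * 2^(exp - 15 - 10)
        v = std::ldexp(static_cast<float>(mant | 0x400u), static_cast<int>(exp) - 25);
    }
    return (h & 0x8000u) ? -v : v;
}

// Dequantize one block into 32 floats.
void dequant_iq4nl(const BlockIQ4NL& blk, float out[32]) {
    const float d = half_to_float(blk.d);
    for (int j = 0; j < 16; ++j) {
        out[j]      = d * kIQ4NLValues[blk.qs[j] & 0xf];  // low nibble  -> element j
        out[j + 16] = d * kIQ4NLValues[blk.qs[j] >> 4];   // high nibble -> element j + 16
    }
}
```

With a scale of 1.0 (`0x3C00`), a byte of `0x80` decodes to -127 for element j and 1 for element j + 16.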
The total weight pool — all 30 layers of IQ4_NL/IQ4_XS projections, Q8_0 embeddings, FP16 norms — comes to 85 MB.
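The 85 MB figure can be reproduced with a little arithmetic. This C++ sketch is a back-of-the-envelope check that rests on assumptions the article does not spell out. The format assignment is inferred from the kernel names in the forward-pass listing: IQ4_NL for Q/K/V/O and gate/up, IQ4_XS for the FFN down projection, whose 1536-wide rows divide evenly into IQ4_XS's 256-value super-blocks while 576-wide rows do not. The vocabulary is taken to be 49,152 tokens with one tied Q8_0 embedding matrix, and "MB" is read as MiB. Block sizes are GGUF's: 18 bytes per 32 values (IQ4_NL), 136 per 256 (IQ4_XS), and 34 per 32 (Q8_0).

```cpp
#include <cassert>
#include <cstdint>

// Bytes and elements per quantization block (GGUF layouts).
struct Format {
    std::int64_t bytes_per_block;
    std::int64_t elems_per_block;
};
constexpr Format kIQ4_NL{18, 32};
constexpr Format kIQ4_XS{136, 256};
constexpr Format kQ8_0{34, 32};

// SmolLM2-135M shapes from the text; kVocab is an assumption (tied embeddings).
constexpr std::int64_t kDim = 576, kKvDim = 192, kFfn = 1536, kLayers = 30;
constexpr std::int64_t kVocab = 49152;
constexpr std::int64_t kFp16Bytes = 2;

std::int64_t tensor_bytes(std::int64_t n_elems, Format f) {
    assert(n_elems % f.elems_per_block == 0);  // whole blocks only
    return n_elems / f.elems_per_block * f.bytes_per_block;
}

// `proj` is used for Q/K/V/O and gate/up, `down` for the FFN down projection.
std::int64_t weight_pool_bytes(Format proj, Format down) {
    const std::int64_t proj_elems =
        2 * kDim * kDim        // Q and O
        + 2 * kDim * kKvDim    // K and V (3 KV heads x 64)
        + 2 * kDim * kFfn;     // gate and up
    const std::int64_t per_layer = tensor_bytes(proj_elems, proj)
                                 + tensor_bytes(kFfn * kDim, down)
                                 + 2 * kDim * kFp16Bytes;  // two RMSNorm weight vectors
    const std::int64_t embedding = tensor_bytes(kVocab * kDim, kQ8_0);
    const std::int64_t final_norm = kDim * kFp16Bytes;
    return kLayers * per_layer + embedding + final_norm;
}

double to_mib(std::int64_t bytes) { return static_cast<double>(bytes) / (1024.0 * 1024.0); }
```

Under those assumptions `weight_pool_bytes(kIQ4_NL, kIQ4_XS)` comes to 89,041,536 bytes (about 84.9 MiB), and switching every projection to Q8_0 with `weight_pool_bytes(kQ8_0, kQ8_0)` gives 142,955,136 bytes (about 136.3 MiB), in line with the 85 MB and 136 MB figures quoted in the article.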
The RTX 5090 (Blackwell, SM 12.0) has 96 MB of L2 cache. At engine startup, we pin the weight pool into L2 by attaching a cudaAccessPolicyWindow to the inference stream with cudaStreamSetAttribute. The resulting configuration:
| Property | Value |
|---|---|
| GPU | RTX 5090 (Blackwell) |
| VRAM | 32 GB GDDR7, 1,790 GB/s |
| L2 cache | 96 MB |
| Weight pool | 85 MB (IQ4_NL/IQ4_XS + Q8_0) |
| L2 hit ratio | ~100% during generation |
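How much of the pool actually gets persisting treatment depends on two device limits, not just the 96 MB L2 size: the window is capped at `cudaDeviceProp::accessPolicyMaxWindowSize`, and the persisting carve-out set with `cudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, ...)` is capped at `cudaDeviceProp::persistingL2CacheMaxSize`. One common heuristic, in line with the `hitRatio` discussion in the CUDA C++ Programming Guide, is to lower `hitRatio` to carve-out / window size when the window is larger than the carve-out, so persisting lines don't evict each other. The host-side sketch below only does that sizing arithmetic. `WindowPlan` and `plan_l2_window` are hypothetical names, the limits are passed in rather than queried, and the real CUDA calls appear only as a comment. Note that `hitRatio` is a policy knob (the fraction of window accesses marked persisting), not the measured hit rate in the table.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Hypothetical helper: sizing for an L2 access-policy window.
struct WindowPlan {
    std::size_t num_bytes;  // bytes covered by the window
    float hit_ratio;        // fraction of window accesses marked persisting
};

// pool_bytes:       size of the weight pool.
// max_window_bytes: cudaDeviceProp::accessPolicyMaxWindowSize on a real device.
// carveout_bytes:   value given to cudaLimitPersistingL2CacheSize
//                   (itself capped at cudaDeviceProp::persistingL2CacheMaxSize).
WindowPlan plan_l2_window(std::size_t pool_bytes, std::size_t max_window_bytes,
                          std::size_t carveout_bytes) {
    WindowPlan plan{};
    plan.num_bytes = std::min(pool_bytes, max_window_bytes);
    if (plan.num_bytes == 0) return plan;
    // Window larger than the carve-out: mark only a matching fraction as persisting.
    plan.hit_ratio = std::min(1.0f, static_cast<float>(carveout_bytes) /
                                    static_cast<float>(plan.num_bytes));
    return plan;
}

// On a real device the plan would then be applied roughly like this:
//   cudaStreamAttrValue attr{};
//   attr.accessPolicyWindow.base_ptr  = weight_pool;
//   attr.accessPolicyWindow.num_bytes = plan.num_bytes;
//   attr.accessPolicyWindow.hitRatio  = plan.hit_ratio;
//   attr.accessPolicyWindow.hitProp   = cudaAccessPropertyPersisting;
//   attr.accessPolicyWindow.missProp  = cudaAccessPropertyStreaming;
//   cudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &attr);
```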
Once the weights are warm in L2, every mat-vec reads from on-chip cache. No DRAM traffic for weights. This is the ASIC scenario.
Phase 1: Naive FP16 — 750 tok/s
The first version used FP16 weights and straightforward kernels: one RMSNorm, one mat-vec per projection, separate RoPE, separate KV cache writes. This was the baseline to validate correctness.
At 750 tok/s for 128-token generation, it was already faster than running the same model under most Python-based frameworks, but well below llama.cpp's 1,110 tok/s. The FP16 weight pool was too large for L2 pinning, so this phase still hit DRAM for weights.
Phase 2: IQ4 Quantization + L2 Pinning — 1,255 tok/s
Switching to IQ4_NL/IQ4_XS quantization (loaded directly from GGUF, no conversion) shrank the weight pool from ~270 MB to 85 MB. Now it fits in L2.
The mat-vec kernel assigns one warp (32 threads) to each output row. Each warp walks the row's IQ4 blocks with one lane per element, dequantizing through the shared-memory copy of the lookup table (s_kv) and accumulating a dot product against the input vector, which is also staged in shared memory (s_x). The partial sums are then combined with a standard warp shuffle tree:
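```cuda
template<bool FUSE_RESIDUAL = false>
__global__ void matvec_iq4nl(half* __restrict__ out,
                             const void* __restrict__ W,
                             const half* __restrict__ x,
                             half* __restrict__ residual,
                             int out_dim, int in_dim) {
    // ... cooperative load of x into shared memory ...
    // (s_x holds the input vector, s_kv the IQ4_NL lookup table)
    const int lane = threadIdx.x & 31;                 // lane within the warp
    const int warp_id = threadIdx.x >> 5;              // warp within the block
    const int warps_per_block = blockDim.x >> 5;
    const int row = blockIdx.x * warps_per_block + warp_id;  // one warp per output row
    if (row >= out_dim) return;                        // no block-wide barriers after this point

    // block_iq4_nl: { half d; uint8_t qs[16]; }, the 18-byte layout (name as in ggml)
    const int blocks_per_row = in_dim / 32;
    const block_iq4_nl* row_blocks =
        static_cast<const block_iq4_nl*>(W) + (size_t)row * blocks_per_row;

    float sum = 0.0f;
    for (int b = 0; b < blocks_per_row; b++) {
        float d = __half2float(row_blocks[b].d);       // per-block scale
        uint8_t q = row_blocks[b].qs[lane & 15];       // lanes l and l + 16 share a byte
        int shift = (lane >> 4) << 2;                  // low nibble for lanes 0-15, high for 16-31
        int idx = (q >> shift) & 0xf;
        float w = d * s_kv[idx];                       // non-linear codebook lookup
        sum += w * s_x[b * 32 + lane];
    }
    // warp shuffle reduction: afterwards lane 0 holds the full dot product
    for (int offset = 16; offset > 0; offset >>= 1)
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    // ... epilogue: lane 0 writes out[row] ...
}
```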
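The lane arithmetic is easy to get wrong, so it can help to check it on the CPU. This host-side C++ sketch (hypothetical names; scales already converted to float for simplicity) emulates one warp computing one row. Lane `l` reads byte `l & 15` and selects a nibble with `(l >> 4) << 2`, as in the kernel, and the per-lane partial sums are combined with the same offsets (16, 8, 4, 2, 1) a `__shfl_down_sync` tree uses, leaving the total in lane 0. The result is compared against a plain dequantize-then-dot reference in ggml's element order. It does not model shared memory or timing.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

static const float kLevels[16] = {
    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f,
       1.f,   13.f,  25.f,  38.f,  53.f,  69.f,  89.f, 113.f};

// IQ4_NL block with the scale pre-converted to float (a simplification).
struct Block {
    float d;
    std::uint8_t qs[16];
};

// Emulates one warp computing the dot product of one quantized row with x.
float warp_row_dot(const std::vector<Block>& row, const std::vector<float>& x) {
    float lane_sum[32] = {};
    for (std::size_t b = 0; b < row.size(); ++b) {
        for (int lane = 0; lane < 32; ++lane) {      // the 32 lanes run in lockstep on the GPU
            const std::uint8_t q = row[b].qs[lane & 15];
            const int shift = (lane >> 4) << 2;      // 0 for lanes 0-15, 4 for lanes 16-31
            const float w = row[b].d * kLevels[(q >> shift) & 0xf];
            lane_sum[lane] += w * x[b * 32 + lane];
        }
    }
    // Shuffle-down tree: lane i adds the value held by lane i + offset.
    // Only the lanes that feed lane 0 are emulated here.
    for (int offset = 16; offset > 0; offset >>= 1)
        for (int lane = 0; lane < offset; ++lane)
            lane_sum[lane] += lane_sum[lane + offset];
    return lane_sum[0];
}

// Reference: dequantize in ggml order, then a plain sequential dot product.
float reference_row_dot(const std::vector<Block>& row, const std::vector<float>& x) {
    float sum = 0.0f;
    for (std::size_t b = 0; b < row.size(); ++b)
        for (int j = 0; j < 16; ++j) {
            sum += row[b].d * kLevels[row[b].qs[j] & 0xf] * x[b * 32 + j];
            sum += row[b].d * kLevels[row[b].qs[j] >> 4] * x[b * 32 + j + 16];
        }
    return sum;
}
```

The two results differ only by floating-point summation order, so a small tolerance is enough when comparing them.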
With L2 pinning, this reached 1,255 tok/s, a 67% improvement over FP16. Most of that gain comes from the L2 effect: weights are served at L2 bandwidth (~3-4 TB/s effective) instead of DRAM bandwidth (1,790 GB/s peak).
At this point, the memory wall was gone. Now what?
The Dead End: Optimizing the Inner Loop
The natural instinct was to optimize the compute. IQ4_NL dequantization requires a shared-memory table lookup — what if we converted everything to Q8_0 at load time? Q8_0 dequant is a simple d * qs[i], no lookup needed.
We tried it. Mat-vec bandwidth improved from 95 to 152 GB/s. But tok/s barely moved: 1,255 to 1,262.
Why? Two reasons. First, Q8_0 is 34 bytes per 32 values vs. IQ4_NL's 18 bytes, so the weight pool grew from 85 to 136 MB, too large for L2 pinning. We traded lookup latency for cache misses. Second, and more fundamental: the layer matrices are tiny. The largest FFN projection is 1536 rows of 576 elements. At that size, a single mat-vec completes in microseconds regardless of dequant cost; the kernel finishes before either memory bandwidth or arithmetic throughput can become the limit.
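A rough calculation shows why the dequant cost barely matters. The sketch below computes how long it takes just to stream the largest projection (1536 x 576) at a given bandwidth. The 3 TB/s input is the low end of the article's ~3-4 TB/s effective L2 estimate and 1,790 GB/s is the DRAM peak; both are used only as inputs to a back-of-the-envelope estimate, and `matrix_bytes` and `stream_time_us` are hypothetical helpers.

```cpp
#include <cassert>
#include <cstdint>

// Bytes for a rows x cols matrix stored in blocks of `elems_per_block` values.
std::int64_t matrix_bytes(std::int64_t rows, std::int64_t cols,
                          std::int64_t bytes_per_block, std::int64_t elems_per_block) {
    assert(cols % elems_per_block == 0);
    return rows * (cols / elems_per_block) * bytes_per_block;
}

// Time in microseconds to move `bytes` at `gb_per_s` (1 GB = 1e9 bytes).
double stream_time_us(std::int64_t bytes, double gb_per_s) {
    return static_cast<double>(bytes) / (gb_per_s * 1e9) * 1e6;
}
```

The IQ4_NL matrix is 497,664 bytes, about 0.17 us at 3 TB/s; the Q8_0 version is 940,032 bytes, about 0.53 us even at the DRAM peak. Both are lower bounds that ignore latency, and both are well under a single ~2.5 us kernel launch.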
The real bottleneck was hiding in the profile output. Each forward pass launched 301 kernels, and each kernel launch costs ~2.5 microseconds of launch overhead. That's about 750 microseconds of pure launch tax, almost the entire per-token time budget of 792 microseconds (at 1,262 tok/s).
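The same budget check as a few lines of host-side C++. `launch_tax_us` and `token_budget_us` are hypothetical helpers, and the fixed 2.5 us per-launch cost is the article's estimate rather than a measured constant.

```cpp
#include <cassert>

// Total launch overhead per token, assuming a fixed cost per kernel launch.
double launch_tax_us(int launches, double us_per_launch) {
    return launches * us_per_launch;
}

// Wall-clock budget per token at a given decode rate.
double token_budget_us(double tokens_per_s) {
    assert(tokens_per_s > 0.0);
    return 1e6 / tokens_per_s;
}
```

At 1,262 tok/s the budget is about 792 us, and 301 launches at 2.5 us each come to 752.5 us, roughly 95% of it.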
The memory wall was gone. The dispatch wall had replaced it.
Phase 3: Kernel Fusion — 1,508 tok/s
Once we identified dispatch overhead as the bottleneck, the optimization strategy flipped. Instead of making individual kernels faster, we needed fewer of them. Each of the 30 layers ran 10 kernels. We fused them down to 6.
Fusion 1: Residual Addition into Mat-Vec
After the attention output projection and the FFN down projection, the original code ran a separate vec_add kernel to accumulate the residual:
```cuda
// Before: two kernel launches
matvec_iq4nl<<<...>>>(xb, attn_output, attn_out, nullptr, DIM, DIM);
vec_add<<<...>>>(x, x, xb, DIM);
```
The vec_add kernel reads two 576-element half vectors and writes one. It takes about 2 microseconds of compute but 2.5 microseconds to launch. We added a template parameter, FUSE_RESIDUAL, to the mat-vec kernel and folded the add into its epilogue:
```cuda
if (lane == 0) {
    float result = sum;
    if constexpr (FUSE_RESIDUAL) {
        result += __half2float(residual[row]);
        residual[row] = __float2half(result);
    }
    out[row] = __float2half(result);
}
```
Two lines of code. One fewer kernel launch per fusion site, two sites per layer, 60 launches eliminated.
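The `if constexpr` switch means the residual add costs nothing when it isn't requested: each instantiation is compiled separately, and the unfused one contains no residual code at all. A host-side C++17 sketch of the same pattern; `epilogue` is a hypothetical function and works in float rather than half.

```cpp
#include <cassert>

// Writes one output element; with FUSE_RESIDUAL the residual stream is updated in place.
template <bool FUSE_RESIDUAL = false>
void epilogue(float sum, float* out, float* residual, int row) {
    float result = sum;
    if constexpr (FUSE_RESIDUAL) {
        result += residual[row];   // residual stream: x = x + projection
        residual[row] = result;
    }
    out[row] = result;
}
```

In the forward pass this corresponds to the two fused launch sites, `matvec_iq4nl<true>` for the attention output and `matvec_iq4xs<true>` for the FFN down projection.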
Fusion 2: Gate/Up Projection + SwiGLU
The FFN block computes silu(gate(x)) * up(x) where gate and up are separate linear projections. The original code ran a fused RMSNorm + gate/up mat-vec (dispatching 384 blocks for 1536+1536 output rows) followed by a separate SwiGLU kernel.
We rewrote this so each warp computes both the gate and up dot products in a single pass over the normalized input in shared memory, then applies SwiGLU inline:
```cuda
for (int b = 0; b < blocks_per_row; b++) {
    float xval = s_xn[b * 32 + lane];
    // Gate dot product
    float dg = __half2float(gate_row_blocks[b].d);
    uint8_t qg = gate_row_blocks[b].qs[lane & 15];
    gate_sum += (dg * s_kv[(qg >> shift) & 0xf]) * xval;
    // Up dot product
    float du = __half2float(up_row_blocks[b].d);
    uint8_t qu = up_row_blocks[b].qs[lane & 15];
    up_sum += (du * s_kv[(qu >> shift) & 0xf]) * xval;
}
// After warp reduction of both accumulators:
float silu_gate = gate_sum / (1.0f + expf(-gate_sum));
gate_out[row] = __float2half(silu_gate * up_sum);
```
This halves the grid from 384 to 192 blocks, eliminates the SwiGLU kernel, and avoids writing the intermediate up_out buffer to global memory and reading it back. One fewer launch per layer, 30 eliminated.
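For checking a fused kernel like this, a CPU reference is handy. The host-side C++ sketch below uses plain float weights instead of IQ4 blocks (a simplification) and computes both dot products in one pass per row before applying SwiGLU, `silu(g) * u` with `silu(g) = g / (1 + exp(-g))`. `fused_gate_up_swiglu` is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

float silu(float g) { return g / (1.0f + std::exp(-g)); }

// out[r] = silu(gate[r] . x) * (up[r] . x); gate and up are row-major [rows x cols].
std::vector<float> fused_gate_up_swiglu(const std::vector<float>& gate,
                                        const std::vector<float>& up,
                                        const std::vector<float>& x,
                                        std::size_t rows, std::size_t cols) {
    assert(gate.size() == rows * cols && up.size() == rows * cols && x.size() == cols);
    std::vector<float> out(rows);
    for (std::size_t r = 0; r < rows; ++r) {
        float gate_sum = 0.0f, up_sum = 0.0f;
        for (std::size_t c = 0; c < cols; ++c) {
            const float xv = x[c];                      // one load feeds both products
            gate_sum += gate[r * cols + c] * xv;
            up_sum   += up[r * cols + c] * xv;
        }
        out[r] = silu(gate_sum) * up_sum;               // SwiGLU applied in the epilogue
    }
    return out;
}
```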
Fusion 3: RoPE + KV Cache Write
RoPE (rotary position embeddings) and KV cache writes are both small operations on the q (576-dimensional), k, and v (192-dimensional each) vectors. We fused them into a single one-block kernel with 384 threads:
```cuda
__global__ void fused_rope_kv_write(half* q, half* k, half* v,
                                    half* key_cache, half* value_cache,
                                    const int* pos_ptr, ...) {
    // Phase 1: threads 0-287 apply RoPE to q (9 heads * 32 pairs)
    //          threads 288-383 apply RoPE to k (3 heads * 32 pairs)
    __syncthreads();
    // Phase 2: threads 0-191 write k to cache
    //          threads 192-383 write v to cache
}
```
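The thread layout packs both RoPE targets into one block with no idle threads in the first phase: 9 query heads x 32 pairs = 288 threads plus 3 key heads x 32 pairs = 96 threads. Here is a host-side C++ sketch of the index mapping and the per-pair rotation. It rests on assumptions: `RopeTask`, `map_rope_thread`, and `rope_rotate` are hypothetical names; the head dimension is 64; pairs are adjacent elements (2i, 2i+1), as in llama.cpp's "normal" RoPE mode, rotated by pos * base^(-2i/64). The engine's actual pairing convention and frequency base are not shown in the article.

```cpp
#include <cassert>
#include <cmath>

constexpr int kHeadDim = 64, kPairs = kHeadDim / 2;
constexpr int kQHeads = 9, kKHeads = 3;

struct RopeTask {
    bool is_query;  // false -> key
    int head;
    int pair;       // rotates elements (2 * pair, 2 * pair + 1) of the head
};

// Phase 1 mapping: threads [0, 288) -> q, threads [288, 384) -> k.
RopeTask map_rope_thread(int tid) {
    assert(tid >= 0 && tid < (kQHeads + kKHeads) * kPairs);
    const bool is_q = tid < kQHeads * kPairs;
    const int local = is_q ? tid : tid - kQHeads * kPairs;
    return {is_q, local / kPairs, local % kPairs};
}

// Rotates one pair (x0, x1) in place by angle pos * base^(-2 * pair / kHeadDim).
void rope_rotate(float& x0, float& x1, int pos, int pair, float base) {
    const float theta = static_cast<float>(pos) *
                        std::pow(base, -2.0f * static_cast<float>(pair) / kHeadDim);
    const float c = std::cos(theta), s = std::sin(theta);
    const float r0 = x0 * c - x1 * s;
    const float r1 = x0 * s + x1 * c;
    x0 = r0;
    x1 = r1;
}
```

Position 0 leaves a pair unchanged, and any rotation preserves the pair's norm, which makes both easy properties to test against the CUDA kernel.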
Two kernel launches replaced by one, 30 more eliminated across all layers.
The Result
| Metric | Phase 2 | Phase 3 | Change |
|---|---|---|---|
| Dispatches per forward | 301 | 181 | -120 (-40%) |
| 128 tokens: tok/s | 1,327 | 1,508 | +13.7% |
| 128 tokens: per token | 754 us | 663 us | -91 us |
| 256 tokens: tok/s | 1,156 | 1,269 | +9.8% |
| 256 tokens: per token | 865 us | 788 us | -77 us |
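Output is byte-identical between Phase 2 and Phase 3. The fusions compute the same math in the same accumulation order; the one numeric difference is that a few intermediates (the projection result before the residual add, the gate and up sums before SwiGLU) now stay in FP32 registers instead of being rounded to FP16 between kernels, and that did not change the output.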
The improvement shrinks at longer sequences because attention cost grows with sequence length while the dispatch savings stay roughly constant (77-91 microseconds per token in the table above).
The Forward Pass: 6 Kernels Per Layer
After fusion, each transformer layer runs exactly 6 kernel launches:
```cuda
for (int l = 0; l < N_LAYERS; l++) {
    // 1. Fused: RMSNorm + QKV projection (IQ4_NL, 120 blocks)
    fused_rmsnorm_qkv_iq4nl<<<mv_grid(960), 256, smem, stream>>>(...);
    // 2. Fused: RoPE + KV cache write (1 block, 384 threads)
    fused_rope_kv_write<<<1, 384, 0, stream>>>(...);
    // 3. GQA attention (9 blocks, one per head)
    gqa_attention_device<<<9, 256, smem, stream>>>(...);
    // 4. Attention output projection + residual (72 blocks)
    matvec_iq4nl<true><<<mv_grid(576), 256, smem, stream>>>(...);
    // 5. Fused: RMSNorm + gate/up + SwiGLU (192 blocks)
    fused_rmsnorm_gate_up_swiglu_iq4nl<<<mv_grid(1536), 256, smem, stream>>>(...);
    // 6. FFN down projection + residual (72 blocks)
    matvec_iq4xs<true><<<mv_grid(576), 256, smem, stream>>>(...);
}
```
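The grid sizes in the comments follow from one warp per output row in 256-thread (8-warp) blocks. Below is a plausible host-side definition of `mv_grid`, which the article uses but does not show, plus the dispatch count; both are sketches under that one-warp-per-row assumption, and `dispatches` is a hypothetical helper.

```cpp
#include <cassert>

constexpr int kThreadsPerBlock = 256;
constexpr int kWarpsPerBlock = kThreadsPerBlock / 32;  // one warp per output row

// Number of blocks needed to cover `rows` output rows (rounded up).
constexpr int mv_grid(int rows) { return (rows + kWarpsPerBlock - 1) / kWarpsPerBlock; }

// Kernel launches per forward pass: per-layer kernels plus the final RMSNorm + lm_head.
constexpr int dispatches(int layers, int kernels_per_layer) {
    return layers * kernels_per_layer + 1;
}

static_assert(mv_grid(960) == 120, "QKV: 576 + 192 + 192 rows");
static_assert(mv_grid(576) == 72, "attention output / FFN down");
static_assert(mv_grid(1536) == 192, "gate/up + SwiGLU");
static_assert(dispatches(30, 10) == 301, "before fusion");
static_assert(dispatches(30, 6) == 181, "after fusion");
```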
Plus one final kernel for RMSNorm + lm_head. Total: 181 dispatches, captured as a CUDA graph and replayed each token.
What This Tells Us About the ASIC Thesis
The ASIC pitch is "put weights on-chip and inference gets fast." Our experiment confirms the first half: L2 pinning does eliminate the memory wall, and you get a significant speedup from quantization strategies that make your model fit.
But the second half — that inference then becomes compute-limited — doesn't hold for small models on GPUs. What we found instead is a third regime: dispatch-limited inference, where the overhead of launching hundreds of tiny kernels dominates both compute and memory access time.
This matters because it's a bottleneck that ASICs solve structurally. A hardwired transformer pipeline doesn't have kernel launch overhead; it's a static dataflow graph etched in silicon. GPUs, by contrast, pay a tax for their generality: every kernel launch must be set up (arguments, shared-memory configuration) and have its thread blocks scheduled before any work starts, even if the kernel then runs for only 3 microseconds.
| Bottleneck | Phase | Tok/s | What limits performance |
|---|---|---|---|
| Memory bandwidth | Phase 1 (FP16) | 750 | Weights in DRAM, 1,790 GB/s bus |
| Dispatch overhead | Phase 2 (IQ4 + L2) | 1,255 | Weights in L2; 301 launches at ~2.5 us each |
| Dispatch overhead | Phase 3 (fused) | 1,508 | 181 launches at ~2.5 us each |
At 1,508 tok/s with 128 tokens, per-token time is 663 microseconds. The 181 dispatches account for roughly 450 microseconds of that, and actual compute is somewhere around 200 microseconds. If dispatch overhead were zero, a 2-3x speedup would still be on the table, and zero dispatch overhead is roughly what an ASIC achieves.
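The zero-dispatch ceiling is Amdahl-style arithmetic, sketched below with a hypothetical `zero_dispatch_speedup` helper. The per-launch cost is the input that matters most. With the article's ~2.5 us estimate, removing 181 launches from a 663 us token leaves about 210 us, a ceiling of roughly 3.1x. The Phase 2 to Phase 3 table implies a smaller effective saving per removed launch (91 us for 120 launches, about 0.76 us each), which would put the ceiling nearer 1.3x, so the estimate is only as good as the per-launch figure fed into it.

```cpp
#include <cassert>

// Upper bound on speedup if all per-launch overhead disappeared.
double zero_dispatch_speedup(double token_us, int launches, double us_per_launch) {
    const double overhead = launches * us_per_launch;
    assert(overhead < token_us);  // the remaining work must take positive time
    return token_us / (token_us - overhead);
}
```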
Diminishing Returns and What's Next
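The remaining 181 dispatches are harder to fuse pairwise. The QKV projection is already fused (3 weight matrices, 1 kernel), and attention is inherently a single kernel. The two remaining mat-vecs (attention output, FFN down) each need the complete output of the kernel before them (attention and SwiGLU, respectively), because every output row reads the whole input vector; fusing them with their producers would require a grid-wide synchronization.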
The next lever is a persistent kernel: instead of launching 6 kernels per layer, launch a single kernel that runs all 6 operations, separated by grid-wide barriers (for example, cooperative groups' grid.sync(), which requires a cooperative launch). This eliminates inter-kernel dispatch overhead within a layer entirely, potentially cutting per-token time by another 200+ microseconds. It also makes the code significantly harder to write: you're essentially building a manual scheduler inside a kernel.
Beyond that, speculative decoding is the orthogonal win. Rather than making one forward pass faster, it drafts several candidate tokens cheaply and verifies them together in a single forward pass, so one pass can emit more than one token. This is multiplicative with all the kernel-level optimizations.
Practical Takeaways
For model deployment: If your quantized model fits in L2, you're in a fundamentally different performance regime. Check your GPU's L2 size and do the math. At about 4.5 bits per weight, the RTX 5090's 96 MB holds roughly 180M parameters, and fewer if embeddings stay at 8 bits, which is why SmolLM2-135M already needs 85 MB. The RTX 4090's 72 MB tops out around 130M parameters at that density, so even SmolLM2-135M would need more aggressive quantization to fit.
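Doing that math takes one line: divide the L2 capacity by the average bits per weight. The C++ sketch below (`max_params_in_l2` is a hypothetical helper) reads "MB" as MiB and assumes the entire L2 holds weights, ignoring activations, the KV cache, and any cap on the persisting carve-out, so it gives an optimistic upper bound.

```cpp
#include <cassert>

// Upper bound on the parameter count whose weights fit in `l2_mib` MiB of L2.
double max_params_in_l2(double l2_mib, double bits_per_weight) {
    assert(bits_per_weight > 0.0);
    const double bytes = l2_mib * 1024.0 * 1024.0;
    return bytes * 8.0 / bits_per_weight;
}
```

At 4.5 bits per weight (IQ4_NL's 18 bytes per 32 values), 96 MiB works out to about 179M parameters and 72 MiB to about 134M; at a flat 4 bits, 96 MiB holds about 201M.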
For kernel development: Profile dispatches, not just compute. NVIDIA's Nsight tools report kernel launch overhead, but it's easy to overlook when individual kernels show microsecond execution times. The intuition that "the kernel is fast, so the code is fast" breaks down when you're launching hundreds of them.
For the ASIC vs. GPU question: Modern GPUs can already simulate the on-chip-weight scenario for small models, and the results are informative. The memory wall is real but solvable with quantization and cache pinning. What you find underneath is the dispatch wall — and solving that on a GPU requires increasingly aggressive kernel fusion, eventually converging on something that looks a lot like a hardwired pipeline. At some point, you're fighting the GPU's generality rather than leveraging it, and that's exactly the gap ASICs are designed to fill.
The code is open and the numbers are reproducible. SmolLM2-135M is small enough to experiment with in an afternoon, yet it is built from the same components (RMSNorm, RoPE, grouped-query attention, SwiGLU) as models 100x its size. Every technique here (IQ4 quantization, L2 pinning, warp-per-row mat-vec, kernel fusion) transfers directly. The only thing that changes at scale is which wall you hit first.