English | 中文版
9. DeepSeek Inference: A Cross-Platform Kernel Benchmark Suite
Summary: Softmax and GEMM are useful microbenchmarks, but a real inference workload is the only honest test of a kernel toolchain. We packaged the 13 kernels needed for a full DeepSeek-R1-Distill-Qwen-1.5B decode step as a portable suite and measured the same Rust source on four production accelerators. Headline result: 168.9 tok/s on Ascend 910B2 via the joint
mlir_to_cpp + mlir_to_pto path (2.47× the aclnn-only baseline, 45.6× the 3.7 tok/s CPU reference). Cross-vendor cross-validation: 162.9 tok/s on Google TPU v2-8 via emitted Pallas, 91.7 tok/s on Apple M2 Max via emitted Metal (beating Apple’s hand-tuned MLX on decode), and 53.7 tok/s on NVIDIA T4 via emitted CUDA — all from the same 13-kernel Rust source. The rest of the chapter documents the suite, the per-platform results, the two backends where end-to-end tok/s is not reported (AWS Trainium NKI and Vulkan/SPIR-V) and why, and how to reproduce any of the numbers above.
9.1 Why DeepSeek?
DeepSeek-R1-Distill-Qwen-1.5B is small enough to fit in 8 GB of unified memory, large enough to be bandwidth-bound on every realistic accelerator, and architecturally representative of the modern transformer family:
- Grouped-query attention (GQA) — 12 Q-heads share 2 KV-heads.
- SwiGLU MLP — three matmuls per layer, fusable into one kernel.
- RMSNorm — replaces LayerNorm everywhere.
- Rotary position embeddings — applied in-place to Q and K.
Per token, decode reads ≈ 2.6 GB of weights across 28 layers. That makes it a bandwidth benchmark, not a FLOPs benchmark. The hardware ceiling is bandwidth ÷ bytes_per_token:
| Device | Memory bandwidth | Theoretical max tok/s |
|---|---|---|
| Apple M2 Max | 400 GB/s | 154 |
| Apple M4 | 120 GB/s | 46 |
| Apple M4 Pro | 273 GB/s | 105 |
| NVIDIA H100 SXM | 3,350 GB/s | 1,288 |
| NVIDIA RTX 4090 | 1,008 GB/s | 388 |
| NVIDIA Tesla T4 | 320 GB/s | 123 |
| AWS Trainium2 | 2,800 GB/s | 1,077 |
| Google TPU v2-8 | 600 GB/s | 231 |
| Huawei Ascend 910B2 | 1,228 GB/s | 472 |
| Cambricon MLU590 | 1,228 GB/s | 472 |
Any kernel that reaches 60% of this ceiling is competitive with hand-tuned production code; 80% is the realistic goal for a memory-bound kernel. CPU reference throughput on the same model is 3.7 tok/s — the floor every accelerator path has to clear.
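The ceiling column in the table above is a one-line roofline division. A minimal sketch, using the ≈ 2.6 GB/token figure derived above (the helper name is ours, not part of the suite):

```rust
/// Bandwidth-bound decode ceiling: tok/s = memory bandwidth in bytes/s
/// divided by bytes of weights read per decoded token.
fn max_tok_per_s(bandwidth_gb_s: f64, bytes_per_token_gb: f64) -> f64 {
    bandwidth_gb_s / bytes_per_token_gb
}

fn main() {
    let bpt = 2.6; // ≈ 2.6 GB of f16 weights read per token across 28 layers
    for (name, bw) in [
        ("Apple M2 Max", 400.0),
        ("NVIDIA Tesla T4", 320.0),
        ("Huawei Ascend 910B2", 1228.0),
    ] {
        // Prints 154, 123, and 472 — matching the table rows above.
        println!("{name}: {:.0} tok/s ceiling", max_tok_per_s(bw, bpt));
    }
}
```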
9.2 The 13-Kernel Suite
A full transformer layer in decode mode reduces to 8 dispatches plus 5 model-level kernels (embedding, two RMSNorm variants, RoPE, argmax). The complete list, with shapes for the 1.5B model (D=1536, NH=12, NKV=2, DH=128, INTER=8960, VOCAB=151936):
| # | Kernel | Op | Input → Output shape |
|---|---|---|---|
| 1 | rms_norm_1536 | RMSNorm + γ scale | (1, D) → (1, D) |
| 2 | embedding_lookup | gather row from table | (VOCAB, D), (1,) → (1, D) |
| 3 | q_proj_matvec | matvec + bias | (1, D) → (1, NH·DH) |
| 4 | kv_proj_matvec | fused K + V matvec + bias | (1, D) → (1, NKV·DH) × 2 |
| 5 | rope_q_decode | RoPE on Q heads, in place | (NH, DH) → (NH, DH) |
| 6 | rope_k_decode | RoPE on K heads, in place | (NKV, DH) → (NKV, DH) |
| 7 | attention_decode_gqa | GQA attention with KV cache | (NH, DH) + KV cache → (NH, DH) |
| 8 | o_proj_residual | O-projection + residual add | (1, NH·DH) → (1, D) |
| 9 | mlp_gate_up_silu | fused gate + up + silu·mul | (1, D) → (1, INTER) |
| 10 | down_proj_residual | down-projection + residual add | (1, INTER) → (1, D) |
| 11 | silu_mul_fused | standalone SwiGLU | (1, INTER) × 2 → (1, INTER) |
| 12 | residual_add | elementwise add | (1, D) × 2 → (1, D) |
| 13 | argmax_greedy | argmax over logits | (1, VOCAB) → (1, 1) u32 |
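For reference, the semantics of item 11 (silu_mul_fused) fit in a few lines of scalar Rust. This is a hedged host-side sketch assuming the standard SiLU definition silu(x) = x·σ(x); the on-device kernel is generated from tile_kernels.rs, not from this code:

```rust
// Scalar reference for silu_mul_fused (suite item 11): out = silu(gate) * up.
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

fn silu_mul(gate: &[f32], up: &[f32]) -> Vec<f32> {
    gate.iter().zip(up).map(|(&g, &u)| silu(g) * u).collect()
}

fn main() {
    // silu(0) = 0, so a zero gate kills the up-projection contribution.
    let out = silu_mul(&[0.0, 1.0], &[2.0, 3.0]);
    println!("{out:?}");
}
```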
The full Rust source is at crates/deepseek_metal/src/tile_kernels.rs, expressed against the safe tile.rs view API:
#[ascend_std::aiv_kernel]
pub unsafe fn rms_norm_1536(input: *const f32, gamma: *const f32, output: *mut f32) {
    let ctx = unsafe { GmDeviceCtx::new() };
    let in_v = unsafe { ctx.view::<1, D, f32>(input) };
    let g_v = unsafe { ctx.view::<1, D, f32>(gamma) };
    let out_v = unsafe { ctx.view_mut::<1, D, f32>(output) };
    let x = tile_load_view_f32(&in_v);
    let g = tile_load_view_f32(&g_v);
    let normed = safe::tile_rms_norm_f32::<1, D>(x, 1e-6);
    let out = safe::tile_mul_f32::<1, D>(normed, g);
    tile_store_view_f32(&out_v, out);
}
The same source compiles to every mlir_to_<target> backend. Per-target reference kernels are checked in under benchmarks/deepseek_tile_kernels/templates/<target>/.
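A host-side scalar reference for the same kernel is useful for parity checks. This sketch assumes the conventional RMSNorm, x / √(mean(x²) + ε) · γ, with the 1e-6 epsilon visible in the tile source; it is not the tile-API implementation itself:

```rust
// Scalar reference for rms_norm_1536: normalise the row by its
// root-mean-square, then scale elementwise by gamma.
fn rms_norm(x: &[f32], gamma: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(gamma).map(|(&v, &g)| v * inv_rms * g).collect()
}

fn main() {
    // A constant row has RMS equal to its magnitude, so it normalises to ≈ 1.
    let out = rms_norm(&[2.0, 2.0, 2.0, 2.0], &[1.0; 4], 1e-6);
    println!("{out:?}");
}
```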
9.3 Ascend 910B2 — Headline Result
Hardware: Huawei Ascend 910B2, CANN 8.5.0, bisheng compiler, joint mlir_to_cpp + mlir_to_pto codegen path.
Setup: 28-layer DeepSeek-R1-Distill-Qwen-1.5B, f16 weights, single ACL stream per forward pass. The decode path uses cpp-tile kernels for RMSNorm / RoPE / SiLU, PTO cube matmul for the per-layer f16 projections, and cached-executor aclnnIncreFlashAttention for attention.
| Implementation | Decode tok/s | Speedup |
|---|---|---|
| CPU reference (float) | 3.7 | 1.00× |
| aclnn-only baseline | 68.3 | 18.5× |
| ascend-rs (joint mlir_to_cpp + mlir_to_pto) | 168.9 | 45.6× (2.47× vs. aclnn) |
How that 168.9 was reached
The sequence of optimisations applied on 910B2, each measured against the previous step:
| Step | tok/s | Δ |
|---|---|---|
| aclnn-only baseline (f16 matmuls via aclnnMatmul) | 68.3 | — |
| f16 PTO matmuls for all per-layer Q/K/V/O/gate/up/down projections | 114.5 | +46.2 |
| Host-side B-repack lm_head on PTO | 149.4 | +34.9 |
| Fused kv-proj and gate-up weight concatenation (single matmul per pair) | 151.6 | +2.2 |
| Custom cpp-tile residual_add_rms_norm (4.4 µs vs aclnn fused 27 µs) | 157.5 | +5.9 |
| Cached-executor aclnnIncreFlashAttention (38 µs vs plain 61 µs) | 168.0 | +10.5 |
| Misc: lm_head chunk sweep, QKV fusion, attention_1head_cpp via vec matvec | 168.9 | +0.9 |
The two custom kernels that contributed most (residual_add_rms_norm cpp-tile fused, and the f16 PTO matmul blocking) are both generated by rustc_codegen_mlir from plain Rust tile-API source — no hand-written AscendC. Detailed per-op timings are in Appendix I.
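The kv-proj / gate-up fusion step in the table above rests on a simple identity: stacking two weight matrices row-wise turns two matvecs over the same input into one dispatch. A minimal sketch with toy dense f32 weights (the real path fuses the f16 PTO matmuls; the names here are illustrative):

```rust
// One matvec over the row-wise concatenation [Wk; Wv] equals the
// concatenation of the two separate matvecs — the identity behind the
// fused kv-proj (and gate-up) step.
fn matvec(w: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    w.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum())
        .collect()
}

fn main() {
    let x = vec![1.0, 2.0];
    let wk = vec![vec![1.0, 0.0]]; // toy K projection
    let wv = vec![vec![0.0, 1.0]]; // toy V projection
    let fused: Vec<Vec<f32>> = wk.iter().chain(wv.iter()).cloned().collect();

    let separate = [matvec(&wk, &x), matvec(&wv, &x)].concat();
    assert_eq!(matvec(&fused, &x), separate); // one dispatch, same result
    println!("fused kv matvec = {:?}", matvec(&fused, &x));
}
```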
Same binary on 910C
The same kernel source rebuilds for Ascend 910C (cube-only), with the ptoas --cce-fatobj-link path handling the matmul side. On 910C, 98.4% of per-layer time runs on the NPU and 1.6% on the CPU — the only kernel still on host is RMSNorm, because the 910C cube unit does not speed it up (the kernel is memory-bound and the DMA copy dominates). End-to-end tok/s on 910C is not reported pending longer correctness validation across all 28 layers on a stable 910C chip allocation.
9.4 Google TPU v2-8 (Colab) — 162.9 tok/s
Hardware: Google Colab v2-8 (Cloud TPU, 8 cores × 8 MiB MXU, 600 GB/s HBM), mlir_to_tpu codegen emitting JAX Pallas.
Setup: rms_norm and rope_inplace via emitted Pallas kernels; GQA attention via emitted Pallas; matvec split by memory tier — Pallas for q/k/v/o projections (small, VMEM-friendly shapes) and XLA jnp.dot for gate/up/down/lm_head (large, benefit from XLA’s HBM-staging).
| Implementation | Decode tok/s | Parity with HF |
|---|---|---|
| ascend-rs (Rust → Pallas) | 162.9 | 16/16 greedy |
| Native JAX baseline (same shapes) | ≈ 166 | 16/16 |
The generated Pallas kernels reach 0.98× the native JAX baseline averaged over all per-op head-to-head measurements. Greedy-token parity was confirmed end-to-end: 16 out of 16 generated tokens match the HuggingFace reference implementation byte-for-byte. The TPU result is the most important cross-vendor cross-validation in the suite: it shows that a backend with no C++ exit at all — Pallas goes straight from a Python DSL to XLA — produces competitive output from the same Rust source that targets AscendC.
9.5 Apple M2 Max — 91.7 tok/s (Beats Hand-Tuned MLX)
Hardware: Apple M2 Max, 12-core CPU, 38-core GPU, 400 GB/s unified memory bandwidth, macOS 14.5, Metal 3.1.
Setup: 28-layer DeepSeek-R1-Distill-Qwen-1.5B, bf16 weights uploaded directly to GPU as Metal bfloat. Single Metal command buffer per forward pass. Repetition penalty 1.3, temperature 0.0 (greedy).
| Implementation | Decode tok/s | % of peak (154) |
|---|---|---|
| ascend-rs (Rust → MSL) | 91.7 | 60% |
| MLX 0.29.1 (Apple, hand-tuned) | ≈ 88 | 57% |
The Rust-source kernels, after passing through rustc_codegen_mlir → mlir_to_msl, outperform Apple’s hand-tuned MLX on decode. Decode is the dominant cost in a typical inference session (one prompt, hundreds of generated tokens), so this is the number that matters for end-user latency.
Apple M4 (4P+6E CPU, 10-core GPU, 120 GB/s): decode 33–35 tok/s vs MLX's 32 tok/s — the Metal codegen path beats MLX on this smaller part as well, but prefill (9.3 tok/s vs MLX's 72) is still gated on rewriting the prefill matmul to use simdgroup_matrix_multiply.
How that 91.7 was reached
Optimization rounds on M2 Max (each step measured against the previous):
| Step | tok/s | Δ |
|---|---|---|
| Baseline (templates as committed) | 90.3 | — |
| attention_decode_v4 (TG-mem Q cache + float4) | 91.3 | +1.0 |
| Token-buffer hoist out of inner loop | 91.7 | +0.4 |
| Final | 91.7 | +1.4 |
Two attempted optimisations were measured and rolled back because they regressed:
| Attempted | tok/s | Δ |
|---|---|---|
| matvec_f16_cached (manual A-cache) | 85.1 | −5.2 (revert) |
| Fused RMSNorm + next matvec | 78.7 | −13 (revert) |
The Apple GPU’s L1/L2 already caches reused activations, so manual threadgroup caching only helps when (a) the data doesn’t fit in cache and (b) the per-thread compute is large enough to amortize the barrier. For decode matvec with K = 1536 (6 KB), neither holds.
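The arithmetic behind the rollback is worth making explicit: per decode matvec, the activation row is tiny and reused, while the weight matrix is streamed exactly once and dwarfs any cache. A back-of-envelope sketch (f32 activations, f16 weights, shapes from §9.2; the printed figures are just this arithmetic, not a measurement):

```rust
fn main() {
    const K: usize = 1536; // model dim D

    // The activation row: 1536 × 4 bytes = 6 KB, trivially resident in the
    // GPU's own L1/L2 — manual threadgroup caching buys nothing.
    let activation_bytes = K * std::mem::size_of::<f32>();
    println!("activation row: {} KB", activation_bytes / 1024);

    // The f16 weight matrix for one D×D projection: 1536 × 1536 × 2 bytes
    // ≈ 4.5 MB, streamed once per token and never reusable across tokens.
    let weight_bytes = K * K * 2;
    println!("weight matrix:  {} KB", weight_bytes / 1024);
}
```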
9.6 NVIDIA Tesla T4 (Colab) — 53.7 tok/s
Hardware: NVIDIA Tesla T4 on Google Colab, 320 GB/s HBM2, CUDA 12.1, mlir_to_gpu codegen emitting CUDA C, compiled with nvcc -arch=sm_75 -O3.
Setup: emitted rms_norm_1536, matvec_f16 (with _bias and _add variants for the fused cases), and GQA attention_decode_gqa drive the decode loop; host-side Python glue for weight loading and tokenization.
| Implementation | Decode tok/s |
|---|---|
| ascend-rs (Rust → CUDA) | 53.7 |
| Theoretical peak at 320 GB/s | 123 |
53.7 tok/s is 44% of the T4’s theoretical bandwidth ceiling. The remaining gap is split between sub-optimal matvec tiling (the mlir_to_gpu path uses one-element-per-thread today, not warp-striped) and matmul_f32 still routing through cuBLAS as a placeholder. Both are tracked in Chapter 13 §12.3.1 as the short-term mlir_to_gpu + cudarc integration work.
Per-token kernel parity with the Ascend result: all 13 kernels compile; the emitted .cu source is 2,001 LOC generated from the same 13-kernel tile_kernels.rs.
9.7 Where the Time Goes — Per-Kernel Breakdown (M2 Max)
For one decoded token on M2 Max (28 layers × 8 dispatches + 5 model-level dispatches = 229 kernel launches):
| Kernel class | Per-token time (ms) | % of decode |
|---|---|---|
| Q/K/V/O matvecs | 4.3 | 39% |
| Gate + up + silu (MLP) | 3.1 | 28% |
| Down-projection | 2.1 | 19% |
| Attention (decode v4) | 0.8 | 7% |
| RMSNorm × 2/layer | 0.4 | 4% |
| RoPE Q + K | 0.2 | 2% |
| Argmax over vocab | 0.1 | 1% |
| Total | 11.0 | 100% |
The seven projection matvecs (Q, K, V, O, gate, up, down) — packed into suite items 3, 4, 8, 9, and 10 from §9.2 — account for 86% of decode time. Optimisation effort returns the most when spent on those kernels, which is why all the wins listed in §9.5 targeted the matvec / attention path. Norms and RoPE together cost less than 1 ms per token; fusing them away (as we tried) saves no measurable bandwidth and adds compute.
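Both headline figures in this section reduce to two lines of arithmetic, reproduced here as a sanity check:

```rust
fn main() {
    // 28 layers × 8 per-layer dispatches + 5 model-level kernels.
    let launches = 28 * 8 + 5;
    println!("kernel launches per token: {launches}"); // 229

    // matvec/MLP share of the 11.0 ms per-token budget
    // (Q/K/V/O matvecs + gate/up/silu + down-projection rows above).
    let share = (4.3 + 3.1 + 2.1) / 11.0_f64;
    println!("matvec/MLP share: {:.0}%", share * 100.0); // 86%
}
```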
9.8 Cross-Vendor Status
The same 13-kernel Rust source is the input to every mlir_to_<target> backend. Current measured end-to-end status (numbers from Table 2 of the companion paper):
| Backend | Target | LOC | Decode tok/s |
|---|---|---|---|
| mlir_to_cpp + mlir_to_pto | Ascend 910B2 (joint) | 11,383 + 4,955 | 168.9 |
| mlir_to_tpu | Google TPU v2-8 (Pallas) | 1,645 | 162.9 |
| mlir_to_msl | Apple M2 Max (Metal) | 1,730 | 91.7 |
| mlir_to_gpu | NVIDIA T4 (CUDA) | 2,001 | 53.7 |
| mlir_to_nki | AWS Trainium (trn1.2xlarge) | 1,872 | see note below |
| mlir_to_spirv | Vulkan (any GPU) | 1,571 | see note below |
NKI (AWS Trainium). All six emitted kernels compile and run (rms_norm_1536, matvec_f16 / _bias / _add, gate_up_silu, GQA attention). End-to-end tok/s is not reported because @nki.jit uses eager dispatch with no cross-call kernel caching — each of the 370+ kernel dispatches per decoded token incurs ≈ 10 s of setup overhead. A compiled torch-neuronx graph wrapper would fold this into a single graph dispatch; that’s future work, not a codegen gap.
Vulkan (SPIR-V). End-to-end decode requires an adapter exposing the shader-f16 feature. The only hardware we have access to that both supports SPIR-V and runs Colab notebooks is the T4, and Colab’s T4 only exposes Mesa llvmpipe (a CPU rasterizer) through Vulkan — which would time out the decode loop. Per-kernel softmax on Apple M2 Max via the Vulkan backend reaches 90× CPU speedup (see Appendix I).
For the remaining backends in the tree (mlir_to_musa, mlir_to_aie, mlir_to_bang, mlir_to_gaudi, mlir_to_csl, mlir_to_hexagon, mlir_to_linalg), the 13-kernel suite compiles cleanly; on-device decode measurement is blocked only on hardware time allocation for each rig.
9.9 Reproducing the Results
Apple M2 Max / M4:
git clone https://github.com/yijunyu/ascend-rs
cd ascend-rs
cargo run --release -p deepseek_metal -- \
--prompt "The capital of France is" \
--max-tokens 128
The first run downloads DeepSeek-R1-Distill-Qwen-1.5B from Hugging Face (≈ 3 GB) and caches it at ~/.cache/huggingface/. Subsequent runs print:
Loaded DeepSeek-R1-Distill-Qwen-1.5B on Metal
Prefill: 0.23s (26.1 tok/s)
[generated text]
Generated 128 tokens in 1.40s (91.43 tok/s)
MLX baseline for comparison:
pip install mlx mlx-lm
python -m mlx_lm.generate \
--model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
--prompt "The capital of France is" \
--max-tokens 128
Ascend 910B2 (requires CANN 8.5.0 and hardware access):
source /usr/local/Ascend/cann-8.5.0/set_env.sh
export ACLRS_SOC_VERSION=Ascend910B2
cargo run --release -p deepseek_e2e -- --max-tokens 128
TPU v2-8 (Colab) and NVIDIA T4 (Colab): notebooks are under benchmarks/deepseek_tile_kernels/notebooks/. Each notebook pulls the emitted mlir_to_<target> output from the repo and runs the decode loop against the same prompt set. All reproducible runs are logged as CSVs to the pu-rs.org open leaderboard (3,924 data points across all backends and targets as of 2026-04-23).
9.10 Why a Suite, Not a Single Kernel
Single-kernel benchmarks (softmax, GEMM, RMSNorm in isolation) are useful for diagnosing a specific bottleneck, but they systematically over-report the value of optimisations that don’t compose:
- Caching activations is a clear win on a standalone matvec benchmark and a clear loss inside a transformer layer where the cache is already warm from the previous matvec (§9.5).
- Fusing RMSNorm into the next matvec wins on a fused-kernel microbenchmark and loses inside a real layer where the same norm output is consumed by three matvecs (Q, K, V).
- A “fast attention” kernel that ignores the KV cache is irrelevant; in decode, the KV cache is the attention input.
A 13-kernel suite tied to a real model is the smallest benchmark that catches these mistakes. It also lets vendors compare backends honestly: every backend in §9.8 sees the same Rust source, the same shapes, and the same memory-traffic budget.
9.11 Key Takeaways
- One Rust source, four production accelerators measured end-to-end. 168.9 tok/s on Ascend 910B2, 162.9 on Google TPU v2-8, 91.7 on Apple M2 Max, 53.7 on NVIDIA T4 — all from the same 13-kernel tile_kernels.rs, compiled through different mlir_to_<target> backends. The backends range from 1,571 LOC (mlir_to_spirv) to 11,383 LOC (mlir_to_cpp), so targeting a new vendor is a bounded engineering exercise, not a research project.
- 45.6× over CPU reference on 910B2, 2.47× over the aclnn-only baseline. The Ascend path demonstrates that a safety-first Rust kernel toolchain does not give up performance: the headline is set by a compiler-generated kernel pipeline, not by hand-written AscendC.
- The Metal codegen path beats hand-tuned MLX on decode. 91.7 vs ≈ 88 tok/s on M2 Max and 33–35 vs 32 on M4. Apple’s engineers hand-tuned MLX against Apple’s own hardware; ascend-rs produces competitive output from Rust source written for a different vendor.
- TPU Pallas cross-validation at 0.98× native JAX, 16/16 greedy-token parity with HF. The cleanest evidence that the Rust → MLIR → Pallas path is producing sound kernels, not numerically approximate ones.
- Microbenchmarks lie about full-pipeline performance. Two optimisations measured in isolation as wins (caching, fusion) regressed the full decode path by 5–13 tok/s on M2 Max. Suite-level measurement is the only way to catch this.