feat: add fa3 attention for mtp #980
Open
zerozw wants to merge 1 commit into
Conversation
Collaborator
AI Code Review - PR #980
Status: BLOCKING
Summary: P0/0 · P1/2 · P2/0 · P3/0
Blocking Issues: P1
Checklist Violations (6 fail / 78 total)
General Principles Checklist
RTP-LLM Checklist
Python Static-First Checklist
Strengths
Contributor
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds a FlashAttention-3 (FA3) paged-attention implementation for MTP target-verify and draft-prefill CUDA-graph flows, plus a new heterogeneous-batch regression test suite to validate eager vs CUDA-graph replay correctness.
Changes:
- Introduces `PyFA3PagedAttnOp` and `PyFA3PrefillImpl` to run FA3 paged attention with stable CUDA-graph buffers.
- Adds extensive heterogeneous-batch CUDA-graph replay regression tests and wires them into Bazel.
- Fixes CUDA-graph state leakage by zeroing host-side buffers (KV page table host + FlashInfer MLA host buffer).
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| rtp_llm/models_py/modules/factory/attention/cuda_impl/test/test_fa3_heterogeneous.py | New GPU regression suite covering heterogeneous batching + CUDA-graph capture/replay edge cases for FA3. |
| rtp_llm/models_py/modules/factory/attention/cuda_impl/test/BUILD | Adds Bazel py_test target for the new FA3 heterogeneous regression suite. |
| rtp_llm/models_py/modules/factory/attention/cuda_impl/py_fa3_paged.py | New FA3 paged attention op + factory-facing prefill impl with CUDA-graph-stable buffers. |
| rtp_llm/models_py/modules/factory/attention/__init__.py | Registers PyFA3PrefillImpl in the attention factory import set. |
| rtp_llm/models_py/bindings/cuda/FlashInferMlaParams.cc | Zeroes the full host buffer before filling to prevent stale tail data across replays. |
| rtp_llm/cpp/cuda_graph/cuda_graph_runner.cc | Clears kv_cache_kernel_block_id_host alongside the device tensor to avoid stale host-side block IDs. |
| rtp_llm/cpp/cuda_graph/cuda_graph_prefill.cc | Minor formatting + sets sequence_lengths_plus_1_d during prefill capture. |
```python
cu_scratch[0].zero_()
torch.cumsum(input_h, dim=0, out=cu_scratch[1 : n + 1])
if n < self._fixed_batch_size:
    cu_scratch[n + 1 :].fill_(cu_scratch[n])
```
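A minimal standalone sketch of this padding trick (the helper name `fill_cu_seqlens` and the fixed batch size of 4 are illustrative, not from the PR): the cumulative-sequence-length buffer keeps a fixed size so CUDA-graph replay always reads a valid array, and when fewer requests are active the unused tail is filled with the last prefix sum, so the padded entries describe zero-length slots.

```python
import torch

def fill_cu_seqlens(cu_scratch: torch.Tensor, seq_lens: torch.Tensor) -> None:
    """Write cumulative sequence lengths into a fixed-size scratch buffer,
    padding the tail with the final prefix sum for inactive batch slots."""
    n = seq_lens.numel()
    cu_scratch[0].zero_()
    torch.cumsum(seq_lens, dim=0, out=cu_scratch[1 : n + 1])
    if n < cu_scratch.numel() - 1:
        # Padded slots get start == end, i.e. zero-length sequences.
        cu_scratch[n + 1 :].fill_(cu_scratch[n])

cu = torch.empty(5, dtype=torch.int32)  # fixed batch size 4 -> 5 entries
fill_cu_seqlens(cu, torch.tensor([3, 2], dtype=torch.int32))
print(cu.tolist())  # [0, 3, 5, 5, 5]
```

Because the buffer address and size never change between replays, the captured kernel can keep reading it without re-capture.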
Comment on lines +371 to +374
```python
k = paged[:, 0]  # [num_pages, kv_heads, page_size, head_dim]
v = paged[:, 1]
# -> [num_pages, page_size, kv_heads, head_dim]
return k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
```
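A small sketch of what this layout change does, with made-up toy shapes (2 pages, 1 KV head, page size 4, head dim 8 are assumptions for illustration): `permute` is a zero-copy view that swaps the head and page-position axes, so element `(page, head, slot, dim)` of the source is element `(page, slot, head, dim)` of the result.

```python
import torch

# Toy paged KV cache: [num_pages, 2 (k/v), kv_heads, page_size, head_dim].
paged = torch.arange(2 * 2 * 1 * 4 * 8, dtype=torch.float32).view(2, 2, 1, 4, 8)

k = paged[:, 0]                # [num_pages, kv_heads, page_size, head_dim]
v = paged[:, 1]
k_fa3 = k.permute(0, 2, 1, 3)  # -> [num_pages, page_size, kv_heads, head_dim]
v_fa3 = v.permute(0, 2, 1, 3)

print(tuple(k_fa3.shape))  # (2, 4, 1, 8)
# Same storage, reindexed: head/slot axes are swapped, data is not copied.
assert torch.equal(k[0, 0, 3], k_fa3[0, 3, 0])
```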
Comment on lines +543 to +549
```python
"""Reproduction target: 4 requests with cache lengths spanning 6×.

Production smoke (5 iter × 8 concurrent × mixed taobao+cooking pool)
reliably triggers garbled output / cross-request KV-pollution in this
regime. Until the bug is fixed, this assertion will fail at the
``CG-replay vs eager`` compare; the eager and per-request reference
paths still agree (sanity holds — only the CG path corrupts).
```
```cpp
// pages owned by other active requests, producing cross-request KV cache
// corruption and the "** ** **" garbled output.
if (buf_h.defined() && buf_h.numel() > 0) {
    std::memset(buf_h.data_ptr(), 0, buf_h.numel() * sizeof(int32_t));
```
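The stale-tail failure mode this memset guards against can be sketched with a toy host-side page table (the buffer size of 8 and the block-ID values are hypothetical): a fixed-size buffer is reused across replays, so when a shorter request follows a longer one, the unwritten tail still holds the previous request's block IDs unless the buffer is cleared first.

```python
import torch

# Hypothetical fixed-size host page table (8 slots), reused across replays.
page_table_host = torch.empty(8, dtype=torch.int32)

# Replay 1: a long request fills every slot with its block IDs.
page_table_host[:] = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17], dtype=torch.int32)

# Replay 2: a shorter request writes only 3 slots. Without clearing first,
# slots 3..7 would still point at pages owned by the previous request.
page_table_host.zero_()  # the fix: zero before filling
page_table_host[:3] = torch.tensor([20, 21, 22], dtype=torch.int32)
print(page_table_host.tolist())  # [20, 21, 22, 0, 0, 0, 0, 0]
```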
Comment on lines +205 to +216
```python
py_test(
    name = "test_fa3_heterogeneous",
    srcs = [
        "test_fa3_heterogeneous.py",
        "base_attention_test.py",
    ],
    deps = py_test_deps,
    tags = ["H20"],
    exec_properties = {
        'gpu': 'H20',
    },
)
```
No description provided.