
feat - add fa3 attention for mtp #980

Open

zerozw wants to merge 1 commit into main from feature/fa3_attention

Conversation

Collaborator

@zerozw zerozw commented May 8, 2026

No description provided.

Copilot AI review requested due to automatic review settings May 8, 2026 12:06
@zerozw zerozw requested a review from LLLLKKKK as a code owner May 8, 2026 12:06
Collaborator

LLLLKKKK commented May 8, 2026

AI Code Review - PR #980

Status: BLOCKING

Summary: P0/0 · P1/2 · P2/0 · P3/0

Blocking Issues

P1

  • Top-level import of the FA3 dependency breaks startup in non-FA3 environments @ rtp_llm/models_py/modules/factory/attention/cuda_impl/py_fa3_paged.py:47
    • Suggestion: follow the headwise_fp8 pattern: record dependency availability with try/except and have support() return False when the dependency is unavailable, or defer the import until FA3 is actually selected.
  • FA3 support() does not exclude pre-Hopper GPUs @ rtp_llm/models_py/modules/factory/attention/cuda_impl/py_fa3_paged.py:441
    • Suggestion: add a runtime_ready check (dependency present, CUDA available, compute capability >= 9, and not SM100) before reporting support, and add factory-selection unit tests for the missing-dependency and non-Hopper cases (see the sketch after this list).
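
A minimal sketch of the suggested guard, assuming a classmethod-style support() hook on PyFA3PrefillImpl and using torch.cuda.get_device_capability() for the Hopper check; the real factory interface in this PR may differ, so treat this as an illustration rather than the PR's code.

import torch

try:
    import flash_attn_interface  # optional FA3 runtime dependency
    _FA3_AVAILABLE = True
except ImportError:
    flash_attn_interface = None
    _FA3_AVAILABLE = False

def _fa3_runtime_ready() -> bool:
    # Dependency present, CUDA usable, and a Hopper (SM90) device;
    # major == 9 also rules out SM100-class parts.
    if not _FA3_AVAILABLE or not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 9

class PyFA3PrefillImpl:
    @classmethod
    def support(cls, *args, **kwargs) -> bool:
        # Returning False lets the factory fall back to another attention
        # impl instead of failing at import time on machines without FA3.
        return _fa3_runtime_ready()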

Checklist Violations (6 fail / 78 total)

General Principles Checklist

  • [6.1] Architecture — Compatibility: public API / persisted data / config / environment migration safety → issue: top-level import of the FA3 dependency breaks startup in non-FA3 environments
    Importing the FA3 module from the CUDA factory also imports flash_attn_interface synchronously, with no fallback when the dependency is missing.
  • [6.1] Tests — Distributed / cross-platform changes have matching coverage → issue: FA3 support() does not exclude pre-Hopper GPUs
    FA3 support() is not restricted by the Hopper precondition assumed in the tests; the SM80/SM70 fallback paths are not covered.
  • [6.1] Tests — New logic has focused unit tests plus related integration/smoke tests → issue: FA3 support() does not exclude pre-Hopper GPUs
    Existing tests only cover the H20 happy path; there is no coverage of support() falling back when the dependency is missing or the GPU is not Hopper.

RTP-LLM Checklist

  • [A] Compatibility & configuration — Optional dependencies are lazily imported; new pybind fields have C++ defaults → issue: top-level import of the FA3 dependency breaks startup in non-FA3 environments
    flash_attn_interface is a new optional runtime dependency, but it is currently imported at module top level, so a missing package breaks the factory import.
  • [H] Tests & CI — Test coverage is sufficient: equivalence coverage for large refactors, end-to-end tests for new features → issue: FA3 support() does not exclude pre-Hopper GPUs
    The H20 heterogeneous replay coverage is fairly complete, but there are no tests for the factory support() fallback when the dependency is missing or the GPU is not Hopper.

Python Static-First Checklist

  • [P.F] Language pitfalls — No module-level import side effects → issue: top-level import of the FA3 dependency breaks startup in non-FA3 environments
    flash_attn_interface is loaded at module import time, so environments that never select FA3 still bear the risk of a dependency failure.

Strengths

  • The FA3 path wraps cache_seqlens, cu_seqlens_q, and page_table (the buffers that CUDA graph replay needs to keep a stable data_ptr for) inside the op, so the fix direction is clear (a rough sketch of the pattern follows this list).
  • The heterogeneous-batch tests cover the target-verify and draft-prefill replay/padding scenarios, which protects the core CUDA graph behavior.
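
A rough sketch of that stable-buffer pattern, assuming fixed maximum sizes and an in-place update hook; the buffer names follow the review text, but everything else here is illustrative rather than the PR's actual implementation.

import torch

class FA3GraphBuffers:
    # Preallocated, fixed-shape device buffers whose data_ptr stays stable
    # across CUDA graph capture and replay; per-step values are copied in place.
    def __init__(self, max_batch: int, max_pages_per_seq: int, device: str = "cuda"):
        self.cache_seqlens = torch.zeros(max_batch, dtype=torch.int32, device=device)
        self.cu_seqlens_q = torch.zeros(max_batch + 1, dtype=torch.int32, device=device)
        self.page_table = torch.zeros(max_batch, max_pages_per_seq, dtype=torch.int32, device=device)

    def update(self, cache_seqlens, cu_seqlens_q, page_table) -> None:
        # Copy live-request values into the captured buffers and neutralize the
        # padded tail so replay never reads stale state from a previous step.
        n = cache_seqlens.numel()
        self.cache_seqlens[:n].copy_(cache_seqlens)
        self.cache_seqlens[n:].zero_()
        self.cu_seqlens_q[: n + 1].copy_(cu_seqlens_q)
        self.cu_seqlens_q[n + 1 :].fill_(cu_seqlens_q[-1])
        rows, cols = page_table.shape
        self.page_table[:rows, :cols].copy_(page_table)
        self.page_table[:rows, cols:].zero_()
        self.page_table[rows:].zero_()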

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds a FlashAttention-3 (FA3) paged-attention implementation for MTP target-verify and draft-prefill CUDA-graph flows, plus a new heterogeneous-batch regression test suite to validate eager vs CUDA-graph replay correctness.

Changes:

  • Introduces PyFA3PagedAttnOp and PyFA3PrefillImpl to run FA3 paged attention with stable CUDA-graph buffers.
  • Adds extensive heterogeneous-batch CUDA-graph replay regression tests and wires them into Bazel.
  • Fixes CUDA-graph state leakage by zeroing host-side buffers (KV page table host + FlashInfer MLA host buffer).

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 5 comments.

  • rtp_llm/models_py/modules/factory/attention/cuda_impl/test/test_fa3_heterogeneous.py: New GPU regression suite covering heterogeneous batching + CUDA-graph capture/replay edge cases for FA3.
  • rtp_llm/models_py/modules/factory/attention/cuda_impl/test/BUILD: Adds a Bazel py_test target for the new FA3 heterogeneous regression suite.
  • rtp_llm/models_py/modules/factory/attention/cuda_impl/py_fa3_paged.py: New FA3 paged attention op + factory-facing prefill impl with CUDA-graph-stable buffers.
  • rtp_llm/models_py/modules/factory/attention/__init__.py: Registers PyFA3PrefillImpl in the attention factory import set.
  • rtp_llm/models_py/bindings/cuda/FlashInferMlaParams.cc: Zeroes the full host buffer before filling to prevent stale tail data across replays.
  • rtp_llm/cpp/cuda_graph/cuda_graph_runner.cc: Clears kv_cache_kernel_block_id_host alongside the device tensor to avoid stale host-side block IDs.
  • rtp_llm/cpp/cuda_graph/cuda_graph_prefill.cc: Minor formatting + sets sequence_lengths_plus_1_d during prefill capture.


cu_scratch[0].zero_()
torch.cumsum(input_h, dim=0, out=cu_scratch[1 : n + 1])
if n < self._fixed_batch_size:
    cu_scratch[n + 1 :].fill_(cu_scratch[n])
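
To illustrate the padding scheme in the snippet above (a standalone toy example; names and dtypes here are arbitrary, not the PR's): with a fixed batch size of 4 but only 2 live requests of query lengths [3, 5], the scratch ends up as [0, 3, 8, 8, 8], so the padded slots describe zero-length segments during CUDA-graph replay.

import torch

fixed_batch_size = 4
n = 2
input_h = torch.tensor([3, 5])
cu_scratch = torch.empty(fixed_batch_size + 1, dtype=torch.int64)

cu_scratch[0].zero_()
torch.cumsum(input_h, dim=0, out=cu_scratch[1 : n + 1])
if n < fixed_batch_size:
    cu_scratch[n + 1 :].fill_(cu_scratch[n])

print(cu_scratch.tolist())  # [0, 3, 8, 8, 8]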
Comment on lines +371 to +374
k = paged[:, 0] # [num_pages, kv_heads, page_size, head_dim]
v = paged[:, 1]
# -> [num_pages, page_size, kv_heads, head_dim]
return k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
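
For context on the layout conversion above (shapes inferred from the inline comments; the sizes below are made up): index 0 along dim 1 of the paged cache holds K pages and index 1 holds V, and the permute moves the page_size (token) axis ahead of the kv_heads axis.

import torch

num_pages, kv_heads, page_size, head_dim = 8, 4, 16, 128
paged = torch.randn(num_pages, 2, kv_heads, page_size, head_dim)

k = paged[:, 0]  # [num_pages, kv_heads, page_size, head_dim]
v = paged[:, 1]
k_t, v_t = k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
assert k_t.shape == (num_pages, page_size, kv_heads, head_dim)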
Comment on lines +543 to +549
"""Reproduction target: 4 requests with cache lengths spanning 6×.

Production smoke (5 iter × 8 concurrent × mixed taobao+cooking pool)
reliably triggers garbled output / cross-request KV-pollution in this
regime. Until the bug is fixed, this assertion will fail at the
``CG-replay vs eager`` compare; the eager and per-request reference
paths still agree (sanity holds — only the CG path corrupts).

// pages owned by other active requests, producing cross-request KV cache
// corruption and the "** ** **" garbled output.
if (buf_h.defined() && buf_h.numel() > 0) {
    std::memset(buf_h.data_ptr(), 0, buf_h.numel() * sizeof(int32_t));
Comment on lines +205 to +216
py_test(
    name = "test_fa3_heterogeneous",
    srcs = [
        "test_fa3_heterogeneous.py",
        "base_attention_test.py",
    ],
    deps = py_test_deps,
    tags = ["H20"],
    exec_properties = {
        'gpu': 'H20',
    },
)