Enable FA4 for context-parallel attention #3149

Draft
sudhakarsingh27 wants to merge 2 commits into NVIDIA:main from sudhakarsingh27:fa4-cp-exploration

@sudhakarsingh27

Member

Description

Enable FlashAttention 4 for context-parallel attention on the already-supported CP communication shapes: p2p, all_gather, and a2a.

FA4 can represent padded THD layouts with physical padded offsets plus seqused_* visible lengths. This lets the CP path handle fa_pad_between_seqs=True with the same layout distinction used by the non-CP reference path. The hierarchical a2a+p2p path remains disabled because it has not been validated for FA4.
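The distinction between physical padded offsets and visible lengths can be sketched in plain Python. This is an illustration only, not code from this PR: the helper names and the round-up-to-a-multiple padding rule are assumptions, and real inputs would be device tensors rather than lists.

```python
from itertools import accumulate


def padded_thd_metadata(visible_lens, pad_multiple):
    """Return (cu_seqlens_padded, seqused) for a packed THD buffer.

    Assumption for illustration: each sequence's slot is rounded up to a
    multiple of `pad_multiple` tokens. `seqused` keeps the real token counts.
    """
    # Ceiling division, then scale back up to get the padded slot size.
    slot_lens = [-(-n // pad_multiple) * pad_multiple for n in visible_lens]
    # Cumulative offsets locate where each sequence physically starts.
    cu_seqlens_padded = [0, *accumulate(slot_lens)]
    seqused = list(visible_lens)
    return cu_seqlens_padded, seqused


def visible_token_ranges(cu_seqlens_padded, seqused):
    """Half-open [start, end) ranges of real tokens inside the padded buffer."""
    return [(start, start + used) for start, used in zip(cu_seqlens_padded, seqused)]


cu, used = padded_thd_metadata([5, 3, 7], pad_multiple=4)
print(cu)                              # [0, 8, 12, 20]
print(visible_token_ranges(cu, used))  # [(0, 5), (8, 11), (12, 19)]
```

The offsets say where each sequence lives in memory, while the visible lengths say how many of those tokens attention should actually read; the tokens in between are padding.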

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Route FA4 raw forward/backward calls through the CP p2p, all_gather, and a2a FlashAttention paths.
  • Pass padded THD cu_seqlens_* together with seqused_* for FA4 when padding exists between sequences.
  • Keep FA4 disabled for a2a+p2p until that path is separately validated.
  • Update FlashAttention CP tests so FA4-capable environments exercise the CP matrix and padded THD guards.
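The routing rule in the list above amounts to an allow-list over CP communication types. A minimal sketch of such a gate follows; the constant and function names are hypothetical and do not come from the TransformerEngine source.

```python
# Hypothetical names: this mirrors the PR description, not the actual code.
FA4_VALIDATED_CP_COMM_TYPES = frozenset({"p2p", "all_gather", "a2a"})


def check_fa4_cp_support(cp_comm_type):
    """Return (supported, reason) for running FA4 under a CP communication type."""
    if cp_comm_type in FA4_VALIDATED_CP_COMM_TYPES:
        return True, ""
    if cp_comm_type == "a2a+p2p":
        # Kept disabled until the hierarchical path is validated separately.
        return False, "hierarchical a2a+p2p has not been validated for FA4"
    return False, f"unknown cp_comm_type: {cp_comm_type!r}"


for comm_type in ("p2p", "all_gather", "a2a", "a2a+p2p"):
    print(comm_type, check_fa4_cp_support(comm_type))
```

An explicit allow-list means any communication type added later stays disabled for FA4 until someone validates it and opts it in.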

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Validation

  • git diff --check
  • python3.12 -m py_compile on the touched attention/test files
  • Static forward/backward arity check for the three touched CP autograd functions
  • Padded THD FlashAttention CP pytest subset on H100-class and B200-class GPUs: 6 passed, 10 skipped, 48 deselected on each platform. The skips are existing unsupported combinations: p2p+SWA, MLA with non-P2P CP, and padded THD with a2a+p2p.
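The PR does not say how the static arity check was run. One possible approach, sketched here with a toy class rather than the real CP autograd functions, is to parse the source with `ast` and compare the number of `forward` inputs (excluding `ctx`) with the number of values each `backward` return statement yields.

```python
import ast

# Toy stand-in for an autograd function; not taken from the PR.
SAMPLE = '''
class ToyAttnFunc:
    @staticmethod
    def forward(ctx, q, k, v, use_fa4):
        return q

    @staticmethod
    def backward(ctx, dout):
        return dout, dout, dout, None
'''


def forward_backward_arity(source, class_name):
    """Return (n_forward_inputs, set of backward return lengths) for one class."""
    tree = ast.parse(source)
    cls = next(
        node for node in ast.walk(tree)
        if isinstance(node, ast.ClassDef) and node.name == class_name
    )
    methods = {n.name: n for n in cls.body if isinstance(n, ast.FunctionDef)}
    # Forward inputs exclude the leading ctx argument.
    n_inputs = len(methods["forward"].args.args) - 1
    return_lengths = set()
    for node in ast.walk(methods["backward"]):
        if isinstance(node, ast.Return):
            if node.value is None:
                return_lengths.add(0)
            elif isinstance(node.value, ast.Tuple):
                return_lengths.add(len(node.value.elts))
            else:
                return_lengths.add(1)
    return n_inputs, return_lengths


n_inputs, return_lengths = forward_backward_arity(SAMPLE, "ToyAttnFunc")
# backward should return one gradient (or None) per forward input.
assert return_lengths == {n_inputs}
```

A check like this catches the common mistake of adding a new `forward` argument, such as an FA4 flag, without adding the matching `None` to the `backward` return tuple.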

sudhakarsingh27 and others added 2 commits June 26, 2026 01:20
Add minimal FA4 raw-call plumbing for p2p, all_gather, and a2a context-parallel attention. FA4 accepts padded THD cu_seqlens plus seqused values, so keep padded physical offsets separate from visible token lengths for CP and non-CP reference paths.

Keep a2a+p2p disabled because that hierarchical path has not been validated for FA4. Update CP tests so FA4-only environments do not skip the FlashAttention CP matrix before the FA4-specific guards run.