Enable FA4 for context-parallel attention #3149
Draft
sudhakarsingh27 wants to merge 2 commits into
Add minimal FA4 raw-call plumbing for p2p, all_gather, and a2a context-parallel attention. FA4 accepts padded THD cu_seqlens plus seqused values, so keep padded physical offsets separate from visible token lengths for CP and non-CP reference paths. Keep a2a+p2p disabled because that hierarchical path has not been validated for FA4. Update CP tests so FA4-only environments do not skip the FlashAttention CP matrix before the FA4-specific guards run.
for more information, see https://pre-commit.ci
Description
Enable FlashAttention 4 for context-parallel attention on the already-supported CP communication shapes: p2p, all_gather, and a2a.

FA4 can represent padded THD layouts with physical padded offsets plus seqused_* visible lengths. This lets the CP path handle fa_pad_between_seqs=True with the same layout distinction used by the non-CP reference path. The hierarchical a2a+p2p path remains disabled because it has not been validated for FA4.

Fixes # (issue)
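The distinction between physical padded offsets and visible lengths can be sketched in plain Python. This is only an illustration under stated assumptions: `padded_thd_layout`, `visible_token_ranges`, and `pad_multiple` are hypothetical names that do not appear in this PR, and rounding each sequence slot up to a fixed multiple is an assumed padding rule, not necessarily the one the CP path uses.

```python
from itertools import accumulate


def padded_thd_layout(visible_lens, pad_multiple):
    """Return (cu_seqlens_padded, seqused) for a packed THD batch.

    Hypothetical helper for illustration only.
    cu_seqlens_padded: physical start offset of each sequence slot in the
        packed token dimension, plus the total as the last entry.
    seqused: number of visible (non-padding) tokens in each slot.
    """
    if pad_multiple <= 0:
        raise ValueError("pad_multiple must be positive")
    # Assumed padding rule: round every slot up to a multiple of pad_multiple.
    padded_lens = [-(-n // pad_multiple) * pad_multiple for n in visible_lens]
    cu_seqlens_padded = [0] + list(accumulate(padded_lens))
    seqused = list(visible_lens)
    return cu_seqlens_padded, seqused


def visible_token_ranges(cu_seqlens_padded, seqused):
    """Half-open [start, end) ranges of real tokens inside the padded buffer."""
    return [(start, start + used) for start, used in zip(cu_seqlens_padded, seqused)]


cu, used = padded_thd_layout([5, 3, 8], pad_multiple=4)
print(cu)                              # [0, 8, 12, 20]
print(used)                            # [5, 3, 8]
print(visible_token_ranges(cu, used))  # [(0, 5), (8, 11), (12, 20)]
```

The offsets alone would imply sequence lengths of 8, 4, and 8 tokens; the separate visible lengths are what tell a kernel to ignore the trailing padding in each slot.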
Type of change
Changes
- Add FA4 plumbing for the p2p, all_gather, and a2a FlashAttention paths.
- Pass padded cu_seqlens_* together with seqused_* for FA4 when padding exists between sequences.
- Keep a2a+p2p disabled until that path is separately validated.

Checklist:
Validation
- git diff --check
- python3.12 -m py_compile on the touched attention/test files
- 6 passed, 10 skipped, 48 deselected on each platform. The skips are existing unsupported combinations: p2p+SWA, MLA with non-P2P CP, and padded THD with a2a+p2p.
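The skip conditions above, together with the a2a+p2p restriction for FA4, can be written as one predicate. This is a rough sketch and not the test suite's actual guard code: `fa4_cp_skip_reason`, its parameters, and the reason strings are all hypothetical, and the order of the checks is an assumption.

```python
# Communication types this PR describes as enabled for FA4 context parallelism.
FA4_ENABLED_CP_COMM_TYPES = {"p2p", "all_gather", "a2a"}


def fa4_cp_skip_reason(cp_comm_type, sliding_window=False, is_mla=False, padded_thd=False):
    """Return a skip reason string, or None if the combination should run.

    Hypothetical helper; the flags mirror the combinations named in the
    validation notes, not real test parameters.
    """
    if cp_comm_type == "p2p" and sliding_window:
        return "p2p with sliding window attention is unsupported"
    if is_mla and cp_comm_type != "p2p":
        return "MLA is unsupported with non-P2P CP"
    if cp_comm_type == "a2a+p2p" and padded_thd:
        return "padded THD is unsupported with a2a+p2p"
    if cp_comm_type not in FA4_ENABLED_CP_COMM_TYPES:
        return "cp_comm_type is not validated for FA4"
    return None


print(fa4_cp_skip_reason("a2a"))                       # None
print(fa4_cp_skip_reason("p2p", sliding_window=True))  # p2p with sliding window attention is unsupported
print(fa4_cp_skip_reason("a2a+p2p"))                   # cp_comm_type is not validated for FA4
```

Returning a reason string rather than a boolean keeps the skip message next to the condition that triggers it, which helps when reading a test report with many skipped cases.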