Cosmos3 context parallel#14054
Draft
atharvajoshi10 wants to merge 2 commits into
Draft
Conversation
Under sharded placement (device_map="balanced"), vae.encode() runs on the VAE's own device while the mean/inv_std buffers were pinned to x.device, causing a cross-device RuntimeError. Compute raw_mu first, then pin the normalization buffers to its device so all tensors share one device.
962f513 to
67fb9ec
Compare
Cosmos 3 cannot use diffusers' declarative `_cp_plan` CP path: it is grouped-query
attention (the shared Ulysses kernel assumes K/V share the query head count), its
understanding (causal) and generation (full) streams are separate packed sequences
(gen attends to cat(und, gen)), and per-pathway lengths are ragged. The model carries
no parallelism logic -- it exposes only small, CP-agnostic seams; all sharding lives
outside it, in a reusable example module.
Model (transformer_cosmos3.py): adds two default-None `forward` seams -- `_cp_shard_fn`
(shards und/gen + rotary before the decoder layers) and `_cp_gather_fn` (gathers/unpads
after the final norm) -- and extracts `Cosmos3AttnProcessor._run_attention` as an
override point. The non-parallel path is unchanged.
Helpers (examples/cosmos3/cosmos_parallel.py): one importable module, two orthogonal
and composable axes:
* Context parallelism (Ulysses) -- `enable_cosmos3_context_parallel`. Shards the
sequence; brackets the two attention pathways with all-to-all (DTensor redistribute),
repeats GQA KV heads, pads ragged lengths and masks padded generation keys.
* Tensor parallelism (Megatron) -- `enable_cosmos3_tensor_parallel`. Column/row-shards
the attention + MLP weights so a checkpoint that does not fit one GPU (Super, ~120 GB)
loads across several; weights load to CPU then shard layer by layer.
Both expand KV heads to the query-head count and call SDPA with enable_gqa=False so it
dispatches to the flash kernel; enable_gqa=True forces the math path, which materializes
the full [S, S] score matrix and OOMs on long videos. A dense `Cosmos3FlashAttnProcessor`
(`enable_cosmos3_flash_attention`) provides the same for TP without CP.
CLI (examples/cosmos3/inference_cosmos3.py): imports these helpers, so any modality
(text-to-image/video, image-to-video, sound, action) runs single- or multi-GPU via
`--tp-degree` / `--cp-degree` (their product must equal --nproc_per_node). Single-GPU
behavior is unchanged.
Docs + example README updated. Verified: CP attention core is bit-exact vs non-CP in
fp32 (max|d|=0), and a full 36-layer forward matches CP-on vs CP-off to ~1e-6 in fp32
(bf16 differs only by floating-point rounding).
67fb9ec to
6edc5fd
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Fixes # (issue)
Before submitting
.ai/review-rules.md?documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.