Skip to content

[Common/PyTorch] bugfix: Token-linear fused RoPE impl. for THD tensors.#3057

Open
plugyawn wants to merge 11 commits into
NVIDIA:mainfrom
plugyawn:rope-thd-token-linear
Open

[Common/PyTorch] bugfix: Token-linear fused RoPE impl. for THD tensors.#3057
plugyawn wants to merge 11 commits into
NVIDIA:mainfrom
plugyawn:rope-thd-token-linear

Conversation

@plugyawn

@plugyawn plugyawn commented May 28, 2026

Copy link
Copy Markdown

Description

Adds a token-linear implementation of the existing THD fused RoPE path to remove a launch-scaling bug.

Addresses #2866, which finds an interesting case with RoPE scales by freqs_len × n_spans, which is pathological; it should scale by total tokens. I reproduced the issue and found that it's causing a noticeable drops on even plausibly routine shapes. For eg: the [128/512] and [512/128] cases here.

The new kernel reuses the existing fused_rope_block_forward and fused_rope_block_backward device helpers, so the math doesn't change. All we need to do is add a THD-only path that launches one bloc/packed token.

n_seqs max span old layer fwd+bwd (ms) new layer fwd+bwd (ms) layer speedup old paired-RoPE share new paired-RoPE share
128 512 41.8151 23.0284 1.816x 49.12% 6.14%
512 128 102.1047 23.0167 4.436x 79.38% 6.59%
1024 64 182.9933 23.3783 7.827x 88.36% 6.77%
2401 28 401.0516 24.5668 16.325x 94.40% 6.41%

This is mostly pathological, however, so I've added a condition on the dispatch to avoid the unnecessary binary search overhead, although the overhead appears to be not-that-relevant. The condition is: token-linear only when b >= 64 and the old launch would issue ≥ 8× as many blocks as there are tokens. I'm not sure if this the usual shape of TE updates, so I could remove it!

Some more relevant tests:
Microbenchmark on H100 (bf16, h=32, d=d2=128, freqs_len=T_local=65536, single GPU):

n_seqs old fwd+bwd (ms) new fwd+bwd (ms) speedup
1 1.2746 1.2734 1.001x
8 1.8860 1.3827 1.364x
32 3.9359 1.4462 2.722x
128 12.1849 1.5024 8.110x
512 44.9411 1.5600 28.808x
1024 89.1110 1.5919 55.977x
2401 208.4182 1.6373 127.296x

Fixes: #2866.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add token-linear THD fused RoPE forward/backward kernels that launch one CUDA block per packed local token row.
  • Add NVTE_FUSED_ROPE_THD_TOKEN_LINEAR=0|1.
  • Reuses existing fused_rope_block_forward and fused_rope_block_backward device helpers.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation <<(none?)>>
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 28, 2026
@greptile-apps

greptile-apps Bot commented May 28, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR fixes a launch-scaling bug in the THD fused RoPE path where the old kernel launched freqs_len × n_seqs blocks regardless of actual token count, causing severe over-launch for batches with many short sequences. It adds new token-linear forward/backward kernels that launch one block per packed local token, reusing the existing fused_rope_block_forward/backward device helpers so the per-token math is unchanged.

  • New fused_rope_thd_token_forward/backward_kernel kernels: one block per packed local token, with a shared-memory broadcast of the binary-search result (fused_rope_thd_find_seq_id) to avoid redundant per-thread lookups.
  • fused_rope_thd_use_token_linear host dispatcher: selects the new path when legacy_blocks > 2 × cp_size × token_linear_blocks, overridable via NVTE_FUSED_ROPE_THD_TOKEN_LINEAR=0|1 for testing; local_tokens is derived from the input tensor's first dimension.
  • Tests: existing test_fused_rope_thd is patched to force the new path; a new test_fused_rope_thd_token_linear_parity exhaustively checks bitwise equality between old and new paths across dtypes, cp_size, cp_rank, variable sequence counts, and zero-length spans.

Confidence Score: 4/5

Safe to merge; the new kernels reuse battle-tested device helpers and the parity test enforces bitwise equality with the original path across a wide parameter grid.

The kernel logic, binary search, shared-memory broadcast, and CP dual-chunk offset all look correct. The two review comments flag a documentation discrepancy in the heuristic threshold description and a test coverage gap — neither affects runtime correctness — but a reviewer should confirm both before merging.

transformer_engine/common/fused_rope/fused_rope.cu (heuristic threshold constant and its documentation) and tests/pytorch/test_fused_rope.py (direct coverage of the old kernel path)

Important Files Changed

Filename Overview
transformer_engine/common/fused_rope/fused_rope.cu Adds two new CUDA kernels (token-linear THD forward/backward), a binary-search helper, and a host-side heuristic dispatcher; uses shared memory to broadcast the per-block valid/seq results; changes are self-consistent with the existing block helpers
tests/pytorch/test_fused_rope.py Adds a comprehensive parity test between old and new THD paths; modifies test_fused_rope_thd to force the new path, leaving the old path validated only transitively through the parity test
benchmarks/attention/benchmark_rope_thd_token_linear.py New standalone microbenchmark sweeping n_seqs under forced-old, forced-new, and heuristic modes; clean, self-contained, no issues found

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["fused_rope_forward / fused_rope_backward\n(host)"] --> B["Compute local_tokens\nfrom input.data.shape[0]\n(THD only)"]
    B --> C["fused_rope_thd_use_token_linear(\n  qkv_format, legacy_blocks,\n  token_linear_blocks, cp_size\n)"]
    C -->|"qkv_format != THD\nor token_linear_blocks == 0"| F["Legacy path\ndim3 blocks(s, b)\nfused_rope_*_kernel"]
    C -->|"env == 0 (forced-old)"| F
    C -->|"env == 1 (forced-new)"| G["Token-linear path\ndim3 blocks(local_tokens)\nfused_rope_thd_token_*_kernel"]
    C -->|"legacy_blocks > 2 x cp_size x token_linear_blocks"| G
    C -->|"otherwise"| F
    G --> H["Each block: threadIdx==(0,0)\ncomputes valid_token and seq_id\ninto shared memory"]
    H --> I["__syncthreads()"]
    I -->|"!valid_token"| J["return (dead block)"]
    I -->|"valid_token"| K["fused_rope_thd_find_seq_id\n(binary search on cu_seqlens)"]
    K --> L["Compute s_id, s_id_for_freqs\n(with CP dual-chunk offset)"]
    L --> M["fused_rope_block_forward /\nfused_rope_block_backward"]
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A["fused_rope_forward / fused_rope_backward\n(host)"] --> B["Compute local_tokens\nfrom input.data.shape[0]\n(THD only)"]
    B --> C["fused_rope_thd_use_token_linear(\n  qkv_format, legacy_blocks,\n  token_linear_blocks, cp_size\n)"]
    C -->|"qkv_format != THD\nor token_linear_blocks == 0"| F["Legacy path\ndim3 blocks(s, b)\nfused_rope_*_kernel"]
    C -->|"env == 0 (forced-old)"| F
    C -->|"env == 1 (forced-new)"| G["Token-linear path\ndim3 blocks(local_tokens)\nfused_rope_thd_token_*_kernel"]
    C -->|"legacy_blocks > 2 x cp_size x token_linear_blocks"| G
    C -->|"otherwise"| F
    G --> H["Each block: threadIdx==(0,0)\ncomputes valid_token and seq_id\ninto shared memory"]
    H --> I["__syncthreads()"]
    I -->|"!valid_token"| J["return (dead block)"]
    I -->|"valid_token"| K["fused_rope_thd_find_seq_id\n(binary search on cu_seqlens)"]
    K --> L["Compute s_id, s_id_for_freqs\n(with CP dual-chunk offset)"]
    L --> M["fused_rope_block_forward /\nfused_rope_block_backward"]
Loading

Reviews (9): Last reviewed commit: "Merge branch 'main' into rope-thd-token-..." | Re-trigger Greptile

Comment on lines +250 to +251
int t_id = blockIdx.x;
int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Redundant binary search across all threads in the block

Every thread in the block calls fused_rope_thd_find_seq_id with the same arguments (t_id = blockIdx.x, nseq, cp_size) and produces an identical result. With warps_per_block = 8, that's 256 threads each doing O(log nseq) global-memory reads of cu_seqlens that could be performed once. For nseq=2401 (~12 iterations x 256 threads), each block reads ~3,072 redundant entries from cu_seqlens. Performing the search once in thread 0 and broadcasting the result via shared memory would eliminate that overhead.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Smart bot!

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved!

Comment thread transformer_engine/common/fused_rope/fused_rope.cu
@ptrendx

ptrendx commented May 28, 2026

Copy link
Copy Markdown
Member

@plugyawn Hi, could you sign your commits? See https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst#sign-your-work
Nice improvement :-).

@sudhakarsingh27 Could you take a look?

plugyawn and others added 3 commits May 29, 2026 03:23
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
for more information, see https://pre-commit.ci

Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
@plugyawn plugyawn force-pushed the rope-thd-token-linear branch from 331a3a0 to 6c46696 Compare May 28, 2026 21:55
@plugyawn

plugyawn commented May 28, 2026

Copy link
Copy Markdown
Author

Thanks! Signed!

fwiw I think the binary search overhead on normal cases can be reduced also, I'll probably add some improvements.

@sudhakarsingh27 sudhakarsingh27 self-requested a review June 3, 2026 22:08

@sudhakarsingh27 sudhakarsingh27 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Posted the RoPE THD token-linear review comments from the local benchmark/coverage analysis. The main concerns are the dispatch heuristic, CP-local token accounting, CP-rank coverage, and benchmark scope.

Comment thread transformer_engine/common/fused_rope/fused_rope.cu Outdated
const int o_stride_h = d;
const int o_stride_d = 1;

if (fused_rope_thd_use_token_linear(qkv_format, b, s, total_tokens)) {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make the compact launch decision use the actual local THD rows and the legacy launch blocks. The local patch uses this shape:

const size_t compact_thd_blocks = input.data.shape[0];
const size_t legacy_thd_blocks = static_cast<size_t>(s) * b;

if (fused_rope_thd_use_compact_launch(legacy_thd_blocks, compact_thd_blocks, cp_size)) {
  const int t = input.data.shape[0];
  dim3 blocks(t);
  ...
}

This also avoids routing the heuristic through a total_tokens value whose CP/global semantics are easy to confuse.

@plugyawn plugyawn Jun 9, 2026

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Also renamed the variable from total_tokens, so no CP/global ambiguity. Could you check if it's fine now?

Comment thread tests/pytorch/test_fused_rope.py Outdated
Comment thread benchmarks/attention/benchmark_rope_thd_token_linear.py Outdated
Comment thread benchmarks/attention/benchmark_rope_thd_full_layer.py Outdated
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
@plugyawn

plugyawn commented Jun 9, 2026

Copy link
Copy Markdown
Author

Fixed some of the review comments, resolving the rest now.


Additional CP-rank validation for the THD token-linear RoPE path on the rebased PR tip:

  • Commit: eaee5a1731141654a006f9872fcfe10132cdcf76 (Cover CP ranks in THD RoPE token-linear tests)
  • Hardware/runtime: Prime Datacrunch A100 80GB, driver 580.126.09, CUDA 12.8, PyTorch 2.8.0+cu128
  • Build: editable TE PyTorch build passed; fused_rope.cu and apply_rope.cpp compiled on this exact tip
  • test_fused_rope_thd_token_linear_parity: 288 passed / 96 skipped / 0 failed. The skips are invalid cp_rank >= cp_size; the JUnit/log includes 96 passing cp_rank=1, cp_size=2 cases.
  • test_fused_rope_thd with the token-linear path forced: 384 passed / 0 failed

This closes the earlier proof gap where old-vs-new parity only covered cp_rank=0.

@plugyawn plugyawn requested a review from sudhakarsingh27 June 9, 2026 08:10
@plugyawn

Copy link
Copy Markdown
Author

@ptrendx @sudhakarsingh27 the last review comments are addressed.

// Heuristic: use the token-linear path when the legacy launch would issue
// enough extra blocks to amortize one sequence lookup per useful token. The
// CP factor keeps the gate conservative because local rows shrink with
// context parallelism while legacy launch space does not.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the legacy launch space? Does it refer to the one without heuristic or the previous heuristic?

offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}

// Token-linear THD forward kernel. Each block handles exactly one packed local

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit iffy about calling it token-linear THD. Maybe the right term is THD linear-grid forward kernel? Pls make that change across the file(s)

// divided cumulative sequence boundaries, then defers to the same
// `fused_rope_block_forward` device function as the original kernel.
template <typename scalar_t>
__global__ void fused_rope_thd_token_forward_kernel(

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similarly, this could be fused_rope_thd_linear_grid_forward_kernel and other function below could follow the suit


const size_t token_linear_blocks = static_cast<size_t>(local_tokens);
const size_t legacy_blocks = static_cast<size_t>(s) * static_cast<size_t>(b);
if (fused_rope_thd_use_token_linear(qkv_format, legacy_blocks, token_linear_blocks, cp_size)) {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similarly, use_fused_rope_thd_linear_grid_launch

const int stride_h, const int stride_d, cudaStream_t stream) {
// For THD the packed local token count is the first dimension of the input
// tensor. SBHD/BSHD ignore this value.
const int64_t local_tokens =

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should call this total_tokens or total_tokens_in_input. local_tokens seems to convey a different unrelated meaning but I understand where you're coming from.

@sudhakarsingh27

Copy link
Copy Markdown
Member

I'm a bit iffy about adding a benchmark since we aren't actively maintaining benchmarks. @cyanguwa wdyt?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Performance] Fused RoPE THD kernel becomes dominant bottleneck in long-context training with many packed sequences

3 participants