[Common/PyTorch] bugfix: Token-linear fused RoPE impl. for THD tensors.#3057
[Common/PyTorch] bugfix: Token-linear fused RoPE impl. for THD tensors.#3057plugyawn wants to merge 11 commits into
Conversation
| int t_id = blockIdx.x; | ||
| int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size); |
There was a problem hiding this comment.
Redundant binary search across all threads in the block
Every thread in the block calls fused_rope_thd_find_seq_id with the same arguments (t_id = blockIdx.x, nseq, cp_size) and produces an identical result. With warps_per_block = 8, that's 256 threads each doing O(log nseq) global-memory reads of cu_seqlens that could be performed once. For nseq=2401 (~12 iterations x 256 threads), each block reads ~3,072 redundant entries from cu_seqlens. Performing the search once in thread 0 and broadcasting the result via shared memory would eliminate that overhead.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
|
@plugyawn Hi, could you sign your commits? See https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst#sign-your-work @sudhakarsingh27 Could you take a look? |
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
for more information, see https://pre-commit.ci Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
331a3a0 to
6c46696
Compare
|
Thanks! Signed! fwiw I think the binary search overhead on normal cases can be reduced also, I'll probably add some improvements. |
sudhakarsingh27
left a comment
There was a problem hiding this comment.
Posted the RoPE THD token-linear review comments from the local benchmark/coverage analysis. The main concerns are the dispatch heuristic, CP-local token accounting, CP-rank coverage, and benchmark scope.
| const int o_stride_h = d; | ||
| const int o_stride_d = 1; | ||
|
|
||
| if (fused_rope_thd_use_token_linear(qkv_format, b, s, total_tokens)) { |
There was a problem hiding this comment.
Please make the compact launch decision use the actual local THD rows and the legacy launch blocks. The local patch uses this shape:
const size_t compact_thd_blocks = input.data.shape[0];
const size_t legacy_thd_blocks = static_cast<size_t>(s) * b;
if (fused_rope_thd_use_compact_launch(legacy_thd_blocks, compact_thd_blocks, cp_size)) {
const int t = input.data.shape[0];
dim3 blocks(t);
...
}This also avoids routing the heuristic through a total_tokens value whose CP/global semantics are easy to confuse.
There was a problem hiding this comment.
Fixed. Also renamed the variable from total_tokens, so no CP/global ambiguity. Could you check if it's fine now?
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
|
Fixed some of the review comments, resolving the rest now. Additional CP-rank validation for the THD token-linear RoPE path on the rebased PR tip:
This closes the earlier proof gap where old-vs-new parity only covered |
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
for more information, see https://pre-commit.ci
|
@ptrendx @sudhakarsingh27 the last review comments are addressed. |
| // Heuristic: use the token-linear path when the legacy launch would issue | ||
| // enough extra blocks to amortize one sequence lookup per useful token. The | ||
| // CP factor keeps the gate conservative because local rows shrink with | ||
| // context parallelism while legacy launch space does not. |
There was a problem hiding this comment.
What's the legacy launch space? Does it refer to the one without heuristic or the previous heuristic?
| offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); | ||
| } | ||
|
|
||
| // Token-linear THD forward kernel. Each block handles exactly one packed local |
There was a problem hiding this comment.
A bit iffy about calling it token-linear THD. Maybe the right term is THD linear-grid forward kernel? Pls make that change across the file(s)
| // divided cumulative sequence boundaries, then defers to the same | ||
| // `fused_rope_block_forward` device function as the original kernel. | ||
| template <typename scalar_t> | ||
| __global__ void fused_rope_thd_token_forward_kernel( |
There was a problem hiding this comment.
similarly, this could be fused_rope_thd_linear_grid_forward_kernel and other function below could follow the suit
|
|
||
| const size_t token_linear_blocks = static_cast<size_t>(local_tokens); | ||
| const size_t legacy_blocks = static_cast<size_t>(s) * static_cast<size_t>(b); | ||
| if (fused_rope_thd_use_token_linear(qkv_format, legacy_blocks, token_linear_blocks, cp_size)) { |
There was a problem hiding this comment.
similarly, use_fused_rope_thd_linear_grid_launch
| const int stride_h, const int stride_d, cudaStream_t stream) { | ||
| // For THD the packed local token count is the first dimension of the input | ||
| // tensor. SBHD/BSHD ignore this value. | ||
| const int64_t local_tokens = |
There was a problem hiding this comment.
Maybe we should call this total_tokens or total_tokens_in_input. local_tokens seems to convey a different unrelated meaning but I understand where you're coming from.
|
I'm a bit iffy about adding a benchmark since we aren't actively maintaining benchmarks. @cyanguwa wdyt? |
Description
Adds a token-linear implementation of the existing THD fused RoPE path to remove a launch-scaling bug.
Addresses #2866, which finds an interesting case with RoPE scales by freqs_len × n_spans, which is pathological; it should scale by total tokens. I reproduced the issue and found that it's causing a noticeable drops on even plausibly routine shapes. For eg: the [128/512] and [512/128] cases here.
The new kernel reuses the existing
fused_rope_block_forwardandfused_rope_block_backwarddevice helpers, so the math doesn't change. All we need to do is add a THD-only path that launches one bloc/packed token.This is mostly pathological, however, so I've added a condition on the dispatch to avoid the unnecessary binary search overhead, although the overhead appears to be not-that-relevant. The condition is: token-linear only when
b >= 64and the old launch would issue ≥ 8× as many blocks as there are tokens. I'm not sure if this the usual shape of TE updates, so I could remove it!Some more relevant tests:
Microbenchmark on H100 (bf16,
h=32,d=d2=128,freqs_len=T_local=65536, single GPU):Fixes: #2866.
Type of change
Changes
Please list the changes introduced in this PR:
NVTE_FUSED_ROPE_THD_TOKEN_LINEAR=0|1.fused_rope_block_forwardandfused_rope_block_backwarddevice helpers.Checklist: