[torch.compile] Bunch of small changes needed for enabling torch.compile #3130
pggPL wants to merge 6 commits into
…stants; fix SP memory leak; test suite hook-up

Wrap CommOverlapCore pybind11 methods that return compile-time constants so torch.compile(fullgraph=True) can trace through them without graph breaks:
- `is_fp8_ubuf()` → `ub_is_fp8()` / `get_ub_is_fp8()` in base.py; `_ub_is_fp8()` in gemm.py
- `with_cublasmp()` → `ub_is_cublasmp()` in base.py

All callers in linear.py, layernorm_linear.py, layernorm_mlp.py, base.py, gemm.py, userbuffers_backward_linear.py and userbuffers_forward_linear.py updated.

Fix quantized grad_output not being freed early for column-parallel SP backward. Row-parallel SP already called clear_tensor_data(grad_output) to release the gathered tensor; column-parallel SP quantizes grad_output to Float8TensorStorage but never freed it before returning. Under torch.compile reduce-overhead this leaves 3 live pool tensors at recording end and triggers "Detected 3 tensor(s) in the cudagraph pool not tracked as outputs". Extend the existing clear_tensor_data guard to cover both parallel modes.

Fix custom-recipe quantizer state being re-initialised on every forward call even when the recipe object has not changed. The existing early-exit for CustomRecipeState was missing an identity check on the recipe object, so any repeated call with the same recipe would bypass the early-return and rebuild quantizers unnecessarily. Add `if recipe_state.recipe is recipe: return` to restore the intended caching behaviour.

Add test_torch_compile.py to L0_pytorch_unittest so the autocast and existing compile tests run in CI.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
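The custom-recipe caching fix described above can be illustrated with a minimal plain-Python sketch. `ToyRecipe`, `ToyRecipeState`, `ToyModule` and `init_recipe_state` are hypothetical stand-ins invented for this example, not Transformer Engine classes; only the `recipe_state.recipe is recipe` identity check comes from the commit message.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False keeps the default identity-based comparison
class ToyRecipe:
    """Hypothetical stand-in for a custom recipe object."""
    name: str


@dataclass
class ToyRecipeState:
    """Hypothetical stand-in for CustomRecipeState."""
    recipe: ToyRecipe
    quantizers: list = field(default_factory=list)


class ToyModule:
    """Hypothetical module that caches its recipe state between calls."""

    def __init__(self):
        self.recipe_state = None
        self.rebuild_count = 0

    def init_recipe_state(self, recipe):
        state = self.recipe_state
        # Early exit only when the very same recipe object is passed again.
        if state is not None and state.recipe is recipe:
            return
        # Otherwise rebuild the quantizers and remember which recipe built them.
        self.rebuild_count += 1
        self.recipe_state = ToyRecipeState(recipe, [f"quantizer_for_{recipe.name}"])


module = ToyModule()
recipe = ToyRecipe("custom")
module.init_recipe_state(recipe)               # first call: builds quantizers
module.init_recipe_state(recipe)               # same object: early return
module.init_recipe_state(ToyRecipe("custom"))  # different object: rebuilds
```

The check uses `is` rather than `==` on purpose: it is cheap, and it treats a freshly constructed recipe as a change even when its fields happen to match.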
…-accumulator booleans

LinearBwdArgs stored the entire FP8 recipe object so the backward could extract fp8_gemm_dgrad.use_split_accumulator and fp8_gemm_wgrad.use_split_accumulator at GEMM time. Recipe objects hold process-group references and are not serialisable as compile-time constants, making them incompatible with torch.compile custom-op paths.

Replace fp8_recipe with two plain bool fields:
- dgrad_use_split_accumulator (default _2X_ACC_DGRAD)
- wgrad_use_split_accumulator (default _2X_ACC_WGRAD)

These are resolved once in _linear_setup_ctx and passed into the args struct, so the backward consumes scalars instead of a live recipe object.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR makes five targeted changes to enable
Confidence Score: 5/5
Safe to merge: all five changes are well-scoped, behavior-equivalent where intended, and correctly fix live-tensor leaks and stale cache hits.
The split-accumulator refactor is a pure structural change (booleans resolved at the same point in time, same recipe object, just threaded as plain values instead of the recipe itself). The column-SP FP8 clear_tensor_data is called only after the wgrad GEMM completes, so no use-after-free risk. The CustomRecipeState identity fix prevents spurious quantizer rebuilds without changing correctness. torch.compiler.reset() in destroy_ub() intentionally nukes all compiled caches on teardown to avoid stale assume_constant_result constants, which is the right trade-off for a function that is almost always called at training end or test boundary.
No files require special attention. layernorm_linear.py and layernorm_mlp.py still carry ctx.fp8_recipe in their backward paths, but that is explicitly out of scope for this PR.
Important Files Changed
Sequence Diagram
%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant M as Linear.forward (nn.Module)
participant GUF as get_ub_is_fp8() @assume_constant_result
participant UB as UB Communicator
participant FWD as _Linear.forward
participant CTX as LinearFwdArgs / LinearBwdArgs
participant BWD as _linear_backward
M->>GUF: get_ub_is_fp8(name, is_fp8_enabled())
GUF->>UB: is_fp8_ubuf()
UB-->>GUF: bool
GUF-->>M: fp8_output / fp8_grad
M->>M: resolve dgrad/wgrad_use_split_accumulator from recipe
M->>FWD: LinearFwdArgs(dgrad_use_split_accumulator, wgrad_use_split_accumulator, ...)
FWD->>CTX: _linear_setup_ctx transfers booleans to LinearBwdArgs
BWD->>BWD: "use_split_accumulator = bwd_args.dgrad_use_split_accumulator"
Note over BWD: Column-SP FP8 path
BWD->>BWD: "grad_output = quantizer(grad_output) [Float8TensorStorage]"
BWD->>BWD: wgrad_gemm(inputmat_total, grad_output)
BWD->>BWD: clear_tensor_data(grad_output) [NEW: free pool tensor]
/te-ci pytorch L1
…t_result

get_ub_is_fp8 bakes is_fp8_ubuf() as a compile-time constant; without a reset, destroy_ub + re-init with different FP8 settings would read stale values until recompile. Only affects in-memory caches, not disk.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
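The stale-constant hazard described in this commit can be mimicked in plain Python. This is only an analogy: `functools.cache` stands in for a constant baked into a compiled graph, and `cache_clear()` plays the role of `torch.compiler.reset()`. The registry, the `initialize_ub` helper, and the single-argument `get_ub_is_fp8` signature are assumptions made for this sketch and do not reflect the real Userbuffers API.

```python
import functools

# Hypothetical stand-in for the set of live UB communicators.
_ub_registry = {}


def initialize_ub(name, use_fp8):
    """Register a toy communicator and record whether its buffer is FP8."""
    _ub_registry[name] = use_fp8


@functools.cache
def get_ub_is_fp8(name):
    # The first result is cached, much like a value baked into a compiled graph.
    return _ub_registry[name]


def destroy_ub(reset_cache):
    """Tear down all toy communicators, optionally dropping cached answers."""
    _ub_registry.clear()
    if reset_cache:
        get_ub_is_fp8.cache_clear()  # analogue of torch.compiler.reset()


initialize_ub("qkv", use_fp8=True)
first = get_ub_is_fp8("qkv")

# Teardown without a reset: the old answer survives re-initialisation.
destroy_ub(reset_cache=False)
initialize_ub("qkv", use_fp8=False)
stale = get_ub_is_fp8("qkv")

# Teardown with a reset: the next query sees the new setting.
destroy_ub(reset_cache=True)
initialize_ub("qkv", use_fp8=False)
fresh = get_ub_is_fp8("qkv")
```

In this toy version the cache can be cleared for one function only. The review discussion below concerns exactly that difference: `torch.compiler.reset()` is process-wide.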
ToyLinear now overrides get_quantizer_roles so CustomRecipeState doesn't hit the no-roles warning, which graph-breaks under fullgraph=True. qfactory dispatches on role.tensor_type instead of a pre-baked string key. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
# Compiled graphs may have baked is_fp8_ubuf() via assume_constant_result;
# reset so re-init with different settings doesn't read stale constants.
torch.compiler.reset()
The current helper call sites are all inside @no_torch_dynamo() forwards, and the added test_torch_compile.py coverage does not exercise user buffers. Or is that done implicitly in the test?
Is it possible to avoid a process-wide compiler reset on UB teardown, or to add a targeted compiled UB test that proves the stale-constant case and justifies this global invalidation?
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mxfp8.xml $TE_PATH/tests/pytorch/mxfp8 || test_fail "test_mxfp8"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_torch_compile.xml $TE_PATH/tests/pytorch/test_torch_compile.py || test_fail "test_torch_compile.py"
That file only compiles a local ToyLinear helper and torch.nn.Linear under te.autocast. It does not instantiate the te.Linear, LayerNormLinear, or LayerNormMLP modules changed in this PR, and it has no UB or sequence_parallel/parallel_mode coverage.
Which tests would fail without the changes to the layernorm_linear and layernorm_mlp files?
I fixed the issue that the test was not connected to CI.
Currently it only tests whether te.autocast() can be traced inside torch.compile.
This is the first in a series of PRs, and here I change only small things to make the next PRs cleaner.
/te-ci pytorch L0 L1
Description
Small standalone fixes extracted from a larger torch.compile branch, branched directly from main. Two independent changes: making Userbuffers pybind11 queries compile-friendly, and freeing quantized grad_output early for column-parallel SP. It also includes a custom-recipe caching fix, a split-accumulator refactor, and a CI test hook-up.
Type of change
Changes
Checklist: