Make TE Sequential Grouped linear Op CUDA graphable #2923
vthumbe1503 wants to merge 15 commits into NVIDIA:main
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR introduces a CUDA-graph-safe forward/backward path for the TE Sequential grouped linear op.
Confidence Score: 4/5
Safe to merge on the core eager path; the P1 concern only manifests for delay_wgrad=True on Blackwell with the non-quantized graph-safe path, which is not yet tested. One P1 finding (backward_dw accessing MXFP8 attributes on plain GroupedTensors) caps the score at 4. The rest of the refactor is clean and well-structured, and the P1 only fires on a narrow, untested code path.
Important files changed: transformer_engine/pytorch/ops/basic/grouped_linear.py — specifically backward_dw (lines 237-246) and the wgrad_store.put call inside _fuser_backward_grouped_tensor.
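The P1 is easiest to see in a sketch. Below is a minimal, hypothetical guard for the wgrad input lookup in backward_dw; the attribute names (columnwise_data, columnwise_scale_inv, rowwise_data) come from the saved-tensor diff later in this PR, while the getattr-based dispatch itself is an illustration, not the PR's code.

```python
from typing import Optional, Tuple
import torch

# Hypothetical guard for the P1 finding: backward_dw must not assume MXFP8
# attributes on a plain GroupedTensor. Attribute names are taken from the
# diff in this PR; this dispatch is an assumption, not the actual fix.
def _wgrad_input_buffers(grouped_x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    if getattr(grouped_x, "columnwise_data", None) is not None:
        # Quantized (MXFP8) GroupedTensor: columnwise data plus its scale_inv.
        return grouped_x.columnwise_data, grouped_x.columnwise_scale_inv
    # Plain GroupedTensor (non-quantized graph-safe path): raw rowwise data,
    # no scale.
    return grouped_x.rowwise_data, None
```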
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[fuser_forward] --> B{_is_graph_safe_path_supported?\nSM100+ AND BF16/FP16 or MXFP8}
B -- Yes --> C[_fuser_forward_grouped_tensor\nGraph-safe: GPU split_sizes,\nGroupedTensor input/weight/output]
B -- No --> D[_fuser_forward_split_quantize\nLegacy: CPU split_sizes_int,\ntex.split_quantize]
C --> E[ctx.use_grouped_tensor_path = True]
D --> F[ctx.use_grouped_tensor_path = False]
E --> G[fuser_backward]
F --> G
G --> H{ctx.use_grouped_tensor_path?}
H -- True --> I[_fuser_backward_grouped_tensor\ngeneral_grouped_gemm_for_grouped_tensor]
H -- False --> J[_fuser_backward_split_quantize\ngeneral_grouped_gemm]
I --> K{delay_wgrad?}
K -- Yes --> L[wgrad_store.put GroupedTensor\n⚠️ backward_dw may access\nMXFP8 attrs on plain GroupedTensor]
K -- No --> M[immediate wgrad GEMM\n⚠️ no clear_tensor_data]
J --> N{delay_wgrad?}
N -- Yes --> O[wgrad_store.put list of tensors\nbackward_dw safe path]
N -- No --> P[immediate wgrad GEMM\nclear_tensor_data xs]
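The forward dispatch the flowchart describes can be summarized in Python. A minimal sketch: the three helper names appear in the flowchart itself, but their signatures and the exact capability test below are assumptions, not the PR's literal code.

```python
import torch

# Sketch of the forward dispatch from the flowchart. Helper names come from
# the flowchart; signatures and the capability check are assumptions.
def _is_graph_safe_path_supported(dtype: torch.dtype, is_mxfp8: bool) -> bool:
    major, _minor = torch.cuda.get_device_capability()
    sm100_plus = major >= 10  # SM100+ (Blackwell and newer)
    return sm100_plus and (dtype in (torch.bfloat16, torch.float16) or is_mxfp8)

def fuser_forward(op, ctx, inputs, dtype, is_mxfp8):
    if _is_graph_safe_path_supported(dtype, is_mxfp8):
        # Graph-safe path: split sizes stay on the GPU and input/weight/output
        # travel as GroupedTensors, so capture avoids graph-breaking syncs.
        out = op._fuser_forward_grouped_tensor(ctx, inputs)
        ctx.use_grouped_tensor_path = True
    else:
        # Legacy path: CPU-side split_sizes_int and tex.split_quantize.
        out = op._fuser_forward_split_quantize(ctx, inputs)
        ctx.use_grouped_tensor_path = False
    return out
```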
    bias_scale: Optional[torch.Tensor] = None
    if has_bias:
        # Bias always needs to be passed as a GroupedTensor for the grouped GEMM.
        grouped_bias = self._get_grouped_bias_for_gemm(dtype, device)
        if self._scale_bias:
            bias_scale = scales.reshape(-1)
            if bias_scale.dtype != torch.float32:
                bias_scale = bias_scale.to(dtype=torch.float32)

    # Forward grouped GEMM (TN layout: out[i] = x[i] @ w[i]^T)
    general_grouped_gemm_for_grouped_tensor(
        grouped_weights,
        grouped_x,
        grouped_out,
Missing contiguity error handling for main_grad.view(-1)
main_grad.view(-1) will raise a generic RuntimeError if main_grad is non-contiguous (e.g. when returned by get_main_grad() via __fsdp_param__). The equivalent code in backward_grouped_mlp.py wraps the reshape in try/except and re-raises with an actionable message that includes the shape and stride. Without that guard, users hitting this case will see an opaque PyTorch error instead of a clear diagnostic.
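A sketch of the guard this comment asks for, mirroring the try/except pattern it cites from backward_grouped_mlp.py; the surrounding variable name is taken from the comment and the message wording is an assumption.

```python
# Hypothetical contiguity guard around the flattening reshape; mirrors the
# pattern the review cites from backward_grouped_mlp.py. Message text is
# an assumption, not the existing code.
try:
    flat_main_grad = main_grad.view(-1)
except RuntimeError as err:
    raise RuntimeError(
        "main_grad must be contiguous to be viewed as a flat buffer "
        f"(shape={tuple(main_grad.shape)}, stride={main_grad.stride()}). "
        "This can happen when main_grad is returned by get_main_grad() via "
        "__fsdp_param__; consider calling .contiguous() first."
    ) from err
```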
    if ctx.requires_grad:
        saved: list[Optional[torch.Tensor]] = [split_sizes, base_offsets]
        if self._scale_bias:
            saved.append(scales)
        # For the wgrad input we save (data, scale_inv).
        # * Quantized path saves columnwise data + scale.
        # * Unquantized path saves the raw rowwise data and a None scale.
        if grouped_x is not None:
            if with_quantized_compute:
                saved.extend(
                    [
                        grouped_x.columnwise_data,
                        grouped_x.columnwise_scale_inv,
                    ]
                )
            else:
                saved.extend([grouped_x.rowwise_data, None])
        else:
            saved.extend([None, None])
        if self.single_grouped_weight:
            saved.append(grouped_weights)
        else:
            saved.extend(grouped_weights)
        ctx.save_for_backward(*saved)
        ctx.use_grouped_tensor_path = True
        ctx.with_quantized_compute = with_quantized_compute
        ctx.input_quantizers = input_quantizers
        ctx.weight_quantizers = weight_quantizers
Comment contradicts implementation for weight saving
The block comment says "we save the GroupedTensor's component buffers (rather than the wrapper) and rebuild it in backward" — but the code that follows saves the entire GroupedTensor wrapper for grouped_weights (when single_grouped_weight=True, saved.append(grouped_weights)). Component-buffer saving only applies to grouped_x (which saves columnwise_data/rowwise_data). The misleading comment could cause confusion when debugging or extending this path.
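One possible resolution, sketched below under the assumption that the code (not the comment) reflects the intent: scope the comment to grouped_x and state explicitly that the weight wrapper is saved whole. This is a suggestion in the review's spirit, not the PR's actual fix.

```python
# Suggested comment rewording (sketch; the code itself is unchanged):
# For grouped_x we save the GroupedTensor's component buffers
# (columnwise/rowwise data plus scale_inv) and rebuild the wrapper in
# backward; the weight GroupedTensor, by contrast, is saved whole.
if self.single_grouped_weight:
    saved.append(grouped_weights)
else:
    saved.extend(grouped_weights)
```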
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Description
This PR makes the TE Sequential grouped linear op CUDA graphable. On SM100+ with BF16/FP16 or MXFP8, the op takes a graph-safe path in which split sizes stay on the GPU and input, weight, and output are handled as GroupedTensors; otherwise it falls back to the legacy path with CPU-side split_sizes_int and tex.split_quantize.
Fixes # (issue)
Type of change
Changes
- Graph-safe forward (_fuser_forward_grouped_tensor) and backward (_fuser_backward_grouped_tensor) built on general_grouped_gemm_for_grouped_tensor.
- Legacy split-quantize forward/backward retained (_fuser_forward_split_quantize, general_grouped_gemm); the two paths are dispatched via ctx.use_grouped_tensor_path.
- delay_wgrad support via wgrad_store.put on both paths.
Checklist: