
Make TE Sequential Grouped linear Op CUDA graphable #2923

Open
vthumbe1503 wants to merge 15 commits into NVIDIA:main from vthumbe1503:grouped_linear_integration_v2

Conversation

@vthumbe1503 (Collaborator)

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as draft April 24, 2026 20:05
@vthumbe1503 vthumbe1503 changed the title Grouped linear integration v2 Make Grouped linear TE Sequential Op CUDA graphable Apr 24, 2026
@vthumbe1503 vthumbe1503 changed the title Make Grouped linear TE Sequential Op CUDA graphable Make TE Sequential Grouped linear Op CUDA graphable Apr 24, 2026
@greptile-apps (Contributor) commented Apr 24, 2026

Greptile Summary

This PR introduces a CUDA-graph-safe forward/backward path for GroupedLinear — dispatching to _fuser_forward_grouped_tensor / _fuser_backward_grouped_tensor on Blackwell (SM100+) and falling back to the legacy split_quantize path elsewhere. It also refactors shared main_grad / dummy-wgrad helpers into _common.py, used by BasicLinear, fused ops, and the new grouped path.
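A rough sketch of the gating this summary describes. This is an illustration rather than the PR's code: the predicate name comes from the flowchart below, while the capability check and dtype list are assumptions.

import torch

def _is_graph_safe_path_supported(dtype: torch.dtype, with_quantized_compute: bool) -> bool:
    """Illustrative gate for the graph-safe grouped-tensor path (assumed logic)."""
    major, _ = torch.cuda.get_device_capability()
    if major < 10:  # SM100+ (Blackwell) is required for the grouped-tensor path
        return False
    # MXFP8 quantized compute or plain BF16/FP16 are assumed to qualify
    return with_quantized_compute or dtype in (torch.bfloat16, torch.float16)

# fuser_forward would then dispatch roughly as:
#   if _is_graph_safe_path_supported(dtype, with_quantized_compute):
#       out = _fuser_forward_grouped_tensor(...)   # graph-safe: GPU split_sizes
#   else:
#       out = _fuser_forward_split_quantize(...)   # legacy: CPU split_sizes_int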

  • The backward_dw method's else branch (triggered for any non-list activation) accesses MXFP8-specific GroupedTensor attributes (.columnwise_data, .scale_inv, .columnwise_scale_inv). The new non-quantized graph-safe backward also stores a plain GroupedTensor here, and if those attributes are absent on plain instances this will raise AttributeError whenever delay_wgrad=True is used with the non-quantized graph-safe path.
  • The new CUDA-graph test does not parametrize delay_wgrad_compute=True, leaving that interaction untested.

Confidence Score: 4/5

Safe to merge on the core eager path; the P1 concern only manifests for delay_wgrad=True on Blackwell with the non-quantized graph-safe path, which is not yet tested.

One P1 finding (backward_dw accessing MXFP8 attributes on plain GroupedTensors) caps the score at 4. The rest of the refactor is clean and well-structured, and the P1 only fires on a narrow, untested code path.

transformer_engine/pytorch/ops/basic/grouped_linear.py — specifically backward_dw (lines 237-246) and the wgrad_store.put call inside _fuser_backward_grouped_tensor.

Important Files Changed

  • transformer_engine/pytorch/ops/basic/grouped_linear.py
    Core change: splits fuser_forward/backward into a graph-safe grouped-tensor path and the legacy split-quantize path; adds single_grouped_weight/bias support. Several tricky backward interactions with delay_wgrad and backward_dw deserve attention.
  • transformer_engine/pytorch/ops/_common.py
    Adds well-factored helpers (get_main_grad_from_param, get_accumulate_flag_in_param, view_main_grad_as_grouped_buffer, get_dummy_wgrads_for_params) shared across BasicLinear, GroupedLinear, and fused ops; clean refactor.
  • transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
    Migrates main_grad/dummy_wgrad handling to the shared helpers; the duplicated numel-check and view logic is replaced by view_main_grad_as_grouped_buffer; clean.
  • tests/pytorch/test_fusible_ops.py
    Adds single_grouped_weight/bias to the existing grouped_linear test and a new CUDA-graph-safety test; delay_wgrad_compute=True is not parametrized in the new graph test (see the sketch below).
  • transformer_engine/pytorch/ops/basic/basic_linear.py
    Migrates main_grad/dummy_wgrad logic to the shared helpers from _common.py; straightforward, no logic changes.
  • transformer_engine/pytorch/ops/fused/backward_linear_add.py
    Routine migration to the shared _common.py helpers; no logic changes.
  • transformer_engine/pytorch/ops/fused/backward_linear_scale.py
    Routine migration to the shared _common.py helpers; no logic changes.
  • transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
    Routine migration to the shared _common.py helpers; no logic changes.
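
To close the gap noted for tests/pytorch/test_fusible_ops.py, the new graph test could parametrize the delayed-wgrad flag as well. The test name and signature below are illustrative, not the PR's actual identifiers:

import pytest

# Illustrative parametrization sketch; the real test lives in
# tests/pytorch/test_fusible_ops.py and its name/arguments may differ.
@pytest.mark.parametrize("delay_wgrad_compute", (False, True))
def test_grouped_linear_cuda_graph(delay_wgrad_compute: bool) -> None:
    # ... build the GroupedLinear op, capture forward/backward in a CUDA graph,
    # and compare against the eager reference, with and without delayed wgrad ...
    ...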

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[fuser_forward] --> B{_is_graph_safe_path_supported?\nSM100+ AND BF16/FP16 or MXFP8}
    B -- Yes --> C[_fuser_forward_grouped_tensor\nGraph-safe: GPU split_sizes,\nGroupedTensor input/weight/output]
    B -- No --> D[_fuser_forward_split_quantize\nLegacy: CPU split_sizes_int,\ntex.split_quantize]
    C --> E[ctx.use_grouped_tensor_path = True]
    D --> F[ctx.use_grouped_tensor_path = False]

    E --> G[fuser_backward]
    F --> G
    G --> H{ctx.use_grouped_tensor_path?}
    H -- True --> I[_fuser_backward_grouped_tensor\ngeneral_grouped_gemm_for_grouped_tensor]
    H -- False --> J[_fuser_backward_split_quantize\ngeneral_grouped_gemm]

    I --> K{delay_wgrad?}
    K -- Yes --> L[wgrad_store.put GroupedTensor\n⚠️ backward_dw may access\nMXFP8 attrs on plain GroupedTensor]
    K -- No --> M[immediate wgrad GEMM\n⚠️ no clear_tensor_data]

    J --> N{delay_wgrad?}
    N -- Yes --> O[wgrad_store.put list of tensors\nbackward_dw safe path]
    N -- No --> P[immediate wgrad GEMM\nclear_tensor_data xs]

Comments Outside Diff (1)

  1. transformer_engine/pytorch/ops/basic/grouped_linear.py, lines 237-246

    P1 backward_dw accesses MXFP8-specific attributes on non-quantized GroupedTensor

    backward_dw reaches the else branch for any non-list activation (i.e., any GroupedTensor). It then accesses .columnwise_data, .scale_inv, and .columnwise_scale_inv — attributes that only exist (non-None) on MXFP8 quantized GroupedTensors returned by tex.group_quantize. The new _fuser_backward_grouped_tensor path also stores a plain GroupedTensor (built from the Python constructor with quantizer=None) when with_quantized_compute=False. If those attributes are not defined on plain GroupedTensor instances, backward_dw will raise AttributeError whenever delay_wgrad=True is combined with the non-quantized graph-safe forward path (Blackwell + BF16/FP16).

    The else branch guard should be tightened, or the non-quantized activation should be stored as a list so it falls through the isinstance(activations, list) branch safely:

    # In _fuser_backward_grouped_tensor, delay_wgrad branch:
    # store as a list so backward_dw uses the safe clear_tensor_data(*activations) path
    self.wgrad_store.put([[grouped_x], grouped_dy, wgrad_output], wgrad_gemm)

    Alternatively, update backward_dw to detect the quantized/non-quantized case before accessing the quantized attributes.
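
A sketch of that second option, assuming the attribute names from this comment and that plain GroupedTensors expose rowwise_data (as the saved-tensor snippet further down suggests); the actual body depends on what backward_dw does with these buffers:

    # Sketch only: branch on whether the stored activation is a quantized GroupedTensor
    # before touching MXFP8-specific attributes.
    if isinstance(activations, list):
        clear_tensor_data(*activations)  # existing safe path
    elif getattr(activations, "columnwise_data", None) is not None:
        # Quantized MXFP8 GroupedTensor from tex.group_quantize: columnwise_data,
        # scale_inv, and columnwise_scale_inv are present and can be handled here.
        ...
    else:
        # Plain GroupedTensor from the non-quantized graph-safe path (quantizer=None):
        # skip the MXFP8-specific handling, e.g. only release rowwise_data.
        ...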

Reviews (4): Last reviewed commit: "fix on l40/hopper to skip"

Comment on lines +1157 to +1170
bias_scale: Optional[torch.Tensor] = None
if has_bias:
    # Bias always needs to be passed as a GroupedTensor for the grouped GEMM.
    grouped_bias = self._get_grouped_bias_for_gemm(dtype, device)
    if self._scale_bias:
        bias_scale = scales.reshape(-1)
        if bias_scale.dtype != torch.float32:
            bias_scale = bias_scale.to(dtype=torch.float32)

# Forward grouped GEMM (TN layout: out[i] = x[i] @ w[i]^T)
general_grouped_gemm_for_grouped_tensor(
    grouped_weights,
    grouped_x,
    grouped_out,

P2 Missing contiguity error handling for main_grad.view(-1)

main_grad.view(-1) will raise a generic RuntimeError if main_grad is non-contiguous (e.g. when returned by get_main_grad() via __fsdp_param__). The equivalent code in backward_grouped_mlp.py wraps the reshape in try/except and re-raises with an actionable message that includes the shape and stride. Without that guard, users hitting this case will see an opaque PyTorch error instead of a clear diagnostic.
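
A sketch of the guard, modeled on the description of backward_grouped_mlp.py above; the exact message wording is an assumption:

# Sketch: re-raise with an actionable diagnostic instead of an opaque RuntimeError.
try:
    main_grad_flat = main_grad.view(-1)
except RuntimeError as err:
    raise RuntimeError(
        "main_grad could not be viewed as a flat buffer, likely because it is "
        f"non-contiguous (shape={tuple(main_grad.shape)}, stride={main_grad.stride()}). "
        "This can happen when main_grad is returned by get_main_grad() via __fsdp_param__."
    ) from err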

Comment on lines +1191 to 1218
if ctx.requires_grad:
    saved: list[Optional[torch.Tensor]] = [split_sizes, base_offsets]
    if self._scale_bias:
        saved.append(scales)
    # For the wgrad input we save (data, scale_inv).
    # * Quantized path saves columnwise data + scale.
    # * Unquantized path saves the raw rowwise data and a None scale.
    if grouped_x is not None:
        if with_quantized_compute:
            saved.extend(
                [
                    grouped_x.columnwise_data,
                    grouped_x.columnwise_scale_inv,
                ]
            )
        else:
            saved.extend([grouped_x.rowwise_data, None])
    else:
        saved.extend([None, None])
    if self.single_grouped_weight:
        saved.append(grouped_weights)
    else:
        saved.extend(grouped_weights)
    ctx.save_for_backward(*saved)
    ctx.use_grouped_tensor_path = True
    ctx.with_quantized_compute = with_quantized_compute
    ctx.input_quantizers = input_quantizers
    ctx.weight_quantizers = weight_quantizers

P2 Comment contradicts implementation for weight saving

The block comment says "we save the GroupedTensor's component buffers (rather than the wrapper) and rebuild it in backward" — but the code that follows saves the entire GroupedTensor wrapper for grouped_weights (when single_grouped_weight=True, saved.append(grouped_weights)). Component-buffer saving only applies to grouped_x (which saves columnwise_data/rowwise_data). The misleading comment could cause confusion when debugging or extending this path.
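
If the comment is the part that gets fixed, a clarified version might read along these lines (wording is a suggestion, not the PR's text):

# For the wgrad input (grouped_x) we save the GroupedTensor's component buffers
# (columnwise data + scale for the quantized path, rowwise data otherwise) and
# rebuild the GroupedTensor in backward. grouped_weights, by contrast, is saved
# as the full GroupedTensor wrapper when single_grouped_weight=True.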

vthumbe1503 and others added 4 commits April 25, 2026 00:34
@vthumbe1503 vthumbe1503 marked this pull request as ready for review April 28, 2026 23:32
