
Correctly pad scaling factor inverses to satisfy cuteDSL requirements#2924

Open
ksivaman wants to merge 10 commits into NVIDIA:main from ksivaman:pad_weight_scale_inv

Conversation

@ksivaman
Member

Description

Fix grouped MXFP8 swizzle when per-expert rows aren't a multiple of 128 and pad each expert's scales to (128, 4).

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Ensure each tensor's scaling factor inverses are padded to multiples of (128, 4).

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman
Member Author

/te-ci

@greptile-apps
Contributor

greptile-apps Bot commented Apr 24, 2026

Greptile Summary

This PR fixes grouped MXFP8 swizzle when each expert's row count is not a multiple of 128 by teaching the swizzle kernels to accept a "compact" input buffer (per-tensor stride = m × padded_k rather than padded_m × padded_k) and allocating the output in the correct per-tensor padded layout. The kernel refactor adds compile-time IS_PADDED_K/IS_PADDED_M template specializations to avoid out-of-bounds loads past the unpadded extent of each expert's buffer, dispatching at the block level where the decision is uniform across all threads.

Confidence Score: 5/5

Safe to merge; all findings are P2 documentation nits with no impact on correctness

The compact-layout detection, per-tensor stride separation, and IS_PADDED_K/IS_PADDED_M dispatch are logically sound. The out-of-bounds guard correctly prevents reading past the unpadded per-tensor extent in every grouped kernel variant. No P1/P0 issues found.

transformer_engine/common/swizzle/swizzle.cu — most complex change; worth a final read of the compact colwise stride semantics

Important Files Changed

Filename Overview
transformer_engine/common/swizzle/swizzle.cu Core fix: introduces IS_PADDED_K/IS_PADDED_M compile-time dispatch to skip out-of-bounds loads from compact input buffers; adds compact-layout detection and separate input/output strides for grouped kernels; one minor comment typo (DIVUP(original_K, 1))
transformer_engine/pytorch/csrc/extensions/swizzle.cpp Output scale buffers are now allocated with the per-tensor padded shape (num_tensors * padded_m, padded_k) instead of the raw input shape, ensuring the swizzle kernel receives correctly sized output regardless of whether the input is compact or padded
transformer_engine/common/common.h Adds TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH macro to replace the repeated switch-case boilerplate for vec_load_size in the grouped swizzle dispatch
tests/cpp/operator/test_swizzle.cu Adds SwizzleGroupedCompactInputTestSuite covering aligned/unaligned M and K shapes, including the originally-failing 2880×2880 case; also refactors existing ceiling-division calls to divide_round_up

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[swizzle_grouped_scaling_factors] --> B{Detect input layout}
    B -->|numel == num_tensors x padded_scale_elems| C[Padded layout\ninput_stride = padded_m x padded_k]
    B -->|numel == compact_total_scale_elems| D[Compact layout\ninput_stride = m x padded_k]
    B -->|neither| E[NVTE_ERROR]
    C --> F[output_stride = num_tensors x padded_scale_elems]
    D --> F
    F --> G{rowwise?}
    G -->|yes| H[grouped_swizzle_row_scaling_uniform_shape_kernel]
    G -->|no| I[grouped_swizzle_col_scaling_uniform_shape_kernel]
    H --> J{boundary block?}
    I --> J
    J -->|IS_PADDED_M| K[Skip __ldg for row >= original_M, write 0]
    J -->|IS_PADDED_K| L[Skip __ldg for k_coord >= original_K, write 0]
    J -->|neither| M[Normal __ldg load + shuffle + store]
    K --> N[Output: num_tensors x padded_m x padded_k, padded regions = 0]
    L --> N
    M --> N

Reviews (4): Last reviewed commit: "Merge branch 'main' into pad_weight_scal..."

const auto logical_shape_nvte = input.logical_shape();
NVTE_CHECK(logical_shape_nvte.ndim >= 2,
           "Grouped GEMM swizzle expects logical_shape with ndim >= 2.");
const size_t per_tensor_first_dim = logical_shape_nvte.data[0] / num_tensors;
Contributor


P2 Silent truncation when logical_shape_nvte.data[0] is not divisible by num_tensors

per_tensor_first_dim is computed with plain integer division. If logical_shape_nvte.data[0] is not an exact multiple of num_tensors (e.g. due to a caller bug or unexpected grouped layout), the result is silently truncated, causing padded_m to be underestimated and the output buffer to be too small. A divisibility assertion would catch this much earlier with a clear error message.

Suggested change
  const size_t per_tensor_first_dim = logical_shape_nvte.data[0] / num_tensors;
+ NVTE_CHECK(logical_shape_nvte.data[0] % num_tensors == 0,
+            "Grouped GEMM swizzle expects logical_shape first dim to be divisible by num_tensors.");

Comment on lines +2077 to 2087
bool input_is_compact;
if (input_scale_numel == input->num_tensors * padded_scale_elems) {
  input_is_compact = false;
} else if (input_scale_numel == compact_total_scale_elems) {
  input_is_compact = true;
} else {
  NVTE_CHECK(input->columnwise_scale_inv.numel() == input->num_tensors * scale_elems,
             "Grouped input columnwise_scale_inv size does not match expected packed size.");
  NVTE_CHECK(output->columnwise_scale_inv.numel() == output->num_tensors * scale_elems,
             "Grouped output columnwise_scale_inv size does not match expected packed size.");
  NVTE_ERROR("Grouped input ", (rowwise ? "scale_inv" : "columnwise_scale_inv"),
             " size does not match expected packed size (got ", input_scale_numel,
             ", expected either ", input->num_tensors * padded_scale_elems,
             " (per-tensor padded) or ", compact_total_scale_elems, " (compact)).");
}
Contributor


P2 Implicit contract on compact-buffer alignment is not validated

The compact_total_scale_elems formula assumes the upstream quantize kernel allocates the compact scale buffer with its total first dim rounded up to 128 (rowwise) or 4 (colwise). If a caller passes a "plain compact" buffer of size exactly num_tensors * m * padded_k (without trailing alignment slack), neither branch matches and NVTE_ERROR fires with a size-mismatch message that may be hard to diagnose.

Consider also accepting num_tensors * compact_scale_elems as a valid compact size, or documenting this alignment requirement in the error message.

@ptrendx
Member

ptrendx commented Apr 24, 2026

@ksivaman Could you add a test exercising the change?

@ksivaman
Member Author

/te-ci

Collaborator

@Oleg-Goncharov Oleg-Goncharov left a comment


LGTM overall

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman
Member Author

/te-ci

Collaborator

@Oleg-Goncharov Oleg-Goncharov left a comment


LGTM
