Correctly pad scaling factor inverses to satisfy cuteDSL requirements #2924
ksivaman wants to merge 10 commits into NVIDIA:main
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci
Greptile Summary

This PR fixes grouped MXFP8 swizzle when each expert's row count is not a multiple of 128 by teaching the swizzle kernels to accept a "compact" input buffer (per-tensor stride = m x padded_k).

Confidence Score: 5/5

Safe to merge; all findings are P2 documentation nits with no impact on correctness. The compact-layout detection, per-tensor stride separation, and IS_PADDED_K/IS_PADDED_M dispatch are logically sound. The out-of-bounds guard correctly prevents reading past the unpadded per-tensor extent in every grouped kernel variant. No P1/P0 issues found.

Important Files Changed

transformer_engine/common/swizzle/swizzle.cu — most complex change; worth a final read of the compact colwise stride semantics
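As a rough illustration of the out-of-bounds guard described above (a sketch only; the function name, template parameters, and indexing here are simplified stand-ins for the actual grouped swizzle kernels, not the PR's code):

```cuda
// Hypothetical fragment modeling the boundary handling: threads whose
// coordinates fall in the padded region skip the __ldg load and return 0,
// so the padded tail of each expert's scale block is zero-filled.
template <bool IS_PADDED_M, bool IS_PADDED_K>
__device__ uint8_t load_scale_or_zero(const uint8_t* src, int row, int k_coord,
                                      int original_M, int original_K, int row_stride) {
  if (IS_PADDED_M && row >= original_M) return 0;      // below the real rows
  if (IS_PADDED_K && k_coord >= original_K) return 0;  // right of the real columns
  return __ldg(src + row * row_stride + k_coord);      // in-bounds read-only load
}
```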
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[swizzle_grouped_scaling_factors] --> B{Detect input layout}
    B -->|numel == num_tensors x padded_scale_elems| C[Padded layout\ninput_stride = padded_m x padded_k]
    B -->|numel == compact_total_scale_elems| D[Compact layout\ninput_stride = m x padded_k]
    B -->|neither| E[NVTE_ERROR]
    C --> F[output_stride = num_tensors x padded_scale_elems]
    D --> F
    F --> G{rowwise?}
    G -->|yes| H[grouped_swizzle_row_scaling_uniform_shape_kernel]
    G -->|no| I[grouped_swizzle_col_scaling_uniform_shape_kernel]
    H --> J{boundary block?}
    I --> J
    J -->|IS_PADDED_M| K[Skip __ldg for row >= original_M, write 0]
    J -->|IS_PADDED_K| L[Skip __ldg for k_coord >= original_K, write 0]
    J -->|neither| M[Normal __ldg load + shuffle + store]
    K --> N[Output: num_tensors x padded_m x padded_k, padded regions = 0]
    L --> N
    M --> N
```
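To make the dispatch above concrete, here is a minimal host-side sketch of the layout detection (a standalone illustration; the real logic lives in transformer_engine/common/swizzle/swizzle.cu, and the names mirror the diff quoted later in this thread):

```cpp
#include <cstddef>
#include <stdexcept>

// Returns true for the compact layout, false for the padded layout.
// padded_scale_elems        = padded_m * padded_k (per tensor)
// compact_total_scale_elems = total compact element count, including any
//                             trailing alignment slack from the quantizer
bool detect_compact_layout(std::size_t input_scale_numel, std::size_t num_tensors,
                           std::size_t padded_scale_elems,
                           std::size_t compact_total_scale_elems) {
  if (input_scale_numel == num_tensors * padded_scale_elems) {
    return false;  // per-tensor padded layout: input_stride = padded_m * padded_k
  }
  if (input_scale_numel == compact_total_scale_elems) {
    return true;   // compact layout: input_stride = m * padded_k
  }
  // The real code raises NVTE_ERROR with both expected sizes here.
  throw std::invalid_argument("grouped scale_inv size matches neither layout");
}
```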
Reviews (4): Last reviewed commit: "Merge branch 'main' into pad_weight_scal..."
```cpp
const auto logical_shape_nvte = input.logical_shape();
NVTE_CHECK(logical_shape_nvte.ndim >= 2,
           "Grouped GEMM swizzle expects logical_shape with ndim >= 2.");
const size_t per_tensor_first_dim = logical_shape_nvte.data[0] / num_tensors;
```
Silent truncation when logical_shape_nvte.data[0] is not divisible by num_tensors
per_tensor_first_dim is computed with plain integer division. If logical_shape_nvte.data[0] is not an exact multiple of num_tensors (e.g. due to a caller bug or unexpected grouped layout), the result is silently truncated, causing padded_m to be underestimated and the output buffer to be too small. A divisibility assertion would catch this much earlier with a clear error message.
Suggested change:

```cpp
const size_t per_tensor_first_dim = logical_shape_nvte.data[0] / num_tensors;
NVTE_CHECK(logical_shape_nvte.data[0] % num_tensors == 0,
           "Grouped GEMM swizzle expects logical_shape first dim to be divisible by num_tensors.");
```
```cpp
bool input_is_compact;
if (input_scale_numel == input->num_tensors * padded_scale_elems) {
  input_is_compact = false;
} else if (input_scale_numel == compact_total_scale_elems) {
  input_is_compact = true;
} else {
  NVTE_CHECK(input->columnwise_scale_inv.numel() == input->num_tensors * scale_elems,
             "Grouped input columnwise_scale_inv size does not match expected packed size.");
  NVTE_CHECK(output->columnwise_scale_inv.numel() == output->num_tensors * scale_elems,
             "Grouped output columnwise_scale_inv size does not match expected packed size.");
  NVTE_ERROR("Grouped input ", (rowwise ? "scale_inv" : "columnwise_scale_inv"),
             " size does not match expected packed size (got ", input_scale_numel,
             ", expected either ", input->num_tensors * padded_scale_elems,
             " (per-tensor padded) or ", compact_total_scale_elems, " (compact)).");
}
```
Implicit contract on compact-buffer alignment is not validated
The compact_total_scale_elems formula assumes the upstream quantize kernel allocates the compact scale buffer with its total first dim rounded up to 128 (rowwise) or 4 (colwise). If a caller passes a "plain compact" buffer of size exactly num_tensors * m * padded_k (without trailing alignment slack), neither branch matches and NVTE_ERROR fires with a size-mismatch message that may be hard to diagnose.
Consider also accepting num_tensors * compact_scale_elems as a valid compact size, or documenting this alignment requirement in the error message.
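A possible way to implement that suggestion (illustrative only; compact_scale_elems here is assumed to be the per-tensor compact element count m * padded_k, which is not a name the PR currently uses in this branch):

```cpp
// Accept both the alignment-padded compact size and the plain compact size.
const size_t plain_compact_elems = num_tensors * compact_scale_elems;
if (input_scale_numel == num_tensors * padded_scale_elems) {
  input_is_compact = false;
} else if (input_scale_numel == compact_total_scale_elems ||
           input_scale_numel == plain_compact_elems) {
  input_is_compact = true;
} else {
  NVTE_ERROR("Grouped input scale_inv size (", input_scale_numel,
             ") matches neither the per-tensor padded size (",
             num_tensors * padded_scale_elems, "), the aligned compact size (",
             compact_total_scale_elems, "), nor the plain compact size (",
             plain_compact_elems, ").");
}
```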
@ksivaman Could you add a test exercising the change?
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci
Description
Fix grouped MXFP8 swizzle when per-expert rows aren't a multiple of 128, and pad each expert's scales to multiples of (128, 4).
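For intuition, the padded extents implied by this description can be computed as below (a sketch; roundup and padded_extents are hypothetical helpers, and m/k denote one expert's scale-inverse extents):

```cpp
#include <cstddef>

constexpr std::size_t roundup(std::size_t x, std::size_t multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

// Each expert's (m, k) scale-inverse block is padded to multiples of (128, 4).
// Example: an expert with 96 scale rows and 3 scale columns pads to (128, 4).
void padded_extents(std::size_t m, std::size_t k,
                    std::size_t* padded_m, std::size_t* padded_k) {
  *padded_m = roundup(m, 128);  // row extent rounded up to a multiple of 128
  *padded_k = roundup(k, 4);    // column extent rounded up to a multiple of 4
}
```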
Type of change
Changes
Checklist: