
[Common/PyTorch/JAX] make offset of ClampedSwiGLU configurable #2938

Open

hxbai wants to merge 2 commits into NVIDIA:main from hxbai:swiglu_offset

Conversation

Contributor

@hxbai hxbai commented Apr 28, 2026

Description

The existing ClampedSwiGLU follows GPT-OSS, which hard-codes the offset to 1.0.
DeepSeek-V4 uses ClampedSwiGLU without the alpha scaling and without the offset.
This PR makes the offset of ClampedSwiGLU configurable so that DeepSeek-V4 can be supported.
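For concreteness, here is a minimal NumPy sketch of the activation with the offset made configurable. Only the gate(x_glu) * (clamp(x_linear) + glu_linear_offset) structure and the defaults alpha = 1.702 and glu_linear_offset = 1.0 are taken from this thread; the exact clamping bounds and the split of the input are illustrative assumptions, not Transformer Engine's exact kernel.

```python
import numpy as np

def clamped_swiglu_ref(x_glu, x_linear, limit, alpha=1.702, glu_linear_offset=1.0):
    """Illustrative reference only; clamping details are assumptions, not TE's exact kernel."""
    x_glu = np.minimum(x_glu, limit)              # assumed: gate input clamped from above at `limit`
    x_linear = np.clip(x_linear, -limit, limit)   # assumed: linear input clamped to [-limit, limit]
    gate = x_glu / (1.0 + np.exp(-alpha * x_glu)) # x * sigmoid(alpha * x)
    # glu_linear_offset = 1.0 reproduces the previous hard-coded GPT-OSS behavior;
    # glu_linear_offset = 0.0 gives the DeepSeek-V4-style variant.
    return gate * (x_linear + glu_linear_offset)
```

With glu_linear_offset=0.0 this reduces to a plain clamped SwiGLU without the additive offset on the linear branch.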

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Contributor

greptile-apps Bot commented Apr 28, 2026

Greptile Summary

This PR adds a configurable glu_linear_offset parameter (default 1.0) to ClampedSwiGLU and all related operations across the PyTorch, JAX, and common CUDA backends, enabling DeepSeek-V4-style ClampedSwiGLU that omits the offset. The change is consistently applied in all forward and backward kernels, the cuDNN fusion opt-out condition, and the Python-level fallback paths, with default values preserving existing behavior.

Confidence Score: 4/5

Safe to merge once the breaking public C API change (insertion of glu_linear_offset before cudaStream_t) is acknowledged in the PR checklist and a compatibility plan is confirmed.

The implementation is mathematically correct across all kernel paths and backward passes; default values preserve backward compatibility at the Python/PyTorch level. The unresolved concern is the ABI-breaking change to the public versioned C header (flagged in a prior review thread), which is still unchecked in the PR checklist and undocumented as a breaking change.

transformer_engine/common/include/transformer_engine/activation.h is the only file whose changes (the public C API signatures) need careful attention from a compatibility standpoint.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/include/transformer_engine/activation.h | Public C API signatures for nvte_clamped_swiglu and nvte_clamped_dswiglu changed (new glu_linear_offset parameter inserted before cudaStream_t) — ABI-breaking for pre-compiled binaries, already flagged in a prior review thread. |
| transformer_engine/common/util/math.h | Added glu_linear_offset field with default 1.0f to the ClampedSwiGLUParam struct — backward compatible, correct. |
| transformer_engine/common/util/vectorized_pointwise.h | Both forward and backward kernels updated to use p.glu_linear_offset instead of the hardcoded 1/1.0f; gradient math is correct — the offset factors into the after_dgelu term and is absent from after_dgate, matching the derivative of silu(x_glu) * (clamp(x_linear) + offset) (sketched below this table). |
| transformer_engine/common/cast/fp8/gated_fp8.cuh | FP8 gated kernel updated to use p.glu_linear_offset; the backward pass correctly uses the offset in after_dact but not in after_dgate, consistent with the non-FP8 path. |
| transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh | Both forward kernel variants updated to use p.glu_linear_offset consistently — two separate template instantiation sites both updated. |
| transformer_engine/pytorch/ops/_common.py | cuDNN fusion opt-out correctly extended to disable fusion when glu_linear_offset != 1.0, matching the existing pattern for a non-default alpha. |
| transformer_engine/pytorch/ops/basic/swiglu.py | ClampedSwiGLU and ScaledClampedQGeGLU both gain a glu_linear_offset parameter with default 1.0; passed through to all tex kernel calls and to the inner _clamped instance of ScaledClampedQGeGLU. |
| transformer_engine/jax/cpp_extensions/activation.py | ClampedSwigluParams hash, to_ffi_lowering_dict, and the clamped_linear lambda all correctly updated to include glu_linear_offset. |
| transformer_engine/jax/csrc/extensions.h | ClampedSwigluConfig struct and XLA_FFI_REGISTER_STRUCT_ATTR_DECODING both updated with glu_linear_offset — field order matches the dict keys in to_ffi_lowering_dict. |
| tests/pytorch/test_fusible_ops.py | Both test_clamped_swiglu and test_scaled_clamped_qgeglu parameterized with (1.0, 0.0) for glu_linear_offset, covering both the legacy and the DeepSeek-V4 case. |
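The backward-pass property described above for vectorized_pointwise.h and gated_fp8.cuh (offset present in the gate-side gradient, absent from the linear-side gradient) can be sketched in NumPy as follows. Clamp boundary masking and TE's exact term names (after_dgelu, after_dgate) are omitted; this is an illustration under the same assumed clamping as the forward sketch earlier, not the PR's code.

```python
import numpy as np

def clamped_swiglu_bwd_ref(grad_out, x_glu, x_linear, limit, alpha=1.702, glu_linear_offset=1.0):
    """Gradient of act(x_glu) * (clamp(x_linear) + offset); clamp boundary masking omitted."""
    x_glu = np.minimum(x_glu, limit)
    x_linear = np.clip(x_linear, -limit, limit)
    sig = 1.0 / (1.0 + np.exp(-alpha * x_glu))
    act = x_glu * sig                               # x * sigmoid(alpha * x)
    dact = sig + alpha * x_glu * sig * (1.0 - sig)  # d/dx [x * sigmoid(alpha * x)]
    dgrad_glu = grad_out * dact * (x_linear + glu_linear_offset)  # offset appears in the gate-side gradient
    dgrad_linear = grad_out * act                                 # offset is absent from the linear-side gradient
    return dgrad_glu, dgrad_linear
```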

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["ClampedSwiGLU / ScaledClampedQGeGLU\n(Python API — glu_linear_offset param added)"] --> B{glu_linear_offset == 1.0\nAND alpha == 1.702?}
    B -- Yes --> C["cuDNN Fusion Path\n(fuse_grouped_mlp_ops)"]
    B -- No --> D["TE Kernel Path"]

    D --> E["tex.clamped_swiglu / tex.clamped_dswiglu\n(PyTorch pybind11 — glu_linear_offset added)"]
    D --> F["nvte_clamped_swiglu / nvte_clamped_dswiglu\n(JAX FFI — glu_linear_offset added)"]

    E --> G["ClampedSwiGLUParam\n{limit, alpha, glu_linear_offset}"]
    F --> G

    G --> H["vectorized_pointwise.h\nFWD: clamp(x_linear) + glu_linear_offset\nBWD: offset in ∂act, not in ∂gate"]
    G --> I["gated_fp8.cuh\ngated_mxfp8.cuh\nFWD+BWD: same offset logic"]

    style C fill:#90EE90
    style D fill:#ADD8E6
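As a usage-level illustration of the dispatch above: keeping the default offset preserves the existing GPT-OSS behavior and remains eligible for the cuDNN fusion path, while a zero offset selects the DeepSeek-V4-style variant on the TE kernel path. In the sketch below, the import path and the limit/alpha keyword names are assumptions; only the glu_linear_offset name and its 1.0 default come from this PR.

```python
# Hedged usage sketch; the import path and the limit/alpha keyword names are assumptions,
# only glu_linear_offset and its 1.0 default come from this PR.
from transformer_engine.pytorch.ops import ClampedSwiGLU

gpt_oss_style  = ClampedSwiGLU(limit=7.0, alpha=1.702, glu_linear_offset=1.0)  # previous hard-coded behavior
deepseek_style = ClampedSwiGLU(limit=7.0, alpha=1.702, glu_linear_offset=0.0)  # DeepSeek-V4-style, no offset
```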

Reviews (3): Last reviewed commit: "fix fusion pattern check"

Comment on lines +339 to 341
* \param[in] glu_linear_offset Offset added to the linear component after clamping (default 1.0).
* \param[in] stream CUDA stream used for the operation.
*/
Contributor


P1 Breaking public C API change

nvte_clamped_swiglu and nvte_clamped_dswiglu are public symbols declared in a versioned public header. Inserting glu_linear_offset before cudaStream_t is an ABI-breaking change: any external binary or shared library compiled against the old header will silently pass the stream pointer as the offset and a garbage value as the stream. Because the mismatch only surfaces when calling through a pre-compiled library, it produces undefined behavior at runtime rather than a clean compile error. This should be acknowledged as a breaking change in the PR checklist, and, if this library follows semantic versioning or a compatibility guarantee, a deprecation/transition path or version bump is needed.

Collaborator

@timmoon10 timmoon10 left a comment


The fused op for grouped MLP is hard-coded for GPT-OSS, so we should make sure not to fuse if glu_linear_offset != 1:

elif isinstance(window[1], ScaledClampedQGeGLU) and (
    abs(window[1]._clamped.alpha - 1.702) > 0.001
    or not _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu()
):
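Per the Greptile summary above, the opt-out in transformer_engine/pytorch/ops/_common.py was extended to also cover a non-default offset. A hedged sketch of what that check could look like, written as a standalone helper: the helper name and the glu_linear_offset attribute path are illustrative assumptions, and only the alpha and cuDNN-frontend checks are quoted from the snippet above.

```python
def _should_skip_cudnn_fusion(op) -> bool:
    # Hypothetical helper mirroring the opt-out pattern quoted above; the exact
    # attribute path for the new offset is an assumption, not the PR's code.
    return (
        abs(op._clamped.alpha - 1.702) > 0.001
        or op._clamped.glu_linear_offset != 1.0  # fall back when the offset is non-default
        or not _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu()
    )
```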

@timmoon10
Collaborator

/te-ci

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@hxbai hxbai marked this pull request as draft April 29, 2026 00:28
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@hxbai hxbai marked this pull request as ready for review April 29, 2026 01:01


2 participants