Add closed-form MXFP4 -> NVFP4 weight cast (--cast_mxfp4_to_nvfp4)#1372
Add closed-form MXFP4 -> NVFP4 weight cast (--cast_mxfp4_to_nvfp4)#1372
Conversation
Research artifact comparing three algorithms for converting an MXFP4
tensor (block 32, E2M1 + E8M0) to NVFP4 (block 16, E2M1 + E4M3 + FP32
global scale):
Algo 1: dequantize MXFP4 -> bf16 -> standard NVFP4 quantize.
Algo 2: keep E2M1 nibbles verbatim; pick global S = 2^m and store
per-block E4M3 scales as 2^(k_j - m), snapping out-of-range
blocks. Two m strategies: midpoint and 1D integer search over
the closed-form snap-error objective.
Algo 3: hybrid - verbatim path for in-range blocks (zero error) plus
NVFP4 requantization with fixed scale_2 = 2^m for OOR blocks.
m chosen by direct-MSE 1D sweep.
Includes 27 scenarios (gaussian, heavy-tail, outlier patterns, spread
boundary tests, layer-shaped LLM weights) and a report summarizing
results, the snap-up/snap-down asymmetry that drives the m choice, and
the one pathological case (single dominant outlier) where Algo 3 still
trails Algo 1 by 0.21% due to integer-m vs continuous scale_2.
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
The m-search loop in the original Algo 3 turns out to be unnecessary.
Across all 27 test scenarios the search converges on m = k_max - 8 and
that closed-form rule is provably the right pick:
- For spread <= 17, every block's k_j - m lands in [8 - spread, 8],
a subset of E4M3's exact-power-of-2 window [-9, 8]. All blocks take
the verbatim path; the conversion is lossless (MSE = 0).
- For spread > 17, m = k_max - 8 is the only choice that does not
NaN the highest-magnitude blocks: a lower m drives the per-block
scale amax/(6*2^m) above E4M3's max (448); a higher m only shrinks
in-range coverage on the low side without helping the high side.
Replaces the brute-force algo3_hybrid_requant with a single-pass
algo3_hybrid using the closed-form m. The Algo 4 / Algo 5 variants
that were used to discover this rule are removed; the script is back
to three algorithms (Algo 1 / Algo 2 / Algo 3) and the report has been
rewritten accordingly.
Same MSE numbers as before. No library changes — strictly under
scratch/.
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
When the source HF checkpoint is MXFP4 (e.g. openai/gpt-oss-20b), the new flag pins NVFP4 weight quantizers' scale_2 to 2^m (m = k_max - 8) and the per-block _amax to 6 * 2^k_j read from the source *_scales. Per-block scale = 2^(k_j - m) is exactly representable in E4M3 for in-range blocks, so NVFP4 dequant matches MXFP4 dequant bit-for-bit (verified SNR=inf on gpt-oss-20b's full ~19B-element MoE expert weights). For out-of-range blocks (k_max - k_j > 17), the per-block amax falls back to data-derived max(|w_block|), keeping the post-clamp scale closer to the actual block magnitude than the closed-form ideal would. Modelopt-side enablers: - max_calibrate auto-promotes static-block NVFP4 weight quantizers to NVFP4StaticQuantizer at the end of calibration. - static_blockwise_fp4_fake_quant kernel accepts N-D inputs (was 2D-only), unblocking MoE expert weights of shape (E, F, K). - BMM-experts NVFP4 export routes through get_weights_scaling_factor_from_quantizer for static-mode quantizers, so the pinned _amax is consumed (was bypassed by recompute-from-weight). - set_expert_quantizer_amax scalar-reduces per-quantizer amax before stacking, supporting per-block (vs scalar) static-mode amax. Wired through scripts/parser.sh + scripts/huggingface_example.sh as the shell-level --cast_mxfp4_to_nvfp4 flag. Removes the scratch/ MSE experiment (kept in PR description for context). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
📝 WalkthroughWalkthroughAdds a new MXFP4-to-NVFP4 casting utility that converts quantization parameters from MXFP4 HF checkpoints to NVFP4 format. It computes global and per-block Changes
Sequence Diagram(s)sequenceDiagram
participant Checkpoint as MXFP4<br/>Checkpoint
participant Collector as Checkpoint<br/>Collector
participant Calc as Amax<br/>Calculator
participant Quantizer as NVFP4<br/>Quantizer
participant Model as Model<br/>Instance
Checkpoint->>Collector: Load *_scales &<br/>*_blocks tensors
Collector->>Calc: scales, blocks per layer
Note over Calc: Compute global_amax from<br/>scale exponents (in-range logic)
Calc->>Calc: Create E2M1 magnitude table
Note over Calc: Compute per-block _amax<br/>(in-range: formula,<br/>out-of-range: data-driven)
Calc->>Model: Locate NVFP4StaticQuantizer<br/>for each weight quantizer
Model->>Quantizer: Retrieve quantizer instance
Quantizer->>Quantizer: Update global_amax<br/>via property setter
Quantizer->>Quantizer: Replace _amax buffer<br/>in-place
Quantizer->>Model: Apply updated parameters
Model->>Model: Ready for export with<br/>NVFP4 scales
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~30 minutes Important Pre-merge checks failedPlease resolve all errors before merging. Addressing warnings is optional. ❌ Failed checks (1 error, 1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Review rate limit: 9/10 reviews remaining, refill in 6 minutes. Comment |
|
There was a problem hiding this comment.
Actionable comments posted: 7
🧹 Nitpick comments (2)
examples/llm_ptq/hf_ptq.py (1)
1126-1128: Add fail-fast validation for cast mode compatibilityPlease reject
--cast_mxfp4_to_nvfp4unless--qformatis NVFP4-family (and preferably disallow with multi-format auto-quantize). Right now invalid combinations can proceed and fail late.Suggested guard
args = parser.parse_args() + if args.cast_mxfp4_to_nvfp4: + qformats = [q.strip() for q in args.qformat.split(",")] + if not all("nvfp4" in q for q in qformats): + parser.error("--cast_mxfp4_to_nvfp4 requires NVFP4-family --qformat values.") + if args.auto_quantize_bits is not None and len(qformats) > 1: + parser.error( + "--cast_mxfp4_to_nvfp4 is not supported with multi-format auto_quantize." + )Also applies to: 1370-1381
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/llm_ptq/hf_ptq.py` around lines 1126 - 1128, Add a fail-fast guard before calling apply_cast_mxfp4_to_nvfp4: validate that args.cast_mxfp4_to_nvfp4 is only true when args.qformat is in the NVFP4 family (check the exact qformat string values your code uses) and reject/raise an error (or exit) if it’s used together with any multi-format auto-quantize option (the flag/variable controlling multi-format auto-quantize in your parser) to prevent late failures; update the two call sites that invoke apply_cast_mxfp4_to_nvfp4 (the one using args.cast_mxfp4_to_nvfp4 and args.pyt_ckpt_path and the similar block around the 1370-1381 region) to perform this validation first and emit a clear message about allowed combinations.examples/llm_ptq/cast_mxfp4_to_nvfp4.py (1)
39-42: Lazy-load the optional HF dependencies.Importing
safetensorsand the quantizer class at module load time makes this helper fail to import unless the full extra set is already installed. Please move these imports into the code paths that actually need them, or gate them through the repo’s plugin-loading pattern.As per coding guidelines, "Avoid hard imports of optional dependencies at module level; features should be gated by install extras (
[onnx],[hf],[all]) and loaded lazily via `import_plugin()``."🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py` around lines 39 - 42, The module-level imports of safetensors and NVFP4StaticQuantizer cause hard dependency loading; move these imports into the function(s) that actually use them (e.g., inside the routine that opens the safetensors file or performs quantization) or replace them with the repo's plugin loader (import_plugin) so they are lazily loaded; specifically, remove "from safetensors import safe_open" and "from modelopt.torch.quantization.nn.modules.tensor_quantizer import NVFP4StaticQuantizer" from the top-level and import safe_open and NVFP4StaticQuantizer within the function that calls them (or use import_plugin('safetensors').safe_open / import_plugin('modelopt.torch.quantization').NVFP4StaticQuantizer) and add a clear error message when the import fails.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py`:
- Around line 144-148: The shape check currently only compares leading dims
(blocks.shape[:-1] vs e8m0_scales.shape) and misses verifying the trailing block
size; update the validation to also require blocks.shape[-1] == 16 (and raise
the same ValueError if not) so malformed tensors like (..., N, 15) or (..., N,
32) are rejected; keep the existing error message (referencing blocks and
e8m0_scales) and perform this check where the current if comparing
blocks.shape[:-1] and e8m0_scales.shape appears.
- Around line 328-370: The code still unconditionally loads the packed blocks
tensor via _read and calls compute_per_block_amax_for_mxfp4 even when
compute_global_amax_for_scales reported info["pct_lossless"] >= 100.0; change
the logic in the loop (around compute_global_amax_for_scales, qname,
weight_quantizer checks) to short-circuit the I/O path for fully-lossless
layers: if info["pct_lossless"] >= 100.0, construct global_amax from
global_amax_value on the existing per-block buffer device and reuse the existing
per-block `_amax` (e.g., existing.to(dtype=torch.float32, device=device)) as
per_block_amax, and skip calling _read(blocks_key, ...) and
compute_per_block_amax_for_mxfp4; keep the NVFP4StaticQuantizer/assert checks
and only avoid the expensive block read when pct_lossless==100.0.
- Around line 338-382: Several assertions validating untrusted checkpoint/model
data (the check that blocks_shard is not None, the type-check of
weight_quantizer, the presence/shape check of weight_quantizer._amax, and the
element-count check comparing existing.numel() to per_block_amax.numel()) must
be replaced with explicit exception handling; update the checks around
blocks_shard, weight_quantizer (qname → NVFP4StaticQuantizer), existing/_amax,
and element count to raise specific exceptions (e.g., ValueError or
RuntimeError) with clear diagnostic messages including qname, expected vs actual
types/values, and any relevant shapes/numel counts, rather than using assert so
validation still runs under python -O and treats all checkpoint artifacts as
untrusted.
- Around line 258-263: The handle cache (handles: dict and safe_open(...)
usages) is never closed, leaking file descriptors/mmaps; wrap the cache creation
and all safe_open acquisitions in a context-managed ExitStack (or equivalent) so
each safe_open call is entered via stack.enter_context(...) and all handles are
closed deterministically when the function returns—apply this change in both the
function that iterates sorted(scales_keys.items()) (where scales =
handles[shard].get_tensor(tensor_key)) and in apply_to_model() around its main
loop, ensuring the ExitStack is created at the start of the function and closed
on all return paths so file handles are always released.
In `@modelopt/torch/export/layer_utils.py`:
- Around line 1094-1096: The collection of valid amax values should skip meta
tensors to avoid calling .to(target_device) on tensors without storage; modify
the block that appends to valid_amax_values (where existing_amax is checked) to
first check if existing_amax is a meta tensor (existing_amax.is_meta) and only
call existing_amax.amax().to(target_device) for non-meta tensors, otherwise
ignore/skip that existing_amax so meta tensors are not included in the fallback
aggregation.
In `@modelopt/torch/kernels/quantization/gemm/fp4_kernel.py`:
- Around line 267-274: The code assumes amax has elements but will divide by
zero if amax.numel() == 0; add an explicit guard before computing NUM_FP4_BLOCKS
and BLOCK_SIZE: check if amax.numel() == 0 and raise a clear ValueError (or
handle it) indicating amax is empty; then compute NUM_FP4_BLOCKS = amax.numel(),
verify x.numel() % NUM_FP4_BLOCKS == 0 as before, and compute BLOCK_SIZE =
x.numel() // NUM_FP4_BLOCKS (references: amax, NUM_FP4_BLOCKS, x, BLOCK_SIZE).
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 283-289: The promotion call to promote_nvfp4_static_quantizers is
placed after an early return in max_calibrate so the branch taken when
distributed_sync=False never gets NVFP4StaticQuantizer promotion; move the
promote_nvfp4_static_quantizers(model) invocation so it runs before the early
return in max_calibrate (or invoke it in both the distributed_sync=True and
distributed_sync=False branches) so that promotion always occurs regardless of
the distributed_sync flag.
---
Nitpick comments:
In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py`:
- Around line 39-42: The module-level imports of safetensors and
NVFP4StaticQuantizer cause hard dependency loading; move these imports into the
function(s) that actually use them (e.g., inside the routine that opens the
safetensors file or performs quantization) or replace them with the repo's
plugin loader (import_plugin) so they are lazily loaded; specifically, remove
"from safetensors import safe_open" and "from
modelopt.torch.quantization.nn.modules.tensor_quantizer import
NVFP4StaticQuantizer" from the top-level and import safe_open and
NVFP4StaticQuantizer within the function that calls them (or use
import_plugin('safetensors').safe_open /
import_plugin('modelopt.torch.quantization').NVFP4StaticQuantizer) and add a
clear error message when the import fails.
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 1126-1128: Add a fail-fast guard before calling
apply_cast_mxfp4_to_nvfp4: validate that args.cast_mxfp4_to_nvfp4 is only true
when args.qformat is in the NVFP4 family (check the exact qformat string values
your code uses) and reject/raise an error (or exit) if it’s used together with
any multi-format auto-quantize option (the flag/variable controlling
multi-format auto-quantize in your parser) to prevent late failures; update the
two call sites that invoke apply_cast_mxfp4_to_nvfp4 (the one using
args.cast_mxfp4_to_nvfp4 and args.pyt_ckpt_path and the similar block around the
1370-1381 region) to perform this validation first and emit a clear message
about allowed combinations.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 636d8bc6-4bef-45a4-a6e8-d12387edfcf0
📒 Files selected for processing (9)
examples/llm_ptq/cast_mxfp4_to_nvfp4.pyexamples/llm_ptq/hf_ptq.pyexamples/llm_ptq/scripts/huggingface_example.shexamples/llm_ptq/scripts/parser.shmodelopt/torch/export/layer_utils.pymodelopt/torch/export/unified_export_hf.pymodelopt/torch/kernels/quantization/gemm/fp4_kernel.pymodelopt/torch/quantization/model_calib.pytests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py
| if blocks.shape[:-1] != e8m0_scales.shape: | ||
| raise ValueError( | ||
| f"shape mismatch: blocks {tuple(blocks.shape)} (expect last dim 16) " | ||
| f"vs scales {tuple(e8m0_scales.shape)}" | ||
| ) |
There was a problem hiding this comment.
Reject malformed block tensors with a non-16-byte trailing dimension.
The current check only compares leading dims. A tensor shaped like (..., N, 15) or (..., N, 32) will pass validation and then produce the wrong per-block max while still duplicating by 2 as if each MXFP4 block were 32 elements.
Suggested fix
- if blocks.shape[:-1] != e8m0_scales.shape:
+ if blocks.shape[-1] != 16 or blocks.shape[:-1] != e8m0_scales.shape:
raise ValueError(
- f"shape mismatch: blocks {tuple(blocks.shape)} (expect last dim 16) "
+ f"shape mismatch: blocks {tuple(blocks.shape)} "
+ "(expected (..., num_mxfp4_blocks, 16)) "
f"vs scales {tuple(e8m0_scales.shape)}"
)As per coding guidelines, "When implementing/using the MXFP4→NVFP4 cast flow and checkpoint loading, treat all model/weight artifacts as untrusted."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py` around lines 144 - 148, The shape
check currently only compares leading dims (blocks.shape[:-1] vs
e8m0_scales.shape) and misses verifying the trailing block size; update the
validation to also require blocks.shape[-1] == 16 (and raise the same ValueError
if not) so malformed tensors like (..., N, 15) or (..., N, 32) are rejected;
keep the existing error message (referencing blocks and e8m0_scales) and perform
this check where the current if comparing blocks.shape[:-1] and
e8m0_scales.shape appears.
| handles: dict[Path, safe_open] = {} | ||
| amax_map: dict[str, dict] = {} | ||
| for tensor_key, shard in sorted(scales_keys.items()): | ||
| if shard not in handles: | ||
| handles[shard] = safe_open(shard, framework="pt", device="cpu") | ||
| scales = handles[shard].get_tensor(tensor_key) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Find the file and check if it exists
fd -t f "cast_mxfp4_to_nvfp4.py"Repository: NVIDIA/Model-Optimizer
Length of output: 157
🏁 Script executed:
#!/bin/bash
# Read the file to understand the context
wc -l examples/llm_ptq/cast_mxfp4_to_nvfp4.pyRepository: NVIDIA/Model-Optimizer
Length of output: 110
🏁 Script executed:
#!/bin/bash
# Read the relevant sections around lines 258-263 and 320-323
echo "=== Lines 250-275 ==="
sed -n '250,275p' examples/llm_ptq/cast_mxfp4_to_nvfp4.py
echo ""
echo "=== Lines 310-330 ==="
sed -n '310,330p' examples/llm_ptq/cast_mxfp4_to_nvfp4.pyRepository: NVIDIA/Model-Optimizer
Length of output: 1674
🏁 Script executed:
#!/bin/bash
# Get more context - read the function containing lines 258-263
echo "=== Context around line 258 (full function) ==="
sed -n '240,275p' examples/llm_ptq/cast_mxfp4_to_nvfp4.pyRepository: NVIDIA/Model-Optimizer
Length of output: 1392
🏁 Script executed:
#!/bin/bash
# Get full apply_to_model function to see how it ends
echo "=== apply_to_model function start ==="
sed -n '273,320p' examples/llm_ptq/cast_mxfp4_to_nvfp4.pyRepository: NVIDIA/Model-Optimizer
Length of output: 2064
🏁 Script executed:
#!/bin/bash
# Check the end of apply_to_model function
echo "=== End of apply_to_model function ==="
sed -n '370,409p' examples/llm_ptq/cast_mxfp4_to_nvfp4.pyRepository: NVIDIA/Model-Optimizer
Length of output: 2073
🏁 Script executed:
#!/bin/bash
# Check import of safe_open to understand its type
head -50 examples/llm_ptq/cast_mxfp4_to_nvfp4.py | grep -n "safe_open\|import"Repository: NVIDIA/Model-Optimizer
Length of output: 254
Close safe_open handles deterministically.
Both functions cache safe_open(...) objects in a dict and never close them. On heavily sharded checkpoints, this can keep many file descriptors and mmaps alive until garbage collection runs. Wrap the handle cache in an ExitStack or equivalent scoped context to ensure deterministic resource cleanup.
The first function returns without closing handles (lines 258–270), and apply_to_model() returns without closing handles after the main loop (lines 310–407).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py` around lines 258 - 263, The handle
cache (handles: dict and safe_open(...) usages) is never closed, leaking file
descriptors/mmaps; wrap the cache creation and all safe_open acquisitions in a
context-managed ExitStack (or equivalent) so each safe_open call is entered via
stack.enter_context(...) and all handles are closed deterministically when the
function returns—apply this change in both the function that iterates
sorted(scales_keys.items()) (where scales =
handles[shard].get_tensor(tensor_key)) and in apply_to_model() around its main
loop, ensuring the ExitStack is created at the start of the function and closed
on all return paths so file handles are always released.
| global_amax_value, info = compute_global_amax_for_scales(scales) | ||
| n_total_layers += 1 | ||
| if info["pct_lossless"] >= 100.0: | ||
| n_lossless_layers += 1 | ||
| grand_total_blocks += info["n_total_blocks"] | ||
| grand_lossless_blocks += info["n_lossless_blocks"] | ||
|
|
||
| blocks_key = tensor_key[: -len("_scales")] + "_blocks" | ||
| qname = quantizer_name_from_blocks_key(blocks_key) | ||
| blocks_shard = blocks_keys.get(blocks_key) | ||
| assert blocks_shard is not None, ( | ||
| f"{tensor_key}: no paired '{blocks_key}' tensor found in source checkpoint." | ||
| ) | ||
|
|
||
| weight_quantizer = name_to_module.get(qname) | ||
| if weight_quantizer is None: | ||
| missed.append(qname) | ||
| continue | ||
|
|
||
| # The cast assumes ``max_calibrate`` already promoted this quantizer | ||
| # to NVFP4StaticQuantizer (with ``_amax`` populated per-block by | ||
| # static-block max-cal and ``_global_amax`` set by the auto-promote). | ||
| # Anything else means the qformat or quant_cfg disabled this module's | ||
| # weight quantization — surface that loudly so we don't silently no-op. | ||
| assert isinstance(weight_quantizer, NVFP4StaticQuantizer), ( | ||
| f"{qname}: expected NVFP4StaticQuantizer (set by max_calibrate's " | ||
| f"auto-promote), got {type(weight_quantizer).__name__}. The cast " | ||
| "requires the matching quantizer to be enabled with static-block " | ||
| "NVFP4 (num_bits=(2,1), scale_bits=(4,3))." | ||
| ) | ||
| existing = getattr(weight_quantizer, "_amax", None) | ||
| assert isinstance(existing, torch.Tensor) and existing.numel() > 1, ( | ||
| f"{qname}: NVFP4StaticQuantizer must have a per-block ``_amax`` " | ||
| f"buffer populated by max_calibrate. Got: {existing!r}." | ||
| ) | ||
|
|
||
| # Pick the device from the existing per-block ``_amax`` buffer. | ||
| device = existing.device | ||
|
|
||
| global_amax = torch.tensor(float(global_amax_value), dtype=torch.float32, device=device) | ||
| blocks = _read(blocks_key, blocks_shard) | ||
| per_block_amax = compute_per_block_amax_for_mxfp4(blocks, scales).to( | ||
| dtype=torch.float32, device=device |
There was a problem hiding this comment.
Don’t read *_blocks for layers that are already 100% lossless.
compute_global_amax_for_scales() already tells you whether every block is in-range, but apply_to_model() still loads the packed block tensor unconditionally. In the common case described in the PR, that defeats the main I/O win of the closed-form path and turns a scale-only cast back into full weight reads.
Suggested direction
- blocks = _read(blocks_key, blocks_shard)
- per_block_amax = compute_per_block_amax_for_mxfp4(blocks, scales).to(
- dtype=torch.float32, device=device
- )
+ if info["n_lossless_blocks"] == info["n_total_blocks"]:
+ per_block_amax = (
+ E2M1_MAX * torch.exp2((scales.to(torch.int32) - E8M0_BIAS).float())
+ ).repeat_interleave(2, dim=-1).to(dtype=torch.float32, device=device)
+ else:
+ blocks = _read(blocks_key, blocks_shard)
+ per_block_amax = compute_per_block_amax_for_mxfp4(blocks, scales).to(
+ dtype=torch.float32, device=device
+ )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py` around lines 328 - 370, The code
still unconditionally loads the packed blocks tensor via _read and calls
compute_per_block_amax_for_mxfp4 even when compute_global_amax_for_scales
reported info["pct_lossless"] >= 100.0; change the logic in the loop (around
compute_global_amax_for_scales, qname, weight_quantizer checks) to short-circuit
the I/O path for fully-lossless layers: if info["pct_lossless"] >= 100.0,
construct global_amax from global_amax_value on the existing per-block buffer
device and reuse the existing per-block `_amax` (e.g.,
existing.to(dtype=torch.float32, device=device)) as per_block_amax, and skip
calling _read(blocks_key, ...) and compute_per_block_amax_for_mxfp4; keep the
NVFP4StaticQuantizer/assert checks and only avoid the expensive block read when
pct_lossless==100.0.
| assert blocks_shard is not None, ( | ||
| f"{tensor_key}: no paired '{blocks_key}' tensor found in source checkpoint." | ||
| ) | ||
|
|
||
| weight_quantizer = name_to_module.get(qname) | ||
| if weight_quantizer is None: | ||
| missed.append(qname) | ||
| continue | ||
|
|
||
| # The cast assumes ``max_calibrate`` already promoted this quantizer | ||
| # to NVFP4StaticQuantizer (with ``_amax`` populated per-block by | ||
| # static-block max-cal and ``_global_amax`` set by the auto-promote). | ||
| # Anything else means the qformat or quant_cfg disabled this module's | ||
| # weight quantization — surface that loudly so we don't silently no-op. | ||
| assert isinstance(weight_quantizer, NVFP4StaticQuantizer), ( | ||
| f"{qname}: expected NVFP4StaticQuantizer (set by max_calibrate's " | ||
| f"auto-promote), got {type(weight_quantizer).__name__}. The cast " | ||
| "requires the matching quantizer to be enabled with static-block " | ||
| "NVFP4 (num_bits=(2,1), scale_bits=(4,3))." | ||
| ) | ||
| existing = getattr(weight_quantizer, "_amax", None) | ||
| assert isinstance(existing, torch.Tensor) and existing.numel() > 1, ( | ||
| f"{qname}: NVFP4StaticQuantizer must have a per-block ``_amax`` " | ||
| f"buffer populated by max_calibrate. Got: {existing!r}." | ||
| ) | ||
|
|
||
| # Pick the device from the existing per-block ``_amax`` buffer. | ||
| device = existing.device | ||
|
|
||
| global_amax = torch.tensor(float(global_amax_value), dtype=torch.float32, device=device) | ||
| blocks = _read(blocks_key, blocks_shard) | ||
| per_block_amax = compute_per_block_amax_for_mxfp4(blocks, scales).to( | ||
| dtype=torch.float32, device=device | ||
| ) | ||
| # Numel must match — calibration may store ``_amax`` flat (e.g. (N, 1)) | ||
| # while we compute it in natural (E, F, num_blocks) layout. The static | ||
| # export path reshapes via ``.view(expected_shape)``, so we just need | ||
| # element count to agree, then reshape for the in-place copy. | ||
| assert existing.numel() == per_block_amax.numel(), ( | ||
| f"{qname}: ``_amax`` element count {existing.numel()} does not " | ||
| f"match the cast-computed count {per_block_amax.numel()}. The " | ||
| "block layout from calibration disagrees with the source MXFP4 " | ||
| "scales — check that the qformat block_size is 16 and the source " | ||
| "checkpoint is the same MXFP4 model." | ||
| ) |
There was a problem hiding this comment.
❓ Verification inconclusive
Script executed:
cat -n examples/llm_ptq/cast_mxfp4_to_nvfp4.py | sed -n '330,390p'Repository: NVIDIA/Model-Optimizer
Repository: NVIDIA/Model-Optimizer
Exit code: 0
stdout:
330 if info["pct_lossless"] >= 100.0:
331 n_lossless_layers += 1
332 grand_total_blocks += info["n_total_blocks"]
333 grand_lossless_blocks += info["n_lossless_blocks"]
334
335 blocks_key = tensor_key[: -len("_scales")] + "_blocks"
336 qname = quantizer_name_from_blocks_key(blocks_key)
337 blocks_shard = blocks_keys.get(blocks_key)
338 assert blocks_shard is not None, (
339 f"{tensor_key}: no paired '{blocks_key}' tensor found in source checkpoint."
340 )
341
342 weight_quantizer = name_to_module.get(qname)
343 if weight_quantizer is None:
344 missed.append(qname)
345 continue
346
347 # The cast assumes ``max_calibrate`` already promoted this quantizer
348 # to NVFP4StaticQuantizer (with ``_amax`` populated per-block by
349 # static-block max-cal and ``_global_amax`` set by the auto-promote).
350 # Anything else means the qformat or quant_cfg disabled this module's
351 # weight quantization — surface that loudly so we don't silently no-op.
352 assert isinstance(weight_quantizer, NVFP4StaticQuantizer), (
353 f"{qname}: expected NVFP4StaticQuantizer (set by max_calibrate's "
354 f"auto-promote), got {type(weight_quantizer).__name__}. The cast "
355 "requires the matching quantizer to be enabled with static-block "
356 "NVFP4 (num_bits=(2,1), scale_bits=(4,3))."
357 )
358 existing = getattr(weight_quantizer, "_amax", None)
359 assert isinstance(existing, torch.Tensor) and existing.numel() > 1, (
360 f"{qname}: NVFP4StaticQuantizer must have a per-block ``_amax`` "
361 f"buffer populated by max_calibrate. Got: {existing!r}."
362 )
363
364 # Pick the device from the existing per-block ``_amax`` buffer.
365 device = existing.device
366
367 global_amax = torch.tensor(float(global_amax_value), dtype=torch.float32, device=device)
368 blocks = _read(blocks_key, blocks_shard)
369 per_block_amax = compute_per_block_amax_for_mxfp4(blocks, scales).to(
370 dtype=torch.float32, device=device
371 )
372 # Numel must match — calibration may store ``_amax`` flat (e.g. (N, 1))
373 # while we compute it in natural (E, F, num_blocks) layout. The static
374 # export path reshapes via ``.view(expected_shape)``, so we just need
375 # element count to agree, then reshape for the in-place copy.
376 assert existing.numel() == per_block_amax.numel(), (
377 f"{qname}: ``_amax`` element count {existing.numel()} does not "
378 f"match the cast-computed count {per_block_amax.numel()}. The "
379 "block layout from calibration disagrees with the source MXFP4 "
380 "scales — check that the qformat block_size is 16 and the source "
381 "checkpoint is the same MXFP4 model."
382 )
383
384 # global_amax via the NVFP4StaticQuantizer property setter (writes to
385 # the canonical ``_global_amax`` buffer).
386 weight_quantizer.global_amax = global_amax
387 # _amax: in-place buffer copy, reshaping our value to the calibrator's
388 # storage layout (numel verified above).
389 with torch.no_grad():
390 existing.data.copy_(per_block_amax.view_as(existing))
Script executed:
# Check if this is in examples/ and verify guideline applicability
ls -la examples/llm_ptq/cast_mxfp4_to_nvfp4.pyRepository: NVIDIA/Model-Optimizer
Repository: NVIDIA/Model-Optimizer
Exit code: 0
stdout:
-rw-r--r-- 1 jailuser jailuser 18001 Apr 29 16:46 examples/llm_ptq/cast_mxfp4_to_nvfp4.py
Script executed:
# Verify the assert pattern used in this section
rg -n "assert.*blocks_shard|assert.*isinstance.*NVFP4StaticQuantizer|assert.*existing.*numel" examples/llm_ptq/cast_mxfp4_to_nvfp4.py -A 2Repository: NVIDIA/Model-Optimizer
Repository: NVIDIA/Model-Optimizer
Exit code: 0
stdout:
338: assert blocks_shard is not None, (
339- f"{tensor_key}: no paired '{blocks_key}' tensor found in source checkpoint."
340- )
--
352: assert isinstance(weight_quantizer, NVFP4StaticQuantizer), (
353- f"{qname}: expected NVFP4StaticQuantizer (set by max_calibrate's "
354- f"auto-promote), got {type(weight_quantizer).__name__}. The cast "
--
359: assert isinstance(existing, torch.Tensor) and existing.numel() > 1, (
360- f"{qname}: NVFP4StaticQuantizer must have a per-block ``_amax`` "
361- f"buffer populated by max_calibrate. Got: {existing!r}."
--
376: assert existing.numel() == per_block_amax.numel(), (
377- f"{qname}: ``_amax`` element count {existing.numel()} does not "
378- f"match the cast-computed count {per_block_amax.numel()}. The "
Replace assert statements with explicit exception handling for checkpoint and model validation.
The four assertions at lines 338, 352, 359, and 376 validate untrusted checkpoint contents (blocks_shard, quantizer type, _amax buffer). Under python -O, these asserts vanish entirely, leaving no validation. Raise concrete exceptions (e.g., ValueError, RuntimeError) instead.
Per SECURITY.md: when implementing the MXFP4→NVFP4 cast flow and checkpoint loading, treat all model/weight artifacts as untrusted and prefer safe, unconditional validation.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py` around lines 338 - 382, Several
assertions validating untrusted checkpoint/model data (the check that
blocks_shard is not None, the type-check of weight_quantizer, the presence/shape
check of weight_quantizer._amax, and the element-count check comparing
existing.numel() to per_block_amax.numel()) must be replaced with explicit
exception handling; update the checks around blocks_shard, weight_quantizer
(qname → NVFP4StaticQuantizer), existing/_amax, and element count to raise
specific exceptions (e.g., ValueError or RuntimeError) with clear diagnostic
messages including qname, expected vs actual types/values, and any relevant
shapes/numel counts, rather than using assert so validation still runs under
python -O and treats all checkpoint artifacts as untrusted.
| if isinstance(existing_amax, torch.Tensor): | ||
| valid_amax_values.append(existing_amax.to(target_device)) | ||
| valid_amax_values.append(existing_amax.amax().to(target_device)) | ||
| else: |
There was a problem hiding this comment.
Skip meta amax tensors before fallback aggregation
Line 1095 can fail when existing_amax is a meta tensor (.to(target_device) on meta has no storage). Please skip meta tensors in valid_amax_values collection.
Suggested fix
if existing_amax is not None:
# Convert to tensor and add to collection
if isinstance(existing_amax, torch.Tensor):
- valid_amax_values.append(existing_amax.amax().to(target_device))
+ if not existing_amax.is_meta:
+ valid_amax_values.append(existing_amax.amax().to(target_device))
else:
valid_amax_values.append(
torch.tensor(existing_amax, dtype=torch.float32, device=target_device)
)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if isinstance(existing_amax, torch.Tensor): | |
| valid_amax_values.append(existing_amax.to(target_device)) | |
| valid_amax_values.append(existing_amax.amax().to(target_device)) | |
| else: | |
| if isinstance(existing_amax, torch.Tensor): | |
| if not existing_amax.is_meta: | |
| valid_amax_values.append(existing_amax.amax().to(target_device)) | |
| else: |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/export/layer_utils.py` around lines 1094 - 1096, The
collection of valid amax values should skip meta tensors to avoid calling
.to(target_device) on tensors without storage; modify the block that appends to
valid_amax_values (where existing_amax is checked) to first check if
existing_amax is a meta tensor (existing_amax.is_meta) and only call
existing_amax.amax().to(target_device) for non-meta tensors, otherwise
ignore/skip that existing_amax so meta tensors are not included in the fallback
aggregation.
| NUM_FP4_BLOCKS = amax.numel() | ||
| if x.numel() % NUM_FP4_BLOCKS != 0: | ||
| raise ValueError( | ||
| f"x.numel() ({x.numel()}) is not divisible by amax.numel() ({NUM_FP4_BLOCKS}); " | ||
| "they must satisfy x.numel() == NUM_FP4_BLOCKS * BLOCK_SIZE." | ||
| ) | ||
| BLOCK_SIZE = x.numel() // NUM_FP4_BLOCKS | ||
|
|
There was a problem hiding this comment.
Guard against empty amax before modulo/division
If amax.numel() == 0, Line 268/Line 273 will hit division-by-zero. Add an explicit precheck.
Suggested fix
original_shape = x.shape
NUM_FP4_BLOCKS = amax.numel()
+ if NUM_FP4_BLOCKS == 0:
+ raise ValueError("amax must contain at least one block.")
if x.numel() % NUM_FP4_BLOCKS != 0:
raise ValueError(
f"x.numel() ({x.numel()}) is not divisible by amax.numel() ({NUM_FP4_BLOCKS}); "
"they must satisfy x.numel() == NUM_FP4_BLOCKS * BLOCK_SIZE."
)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/kernels/quantization/gemm/fp4_kernel.py` around lines 267 -
274, The code assumes amax has elements but will divide by zero if amax.numel()
== 0; add an explicit guard before computing NUM_FP4_BLOCKS and BLOCK_SIZE:
check if amax.numel() == 0 and raise a clear ValueError (or handle it)
indicating amax is empty; then compute NUM_FP4_BLOCKS = amax.numel(), verify
x.numel() % NUM_FP4_BLOCKS == 0 as before, and compute BLOCK_SIZE = x.numel() //
NUM_FP4_BLOCKS (references: amax, NUM_FP4_BLOCKS, x, BLOCK_SIZE).
| # Promote eligible static-block NVFP4 weight quantizers to NVFP4StaticQuantizer | ||
| # so the static blockwise fake-quant path is used in forward and the export | ||
| # picks up the two-level (per-block + global) scaling. ``promote_nvfp4_static_quantizers`` | ||
| # only promotes when ``is_static_block_quant`` is True and the per-block ``_amax`` | ||
| # buffer is populated, so it's a no-op for dynamic-block / non-NVFP4 configs. | ||
| promote_nvfp4_static_quantizers(model) | ||
|
|
There was a problem hiding this comment.
Promotion is bypassed when distributed_sync=False
This new promotion step is after the early return in max_calibrate, so callers using distributed_sync=False never get NVFP4StaticQuantizer promotion. Move promotion before the return (or call it in both branches).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/quantization/model_calib.py` around lines 283 - 289, The
promotion call to promote_nvfp4_static_quantizers is placed after an early
return in max_calibrate so the branch taken when distributed_sync=False never
gets NVFP4StaticQuantizer promotion; move the
promote_nvfp4_static_quantizers(model) invocation so it runs before the early
return in max_calibrate (or invoke it in both the distributed_sync=True and
distributed_sync=False branches) so that promotion always occurs regardless of
the distributed_sync flag.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1372 +/- ##
==========================================
- Coverage 76.93% 76.00% -0.93%
==========================================
Files 471 471
Lines 50404 51850 +1446
==========================================
+ Hits 38776 39411 +635
- Misses 11628 12439 +811
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Summary
--cast_mxfp4_to_nvfp4flag inhf_ptq.py(andhuggingface_example.sh) that converts an MXFP4 source checkpoint (e.g.openai/gpt-oss-20b) into an NVFP4 export with bit-exact weight reconstruction for the in-range blocks.scale_2 = 2^m(wherem = k_max − 8) and_amax = 6·2^k_jper NVFP4 block, both read from the source*_scales. The resulting per-block scale2^(k_j − m)is exactly representable in E4M3, soround_to_E2M1(value / 2^k_j)yields the original MXFP4 nibble verbatim. For out-of-range blocks (k_max − k_j > 17) the per-block amax falls back to data-derivedmax(|w_block|), which keeps the post-E4M3-clamp scale close to the block's actual magnitude.Verification
End-to-end on
openai/gpt-oss-20bwith--qformat=nvfp4_mlp_only --cast_mxfp4_to_nvfp4:Per-tensor MSE between MXFP4 source dequant and NVFP4 export dequant (~19B elements):
Modelopt-side enablers
max_calibrateauto-promotes static-block NVFP4 weight quantizers toNVFP4StaticQuantizerat the end of calibration.static_blockwise_fp4_fake_quantkernel accepts N-D inputs (was 2D-only), unblocking MoE expert weights of shape(E, F, K).get_weights_scaling_factor_from_quantizerfor static-mode quantizers, so the pinned_amaxis actually consumed.set_expert_quantizer_amaxscalar-reduces per-quantizer amax before stacking, supporting per-block (vs scalar) static-mode amax.Test plan
tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py(15 tests, all passing) cover: scalar/global-amax math, per-block hybrid (in-range closed-form vs OOR data-derived), shape preservation, key collection, and end-to-endbuild_amax_mapagainst a synthetic safetensors checkpoint.openai/gpt-oss-20b(nvfp4_mlp_onlyqformat) with--cast_mxfp4_to_nvfp4succeeds; export takes ~21 s.🤖 Generated with Claude Code
Summary by CodeRabbit
Release Notes
New Features
--cast_mxfp4_to_nvfp4command-line flag for automatic conversionBug Fixes
Tests