
Add WACNN and SymmetricalTransFormer (STF, CVPR 2022)#354

Open
Yiozolm wants to merge 3 commits into InterDigitalInc:master from Yiozolm:pr-stf-wacnn

Conversation

Yiozolm commented May 3, 2026

Adds WACNN and SymmetricalTransFormer (STF) from R. Zou, C. Song, Z. Zhang, "The Devil Is in the Details: Window-based Attention for Image Compression", CVPR 2022 (arXiv:2203.08450).

Adapted from the official implementation at https://github.com/Googolxx/STF (Apache-2.0).

This is the first installment of the per-model PR series proposed in #353. Pretrained weights are intentionally not bundled — calling pretrained=True raises a clear RuntimeError until weights are hosted on S3 (per the discussion in #353).

Summary

  • New zoo entries "stf" and "stf-wacnn" (compressai.models.SymmetricalTransFormer and compressai.models.WACNN); usage is sketched just after this list.
  • New compressai.layers.attn subpackage with the Swin window-based attention building blocks the two models depend on. Reuses timm.models.swin_transformer wherever the implementation is generic to avoid vendoring a parallel Swin stack — see Reuse of timm below.
  • New ChannelSliceLatentCodec + SliceEntropyCompressionModel base — designed to be reused by the channel-conditional models in follow-up PRs (CCA, TCM, …).
  • Checkpoint converter in examples/convert_stf_checkpoint.py that loads the published stf_<bpp>_best.pth.tar / cnn_<bpp>_best.pth.tar files from the upstream repo and writes them in compressai layout.
  • timm added to dependencies (the Swin building blocks reuse DropPath, Mlp, trunc_normal_, WindowAttention, SwinTransformerBlock, window_partition, window_reverse from it).
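
A minimal sketch of that zoo surface (the quality argument is an assumption based on the existing zoo factories, not this PR's final API):

```python
from compressai.zoo import image_models

# pretrained defaults to False; a quality argument is assumed here
# because the existing zoo factories take one.
net = image_models["stf-wacnn"](quality=1)

# Per this PR, pretrained=True raises a clear RuntimeError until the
# maintainers host weights on S3 (see #353):
net = image_models["stf"](quality=1, pretrained=True)  # -> RuntimeError
```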

Reuse of timm

Rather than vendor a full Swin stack inside CompressAI, the implementation in this PR delegates to timm.models.swin_transformer everywhere the upstream STF code matches the Swin reference. This kept the diff focused on the genuinely STF-specific pieces and shaved ~280 lines from an earlier vendored draft.

| Component | What we do |
| --- | --- |
| WindowAttention | Thin subclass of timm.models.swin_transformer.WindowAttention that promotes the relative_position_index buffer from persistent=False to True (so released checkpoints load under strict mode) and accepts the historical qk_scale kwarg. ~15 lines instead of a ~50-line reimplementation. |
| SwinTransformerBlock (used inside _STFBasicLayer) | Use timm.models.swin_transformer.SwinTransformerBlock(always_partition=True, dynamic_mask=True) directly. After construction we promote each block's attn.relative_position_index to persistent so per-block keys round-trip under strict mode. Avoids reimplementing the cyclic-shift / pad / window-attn / unpad / unshift forward path. |
| window_partition / window_reverse | Square-window adapters around the timm helpers — the only difference is that timm takes Tuple[int, int] whereas STF passes int. |
| DropPath, Mlp, trunc_normal_ | timm.layers versions used directly. |
| WMSA / WinNoShiftAttention (the STF-specific dual-branch sigmoid-gated attention block) | Vendored, but parameterised with output_proj=True/False so a single class serves both the STF / WACNN topology in this PR (no projection) and the projection-bearing variant used by other window-attention CompressAI models. No private _STF* duplicate is kept. |
| Other STF-specific blocks (SwinBlock, SWAtten, ConvTransBlock, _PatchEmbed, _WinBasedAttention, WinResidualUnit, pad_to_window_multiple, build_window_attention_mask) | Vendored. These are the parts where STF deviates from the Swin reference (or where timm does not expose an equivalent), so vendoring keeps the API stable across timm releases. |

The dependency on timm.models.swin_transformer.* is deliberate (the file lives under timm.models.* rather than timm.layers.*, so it is not part of timm's stability promise). If maintainers prefer to insulate CompressAI from timm model-internals, the subclass / wrapper pattern makes it a small, self-contained ~120-line revert. Happy to do that on request.
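
For concreteness, the subclass / adapter pattern is roughly the following (a sketch, not this PR's exact code; the *args / **kwargs pass-through is deliberate since timm signatures vary across releases):

```python
from timm.models.swin_transformer import (
    WindowAttention as _TimmWindowAttention,
    window_partition as _timm_window_partition,
    window_reverse as _timm_window_reverse,
)


class WindowAttention(_TimmWindowAttention):
    """timm's WindowAttention plus (a) the historical qk_scale kwarg and
    (b) a persistent relative_position_index buffer."""

    def __init__(self, *args, qk_scale=None, **kwargs):
        super().__init__(*args, **kwargs)
        if qk_scale is not None:
            self.scale = qk_scale  # legacy override of the default head scaling
        # Re-registering an existing buffer updates its persistence flag,
        # so the key appears in state_dict() and survives strict loads.
        self.register_buffer(
            "relative_position_index", self.relative_position_index, persistent=True
        )


def window_partition(x, window_size: int):
    # STF passes a single int; timm expects an (h, w) tuple.
    return _timm_window_partition(x, (window_size, window_size))


def window_reverse(windows, window_size: int, *args):
    return _timm_window_reverse(windows, (window_size, window_size), *args)
```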

Commits

Three commits, designed to be reviewed independently:

| Commit | Scope | LOC |
| --- | --- | --- |
| feat(layers): add Swin window-based attention building blocks | compressai/layers/attn/{swin,inference,__init__}.py + tiny re-export in layers/__init__.py | +668 |
| feat(latent_codecs): add ChannelSliceLatentCodec + slice-entropy base | compressai/latent_codecs/channel_slice.py + compressai/models/_bases/{slice_entropy,__init__}.py + re-export | +543 |
| feat(models): add WACNN and SymmetricalTransFormer (STF) from Zou et al. 2022 | compressai/models/stf.py + zoo / converter / smoke tests + timm in pyproject.toml | +833 |
| Total | 15 files, no modifications to existing logic | +2044 |

License & attribution

compressai/models/stf.py carries a dual-license header noting the upstream source URL and Apache-2.0 license alongside the standard InterDigital BSD 3-Clause Clear License for the modifications. The Swin building blocks in compressai/layers/attn/swin.py are a mix of timm subclasses / wrappers (covered by timm's Apache-2.0) and STF-derived classes (also Apache-2.0); happy to add per-file attribution headers there as well if maintainers prefer.

Verified

  • pytest tests/test_models.py tests/test_layers.py tests/test_init.py → 32 passed (3 new TestStf + 29 existing).
  • WACNN.from_state_dict(model.state_dict()) round-trip → x_hat diff = 0.0 (405 keys).
  • SymmetricalTransFormer.from_state_dict(model.state_dict()) round-trip → x_hat diff = 0.0 (315 keys).
  • convert_upstream_stf_state_dict correctly re-roots module.cc_* / module.gaussian_conditional keys under latent_codec.* so the published Googolxx/STF checkpoints load via from_state_dict.
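
Concretely, the last bullet means a published checkpoint loads along these lines (a sketch; the checkpoint layout is an assumption):

```python
import torch

from compressai.models import WACNN
from compressai.models.stf import convert_upstream_stf_state_dict

ckpt = torch.load("cnn_<bpp>_best.pth.tar", map_location="cpu")  # upstream file
state_dict = ckpt.get("state_dict", ckpt)  # checkpoint layout assumed
net = WACNN.from_state_dict(convert_upstream_stf_state_dict(state_dict)).eval()
```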

Test plan

  • Forward + state-dict round-trip for both backbones at small config (already in TestStf; sketched after this list).
  • Smoke-test examples/convert_stf_checkpoint.py against an upstream cnn_<bpp>_best.pth.tar checkpoint locally (x_hat diff = 0 between original and converted state dict in eval mode).
  • Maintainers: confirm adding timm as a hard runtime dependency is acceptable (alternative: keep it behind an [stf] extras group).
  • Maintainers: confirm dependence on timm.models.swin_transformer.* (model-internal API) is acceptable, vs. vendoring a CompressAI copy. Reverting is a small isolated change if preferred.
  • Maintainers: if you want the Swin layer files to carry their own attribution headers (in addition to models/stf.py), I will add them.
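
For the first bullet, the round-trip check amounts to the following sketch (default construction assumed; TestStf actually uses a smaller config):

```python
import torch

from compressai.models import SymmetricalTransFormer

net = SymmetricalTransFormer().eval()
net2 = SymmetricalTransFormer.from_state_dict(net.state_dict()).eval()

x = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    # eval-mode forward is deterministic, so identical weights => identical x_hat
    diff = (net(x)["x_hat"] - net2(x)["x_hat"]).abs().max().item()
assert diff == 0.0
```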

Notes for follow-up PRs (per #353)

The next PR will add CCA + TCM together — both reuse ChannelSliceLatentCodec from this PR, and CCA contributes a CausalContextAdjustmentEntropyModel that TCM can opt into. After that, the remaining license-clear models (InvCompress, MLIC++, HPCM, SAAF, DCAE, GLIC, TIC, TinyLIC, ShiftLIC) follow one or two at a time, each PR layering on top of what's already merged.

Yiozolm changed the title from "Pr stf wacnn" to "Add WACNN and SymmetricalTransFormer (STF, CVPR 2022)" May 3, 2026
Yiozolm marked this pull request as draft May 3, 2026 15:28
Yiozolm added 3 commits May 4, 2026 09:18
feat(layers): add Swin window-based attention building blocks

New compressai.layers.attn subpackage with Swin primitives needed by
transformer-based learned image compression models (STF / WACNN in
this PR, follow-up InterDigitalInc#353).

Module layout:
- compressai/layers/attn/swin.py: WindowAttention, WMSA, SwinBlock,
  SWAtten, ConvTransBlock, WinNoShiftAttention, WinResidualUnit,
  PatchMerging, PatchSplit + window_partition / window_reverse /
  pad_to_window_multiple / build_window_attention_mask helpers.
- compressai/layers/attn/inference.py: infer_swatten_* helpers for
  from_state_dict.
- compressai/layers/attn/__init__.py: single re-export surface.

Implementation reuses timm.models.swin_transformer where possible:
- WindowAttention is a thin subclass that promotes timm's
  relative_position_index buffer from persistent=False to True so
  released checkpoints round-trip under strict mode, plus accepts the
  historical qk_scale kwarg.
- window_partition / window_reverse are square-window adapters around
  the timm helpers.

WMSA / _WinBasedAttention / WinNoShiftAttention all take an
output_proj=True/False switch: True (default) keeps the Linear
projection used by SwinBlock / SWAtten elsewhere in CompressAI; False
drops it so the same WinNoShiftAttention class serves the STF / WACNN
topology in this PR (which has no projection there) without a separate
private copy.
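
A minimal illustration of that switch (constructor shape assumed; the real
class carries the full attention machinery):

```python
import torch.nn as nn


class _OutputProjSwitchDemo(nn.Module):
    # output_proj=True keeps the Linear projection used by SwinBlock /
    # SWAtten; False drops it for the STF / WACNN topology.
    def __init__(self, dim: int, output_proj: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim, dim) if output_proj else nn.Identity()

    def forward(self, x):
        return self.proj(x)
```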

Root compressai/layers/__init__.py only appends from .attn import *
so existing call sites keep working.

feat(latent_codecs): add ChannelSliceLatentCodec + slice-entropy base

Channel-conditional slice-entropy machinery shared by STF and WACNN
in this PR, factored out so other slice-conditional models added in
follow-up PRs (CCA, TCM, ...) can reuse it.

- compressai/latent_codecs/channel_slice.py:
  ChannelSliceLatentCodec implements equal-sized channel slicing
  (Minnen2020 / He2022) — cc_mean_transforms / cc_scale_transforms
  per slice + LRP head + optional mean / scale support transforms.
  Sibling of the existing ChannelGroupsLatentCodec.

- compressai/models/_bases/slice_entropy.py:
  SliceEntropyCompressionModel collects the recurring "build entropy
  bottleneck for z plus a ChannelSliceLatentCodec for y" plumbing
  used by every channel-slice model, plus the from_state_dict helpers
  (infer_num_slices, infer_max_support_slices, slice/lrp support
  channel arithmetic, make_entropy_transform). Subclasses populate
  g_a / g_s / h_a / h_mean_s / h_scale_s, then call
  self._init_slice_entropy(...).

- compressai/models/_bases/__init__.py: re-export surface.
- compressai/latent_codecs/__init__.py: export ChannelSliceLatentCodec.
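
For orientation, the per-slice plumbing this factors out follows the
Minnen2020 pattern, roughly as below (a sketch with attribute names taken
from the description above, not the exact code in this PR):

```python
import torch


def _decode_slices(self, y, latent_means, latent_scales):
    # Sketch of ChannelSliceLatentCodec's equal-sized channel-slice loop.
    y_slices = y.chunk(self.num_slices, dim=1)
    y_hat_slices = []
    for i, y_slice in enumerate(y_slices):
        support = (y_hat_slices if self.max_support_slices < 0
                   else y_hat_slices[: self.max_support_slices])
        mean_support = torch.cat([latent_means] + support, dim=1)
        mu = self.cc_mean_transforms[i](mean_support)
        scale_support = torch.cat([latent_scales] + support, dim=1)
        scale = self.cc_scale_transforms[i](scale_support)
        _, y_likelihood = self.gaussian_conditional(y_slice, scale, means=mu)
        y_hat_slice = torch.round(y_slice - mu) + mu  # STE round in training code
        # Latent residual prediction (LRP) head refines each slice.
        lrp = 0.5 * torch.tanh(
            self.lrp_transforms[i](torch.cat([mean_support, y_hat_slice], dim=1))
        )
        y_hat_slices.append(y_hat_slice + lrp)
    return torch.cat(y_hat_slices, dim=1)
```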

feat(models): add WACNN and SymmetricalTransFormer (STF) from Zou et al. 2022

Adds the WACNN (CNN backbone) and SymmetricalTransFormer (transformer
backbone) from R. Zou, C. Song, Z. Zhang, "The Devil Is in the Details:
Window-based Attention for Image Compression", CVPR 2022
(https://arxiv.org/abs/2203.08450). Adapted from the official
implementation at https://github.com/Googolxx/STF (Apache-2.0).

What is included:

- compressai/models/stf.py: WACNN, SymmetricalTransFormer, and a
  convert_upstream_stf_state_dict helper that strips the DataParallel
  "module." prefix and re-roots cc_mean_transforms / cc_scale_transforms
  / lrp_transforms / gaussian_conditional under latent_codec.* so
  released checkpoints from the upstream repo load via
  WACNN.from_state_dict / SymmetricalTransFormer.from_state_dict.

  Builds on the WinNoShiftAttention (output_proj=False) primitive from
  the previous "feat(layers)" commit and uses
  timm.models.swin_transformer.SwinTransformerBlock with
  always_partition=True / dynamic_mask=True for the transformer stages
  (with the per-block relative_position_index buffer promoted to
  persistent so the upstream key list survives strict load).

- compressai/models/__init__.py: export the two new model classes.

- compressai/zoo/{__init__,image}.py: register "stf" and "stf-wacnn"
  in image_models with thin pretrained=False factory functions;
  pretrained=True raises a clear RuntimeError until weights are hosted
  on S3 by the maintainers (per InterDigitalInc#353).

- examples/convert_stf_checkpoint.py: CLI wrapper around the
  upstream-state-dict converter, with an optional smoke test on a
  synthetic image.

- tests/test_models.py: TestStf class — forward + state_dict
  round-trip for both backbones, plus a unit test for the
  convert_upstream_stf_state_dict helper.

- pyproject.toml: add timm to the runtime dependencies (used by the
  Swin building blocks committed earlier in this PR for DropPath /
  Mlp / SwinTransformerBlock / WindowAttention).

Pretrained weights are intentionally not bundled. State-dict round-trip
diff is 0.0 for both WACNN (405 keys) and STF (315 keys); pytest
tests/test_models.py tests/test_layers.py tests/test_init.py = 32
passed (3 new TestStf + 29 existing).
Yiozolm marked this pull request as ready for review May 4, 2026 01:25
Yiozolm (Author) commented May 4, 2026

I have reused the swin-transformer code from timm to avoid unnecessary rewriting.
