Add WACNN and SymmetricalTransFormer (STF, CVPR 2022)#354
Open · Yiozolm wants to merge 3 commits into InterDigitalInc:master
Conversation
New `compressai.layers.attn` subpackage with Swin primitives needed by transformer-based learned image compression models (STF / WACNN in this PR, follow-up InterDigitalInc#353).

Module layout:

- `compressai/layers/attn/swin.py`: `WindowAttention`, `WMSA`, `SwinBlock`, `SWAtten`, `ConvTransBlock`, `WinNoShiftAttention`, `WinResidualUnit`, `PatchMerging`, `PatchSplit` + `window_partition` / `window_reverse` / `pad_to_window_multiple` / `build_window_attention_mask` helpers.
- `compressai/layers/attn/inference.py`: `infer_swatten_*` helpers for `from_state_dict`.
- `compressai/layers/attn/__init__.py`: single re-export surface.

Implementation reuses `timm.models.swin_transformer` where possible:

- `WindowAttention` is a thin subclass that promotes timm's `relative_position_index` buffer from `persistent=False` to `True` so released checkpoints round-trip under strict mode, and accepts the historical `qk_scale` kwarg.
- `window_partition` / `window_reverse` are square-window adapters around the timm helpers.

`WMSA` / `_WinBasedAttention` / `WinNoShiftAttention` all take an `output_proj=True/False` switch: `True` (the default) keeps the `Linear` projection used by `SwinBlock` / `SWAtten` elsewhere in CompressAI; `False` drops it so the same `WinNoShiftAttention` class serves the STF / WACNN topology in this PR (which has no projection there) without a separate private copy.

The root `compressai/layers/__init__.py` only appends `from .attn import *` so existing call sites keep working.
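The square-window adapter pattern mentioned above boils down to the standard Swin reshape arithmetic. A minimal NumPy sketch (illustrative only; the actual helpers wrap the timm torch implementations) of what `window_partition` / `window_reverse` compute, assuming `H` and `W` are already multiples of the window size:

```python
# Illustrative NumPy sketch of square-window partitioning, not the actual
# torch code in compressai/layers/attn/swin.py. Assumes input (B, H, W, C)
# with H and W divisible by window_size.
import numpy as np

def window_partition(x: np.ndarray, window_size: int) -> np.ndarray:
    """(B, H, W, C) -> (num_windows * B, window_size, window_size, C)."""
    B, H, W, C = x.shape
    x = x.reshape(B, H // window_size, window_size, W // window_size, window_size, C)
    # Bring the two window-grid axes together, then flatten them into the batch.
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)

def window_reverse(windows: np.ndarray, window_size: int, H: int, W: int) -> np.ndarray:
    """Exact inverse of window_partition."""
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)
```

Partition followed by reverse is the identity, which is what makes it safe to insert window attention anywhere in the analysis/synthesis transforms.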
Channel-conditional slice-entropy machinery shared by STF and WACNN in this PR, factored out so other slice-conditional models added in follow-up PRs (CCA, TCM, ...) can reuse it.

- `compressai/latent_codecs/channel_slice.py`: `ChannelSliceLatentCodec` implements equal-sized channel slicing (Minnen2020 / He2022): `cc_mean_transforms` / `cc_scale_transforms` per slice + LRP head + optional mean / scale support transforms. Sibling of the existing `ChannelGroupsLatentCodec`.
- `compressai/models/_bases/slice_entropy.py`: `SliceEntropyCompressionModel` collects the recurring "build entropy bottleneck for z plus a ChannelSliceLatentCodec for y" plumbing used by every channel-slice model, plus the `from_state_dict` helpers (`infer_num_slices`, `infer_max_support_slices`, slice / LRP support channel arithmetic, `make_entropy_transform`). Subclasses populate `g_a` / `g_s` / `h_a` / `h_mean_s` / `h_scale_s`, then call `self._init_slice_entropy(...)`.
- `compressai/models/_bases/__init__.py`: re-export surface.
- `compressai/latent_codecs/__init__.py`: export `ChannelSliceLatentCodec`.
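To make the slice arithmetic concrete, here is a hedged sketch of the two kinds of helper described above. The function names mirror the PR description, but the bodies are illustrative reconstructions of the Minnen2020-style channel conditioning, not the merged implementation:

```python
# Hedged sketch, not the merged code: infer the slice count from
# cc_mean_transforms.<i>.* key prefixes, and derive the input width of each
# conditional-context transform under equal-sized channel slicing.
import re

_CC_KEY = re.compile(r"^cc_mean_transforms\.(\d+)\.")

def infer_num_slices(state_dict: dict) -> int:
    """Count distinct slice indices i appearing as 'cc_mean_transforms.<i>.'."""
    indices = {int(m.group(1)) for k in state_dict if (m := _CC_KEY.match(k))}
    return max(indices) + 1 if indices else 0

def slice_support_channels(M: int, num_slices: int, max_support_slices: int) -> list:
    """Input width of the i-th cc transform: the hyperprior output (M channels)
    concatenated with up to max_support_slices previously decoded slices."""
    slice_ch = M // num_slices  # equal-sized slicing (Minnen2020 / He2022)
    return [M + slice_ch * min(i, max_support_slices) for i in range(num_slices)]
```

For example, with M = 320 latent channels, 10 slices, and 5 support slices, the slice width is 32 and the cc-transform input widths grow from 320 to a plateau of 480 channels.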
feat(models): add WACNN and SymmetricalTransFormer (STF) from Zou et al. 2022

Adds the WACNN (CNN backbone) and SymmetricalTransFormer (transformer backbone) from R. Zou, C. Song, Z. Zhang, "The Devil Is in the Details: Window-based Attention for Image Compression", CVPR 2022 (https://arxiv.org/abs/2203.08450). Adapted from the official implementation at https://github.com/Googolxx/STF (Apache-2.0).

What is included:

- `compressai/models/stf.py`: `WACNN`, `SymmetricalTransFormer`, and a `convert_upstream_stf_state_dict` helper that strips the DataParallel `module.` prefix and re-roots `cc_mean_transforms` / `cc_scale_transforms` / `lrp_transforms` / `gaussian_conditional` under `latent_codec.*` so released checkpoints from the upstream repo load via `WACNN.from_state_dict` / `SymmetricalTransFormer.from_state_dict`. Builds on the `WinNoShiftAttention` (`output_proj=False`) primitive from the previous "feat(layers)" commit and uses `timm.models.swin_transformer.SwinTransformerBlock` with `always_partition=True` / `dynamic_mask=True` for the transformer stages (with the per-block `relative_position_index` buffer promoted to persistent so the upstream key list survives strict load).
- `compressai/models/__init__.py`: export the two new model classes.
- `compressai/zoo/{__init__,image}.py`: register `"stf"` and `"stf-wacnn"` in `image_models` with thin `pretrained=False` factory functions; `pretrained=True` raises a clear `RuntimeError` until weights are hosted on S3 by the maintainers (per InterDigitalInc#353).
- `examples/convert_stf_checkpoint.py`: CLI wrapper around the upstream-state-dict converter, with an optional smoke test on a synthetic image.
- `tests/test_models.py`: `TestStf` class with forward + state-dict round-trip tests for both backbones, plus a unit test for the `convert_upstream_stf_state_dict` helper.
- `pyproject.toml`: add `timm` to the runtime dependencies (used by the Swin building blocks committed earlier in this PR for `DropPath` / `Mlp` / `SwinTransformerBlock` / `WindowAttention`).

Pretrained weights are intentionally not bundled.
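The key re-rooting performed by the converter can be sketched as plain dictionary surgery. This is an illustrative reconstruction from the prefix list in the commit message above; treat the merged `convert_upstream_stf_state_dict` as authoritative for the exact mapping:

```python
# Hedged sketch of the checkpoint-key conversion described above: strip the
# DataParallel "module." wrapper prefix, then re-root the slice-entropy
# submodules under "latent_codec.". Prefix list taken from the commit message.
RE_ROOTED_PREFIXES = (
    "cc_mean_transforms.",
    "cc_scale_transforms.",
    "lrp_transforms.",
    "gaussian_conditional.",
)

def convert_keys(state_dict: dict) -> dict:
    out = {}
    for key, value in state_dict.items():
        key = key.removeprefix("module.")  # DataParallel wrapper prefix
        if key.startswith(RE_ROOTED_PREFIXES):
            key = "latent_codec." + key
        out[key] = value
    return out
```

Keys outside the slice-entropy machinery (e.g. `g_a.*`, `h_a.*`) only lose the `module.` prefix and are otherwise passed through unchanged.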
State-dict round-trip diff is 0.0 for both WACNN (405 keys) and STF (315 keys); `pytest tests/test_models.py tests/test_layers.py tests/test_init.py` reports 32 passed (3 new `TestStf` + 29 existing).
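The round-trip figures above come from a check along these lines. A simplified sketch, with plain floats standing in for tensors so it stays self-contained (the real test rebuilds the model via `from_state_dict` and compares `x_hat`):

```python
# Simplified sketch of a state-dict round-trip check: key sets must match
# exactly (strict-mode loading), and the reported "diff" is the maximum
# absolute difference over all shared entries.
def round_trip_report(original: dict, restored: dict):
    """Return (num_keys, max_abs_diff); raises if the key sets differ."""
    assert original.keys() == restored.keys(), "state-dict key sets differ"
    diff = max((abs(original[k] - restored[k]) for k in original), default=0.0)
    return len(original), diff
```

A diff of exactly 0.0 here indicates the parameters were restored bit-for-bit, not merely within floating-point tolerance.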
Author

I have reused the swin-transformer code from
Adds WACNN and SymmetricalTransFormer (STF) from R. Zou, C. Song, Z. Zhang, "The Devil Is in the Details: Window-based Attention for Image Compression", CVPR 2022 (arXiv:2203.08450).
Adapted from the official implementation at https://github.com/Googolxx/STF (Apache-2.0).
This is the first installment of the per-model PR series proposed in #353. Pretrained weights are intentionally not bundled: calling `pretrained=True` raises a clear `RuntimeError` until weights are hosted on S3 (per the discussion in #353).

Summary

- Registers `"stf"` and `"stf-wacnn"` in the zoo (`compressai.models.SymmetricalTransFormer` and `compressai.models.WACNN`).
- New `compressai.layers.attn` subpackage with the Swin window-based attention building blocks the two models depend on. Reuses `timm.models.swin_transformer` wherever the implementation is generic to avoid vendoring a parallel Swin stack (see "Reuse of timm" below).
- `ChannelSliceLatentCodec` + `SliceEntropyCompressionModel` base, designed to be reused by the channel-conditional models in follow-up PRs (CCA, TCM, ...).
- `examples/convert_stf_checkpoint.py`, which loads the published `stf_<bpp>_best.pth.tar` / `cnn_<bpp>_best.pth.tar` files from the upstream repo and writes them in CompressAI layout.
- `timm` added to `dependencies` (the Swin building blocks reuse `DropPath`, `Mlp`, `trunc_normal_`, `WindowAttention`, `SwinTransformerBlock`, `window_partition`, `window_reverse` from it).

Reuse of timm
Rather than vendor a full Swin stack inside CompressAI, the implementation in this PR delegates to `timm.models.swin_transformer` everywhere the upstream STF code matches the Swin reference. This kept the diff focused on the genuinely STF-specific pieces and shaved ~280 lines from an earlier vendored draft.

- `WindowAttention`: thin subclass of `timm.models.swin_transformer.WindowAttention` that promotes the `relative_position_index` buffer from `persistent=False` to `True` (so released checkpoints load under strict mode) and accepts the historical `qk_scale` kwarg. ~15 lines instead of a ~50-line reimplementation.
- `SwinTransformerBlock` (used inside `_STFBasicLayer`): `timm.models.swin_transformer.SwinTransformerBlock(always_partition=True, dynamic_mask=True)` is used directly. After construction we promote each block's `attn.relative_position_index` to persistent so per-block keys round-trip in strict mode. Avoids reimplementing the cyclic-shift / pad / window-attn / unpad / unshift forward path.
- `window_partition` / `window_reverse`: square-window adapters; timm expects `Tuple[int, int]` window sizes whereas STF passes `int`.
- `DropPath`, `Mlp`, `trunc_normal_`: the `timm.layers` versions are used directly.
- `WMSA` / `WinNoShiftAttention` (the STF-specific dual-branch sigmoid-gated attention block): takes an `output_proj=True/False` switch so a single class serves both the STF / WACNN topology in this PR (no projection) and the projection-bearing variant used by other window-attention CompressAI models. No private `_STF*` duplicate is kept.
- The remaining STF-specific pieces (`SwinBlock`, `SWAtten`, `ConvTransBlock`, `_PatchEmbed`, `_WinBasedAttention`, `WinResidualUnit`, `pad_to_window_multiple`, `build_window_attention_mask`) are vendored (timm does not expose an equivalent), so vendoring keeps the API stable across timm releases.

The dependency on `timm.models.swin_transformer.*` is deliberate (the file lives under `timm.models.*` rather than `timm.layers.*`, so it is not part of timm's stability promise). If maintainers prefer to insulate CompressAI from timm model-internals, the subclass / wrapper pattern makes it a small, self-contained ~120-line revert. Happy to do that on request.

Commits
Three commits, designed to be reviewed independently:
1. feat(layers): add Swin window-based attention building blocks — `compressai/layers/attn/{swin,inference,__init__}.py` + tiny re-export in `layers/__init__.py`
2. feat(latent_codecs): add ChannelSliceLatentCodec + slice-entropy base — `compressai/latent_codecs/channel_slice.py` + `compressai/models/_bases/{slice_entropy,__init__}.py` + re-export
3. feat(models): add WACNN and SymmetricalTransFormer (STF) from Zou et al. 2022 — `compressai/models/stf.py` + zoo / converter / smoke tests + `timm` in `pyproject.toml`

License & attribution
`compressai/models/stf.py` carries a dual-license header noting the upstream source URL and Apache-2.0 license alongside the standard InterDigital BSD 3-Clause Clear License for the modifications. The Swin building blocks in `compressai/layers/attn/swin.py` are a mix of timm subclasses / wrappers (covered by timm's Apache-2.0) and STF-derived classes (also Apache-2.0); happy to add per-file attribution headers there as well if maintainers prefer.

Verified
- `pytest tests/test_models.py tests/test_layers.py tests/test_init.py` → 32 passed (3 new `TestStf` + 29 existing).
- `WACNN.from_state_dict(model.state_dict())` round-trip → `x_hat` diff = 0.0 (405 keys).
- `SymmetricalTransFormer.from_state_dict(model.state_dict())` round-trip → `x_hat` diff = 0.0 (315 keys).
- `convert_upstream_stf_state_dict` correctly re-roots `module.cc_*` / `module.gaussian_conditional` keys under `latent_codec.*` so the published Googolxx/STF checkpoints load via `from_state_dict`.

Test plan
- Unit tests for forward and state-dict round-trip on both backbones (`TestStf`).
- Ran `examples/convert_stf_checkpoint.py` against an upstream `cnn_<bpp>_best.pth.tar` checkpoint locally (`x_hat` diff = 0 between original and converted state dict in eval mode).

Open questions for maintainers:

- Whether `timm` being moved into hard `dependencies` is acceptable (alternative: keep an `[stf]` extras group).
- Whether the dependency on `timm.models.swin_transformer.*` (model-internal API) is acceptable, vs. vendoring a CompressAI copy. Reverting is a small isolated change if preferred.
- If per-file attribution headers are wanted beyond the one in `models/stf.py`, I will add them.
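For reference, the `[stf]` extras alternative mentioned above could look like this in `pyproject.toml`. This is a hedged sketch following the standard PEP 621 layout, not a fragment from this PR's diff:

```toml
# Option A (what this PR does): timm as a hard runtime dependency.
[project]
dependencies = [
    "timm",
]

# Option B (alternative): optional extras group, installed via
#   pip install compressai[stf]
[project.optional-dependencies]
stf = [
    "timm",
]
```

Option B keeps the base install lean but means the STF / WACNN classes would need a lazy import with a helpful error message when `timm` is missing.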
The next PR will add CCA + TCM together: both reuse `ChannelSliceLatentCodec` from this PR, and CCA contributes a `CausalContextAdjustmentEntropyModel` that TCM can opt into. After that, the remaining license-clear models (InvCompress, MLIC++, HPCM, SAAF, DCAE, GLIC, TIC, TinyLIC, ShiftLIC) follow one or two at a time, each PR layering on top of what's already merged.