diff --git a/compressai/latent_codecs/__init__.py b/compressai/latent_codecs/__init__.py index ceada0b1..82a41947 100644 --- a/compressai/latent_codecs/__init__.py +++ b/compressai/latent_codecs/__init__.py @@ -29,6 +29,7 @@ from .base import LatentCodec from .channel_groups import ChannelGroupsLatentCodec +from .channel_slice import ChannelSliceLatentCodec from .checkerboard import CheckerboardLatentCodec from .entropy_bottleneck import EntropyBottleneckLatentCodec from .gain import GainHyperLatentCodec, GainHyperpriorLatentCodec @@ -40,6 +41,7 @@ __all__ = [ "LatentCodec", "ChannelGroupsLatentCodec", + "ChannelSliceLatentCodec", "CheckerboardLatentCodec", "EntropyBottleneckLatentCodec", "GainHyperLatentCodec", diff --git a/compressai/latent_codecs/channel_slice.py b/compressai/latent_codecs/channel_slice.py new file mode 100644 index 00000000..888cd5a8 --- /dev/null +++ b/compressai/latent_codecs/channel_slice.py @@ -0,0 +1,263 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn + +from torch import Tensor + +from compressai.ans import BufferedRansEncoder, RansDecoder +from compressai.entropy_models import GaussianConditional +from compressai.ops import quantize_ste +from compressai.registry import register_module + +from .base import LatentCodec + +__all__ = [ + "ChannelSliceLatentCodec", +] + + +@register_module("ChannelSliceLatentCodec") +class ChannelSliceLatentCodec(LatentCodec): + """Channel-conditional entropy model with separate scale/mean heads and LRP. + + Splits ``y`` into equal-sized slices along the channel axis. For each + slice ``k`` the previously decoded slices (truncated to + ``max_support_slices``) are concatenated with ``latent_means`` / + ``latent_scales`` and pushed through ``cc_mean_transforms[k]`` and + ``cc_scale_transforms[k]`` to obtain ``mu`` / ``scale``. After the + Gaussian conditional step, an optional latent residual prediction + (LRP) head refines ``y_hat``. 
+ + This is the channel-autoregressive entropy model from [Minnen2020] + with the LRP refinement variant used in [Zhu2022] (STF / WACNN), + [He2022] (ELIC) and many follow-up papers (MLIC++, TCM, ...). + + [Minnen2020]: `"Channel-wise Autoregressive Entropy Models for + Learned Image Compression" `_, by + David Minnen and Saurabh Singh, ICIP 2020. + + [Zhu2022]: `"Transformer-based Transform Coding" + `_, by Yinhao Zhu, + Yang Yang and Taco Cohen, ICLR 2022. + """ + + cc_mean_transforms: nn.ModuleList + cc_scale_transforms: nn.ModuleList + lrp_transforms: nn.ModuleList + gaussian_conditional: GaussianConditional + + def __init__( + self, + cc_mean_transforms: nn.ModuleList, + cc_scale_transforms: nn.ModuleList, + lrp_transforms: Optional[nn.ModuleList] = None, + gaussian_conditional: Optional[GaussianConditional] = None, + mean_support_transforms: Optional[nn.ModuleList] = None, + scale_support_transforms: Optional[nn.ModuleList] = None, + *, + num_slices: Optional[int] = None, + max_support_slices: int = -1, + quantizer: str = "ste", + lrp_scale: float = 0.5, + **kwargs: Any, + ) -> None: + super().__init__() + self._kwargs = kwargs + + inferred_num_slices = len(cc_mean_transforms) + if num_slices is None: + num_slices = inferred_num_slices + if inferred_num_slices != num_slices: + raise ValueError( + "cc_mean_transforms must have num_slices entries " + f"(got {inferred_num_slices}, expected {num_slices})" + ) + if len(cc_scale_transforms) != num_slices: + raise ValueError("cc_scale_transforms must have num_slices entries") + if lrp_transforms is not None and len(lrp_transforms) != num_slices: + raise ValueError("lrp_transforms must have num_slices entries") + if mean_support_transforms is not None and len(mean_support_transforms) != num_slices: + raise ValueError("mean_support_transforms must have num_slices entries") + if scale_support_transforms is not None and len(scale_support_transforms) != num_slices: + raise ValueError("scale_support_transforms must have 
num_slices entries") + if quantizer not in ("ste", "noise"): + raise ValueError(f"unknown quantizer {quantizer!r}") + + self.num_slices = int(num_slices) + self.max_support_slices = int(max_support_slices) + self.quantizer = quantizer + self.lrp_scale = float(lrp_scale) + self.cc_mean_transforms = cc_mean_transforms + self.cc_scale_transforms = cc_scale_transforms + self.mean_support_transforms = mean_support_transforms or nn.ModuleList( + nn.Identity() for _ in range(num_slices) + ) + self.scale_support_transforms = scale_support_transforms or nn.ModuleList( + nn.Identity() for _ in range(num_slices) + ) + self.lrp_transforms = lrp_transforms or nn.ModuleList( + nn.Identity() for _ in range(num_slices) + ) + self.gaussian_conditional = gaussian_conditional or GaussianConditional(None) + + def _support_slices(self, y_hat_slices: Sequence[Tensor]) -> List[Tensor]: + if self.max_support_slices < 0: + return list(y_hat_slices) + return list(y_hat_slices[: self.max_support_slices]) + + def _slice_params( + self, + slice_index: int, + latent_means: Tensor, + latent_scales: Tensor, + y_hat_slices: Sequence[Tensor], + spatial_shape: Tuple[int, int], + ) -> Tuple[Tensor, Tensor, Tensor]: + support = self._support_slices(y_hat_slices) + mean_support = torch.cat([latent_means, *support], dim=1) + mean_support = self.mean_support_transforms[slice_index](mean_support) + mu = self.cc_mean_transforms[slice_index](mean_support) + mu = mu[:, :, : spatial_shape[0], : spatial_shape[1]] + scale_support = torch.cat([latent_scales, *support], dim=1) + scale_support = self.scale_support_transforms[slice_index](scale_support) + scale = self.cc_scale_transforms[slice_index](scale_support) + scale = scale[:, :, : spatial_shape[0], : spatial_shape[1]] + return mu, scale, mean_support + + def _apply_lrp( + self, slice_index: int, mean_support: Tensor, y_hat_slice: Tensor + ) -> Tensor: + lrp = self.lrp_transforms[slice_index]( + torch.cat([mean_support, y_hat_slice], dim=1) + ) + return 
y_hat_slice + self.lrp_scale * torch.tanh(lrp) + + def forward( + self, + y: Tensor, + latent_means: Tensor, + latent_scales: Tensor, + ) -> Dict[str, Any]: + spatial_shape = (y.shape[2], y.shape[3]) + y_hat_slices: List[Tensor] = [] + y_likelihoods_slices: List[Tensor] = [] + + for slice_index, y_slice in enumerate(y.chunk(self.num_slices, dim=1)): + mu, scale, mean_support = self._slice_params( + slice_index, latent_means, latent_scales, y_hat_slices, spatial_shape + ) + _, y_slice_likelihoods = self.gaussian_conditional( + y_slice, scale, means=mu + ) + if self.quantizer == "ste": + y_hat_slice = quantize_ste(y_slice - mu) + mu + else: + y_hat_slice = self.gaussian_conditional.quantize( + y_slice, "noise" if self.training else "dequantize", mu + ) + y_hat_slice = self._apply_lrp(slice_index, mean_support, y_hat_slice) + y_hat_slices.append(y_hat_slice) + y_likelihoods_slices.append(y_slice_likelihoods) + + return { + "y_hat": torch.cat(y_hat_slices, dim=1), + "likelihoods": {"y": torch.cat(y_likelihoods_slices, dim=1)}, + } + + def compress( + self, + y: Tensor, + latent_means: Tensor, + latent_scales: Tensor, + ) -> Dict[str, Any]: + spatial_shape = (y.shape[2], y.shape[3]) + cdf = self.gaussian_conditional.quantized_cdf.tolist() + cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist() + offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist() + encoder = BufferedRansEncoder() + symbols_list: List[int] = [] + indexes_list: List[int] = [] + y_hat_slices: List[Tensor] = [] + + for slice_index, y_slice in enumerate(y.chunk(self.num_slices, dim=1)): + mu, scale, mean_support = self._slice_params( + slice_index, latent_means, latent_scales, y_hat_slices, spatial_shape + ) + indexes = self.gaussian_conditional.build_indexes(scale) + y_q_slice = self.gaussian_conditional.quantize(y_slice, "symbols", mu) + y_hat_slice = y_q_slice + mu + symbols_list.extend(y_q_slice.reshape(-1).tolist()) + 
indexes_list.extend(indexes.reshape(-1).tolist()) + y_hat_slice = self._apply_lrp(slice_index, mean_support, y_hat_slice) + y_hat_slices.append(y_hat_slice) + + encoder.encode_with_indexes(symbols_list, indexes_list, cdf, cdf_lengths, offsets) + return { + "strings": [encoder.flush()], + "shape": spatial_shape, + "y_hat": torch.cat(y_hat_slices, dim=1), + } + + def decompress( + self, + strings: Sequence[bytes], + shape: Tuple[int, int], + latent_means: Tensor, + latent_scales: Tensor, + **kwargs: Any, + ) -> Dict[str, Any]: + cdf = self.gaussian_conditional.quantized_cdf.tolist() + cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist() + offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist() + decoder = RansDecoder() + decoder.set_stream(strings[0]) + y_hat_slices: List[Tensor] = [] + + for slice_index in range(self.num_slices): + mu, scale, mean_support = self._slice_params( + slice_index, latent_means, latent_scales, y_hat_slices, shape + ) + indexes = self.gaussian_conditional.build_indexes(scale) + values = decoder.decode_stream( + indexes.reshape(-1).tolist(), cdf, cdf_lengths, offsets + ) + y_q_slice = torch.tensor( + values, device=mu.device, dtype=mu.dtype + ).reshape(mu.shape) + y_hat_slice = self.gaussian_conditional.dequantize(y_q_slice, mu) + y_hat_slice = self._apply_lrp(slice_index, mean_support, y_hat_slice) + y_hat_slices.append(y_hat_slice) + + return {"y_hat": torch.cat(y_hat_slices, dim=1)} diff --git a/compressai/layers/__init__.py b/compressai/layers/__init__.py index 0362981c..cd3e30cd 100644 --- a/compressai/layers/__init__.py +++ b/compressai/layers/__init__.py @@ -30,3 +30,5 @@ from .basic import * from .gdn import * from .layers import * + +from .attn import * diff --git a/compressai/layers/attn/__init__.py b/compressai/layers/attn/__init__.py new file mode 100644 index 00000000..977bba94 --- /dev/null +++ b/compressai/layers/attn/__init__.py @@ -0,0 +1,47 @@ +from .inference import ( + 
infer_swatten_attention_dim, + infer_swatten_head_dim, + infer_swatten_window_size, +) +from .swin import ( + ConvTransBlock, + PatchMerging, + PatchSplit, + SWAtten, + SwinBlock, + WMSA, + WinNoShiftAttention, + WinResidualUnit, + WindowAttention, + build_window_attention_mask, + pad_to_window_multiple, + window_partition, + window_reverse, +) + +__all__ = [ + "ConvTransBlock", + "PatchMerging", + "PatchSplit", + "SWAtten", + "SwinBlock", + "WMSA", + "WinNoShiftAttention", + "WinResidualUnit", + "WindowAttention", + "build_window_attention_mask", + "infer_swatten_attention_dim", + "infer_swatten_head_dim", + "infer_swatten_window_size", + "pad_to_window_multiple", + "window_partition", + "window_reverse", +] + + +def __getattr__(name): + if name == "Win_noShift_Attention": + from .swin import Win_noShift_Attention as _alias + + return _alias + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/compressai/layers/attn/inference.py b/compressai/layers/attn/inference.py new file mode 100644 index 00000000..0369ab93 --- /dev/null +++ b/compressai/layers/attn/inference.py @@ -0,0 +1,53 @@ +"""State-dict introspection helpers for SWAtten-based context heads. + +Used by MambaIC / MambaVC ``from_state_dict`` to recover ``window_size``, +``head_dim`` and ``inter_dim`` (a.k.a. ``support_attention_dim``) from a +checkpoint. Each model used to ship a private ``_infer_*`` copy of these. 
+""" +from __future__ import annotations + +import math + +from typing import Dict + +from torch import Tensor + +__all__ = [ + "infer_swatten_attention_dim", + "infer_swatten_head_dim", + "infer_swatten_window_size", +] + + +def infer_swatten_window_size( + state_dict: Dict[str, Tensor], prefix: str, *, default: int = 8 +) -> int: + for key, tensor in state_dict.items(): + if key.startswith(prefix) and key.endswith("relative_position_bias_table"): + return (math.isqrt(tensor.size(0)) + 1) // 2 + return default + + +def infer_swatten_head_dim( + state_dict: Dict[str, Tensor], + prefix: str, + hidden_channels: int, + *, + default: int = 16, +) -> int: + for key, tensor in state_dict.items(): + if key.startswith(prefix) and key.endswith("relative_position_bias_table"): + return hidden_channels // tensor.size(1) + return default + + +def infer_swatten_attention_dim( + state_dict: Dict[str, Tensor], + prefix: str, + *, + default: int = 128, +) -> int: + key = f"{prefix}.in_conv.weight" + if key in state_dict: + return state_dict[key].size(0) + return default diff --git a/compressai/layers/attn/swin.py b/compressai/layers/attn/swin.py new file mode 100644 index 00000000..e99fbfd1 --- /dev/null +++ b/compressai/layers/attn/swin.py @@ -0,0 +1,566 @@ +from __future__ import annotations + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.layers import DropPath, Mlp +from timm.models.swin_transformer import ( + WindowAttention as _TimmWindowAttention, + window_partition as _timm_window_partition, + window_reverse as _timm_window_reverse, +) +from torch import Tensor + +from ..layers import AttentionBlock, ResidualBlock, conv1x1, conv3x3 + +__all__ = [ + "ConvTransBlock", + "PatchMerging", + "PatchSplit", + "SWAtten", + "SwinBlock", + "WMSA", + "WinNoShiftAttention", + "WinResidualUnit", + "WindowAttention", + "build_window_attention_mask", + "pad_to_window_multiple", + "window_partition", + 
"window_reverse", +] + + +def window_partition(input_tensor: Tensor, window_size: int) -> Tensor: + """Square-window adapter around timm's ``window_partition``. + + timm uses ``Tuple[int, int]`` for the window size; the STF / WACNN models + in compressai always use square windows, so this thin wrapper keeps the + ``window_size: int`` call-site convention while delegating to timm. + """ + return _timm_window_partition(input_tensor, (window_size, window_size)) + + +def window_reverse( + windows: Tensor, + window_size: int, + height: int, + width: int, +) -> Tensor: + """Square-window adapter around timm's ``window_reverse`` (see + :func:`window_partition` for the rationale).""" + return _timm_window_reverse(windows, (window_size, window_size), height, width) + + +def build_window_attention_mask( + height: int, + width: int, + window_size: int, + shift_size: int, + device: torch.device, +) -> Optional[Tensor]: + if shift_size == 0: + return None + + img_mask = torch.zeros((1, height, width, 1), device=device) + h_slices = ( + slice(0, -window_size), + slice(-window_size, -shift_size), + slice(-shift_size, None), + ) + w_slices = ( + slice(0, -window_size), + slice(-window_size, -shift_size), + slice(-shift_size, None), + ) + + count = 0 + for h_index in h_slices: + for w_index in w_slices: + img_mask[:, h_index, w_index, :] = count + count += 1 + + mask_windows = window_partition(img_mask, window_size) + mask_windows = mask_windows.view(-1, window_size * window_size) + attention_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attention_mask = attention_mask.masked_fill(attention_mask != 0, float(-100.0)) + return attention_mask.masked_fill(attention_mask == 0, float(0.0)) + + +def pad_to_window_multiple( + input_tensor: Tensor, + window_size: Union[int, Tuple[int, int]], + *, + layout: str = "BCHW", +) -> Tuple[Tensor, int, int]: + """Right/bottom-pad a 4D tensor so its spatial dims are multiples of + ``window_size``. 
+ + Args: + input_tensor: 4D tensor in either ``BCHW`` or ``BHWC`` layout. + window_size: ``int`` (square window) or ``(window_h, window_w)``. + layout: ``"BCHW"`` (default, PyTorch convention) or ``"BHWC"`` + (Swin / FTIC token-major layout). + + Returns: + ``(padded_tensor, pad_h, pad_w)``, where ``pad_h`` / ``pad_w`` are + the bottom / right padding widths added to the height / width + dimension respectively. + """ + if isinstance(window_size, int): + win_h = win_w = int(window_size) + else: + win_h, win_w = (int(s) for s in window_size) + + if layout == "BCHW": + height, width = input_tensor.shape[-2], input_tensor.shape[-1] + elif layout == "BHWC": + height, width = input_tensor.shape[1], input_tensor.shape[2] + else: + raise ValueError(f"layout must be 'BCHW' or 'BHWC', got {layout!r}") + + pad_h = (win_h - height % win_h) % win_h + pad_w = (win_w - width % win_w) % win_w + if pad_h == 0 and pad_w == 0: + return input_tensor, 0, 0 + + if layout == "BCHW": + # F.pad on BCHW: (W_left, W_right, H_left, H_right) + return F.pad(input_tensor, (0, pad_w, 0, pad_h)), pad_h, pad_w + # F.pad on BHWC: (C_left, C_right, W_left, W_right, H_left, H_right) + return F.pad(input_tensor, (0, 0, 0, pad_w, 0, pad_h)), pad_h, pad_w + + +class WindowAttention(_TimmWindowAttention): + """timm ``WindowAttention`` with two minor tweaks for compressai: + + 1. ``relative_position_index`` is re-registered as a *persistent* buffer + so released compressai checkpoints (which include this tensor) load + under ``strict=True``. timm registers it as ``persistent=False``. + 2. The constructor accepts an optional ``qk_scale`` to keep STF's + (and CompressAI's) call-site convention; timm always derives the + scale from ``head_dim``. + + Forward / state-dict layout otherwise match timm exactly, including + the optional fused-attention path. 
+ """ + + def __init__( + self, + dim: int, + window_size: int, + num_heads: int, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__( + dim=dim, + num_heads=num_heads, + window_size=window_size, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) + if qk_scale is not None: + self.scale = qk_scale + # Promote the index buffer to persistent so checkpoint round-trip + # works without filtering keys at load time. + index = self.relative_position_index + del self._buffers["relative_position_index"] + self.register_buffer("relative_position_index", index, persistent=True) + + +class WMSA(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: Optional[int], + head_dim: int, + window_size: int, + type: str = "W", + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + output_proj: bool = True, + ) -> None: + super().__init__() + if type not in {"W", "SW"}: + raise ValueError(f"Unsupported attention type: {type}") + if input_dim % head_dim != 0: + raise ValueError("`input_dim` must be divisible by `head_dim`.") + + self.window_size = window_size + self.shift_size = 0 if type == "W" else window_size // 2 + self.attn = WindowAttention( + dim=input_dim, + window_size=window_size, + num_heads=input_dim // head_dim, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) + # ``output_proj=False`` mirrors the STF / WACNN topology, which feeds + # the WindowAttention output straight back into the downstream block + # without an extra Linear projection. Set ``True`` (default) for the + # SwinBlock / SWAtten variant used by the rest of CompressAI. 
+ self.output_proj = ( + nn.Linear(input_dim, output_dim or input_dim) + if output_proj + else nn.Identity() + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + _, height, width, _ = input_tensor.shape + output, pad_height, pad_width = pad_to_window_multiple( + input_tensor, + self.window_size, + layout="BHWC", + ) + padded_height, padded_width = output.shape[1], output.shape[2] + + if self.shift_size > 0: + mask = build_window_attention_mask( + padded_height, + padded_width, + self.window_size, + self.shift_size, + output.device, + ) + output = torch.roll( + output, + shifts=(-self.shift_size, -self.shift_size), + dims=(1, 2), + ) + else: + mask = None + + windows = window_partition(output, self.window_size) + windows = windows.view( + -1, + self.window_size * self.window_size, + windows.shape[-1], + ) + windows = self.attn(windows, mask=mask) + windows = windows.view( + -1, + self.window_size, + self.window_size, + windows.shape[-1], + ) + output = window_reverse(windows, self.window_size, padded_height, padded_width) + + if self.shift_size > 0: + output = torch.roll( + output, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2), + ) + if pad_height > 0 or pad_width > 0: + output = output[:, :height, :width, :].contiguous() + return self.output_proj(output) + + +class Block(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: Optional[int], + head_dim: int, + window_size: int, + drop_path: float, + type: str = "W", + mlp_ratio: float = 4.0, + ) -> None: + super().__init__() + output_dim = output_dim or input_dim + self.norm1 = nn.LayerNorm(input_dim) + self.msa = WMSA(input_dim, input_dim, head_dim, window_size, type=type) + self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity() + self.norm2 = nn.LayerNorm(input_dim) + self.mlp = Mlp( + in_features=input_dim, + hidden_features=int(input_dim * mlp_ratio), + out_features=output_dim, + ) + self.residual_proj = ( + nn.Linear(input_dim, output_dim) + if input_dim != 
output_dim + else nn.Identity() + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + output = input_tensor + self.drop_path(self.msa(self.norm1(input_tensor))) + residual = self.residual_proj(output) + return residual + self.drop_path(self.mlp(self.norm2(output))) + + +class SwinBlock(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: Optional[int], + head_dim: int, + window_size: int, + drop_path: float, + mlp_ratio: float = 4.0, + ) -> None: + super().__init__() + output_dim = output_dim or input_dim + self.block_1 = Block( + input_dim, + input_dim, + head_dim, + window_size, + drop_path, + type="W", + mlp_ratio=mlp_ratio, + ) + self.block_2 = Block( + input_dim, + output_dim, + head_dim, + window_size, + drop_path, + type="SW", + mlp_ratio=mlp_ratio, + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + output = input_tensor.permute(0, 2, 3, 1).contiguous() + output = self.block_1(output) + output = self.block_2(output) + return output.permute(0, 3, 1, 2).contiguous() + + +class ConvTransBlock(nn.Module): + def __init__( + self, + conv_dim: int, + trans_dim: int, + head_dim: int, + window_size: int, + drop_path: float, + type: str = "W", + mlp_ratio: float = 4.0, + ) -> None: + super().__init__() + if type not in {"W", "SW"}: + raise ValueError(f"Unsupported attention type: {type}") + + self.conv_dim = conv_dim + self.trans_dim = trans_dim + self.conv1_1 = nn.Conv2d(conv_dim + trans_dim, conv_dim + trans_dim, 1) + self.conv1_2 = nn.Conv2d(conv_dim + trans_dim, conv_dim + trans_dim, 1) + self.conv_block = ResidualBlock(conv_dim, conv_dim) + self.trans_block = Block( + trans_dim, + trans_dim, + head_dim, + window_size, + drop_path, + type=type, + mlp_ratio=mlp_ratio, + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + mixed = self.conv1_1(input_tensor) + conv_tensor, trans_tensor = torch.split( + mixed, + (self.conv_dim, self.trans_dim), + dim=1, + ) + conv_tensor = self.conv_block(conv_tensor) + conv_tensor + trans_tensor 
= trans_tensor.permute(0, 2, 3, 1).contiguous() + trans_tensor = self.trans_block(trans_tensor) + trans_tensor = trans_tensor.permute(0, 3, 1, 2).contiguous() + output = torch.cat((conv_tensor, trans_tensor), dim=1) + return input_tensor + self.conv1_2(output) + + +class SWAtten(AttentionBlock): + def __init__( + self, + input_dim: int, + output_dim: int, + head_dim: int, + window_size: int, + drop_path: float, + inter_dim: Optional[int] = 192, + ) -> None: + hidden_dim = inter_dim or input_dim + super().__init__(N=hidden_dim) + self.in_conv = ( + conv1x1(input_dim, hidden_dim) if inter_dim is not None else nn.Identity() + ) + self.out_conv = ( + conv1x1(hidden_dim, output_dim) if inter_dim is not None else nn.Identity() + ) + self.non_local_block = SwinBlock( + hidden_dim, + hidden_dim, + head_dim, + window_size, + drop_path, + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + output = self.in_conv(input_tensor) + identity = output + non_local = self.non_local_block(output) + output = self.conv_a(output) * torch.sigmoid(self.conv_b(non_local)) + output = output + identity + return self.out_conv(output) + + +class WinResidualUnit(nn.Module): + """1x1 -> 3x3 -> 1x1 GELU residual unit; bottleneck width is half the + input channels. 
Used inside :class:`WinNoShiftAttention`.""" + + def __init__(self, channels: int) -> None: + super().__init__() + self.conv = nn.Sequential( + conv1x1(channels, channels // 2), + nn.GELU(), + conv3x3(channels // 2, channels // 2), + nn.GELU(), + conv1x1(channels // 2, channels), + ) + self.act = nn.GELU() + + def forward(self, input_tensor: Tensor) -> Tensor: + return self.act(self.conv(input_tensor) + input_tensor) + + +class _WinBasedAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + window_size: int, + shift_size: int, + drop_path: float, + output_proj: bool = True, + ) -> None: + super().__init__() + attention_type = "SW" if shift_size > 0 else "W" + self.attn = WMSA( + input_dim=dim, + output_dim=dim, + head_dim=dim // num_heads, + window_size=window_size, + type=attention_type, + output_proj=output_proj, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity() + + def forward(self, input_tensor: Tensor) -> Tensor: + output = input_tensor.permute(0, 2, 3, 1).contiguous() + output = self.attn(output) + output = output.permute(0, 3, 1, 2).contiguous() + return input_tensor + self.drop_path(output) + + +class WinNoShiftAttention(nn.Module): + """Sigmoid-gated dual-branch window attention block, used by STF / WACNN + and (with ``output_proj=True``) by other window-attention CompressAI + models. 
``output_proj=False`` reproduces the STF / WACNN topology in which + the WindowAttention output feeds straight back into the block without + an additional Linear projection.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + window_size: int = 8, + shift_size: int = 0, + drop_path: float = 0.0, + output_proj: bool = True, + ) -> None: + super().__init__() + self.conv_a = nn.Sequential( + WinResidualUnit(dim), + WinResidualUnit(dim), + WinResidualUnit(dim), + ) + self.conv_b = nn.Sequential( + _WinBasedAttention( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=shift_size, + drop_path=drop_path, + output_proj=output_proj, + ), + WinResidualUnit(dim), + WinResidualUnit(dim), + WinResidualUnit(dim), + conv1x1(dim, dim), + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + return input_tensor + self.conv_a(input_tensor) * torch.sigmoid( + self.conv_b(input_tensor) + ) + + +class PatchMerging(nn.Module): + def __init__(self, dim: int, norm_layer: type[nn.Module] = nn.LayerNorm) -> None: + super().__init__() + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, input_tensor: Tensor, height: int, width: int) -> Tensor: + batch_size, length, channels = input_tensor.shape + if length != height * width: + raise ValueError("Input feature has wrong size.") + + output = input_tensor.view(batch_size, height, width, channels) + if height % 2 == 1 or width % 2 == 1: + output = F.pad(output, (0, 0, 0, width % 2, 0, height % 2)) + + x0 = output[:, 0::2, 0::2, :] + x1 = output[:, 1::2, 0::2, :] + x2 = output[:, 0::2, 1::2, :] + x3 = output[:, 1::2, 1::2, :] + output = torch.cat([x0, x1, x2, x3], dim=-1) + output = output.view(batch_size, -1, 4 * channels) + return self.reduction(self.norm(output)) + + +class PatchSplit(nn.Module): + def __init__(self, dim: int, norm_layer: type[nn.Module] = nn.LayerNorm) -> None: + super().__init__() + self.reduction = nn.Linear(dim, dim * 2, 
bias=False) + self.norm = norm_layer(dim) + self.shuffle = nn.PixelShuffle(2) + + def forward(self, input_tensor: Tensor, height: int, width: int) -> Tensor: + batch_size, length, channels = input_tensor.shape + if length != height * width: + raise ValueError("Input feature has wrong size.") + + output = self.reduction(self.norm(input_tensor)) + output = output.permute(0, 2, 1).contiguous() + output = output.view(batch_size, 2 * channels, height, width) + output = self.shuffle(output) + output = output.permute(0, 2, 3, 1).contiguous() + return output.view(batch_size, 4 * length, -1) + + +def __getattr__(name): + if name == "Win_noShift_Attention": + import warnings + + warnings.warn( + "Win_noShift_Attention is deprecated; use WinNoShiftAttention instead.", + DeprecationWarning, + stacklevel=2, + ) + return WinNoShiftAttention + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/compressai/models/__init__.py b/compressai/models/__init__.py index 79112b89..2899722c 100644 --- a/compressai/models/__init__.py +++ b/compressai/models/__init__.py @@ -31,5 +31,6 @@ from .google import * from .pointcloud import * from .sensetime import * +from .stf import * from .vbr import * from .waseda import * diff --git a/compressai/models/_bases/__init__.py b/compressai/models/_bases/__init__.py new file mode 100644 index 00000000..d2520889 --- /dev/null +++ b/compressai/models/_bases/__init__.py @@ -0,0 +1,23 @@ +"""Abstract base classes shared by multiple slice-based LIC models. + +These were historically hidden behind ``stf_support`` / ``dcae_support`` file +names which obscured the fact that they're real abstract :class:`CompressionModel` +subclasses inherited by 3-4 models each. 
+""" +from .slice_entropy import ( + SliceEntropyCompressionModel, + infer_max_support_slices, + infer_num_slices, + lrp_support_channels, + make_entropy_transform, + slice_support_channels, +) + +__all__ = [ + "SliceEntropyCompressionModel", + "infer_max_support_slices", + "infer_num_slices", + "lrp_support_channels", + "make_entropy_transform", + "slice_support_channels", +] diff --git a/compressai/models/_bases/slice_entropy.py b/compressai/models/_bases/slice_entropy.py new file mode 100644 index 00000000..ffe4c4e2 --- /dev/null +++ b/compressai/models/_bases/slice_entropy.py @@ -0,0 +1,255 @@ +"""Slice-conditional entropy backbone shared by WACNN / SymmetricalTransFormer / MambaVC. + +Promoted out of the historical ``models/stf_support.py`` so the abstract base +class is discoverable by name. Channel-counting helpers and a parameterised +entropy-transform factory live here too — they used to be duplicated across +``stf_support`` / ``ssm_support`` / ``weconvene_support``. +""" +from __future__ import annotations + +from typing import Dict, Optional, Sequence, Tuple + +import torch.nn as nn + +from torch import Tensor + +from compressai.entropy_models import EntropyBottleneck +from compressai.latent_codecs import ChannelSliceLatentCodec +from compressai.models.utils import conv + +from ..base import CompressionModel + +__all__ = [ + "SliceEntropyCompressionModel", + "infer_max_support_slices", + "infer_num_slices", + "lrp_support_channels", + "make_entropy_transform", + "slice_support_channels", +] + + +_DEFAULT_NUM_SLICES_PREFIX = "latent_codec.cc_mean_transforms." 
+_KEY_SUFFIX = ".0.weight" + + +def slice_support_channels( + latent_channels: int, + slice_channels: int, + index: int, + max_support_slices: int, +) -> int: + if max_support_slices < 0: + return latent_channels + slice_channels * index + return latent_channels + slice_channels * min(index, max_support_slices) + + +def lrp_support_channels( + latent_channels: int, + slice_channels: int, + index: int, + max_support_slices: int, +) -> int: + if max_support_slices < 0: + return latent_channels + slice_channels * (index + 1) + return latent_channels + slice_channels * min(index + 1, max_support_slices + 1) + + +def make_entropy_transform( + in_channels: int, + out_channels: int, + *, + widths: Sequence[int] = (224, 128), +) -> nn.Sequential: + """Stack of stride-1 3x3 convs with GELU between, used by every slice + entropy model. ``widths`` specifies hidden conv widths; defaults to the + Mamba/WeConvene 3-conv stack. Pass ``widths=(224, 176, 128, 64)`` for the + STF/WACNN 5-conv stack.""" + layers: list[nn.Module] = [] + prev = in_channels + for width in widths: + layers.append(conv(prev, width, stride=1, kernel_size=3)) + layers.append(nn.GELU()) + prev = width + layers.append(conv(prev, out_channels, stride=1, kernel_size=3)) + return nn.Sequential(*layers) + + +def infer_num_slices( + state_dict: Dict[str, Tensor], + *, + prefix: str = _DEFAULT_NUM_SLICES_PREFIX, + suffix: str = _KEY_SUFFIX, +) -> int: + slice_indices = { + int(key[len(prefix) :].split(".", 1)[0]) + for key in state_dict + if key.startswith(prefix) and key.endswith(suffix) + } + return len(slice_indices) + + +def infer_max_support_slices( + state_dict: Dict[str, Tensor], + latent_channels: int, + num_slices: int, + *, + prefix: str = _DEFAULT_NUM_SLICES_PREFIX, + suffix: str = _KEY_SUFFIX, + extra_factor: int = 1, +) -> int: + """Infer ``max_support_slices`` from the input width of the first + cc_mean transform conv. 
``extra_factor`` accounts for models like DCAE/SAAF + that prepend additional copies of the latent (``M*3 + slice_channels*N``); + pass ``extra_factor=3`` there. Slice-only models (STF/Mamba*) keep the + default ``extra_factor=1``.""" + slice_channels = latent_channels // num_slices + matching = [ + tensor.size(1) + for key, tensor in state_dict.items() + if key.startswith(prefix) and key.endswith(suffix) + ] + if not matching: + return 0 + max_input_channels = max(matching) + return max(0, (max_input_channels - extra_factor * latent_channels) // slice_channels) + + +class SliceEntropyCompressionModel(CompressionModel): + """Channel-conditional entropy backbone shared by WACNN, SymmetricalTransFormer, MambaVC. + + Subclasses must populate ``g_a``, ``g_s``, ``h_a``, ``h_mean_s`` and + ``h_scale_s``, then call :meth:`_init_slice_entropy` to wire up the + entropy bottleneck for ``z`` and the :class:`ChannelSliceLatentCodec` + for ``y``. + """ + + h_a: nn.Module + h_mean_s: nn.Module + h_scale_s: nn.Module + entropy_bottleneck: EntropyBottleneck + latent_codec: ChannelSliceLatentCodec + + def _init_slice_entropy( + self, + latent_channels: int, + entropy_bottleneck_channels: int, + num_slices: int, + max_support_slices: int, + mean_support_transforms: Optional[nn.ModuleList] = None, + scale_support_transforms: Optional[nn.ModuleList] = None, + ) -> None: + if latent_channels % num_slices != 0: + raise ValueError("latent_channels must be divisible by num_slices") + if ( + mean_support_transforms is not None + and len(mean_support_transforms) != num_slices + ): + raise ValueError("mean_support_transforms must have num_slices entries") + if ( + scale_support_transforms is not None + and len(scale_support_transforms) != num_slices + ): + raise ValueError("scale_support_transforms must have num_slices entries") + + slice_channels = latent_channels // num_slices + widths = (224, 176, 128, 64) + cc_mean_transforms = nn.ModuleList( + make_entropy_transform( + 
slice_support_channels( + latent_channels, slice_channels, index, max_support_slices + ), + slice_channels, + widths=widths, + ) + for index in range(num_slices) + ) + cc_scale_transforms = nn.ModuleList( + make_entropy_transform( + slice_support_channels( + latent_channels, slice_channels, index, max_support_slices + ), + slice_channels, + widths=widths, + ) + for index in range(num_slices) + ) + lrp_transforms = nn.ModuleList( + make_entropy_transform( + lrp_support_channels( + latent_channels, slice_channels, index, max_support_slices + ), + slice_channels, + widths=widths, + ) + for index in range(num_slices) + ) + + self.entropy_bottleneck = EntropyBottleneck(entropy_bottleneck_channels) + self.latent_codec = ChannelSliceLatentCodec( + cc_mean_transforms=cc_mean_transforms, + cc_scale_transforms=cc_scale_transforms, + lrp_transforms=lrp_transforms, + mean_support_transforms=mean_support_transforms, + scale_support_transforms=scale_support_transforms, + num_slices=num_slices, + max_support_slices=max_support_slices, + quantizer="ste", + ) + + @property + def num_slices(self) -> int: + return self.latent_codec.num_slices + + @property + def max_support_slices(self) -> int: + return self.latent_codec.max_support_slices + + def _hyper_priors(self, y: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + z = self.h_a(y) + z_hat, z_likelihoods = self.entropy_bottleneck(z) + latent_means = self.h_mean_s(z_hat) + latent_scales = self.h_scale_s(z_hat) + return z, z_likelihoods, latent_means, latent_scales + + def _forward_latent_output(self, y: Tensor) -> Dict[str, Dict[str, Tensor] | Tensor]: + _, z_likelihoods, latent_means, latent_scales = self._hyper_priors(y) + y_out = self.latent_codec(y, latent_means, latent_scales) + output: Dict[str, Dict[str, Tensor] | Tensor] = { + "y_hat": y_out["y_hat"], + "likelihoods": {"y": y_out["likelihoods"]["y"], "z": z_likelihoods}, + } + return output + + def _forward_latent(self, y: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + 
output = self._forward_latent_output(y) + return output["y_hat"], output["likelihoods"]["y"], output["likelihoods"]["z"] + + def _compress_latent(self, y: Tensor) -> Dict[str, object]: + z = self.h_a(y) + z_strings = self.entropy_bottleneck.compress(z) + z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) + latent_means = self.h_mean_s(z_hat) + latent_scales = self.h_scale_s(z_hat) + y_out = self.latent_codec.compress(y, latent_means, latent_scales) + return { + "strings": [[y_out["strings"][0]], z_strings], + "shape": z.size()[-2:], + } + + def _decompress_latent( + self, + strings: Sequence[Sequence[bytes]], + shape: Tuple[int, int], + ) -> Tensor: + if len(strings) != 2: + raise ValueError("strings must contain [y_strings, z_strings]") + + z_hat = self.entropy_bottleneck.decompress(strings[1], shape) + latent_means = self.h_mean_s(z_hat) + latent_scales = self.h_scale_s(z_hat) + y_shape = (z_hat.shape[2] * 4, z_hat.shape[3] * 4) + y_out = self.latent_codec.decompress( + strings[0], y_shape, latent_means, latent_scales + ) + return y_out["y_hat"] diff --git a/compressai/models/stf.py b/compressai/models/stf.py new file mode 100644 index 00000000..bfcb15a1 --- /dev/null +++ b/compressai/models/stf.py @@ -0,0 +1,594 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. +# +# This file adapts code from https://github.com/Googolxx/STF +# (originally distributed under the Apache License 2.0). The upstream copyright +# notice is preserved in that repository; modifications by InterDigital +# Communications, Inc. are released under the BSD 3-Clause Clear License +# terms below. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from __future__ import annotations + +import math + +from typing import Dict, Optional, Sequence, Tuple, Type + +import torch +import torch.nn as nn + +from timm.layers import DropPath, Mlp +from timm.models.swin_transformer import SwinTransformerBlock as _TimmSwinBlock +from torch import Tensor + +from compressai.layers import GDN, conv1x1, conv3x3, subpel_conv3x3 +from compressai.layers.attn import ( + PatchMerging, + PatchSplit, + WinNoShiftAttention, +) +from compressai.models._bases import ( + SliceEntropyCompressionModel, + infer_max_support_slices, + infer_num_slices, +) +from compressai.models.utils import conv, deconv +from compressai.registry import register_model + +__all__ = [ + "SymmetricalTransFormer", + "WACNN", + "convert_upstream_stf_state_dict", +] + + +# ---------------------------------------------------------------------------- +# STF building blocks +# (formerly compressai/layers/lic/stf.py; private to the WACNN / SymmetricalTransFormer models) +# ---------------------------------------------------------------------------- + + +class _STFBasicLayer(nn.Module): + def __init__( + self, + dim: int, + depth: int, + num_heads: int, + window_size: int = 7, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float | Sequence[float] = 0.0, + norm_layer: Type[nn.Module] = nn.LayerNorm, + downsample: Optional[Type[nn.Module]] = None, + ) -> None: + del qk_scale # timm SwinTransformerBlock derives scale from head_dim + super().__init__() + drop_path_values = ( + list(drop_path) + if isinstance(drop_path, Sequence) and not isinstance(drop_path, (str, bytes)) + else [float(drop_path)] * depth + ) + self.window_size = window_size + self.shift_size = window_size // 2 + self.blocks = nn.ModuleList( + [ + _TimmSwinBlock( + dim=dim, + input_resolution=(0, 0), # ignored when always_partition=True + num_heads=num_heads, + window_size=window_size, + shift_size=0 if 
index % 2 == 0 else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_drop=drop, + attn_drop=attn_drop, + drop_path=drop_path_values[index], + norm_layer=norm_layer, + always_partition=True, # keep configured window/shift even if input is small + dynamic_mask=True, + ) + for index in range(depth) + ] + ) + self.downsample = downsample(dim=dim, norm_layer=norm_layer) if downsample else None + + # Released STF checkpoints carry `attn.relative_position_index` per block + # (the upstream WindowAttention registers it as a persistent buffer). + # timm's WindowAttention uses persistent=False, so promote it here so + # strict-mode state_dict loading round-trips without filtering keys. + for block in self.blocks: + index = block.attn.relative_position_index + del block.attn._buffers["relative_position_index"] + block.attn.register_buffer("relative_position_index", index, persistent=True) + + def forward(self, input_tensor: Tensor, height: int, width: int) -> tuple[Tensor, int, int]: + batch_size, length, channels = input_tensor.shape + if length != height * width: + raise ValueError("input feature has wrong size") + x = input_tensor.view(batch_size, height, width, channels) + for block in self.blocks: + x = block(x) + x = x.reshape(batch_size, height * width, channels) + + if self.downsample is None: + return x, height, width + + x = self.downsample(x, height, width) + if isinstance(self.downsample, PatchMerging): + return x, (height + 1) // 2, (width + 1) // 2 + return x, height * 2, width * 2 + + +class _PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 4, + in_chans: int = 3, + embed_dim: int = 96, + norm_layer: Optional[Type[nn.Module]] = None, + ) -> None: + super().__init__() + self.patch_size = (patch_size, patch_size) + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + ) + self.norm = norm_layer(embed_dim) if norm_layer 
is not None else None + + def forward(self, input_tensor: Tensor) -> Tensor: + _, _, height, width = input_tensor.size() + if width % self.patch_size[1] != 0: + input_tensor = nn.functional.pad( + input_tensor, + (0, self.patch_size[1] - width % self.patch_size[1]), + ) + if height % self.patch_size[0] != 0: + input_tensor = nn.functional.pad( + input_tensor, + (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]), + ) + + output = self.proj(input_tensor) + if self.norm is None: + return output + + out_height, out_width = output.size(2), output.size(3) + output = output.flatten(2).transpose(1, 2) + output = self.norm(output) + return output.transpose(1, 2).view(-1, self.embed_dim, out_height, out_width) + + +# ---------------------------------------------------------------------------- +# STF / WACNN models +# ---------------------------------------------------------------------------- + + +_UPSTREAM_LATENT_CODEC_PREFIXES = ( + "cc_mean_transforms", + "cc_scale_transforms", + "lrp_transforms", + "gaussian_conditional", +) + + +def convert_upstream_stf_state_dict(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: + """Translate a candidate ``STF`` / ``WACNN`` state dict into compressai layout. + + Upstream checkpoints (``stf__best.pth.tar`` / ``cnn__best.pth.tar`` + from `Zou et al. 2022 <https://arxiv.org/abs/2203.08450>`_) are saved from a + ``DataParallel``-wrapped module and place the channel-conditional entropy + transforms at the model root. compressai houses those transforms (plus the + Gaussian conditional) under ``latent_codec.*``. This helper: + + - strips the leading ``module.`` prefix added by ``DataParallel``; + - re-roots ``cc_mean_transforms`` / ``cc_scale_transforms`` / + ``lrp_transforms`` / ``gaussian_conditional`` under ``latent_codec.``; + - leaves ``g_a`` / ``g_s`` / ``patch_embed`` / ``layers`` / ``syn_layers`` + / ``end_conv`` / ``h_a`` / ``h_mean_s`` / ``h_scale_s`` / + ``entropy_bottleneck`` keys unchanged. 
+ + The returned dict can be loaded by :meth:`WACNN.from_state_dict` or + :meth:`SymmetricalTransFormer.from_state_dict`. Both ``from_state_dict`` + entry points auto-detect the upstream layout and call this helper, so + direct invocation is only needed when persisting the converted dict. + """ + converted: Dict[str, Tensor] = {} + for key, value in state_dict.items(): + new_key = key[len("module."):] if key.startswith("module.") else key + head = new_key.split(".", 1)[0] + if head in _UPSTREAM_LATENT_CODEC_PREFIXES: + new_key = "latent_codec." + new_key + converted[new_key] = value + return converted + + +def _is_upstream_stf_state_dict(state_dict: Dict[str, Tensor]) -> bool: + """Heuristic: upstream checkpoints either carry a ``module.`` prefix or + place ``cc_mean_transforms`` at the root instead of under ``latent_codec``. + """ + for key in state_dict: + if key.startswith("module."): + return True + if key.startswith("cc_mean_transforms.") or key.startswith("gaussian_conditional."): + return True + return False + + +@register_model("stf-wacnn") +class WACNN(SliceEntropyCompressionModel): + r"""WACNN model from R. Zou, C. Song, Z. Zhang: `"The Devil Is in the + Details: Window-based Attention for Image Compression" + <https://arxiv.org/abs/2203.08450>`_, IEEE/CVF Conf. on Computer Vision + and Pattern Recognition (CVPR), 2022. + + CNN-based variant that inserts window-based attention modules + (:class:`compressai.layers.attn.WinNoShiftAttention` with + ``output_proj=False``) inside the analysis/synthesis transforms, paired + with a Minnen2020-style channel-wise autoregressive entropy model. + + Args: + N (int): Number of channels in the hyperprior backbone. + M (int): Number of channels in the latent representation. + num_slices (int): Number of channel slices for the entropy model. 
+ """ + + def __init__( + self, + N: int = 192, + M: int = 320, + num_slices: int = 10, + max_support_slices: int = 5, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.g_a = nn.Sequential( + conv(3, N, kernel_size=5, stride=2), + GDN(N), + conv(N, N, kernel_size=5, stride=2), + GDN(N), + WinNoShiftAttention(dim=N, num_heads=8, window_size=8, shift_size=4, output_proj=False), + conv(N, N, kernel_size=5, stride=2), + GDN(N), + conv(N, M, kernel_size=5, stride=2), + WinNoShiftAttention(dim=M, num_heads=8, window_size=4, shift_size=2, output_proj=False), + ) + self.g_s = nn.Sequential( + WinNoShiftAttention(dim=M, num_heads=8, window_size=4, shift_size=2, output_proj=False), + deconv(M, N, kernel_size=5, stride=2), + GDN(N, inverse=True), + deconv(N, N, kernel_size=5, stride=2), + GDN(N, inverse=True), + WinNoShiftAttention(dim=N, num_heads=8, window_size=8, shift_size=4, output_proj=False), + deconv(N, N, kernel_size=5, stride=2), + GDN(N, inverse=True), + deconv(N, 3, kernel_size=5, stride=2), + ) + self.h_a = nn.Sequential( + conv3x3(M, M), + nn.GELU(), + conv3x3(M, 288), + nn.GELU(), + conv3x3(288, 256, stride=2), + nn.GELU(), + conv3x3(256, 224), + nn.GELU(), + conv3x3(224, N, stride=2), + ) + self.h_mean_s = nn.Sequential( + conv3x3(N, N), + nn.GELU(), + subpel_conv3x3(N, 224, 2), + nn.GELU(), + conv3x3(224, 256), + nn.GELU(), + subpel_conv3x3(256, 288, 2), + nn.GELU(), + conv3x3(288, M), + ) + self.h_scale_s = nn.Sequential( + conv3x3(N, N), + nn.GELU(), + subpel_conv3x3(N, 224, 2), + nn.GELU(), + conv3x3(224, 256), + nn.GELU(), + subpel_conv3x3(256, 288, 2), + nn.GELU(), + conv3x3(288, M), + ) + self._init_slice_entropy( + M, + N, + num_slices, + max_support_slices, + ) + + def forward(self, x: Tensor) -> Dict[str, Dict[str, Tensor] | Tensor]: + y = self.g_a(x) + latent_output = self._forward_latent_output(y) + return { + "x_hat": self.g_s(latent_output["y_hat"]), + "likelihoods": latent_output["likelihoods"], + } + + def compress(self, x: Tensor) 
-> Dict[str, object]: + return self._compress_latent(self.g_a(x)) + + def decompress(self, strings: Sequence[Sequence[bytes]], shape: Tuple[int, int]) -> Dict[str, Tensor]: + return {"x_hat": self.g_s(self._decompress_latent(strings, shape)).clamp_(0, 1)} + + @classmethod + def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "WACNN": + if _is_upstream_stf_state_dict(state_dict): + state_dict = convert_upstream_stf_state_dict(state_dict) + N = state_dict["g_a.0.weight"].size(0) + M = state_dict["g_a.7.weight"].size(0) + num_slices = infer_num_slices(state_dict) or 10 + max_support_slices = infer_max_support_slices(state_dict, M, num_slices) + net = cls( + N=N, + M=M, + num_slices=num_slices, + max_support_slices=max_support_slices, + ) + net.load_state_dict(state_dict) + return net + + +@register_model("stf") +class SymmetricalTransFormer(SliceEntropyCompressionModel): + r"""Symmetrical Transformer model (STF) from R. Zou, C. Song, Z. Zhang: + `"The Devil Is in the Details: Window-based Attention for Image + Compression" <https://arxiv.org/abs/2203.08450>`_, IEEE/CVF Conf. on + Computer Vision and Pattern Recognition (CVPR), 2022. + + Transformer-based companion of :class:`WACNN` that builds the + analysis/synthesis transforms with stacked Swin-style basic layers and a + channel-wise autoregressive entropy model. + + Args: + embed_dim (int): Patch-embedding dimension. + num_slices (int): Number of channel slices for the entropy model. 
+ """ + + def __init__( + self, + pretrain_img_size: int = 256, + patch_size: int = 2, + in_chans: int = 3, + embed_dim: int = 48, + depths: Optional[Sequence[int]] = None, + num_heads: Optional[Sequence[int]] = None, + window_size: int = 4, + num_slices: int = 12, + max_support_slices: Optional[int] = None, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.2, + norm_layer: type[nn.Module] = nn.LayerNorm, + patch_norm: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + depths = list(depths or [2, 2, 6, 2]) + num_heads = list(num_heads or [3, 6, 12, 24]) + if len(depths) != len(num_heads): + raise ValueError("depths and num_heads must have the same length") + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.patch_embed = _PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if patch_norm else None, + ) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [value.item() for value in torch.linspace(0, drop_path_rate, sum(depths))] + self.layers = nn.ModuleList() + for layer_index in range(self.num_layers): + self.layers.append( + _STFBasicLayer( + dim=int(embed_dim * 2**layer_index), + depth=depths[layer_index], + num_heads=num_heads[layer_index], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:layer_index]) : sum(depths[: layer_index + 1])], + norm_layer=norm_layer, + downsample=None if layer_index == self.num_layers - 1 else PatchMerging, + ) + ) + + reversed_depths = list(reversed(depths)) + reversed_heads = list(reversed(num_heads)) + self.syn_layers = nn.ModuleList() + for layer_index in range(self.num_layers): + self.syn_layers.append( + _STFBasicLayer( + 
dim=int(embed_dim * 2 ** (self.num_layers - 1 - layer_index)), + depth=reversed_depths[layer_index], + num_heads=reversed_heads[layer_index], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[ + sum(reversed_depths[:layer_index]) : sum(reversed_depths[: layer_index + 1]) + ], + norm_layer=norm_layer, + downsample=None if layer_index == self.num_layers - 1 else PatchSplit, + ) + ) + + self.end_conv = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim * patch_size**2, kernel_size=5, stride=1, padding=2), + nn.PixelShuffle(patch_size), + nn.Conv2d(embed_dim, 3, kernel_size=3, stride=1, padding=1), + ) + + latent_channels = int(embed_dim * 2 ** (self.num_layers - 1)) + bottleneck_channels = latent_channels // 2 + self.h_a = nn.Sequential( + conv3x3(latent_channels, latent_channels), + nn.GELU(), + conv3x3(latent_channels, latent_channels - embed_dim), + nn.GELU(), + conv3x3(latent_channels - embed_dim, latent_channels - 2 * embed_dim, stride=2), + nn.GELU(), + conv3x3(latent_channels - 2 * embed_dim, latent_channels - 3 * embed_dim), + nn.GELU(), + conv3x3(latent_channels - 3 * embed_dim, bottleneck_channels, stride=2), + ) + self.h_mean_s = nn.Sequential( + conv3x3(bottleneck_channels, latent_channels - 3 * embed_dim), + nn.GELU(), + subpel_conv3x3(latent_channels - 3 * embed_dim, latent_channels - 2 * embed_dim, 2), + nn.GELU(), + conv3x3(latent_channels - 2 * embed_dim, latent_channels - embed_dim), + nn.GELU(), + subpel_conv3x3(latent_channels - embed_dim, latent_channels, 2), + nn.GELU(), + conv3x3(latent_channels, latent_channels), + ) + self.h_scale_s = nn.Sequential( + conv3x3(bottleneck_channels, latent_channels - 3 * embed_dim), + nn.GELU(), + subpel_conv3x3(latent_channels - 3 * embed_dim, latent_channels - 2 * embed_dim, 2), + nn.GELU(), + conv3x3(latent_channels - 2 * embed_dim, latent_channels - embed_dim), + nn.GELU(), + subpel_conv3x3(latent_channels - 
embed_dim, latent_channels, 2), + nn.GELU(), + conv3x3(latent_channels, latent_channels), + ) + self._init_slice_entropy( + latent_channels, + bottleneck_channels, + num_slices, + num_slices // 2 if max_support_slices is None else max_support_slices, + ) + + def _analysis_transform(self, x: Tensor) -> Tuple[Tensor, int, int]: + output = self.patch_embed(x) + height, width = output.size(2), output.size(3) + output = self.pos_drop(output.flatten(2).transpose(1, 2)) + for layer in self.layers: + output, height, width = layer(output, height, width) + channels = self.embed_dim * 2 ** (self.num_layers - 1) + output = output.view(-1, height, width, channels).permute(0, 3, 1, 2).contiguous() + return output, height, width + + def _synthesis_transform(self, y_hat: Tensor, height: int, width: int) -> Tensor: + channels = self.embed_dim * 2 ** (self.num_layers - 1) + output = y_hat.permute(0, 2, 3, 1).contiguous().view(-1, height * width, channels) + for layer in self.syn_layers: + output, height, width = layer(output, height, width) + output = output.view(-1, height, width, self.embed_dim).permute(0, 3, 1, 2).contiguous() + return self.end_conv(output) + + def forward(self, x: Tensor) -> Dict[str, Dict[str, Tensor] | Tensor]: + y, height, width = self._analysis_transform(x) + latent_output = self._forward_latent_output(y) + return { + "x_hat": self._synthesis_transform(latent_output["y_hat"], height, width), + "likelihoods": latent_output["likelihoods"], + } + + def compress(self, x: Tensor) -> Dict[str, object]: + y, _, _ = self._analysis_transform(x) + return self._compress_latent(y) + + def decompress(self, strings: Sequence[Sequence[bytes]], shape: Tuple[int, int]) -> Dict[str, Tensor]: + y_hat = self._decompress_latent(strings, shape) + height, width = y_hat.shape[2:] + return {"x_hat": self._synthesis_transform(y_hat, height, width).clamp_(0, 1)} + + @classmethod + def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "SymmetricalTransFormer": + if 
_is_upstream_stf_state_dict(state_dict): + state_dict = convert_upstream_stf_state_dict(state_dict) + patch_size = state_dict["patch_embed.proj.weight"].size(2) + embed_dim = state_dict["patch_embed.proj.weight"].size(0) + layer_indices = sorted( + { + int(key.split(".")[1]) + for key in state_dict + if key.startswith("layers.") and ".blocks." in key + } + ) + depths = [ + len( + { + int(key.split(".")[3]) + for key in state_dict + if key.startswith(f"layers.{layer_index}.blocks.") + } + ) + for layer_index in layer_indices + ] + num_heads = [ + state_dict[f"layers.{layer_index}.blocks.0.attn.relative_position_bias_table"].size(1) + for layer_index in layer_indices + ] + table_size = state_dict["layers.0.blocks.0.attn.relative_position_bias_table"].size(0) + window_size = (math.isqrt(table_size) + 1) // 2 + num_slices = infer_num_slices(state_dict) or 12 + latent_channels = embed_dim * 2 ** (len(depths) - 1) + max_support_slices = infer_max_support_slices(state_dict, latent_channels, num_slices) + + net = cls( + patch_size=patch_size, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=window_size, + num_slices=num_slices, + max_support_slices=max_support_slices, + ) + net.load_state_dict(state_dict) + return net diff --git a/compressai/zoo/__init__.py b/compressai/zoo/__init__.py index 5c56bee7..acebc705 100644 --- a/compressai/zoo/__init__.py +++ b/compressai/zoo/__init__.py @@ -35,6 +35,8 @@ cheng2020_attn, mbt2018, mbt2018_mean, + stf, + stf_wacnn, ) from .image_vbr import bmshj2018_hyperprior_vbr, mbt2018_mean_vbr, mbt2018_vbr from .pretrained import load_pretrained as load_state_dict @@ -48,6 +50,8 @@ "mbt2018": mbt2018, "cheng2020-anchor": cheng2020_anchor, "cheng2020-attn": cheng2020_attn, + "stf": stf, + "stf-wacnn": stf_wacnn, "bmshj2018-hyperprior-vbr": bmshj2018_hyperprior_vbr, "mbt2018-mean-vbr": mbt2018_mean_vbr, "mbt2018-vbr": mbt2018_vbr, diff --git a/compressai/zoo/image.py b/compressai/zoo/image.py index e0c34492..cd25009d 
100644 --- a/compressai/zoo/image.py +++ b/compressai/zoo/image.py @@ -37,6 +37,8 @@ JointAutoregressiveHierarchicalPriors, MeanScaleHyperprior, ScaleHyperprior, + SymmetricalTransFormer, + WACNN, ) from .pretrained import load_pretrained @@ -49,6 +51,8 @@ "mbt2018_mean", "cheng2020_anchor", "cheng2020_attn", + "stf", + "stf_wacnn", ] model_architectures = { @@ -59,6 +63,8 @@ "mbt2018": JointAutoregressiveHierarchicalPriors, "cheng2020-anchor": Cheng2020Anchor, "cheng2020-attn": Cheng2020Attention, + "stf": SymmetricalTransFormer, + "stf-wacnn": WACNN, } root_url = "https://compressai.s3.amazonaws.com/models/v1" @@ -447,3 +453,39 @@ def cheng2020_attn(quality, metric="mse", pretrained=False, progress=True, **kwa return _load_model( "cheng2020-attn", metric, quality, pretrained, progress, **kwargs ) + + +def stf(pretrained: bool = False, progress: bool = True, **kwargs): + r"""Symmetrical TransFormer (STF) model from R. Zou, C. Song, Z. Zhang: + `"The Devil Is in the Details: Window-based Attention for Image + Compression" `_, IEEE/CVF Conf. on + Computer Vision and Pattern Recognition (CVPR), 2022. + + Args: + pretrained (bool): If True, returns a pre-trained model. Currently + unavailable; raises ``RuntimeError``. + progress (bool): If True, displays a progress bar of the download to + stderr. + """ + del progress + if pretrained: + raise RuntimeError("Pre-trained STF weights are not yet hosted on S3.") + return SymmetricalTransFormer(**kwargs) + + +def stf_wacnn(pretrained: bool = False, progress: bool = True, **kwargs): + r"""WACNN model from R. Zou, C. Song, Z. Zhang: + `"The Devil Is in the Details: Window-based Attention for Image + Compression" `_, IEEE/CVF Conf. on + Computer Vision and Pattern Recognition (CVPR), 2022. + + Args: + pretrained (bool): If True, returns a pre-trained model. Currently + unavailable; raises ``RuntimeError``. + progress (bool): If True, displays a progress bar of the download to + stderr. 
+ """ + del progress + if pretrained: + raise RuntimeError("Pre-trained WACNN weights are not yet hosted on S3.") + return WACNN(**kwargs) diff --git a/examples/convert_stf_checkpoint.py b/examples/convert_stf_checkpoint.py new file mode 100644 index 00000000..a8bb4a55 --- /dev/null +++ b/examples/convert_stf_checkpoint.py @@ -0,0 +1,134 @@ +"""Convert an upstream STF / WACNN checkpoint to compressai layout. + +Loads the published candidate weight file (e.g. ``stf_0018_best.pth.tar`` or +``cnn_0018_best.pth.tar`` from the STF repo), translates it to compressai's +module layout, and writes a state dict that +``compressai.models.SymmetricalTransFormer.from_state_dict`` / +``compressai.models.WACNN.from_state_dict`` can load directly. Optionally +reports forward-pass sanity numbers (PSNR / bpp) on a synthetic input. + +Example:: + + python examples/convert_stf_checkpoint.py \\ + --src candidate/STF/stf_0018_best.pth.tar \\ + --arch stf \\ + --dst /tmp/stf_compressai.pth \\ + --smoke + + python examples/convert_stf_checkpoint.py \\ + --src candidate/STF/cnn_0018_best.pth.tar \\ + --arch wacnn \\ + --smoke +""" +from __future__ import annotations + +import argparse +from pathlib import Path + +import torch + +from compressai.models import ( + SymmetricalTransFormer, + WACNN, + convert_upstream_stf_state_dict, +) + + +_ARCHES = {"stf": SymmetricalTransFormer, "wacnn": WACNN} + + +def _detect_arch(state_dict: dict) -> str: + keys = state_dict.keys() + if any("patch_embed" in k for k in keys): + return "stf" + if any(k.endswith("g_a.0.weight") for k in keys): + return "wacnn" + raise SystemExit( + "could not auto-detect arch; pass --arch {stf,wacnn} explicitly" + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument( + "--src", + type=Path, + required=True, + help="Path to the upstream checkpoint (e.g. 
stf_0018_best.pth.tar).", + ) + parser.add_argument( + "--arch", + choices=sorted(_ARCHES), + default=None, + help="Architecture to instantiate. Auto-detected from key names if omitted.", + ) + parser.add_argument( + "--dst", + type=Path, + default=None, + help=( + "Optional output path for the converted state dict. If omitted, " + "the script only verifies that the checkpoint loads cleanly." + ), + ) + parser.add_argument( + "--smoke", + action="store_true", + help="Run a forward smoke test on a synthetic 256x256 image.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + if not args.src.exists(): + raise SystemExit(f"checkpoint not found: {args.src}") + + upstream = torch.load(args.src, map_location="cpu", weights_only=False) + upstream = upstream.get("state_dict", upstream) if isinstance(upstream, dict) else upstream + converted = convert_upstream_stf_state_dict(upstream) + print(f"loaded {len(upstream)} upstream keys → {len(converted)} compressai keys") + + arch = args.arch or _detect_arch(upstream) + cls = _ARCHES[arch] + net = cls.from_state_dict(converted) + net.eval() + print(f"loaded {arch.upper()}: {sum(p.numel() for p in net.parameters()):,} params") + + if args.dst is not None: + args.dst.parent.mkdir(parents=True, exist_ok=True) + torch.save(net.state_dict(), args.dst) + print(f"wrote converted state dict → {args.dst}") + + if args.smoke: + height = width = 256 + ys, xs = torch.meshgrid( + torch.linspace(0, 1, height), + torch.linspace(0, 1, width), + indexing="ij", + ) + img = torch.stack( + [ + 0.5 + 0.3 * torch.sin(8 * xs), + 0.5 + 0.3 * torch.sin(8 * ys), + 0.5 + 0.3 * torch.cos(8 * (xs + ys)), + ], + dim=0, + ).unsqueeze(0).clamp(0, 1) + + with torch.no_grad(): + out = net(img) + n_pix = height * width + psnr = -10 * torch.log10( + ((out["x_hat"].clamp(0, 1) - img) ** 2).mean() + ).item() + y_bpp = -torch.log2(out["likelihoods"]["y"]).sum().item() / n_pix + z_bpp = -torch.log2(out["likelihoods"]["z"]).sum().item() / 
n_pix + print( + f"smoke: PSNR={psnr:.2f}dB y_bpp={y_bpp:.4f} z_bpp={z_bpp:.4f} " + f"total_bpp={y_bpp + z_bpp:.4f}" + ) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index dc6d2cc9..e99b225c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "pytorch-msssim", "scipy", "setuptools>=68", # For --no-build-isolation. + "timm", "tomli>=2.2.1", "torch-geometric>=2.3.0", "torch>=1.13.1", diff --git a/tests/test_models.py b/tests/test_models.py index c69ae7d5..2e45a111 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -278,6 +278,63 @@ def test_scale_space_flow(self): assert z_likelihoods_shape[3] == x[1].shape[3] / 2**7 +class TestStf: + def test_wacnn_forward_and_state_dict_round_trip(self): + from compressai.models import WACNN + + model = WACNN(N=64, M=128, num_slices=4, max_support_slices=2).eval() + x = torch.rand(1, 3, 64, 64) + with torch.no_grad(): + out = model(x) + assert out["x_hat"].shape == x.shape + assert "y" in out["likelihoods"] + assert "z" in out["likelihoods"] + + loaded = WACNN.from_state_dict(model.state_dict()).eval() + with torch.no_grad(): + out_loaded = loaded(x) + assert torch.allclose(out["x_hat"], out_loaded["x_hat"]) + + def test_symmetrical_transformer_forward_and_state_dict_round_trip(self): + from compressai.models import SymmetricalTransFormer + + model = SymmetricalTransFormer( + embed_dim=24, + depths=(1, 1, 1, 1), + num_heads=(2, 2, 2, 2), + num_slices=4, + max_support_slices=2, + ).eval() + x = torch.rand(1, 3, 128, 128) + with torch.no_grad(): + out = model(x) + assert out["x_hat"].shape == x.shape + assert "y" in out["likelihoods"] + assert "z" in out["likelihoods"] + + loaded = SymmetricalTransFormer.from_state_dict(model.state_dict()).eval() + with torch.no_grad(): + out_loaded = loaded(x) + assert torch.allclose(out["x_hat"], out_loaded["x_hat"]) + + def test_stf_upstream_state_dict_conversion(self): + from compressai.models.stf import ( + 
convert_upstream_stf_state_dict, + ) + + upstream = { + "module.g_a.0.weight": torch.zeros(2), + "module.cc_mean_transforms.0.0.weight": torch.zeros(2), + "module.gaussian_conditional.scale_table": torch.zeros(2), + "module.h_a.0.weight": torch.zeros(2), + } + converted = convert_upstream_stf_state_dict(upstream) + assert "g_a.0.weight" in converted + assert "latent_codec.cc_mean_transforms.0.0.weight" in converted + assert "latent_codec.gaussian_conditional.scale_table" in converted + assert "h_a.0.weight" in converted + + def test_scale_table_default(): table = get_scale_table() assert SCALES_MIN == 0.11