From cc2657616605b2666d6fab5a4b2b437eb917081f Mon Sep 17 00:00:00 2001 From: RobbinBouwmeester Date: Fri, 24 Apr 2026 08:20:11 +0200 Subject: [PATCH] Fix multitask_model.pt loading: add BatchedHeads/MultitaskDeepLCModel to _architecture and register legacy module shim The bundled multitask_model.pt was serialised when MultitaskDeepLCModel and BatchedHeads lived in a top-level module called `multitask_model`. That module no longer exists, so torch.load raised ModuleNotFoundError for any user who tried to load the default model. Fix: - Add BatchedHeads and MultitaskDeepLCModel to deeplc/_architecture.py, where the rest of the model architecture already lives. - Add _patch_legacy_multitask_module() to deeplc/_model_ops.py, which registers a sys.modules shim mapping the old import path to the new classes before torch.load is called. The shim is a no-op if the module is already registered. - Add test_load_multitask_model_without_prior_shim to tests/test_model_ops.py to prevent regression. Co-Authored-By: Claude Sonnet 4.6 --- deeplc/_architecture.py | 66 +++++++++++++++++++++++++++++++++++++++++ deeplc/_model_ops.py | 21 ++++++++++++- tests/test_model_ops.py | 25 ++++++++++++++++ 3 files changed, 111 insertions(+), 1 deletion(-) diff --git a/deeplc/_architecture.py b/deeplc/_architecture.py index a21a1cc..f2495a6 100644 --- a/deeplc/_architecture.py +++ b/deeplc/_architecture.py @@ -423,3 +423,69 @@ def forward( output = self.final_network(concatenated) return output + + +class BatchedHeads(nn.Module): + """Parallel output heads sharing a hidden projection. + + Each head maps the shared trunk output to a scalar via a two-step + computation: a batched linear projection followed by a per-head dot + product with a learned weight vector. + + Parameters + ---------- + input_size + Size of the input feature vector (output of shared trunk). + n_heads + Number of parallel output heads. + hidden + Hidden dimension per head (default: 32). + """ + + def __init__(self, input_size: int, n_heads: int, hidden: int = 32): + super().__init__() + self.layer1 = nn.Linear(input_size, n_heads * hidden) + self.w2 = nn.Parameter(torch.zeros(n_heads, hidden)) + self.b2 = nn.Parameter(torch.zeros(n_heads)) + nn.init.normal_(self.w2, std=0.05) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.layer1(x) # (batch, n_heads * hidden) + n_heads = self.b2.shape[0] + h = torch.relu(h.view(h.shape[0], n_heads, h.shape[1] // n_heads)) + return (h * self.w2.unsqueeze(0)).sum(dim=-1) + self.b2 # (batch, n_heads) + + +class MultitaskDeepLCModel(nn.Module): + """Multi-task DeepLC backbone predicting RT across multiple LC systems. + + Shares the same four input branches as :class:`DeepLCModel` but replaces + the single-output final network with a shared trunk feeding into + :class:`BatchedHeads`, producing one RT value per LC system. + + This class is primarily used for loading pre-trained checkpoints via + ``torch.load``. The child modules (``branch_a``, ``branch_b``, + ``branch_c``, ``branch_d``, ``shared_trunk``, ``heads``) are restored + from the checkpoint state dict and do not need to be constructed here. 
+ """ + + def forward( + self, + x_atom: torch.Tensor, + x_atom_sum: torch.Tensor, + x_global: torch.Tensor, + x_one_hot: torch.Tensor, + ) -> torch.Tensor: + x_atom = x_atom.transpose(1, 2) + x_atom_sum = x_atom_sum.transpose(1, 2) + x_one_hot = x_one_hot.transpose(1, 2) + concatenated = torch.cat( + [ + self.branch_a(x_atom), # type: ignore[attr-defined] + self.branch_b(x_atom_sum), # type: ignore[attr-defined] + self.branch_c(x_global), # type: ignore[attr-defined] + self.branch_d(x_one_hot), # type: ignore[attr-defined] + ], + dim=1, + ) + return self.heads(self.shared_trunk(concatenated)) # type: ignore[attr-defined] diff --git a/deeplc/_model_ops.py b/deeplc/_model_ops.py index ad959dc..209a4c8 100644 --- a/deeplc/_model_ops.py +++ b/deeplc/_model_ops.py @@ -2,6 +2,8 @@ import copy import logging +import sys +import types from collections.abc import Callable from os import PathLike from pathlib import Path @@ -17,12 +19,28 @@ ) from torch.utils.data import DataLoader, Dataset, Subset -from deeplc._architecture import DeepLCModel +from deeplc._architecture import BatchedHeads, DeepLCModel, MultitaskDeepLCModel from deeplc.data import DeepLCDataset logger = logging.getLogger(__name__) +def _patch_legacy_multitask_module() -> None: + """Register a backwards-compatibility shim for multitask_model.pt. + + The bundled multitask checkpoint was saved when MultitaskDeepLCModel and + BatchedHeads lived in a top-level module called ``multitask_model``. That + module no longer exists; the classes now live in ``deeplc._architecture``. + Registering a shim in ``sys.modules`` before ``torch.load`` lets pickle + resolve the old import paths without re-saving the checkpoint. + """ + if "multitask_model" not in sys.modules: + shim = types.ModuleType("multitask_model") + shim.MultitaskDeepLCModel = MultitaskDeepLCModel # type: ignore[attr-defined] + shim.BatchedHeads = BatchedHeads # type: ignore[attr-defined] + sys.modules["multitask_model"] = shim + + def load_model( model: torch.nn.Module | PathLike | str | None = None, device: str | None = None, @@ -33,6 +51,7 @@ def load_model( # Load model from file if a path is provided if isinstance(model, (str, PathLike, Path)): + _patch_legacy_multitask_module() loaded_model = torch.load(model, weights_only=False, map_location=selected_device) elif isinstance(model, torch.nn.Module): loaded_model = model diff --git a/tests/test_model_ops.py b/tests/test_model_ops.py index 38306a0..4959b51 100644 --- a/tests/test_model_ops.py +++ b/tests/test_model_ops.py @@ -1,10 +1,14 @@ from __future__ import annotations +import sys + import pytest import torch from torch.utils.data import Dataset from deeplc import _model_ops +from deeplc._architecture import BatchedHeads, MultitaskDeepLCModel +from deeplc.core import DEFAULT_MULTITASK_MODEL_PACKAGED from deeplc.data import split_datasets @@ -58,3 +62,24 @@ def test_train_rejects_empty_validation_loader(): batch_size=2, show_progress=False, ) + + +def test_load_multitask_model_without_prior_shim(): + """multitask_model.pt must load even when the legacy module is not pre-registered.""" + # Remove any previously registered shim so the test is self-contained. 
+ sys.modules.pop("multitask_model", None) + + model = _model_ops.load_model(DEFAULT_MULTITASK_MODEL_PACKAGED, device="cpu") + + assert isinstance(model, MultitaskDeepLCModel) + + x_atom = torch.zeros(2, 60, 6) + x_sum = torch.zeros(2, 30, 6) + x_global = torch.zeros(2, 55) + x_hc = torch.zeros(2, 60, 20) + with torch.no_grad(): + out = model(x_atom, x_sum, x_global, x_hc) + + assert out.ndim == 2 + assert out.shape[0] == 2 + assert out.shape[1] > 1 # multiple heads
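
The BatchedHeads docstring above describes each head as a batched linear projection followed by a per-head dot product with a learned weight vector. A minimal sketch of that equivalence, assuming the patch is applied and deeplc is importable; the sizes, seed, and batch below are arbitrary illustration, not taken from the bundled checkpoint:

    import torch
    from deeplc._architecture import BatchedHeads

    torch.manual_seed(0)
    input_size, n_heads, hidden = 16, 4, 32
    heads = BatchedHeads(input_size, n_heads, hidden)
    x = torch.randn(8, input_size)

    batched = heads(x)  # (8, n_heads), all heads evaluated in one pass

    # Reference: evaluate each head as its own Linear -> ReLU -> dot product,
    # slicing that head's rows out of the shared projection.
    w1 = heads.layer1.weight.view(n_heads, hidden, input_size)
    b1 = heads.layer1.bias.view(n_heads, hidden)
    for i in range(n_heads):
        h = torch.relu(x @ w1[i].T + b1[i])
        ref = h @ heads.w2[i] + heads.b2[i]
        assert torch.allclose(batched[:, i], ref, atol=1e-5)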
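
_patch_legacy_multitask_module() works because torch.load with weights_only=False goes through pickle, and pickle resolves classes by the module path recorded at save time, so pre-registering the old path in sys.modules is enough to make the legacy checkpoint load. A self-contained sketch of that mechanism with stand-in names (old_mod and Thing are illustrative only, not part of this patch):

    import pickle
    import sys
    import types

    class Thing:
        """Stand-in for a class pickled under a module name that no longer exists."""

    # Pretend Thing lived in a top-level module "old_mod" when it was pickled.
    Thing.__module__ = "old_mod"
    legacy = types.ModuleType("old_mod")
    legacy.Thing = Thing
    sys.modules["old_mod"] = legacy
    payload = pickle.dumps(Thing())

    # Simulate the refactor: the old module is gone.
    del sys.modules["old_mod"]
    try:
        pickle.loads(payload)
    except ModuleNotFoundError:
        pass  # the error users hit when loading multitask_model.pt

    # Register a shim under the old name; the same pickled bytes now load.
    shim = types.ModuleType("old_mod")
    shim.Thing = Thing
    sys.modules["old_mod"] = shim
    assert isinstance(pickle.loads(payload), Thing)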