deeplc/_architecture.py (66 additions, 0 deletions)
@@ -423,3 +423,69 @@ def forward(
        output = self.final_network(concatenated)

        return output


class BatchedHeads(nn.Module):
    """Parallel output heads computed with fused batched operations.

    Each head maps the shared trunk output to a scalar in two steps: a
    per-head hidden projection (fused into a single batched linear layer)
    followed by a per-head dot product with a learned weight vector.

    Parameters
    ----------
    input_size
        Size of the input feature vector (output of the shared trunk).
    n_heads
        Number of parallel output heads.
    hidden
        Hidden dimension per head (default: 32).
    """

    def __init__(self, input_size: int, n_heads: int, hidden: int = 32):
        super().__init__()
        self.layer1 = nn.Linear(input_size, n_heads * hidden)
        self.w2 = nn.Parameter(torch.zeros(n_heads, hidden))
        self.b2 = nn.Parameter(torch.zeros(n_heads))
        nn.init.normal_(self.w2, std=0.05)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.layer1(x)  # (batch, n_heads * hidden)
        n_heads = self.b2.shape[0]
        h = torch.relu(h.view(h.shape[0], n_heads, h.shape[1] // n_heads))
        return (h * self.w2.unsqueeze(0)).sum(dim=-1) + self.b2  # (batch, n_heads)


class MultitaskDeepLCModel(nn.Module):
    """Multi-task DeepLC backbone predicting RT across multiple LC systems.

    Shares the same four input branches as :class:`DeepLCModel` but replaces
    the single-output final network with a shared trunk feeding into
    :class:`BatchedHeads`, producing one RT value per LC system.

    This class is primarily used for loading pre-trained checkpoints via
    ``torch.load``. The child modules (``branch_a``, ``branch_b``,
    ``branch_c``, ``branch_d``, ``shared_trunk``, ``heads``) are restored
    from the checkpoint state dict and do not need to be constructed here.
    """

    def forward(
        self,
        x_atom: torch.Tensor,
        x_atom_sum: torch.Tensor,
        x_global: torch.Tensor,
        x_one_hot: torch.Tensor,
    ) -> torch.Tensor:
        x_atom = x_atom.transpose(1, 2)
        x_atom_sum = x_atom_sum.transpose(1, 2)
        x_one_hot = x_one_hot.transpose(1, 2)
        concatenated = torch.cat(
            [
                self.branch_a(x_atom),  # type: ignore[attr-defined]
                self.branch_b(x_atom_sum),  # type: ignore[attr-defined]
                self.branch_c(x_global),  # type: ignore[attr-defined]
                self.branch_d(x_one_hot),  # type: ignore[attr-defined]
            ],
            dim=1,
        )
        return self.heads(self.shared_trunk(concatenated))  # type: ignore[attr-defined]
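
Note for reviewers: the fused head computation above is mathematically equivalent to giving each head its own Linear -> ReLU -> dot-product stack; ``layer1`` simply packs all per-head projections into one matrix. A minimal sketch checking that equivalence for a single head (all sizes here are illustrative, not taken from the bundled checkpoint):

# Sketch only: check one head of BatchedHeads against an unfused
# Linear -> ReLU -> dot-product stack. All sizes are made up.
import torch

from deeplc._architecture import BatchedHeads

torch.manual_seed(0)
n_heads, hidden = 3, 4
heads = BatchedHeads(input_size=8, n_heads=n_heads, hidden=hidden)
x = torch.randn(2, 8)

i = 1  # head to check
w1 = heads.layer1.weight[i * hidden : (i + 1) * hidden]  # head i's slice of the fused projection
b1 = heads.layer1.bias[i * hidden : (i + 1) * hidden]
h_i = torch.relu(x @ w1.T + b1)            # (batch, hidden)
unfused = h_i @ heads.w2[i] + heads.b2[i]  # (batch,)

assert torch.allclose(unfused, heads(x)[:, i], atol=1e-6)

Fusing the heads this way keeps the forward pass to two batched operations regardless of how many LC systems the checkpoint covers.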
deeplc/_model_ops.py (20 additions, 1 deletion)
@@ -2,6 +2,8 @@

import copy
import logging
import sys
import types
from collections.abc import Callable
from os import PathLike
from pathlib import Path
@@ -17,12 +19,28 @@
)
from torch.utils.data import DataLoader, Dataset, Subset

from deeplc._architecture import DeepLCModel
from deeplc._architecture import BatchedHeads, DeepLCModel, MultitaskDeepLCModel
from deeplc.data import DeepLCDataset

logger = logging.getLogger(__name__)


def _patch_legacy_multitask_module() -> None:
    """Register a backwards-compatibility shim for multitask_model.pt.

    The bundled multitask checkpoint was saved when MultitaskDeepLCModel and
    BatchedHeads lived in a top-level module called ``multitask_model``. That
    module no longer exists; the classes now live in ``deeplc._architecture``.
    Registering a shim in ``sys.modules`` before ``torch.load`` lets pickle
    resolve the old import paths without re-saving the checkpoint.
    """
    if "multitask_model" not in sys.modules:
        shim = types.ModuleType("multitask_model")
        shim.MultitaskDeepLCModel = MultitaskDeepLCModel  # type: ignore[attr-defined]
        shim.BatchedHeads = BatchedHeads  # type: ignore[attr-defined]
        sys.modules["multitask_model"] = shim


def load_model(
    model: torch.nn.Module | PathLike | str | None = None,
    device: str | None = None,
@@ -33,6 +51,7 @@ def load_model(

    # Load model from file if a path is provided
    if isinstance(model, (str, PathLike, Path)):
        _patch_legacy_multitask_module()
        loaded_model = torch.load(model, weights_only=False, map_location=selected_device)
    elif isinstance(model, torch.nn.Module):
        loaded_model = model
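
Note for reviewers: the shim works because pickle stores classes by module path and qualified name, then resolves them through ``sys.modules`` on load. A self-contained sketch of the same mechanism, using the hypothetical names ``legacy_module`` and ``Relocated`` rather than the real checkpoint:

# Sketch of the sys.modules trick that torch.load/pickle relies on.
# "legacy_module" and Relocated are hypothetical stand-ins.
import pickle
import sys
import types


class Relocated:
    pass


# Pretend Relocated was pickled back when it lived in "legacy_module".
Relocated.__module__ = "legacy_module"

shim = types.ModuleType("legacy_module")
shim.Relocated = Relocated
sys.modules["legacy_module"] = shim

payload = pickle.dumps(Relocated())                  # stored as "legacy_module.Relocated"
assert isinstance(pickle.loads(payload), Relocated)  # resolved via the shim

``torch.load`` with ``weights_only=False`` goes through the same pickle resolution, which is why registering the shim before the call is sufficient.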
tests/test_model_ops.py (25 additions, 0 deletions)
@@ -1,10 +1,14 @@
from __future__ import annotations

import sys

import pytest
import torch
from torch.utils.data import Dataset

from deeplc import _model_ops
from deeplc._architecture import BatchedHeads, MultitaskDeepLCModel
from deeplc.core import DEFAULT_MULTITASK_MODEL_PACKAGED
from deeplc.data import split_datasets


@@ -58,3 +62,24 @@ def test_train_rejects_empty_validation_loader():
        batch_size=2,
        show_progress=False,
    )


def test_load_multitask_model_without_prior_shim():
    """multitask_model.pt must load even when the legacy module is not pre-registered."""
    # Remove any previously registered shim so the test is self-contained.
    sys.modules.pop("multitask_model", None)

    model = _model_ops.load_model(DEFAULT_MULTITASK_MODEL_PACKAGED, device="cpu")

    assert isinstance(model, MultitaskDeepLCModel)

    x_atom = torch.zeros(2, 60, 6)
    x_sum = torch.zeros(2, 30, 6)
    x_global = torch.zeros(2, 55)
    x_hc = torch.zeros(2, 60, 20)
    with torch.no_grad():
        out = model(x_atom, x_sum, x_global, x_hc)

    assert out.ndim == 2
    assert out.shape[0] == 2
    assert out.shape[1] > 1  # multiple heads