Add :safetensors_reader option to load_model/2 #456

Merged: jonatanklosko merged 1 commit into elixir-nx:main from ausimian:feat/safetensors-reader-option on May 18, 2026

@ausimian
Contributor

Summary

Adds a :safetensors_reader option to Bumblebee.load_model/2. When supplied, the given function is used instead of the default &Safetensors.read!(&1, lazy: true) to read .safetensors parameter files. It receives a file path and must return a map from tensor name to an Nx.Tensor or any term implementing Nx.LazyContainer. This is the same shape Safetensors.read!/2 already returns, so the rest of the loading pipeline (Bumblebee.Conversion.PyTorchParams.load_params!/4) is unchanged.
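
As a rough illustration of the reader contract described above (a one-argument function from file path to a map of tensor name to tensor), here is a minimal sketch of a wrapper that fails early when a custom reader returns something other than a plain map. The `ReaderContract` module and the stub reader are hypothetical and not part of Bumblebee or this PR. The commented usage assumes the `:safetensors_reader` option behaves as described here, and the repository name is a placeholder.

```elixir
defmodule ReaderContract do
  # Hypothetical helper, not part of Bumblebee: wraps a 1-arity reader so
  # that a return value which is not a plain map raises a clear error.
  def checked(reader) when is_function(reader, 1) do
    fn path ->
      case reader.(path) do
        tensors when is_map(tensors) and not is_struct(tensors) ->
          tensors

        other ->
          raise ArgumentError,
                "reader for #{path} must return a map of tensor name => tensor, " <>
                  "got: #{inspect(other)}"
      end
    end
  end
end

# Stub reader standing in for a real one. It returns a placeholder atom
# instead of an Nx.Tensor so the sketch runs without any dependencies.
stub_reader = fn _path -> %{"dense.weight" => :placeholder_tensor} end
reader = ReaderContract.checked(stub_reader)
reader.("model.safetensors")

# Assumed usage with Bumblebee (requires the bumblebee and safetensors
# packages; "some-org/some-model" is a placeholder repository):
#
#   reader = ReaderContract.checked(&Safetensors.read!(&1, lazy: true))
#   {:ok, model_info} =
#     Bumblebee.load_model({:hf, "some-org/some-model"}, safetensors_reader: reader)
```

Wrapping the default reader this way should leave behaviour unchanged. The wrapper only makes a contract violation in a custom reader visible at the point where the file is read.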

Default behaviour is identical when the option is not supplied.

Motivation

This is a small seam that enables custom safetensors readers without forking Bumblebee. The concrete use case is a memory-mapped reader backed by a NIF resource binary (enif_make_resource_binary), which keeps peak BEAM memory bounded to a single tensor when loading very large checkpoints: pages are demand-faulted by the OS and freed after the backend transfer, instead of being read with pread and copied into the BEAM heap for each tensor.

The current Safetensors.read!/2 lazy path already streams reasonably well, but the per-tensor File.open + pread + binary copy is a real cost on multi-GB checkpoints, and there's no way to swap it today because the loader is wired directly to the module name at bumblebee.ex:771.

The option is scoped to safetensors only. PyTorch pickle isn't a candidate for the same treatment because it requires structural parsing rather than byte-range reads.
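
The "byte-range reads" point follows from the safetensors file layout: an 8-byte little-endian header length, a JSON header, then a raw data buffer in which each tensor's `data_offsets` are relative to the start of that buffer. The sketch below illustrates only that layout. It is not how Safetensors.read!/2 or the memory-mapped reader is implemented. The `SafetensorsLayout` module is hypothetical, the offsets are passed in directly rather than parsed from the JSON header, and an in-memory binary stands in for a file.

```elixir
defmodule SafetensorsLayout do
  # Byte offset at which the tensor data buffer starts:
  # the 8-byte length prefix plus the JSON header of that length.
  def data_start(<<header_size::little-unsigned-integer-size(64), _rest::binary>>) do
    8 + header_size
  end

  # Extracts one tensor's raw bytes, given its {begin, end} offsets
  # relative to the start of the data buffer.
  def tensor_bytes(file, {begin_offset, end_offset}) do
    binary_part(file, data_start(file) + begin_offset, end_offset - begin_offset)
  end
end

# Build a tiny in-memory file: length prefix, JSON header, 4 data bytes.
header = ~s({"w":{"dtype":"U8","shape":[4],"data_offsets":[0,4]}})
file = <<byte_size(header)::little-unsigned-integer-size(64), header::binary, 1, 2, 3, 4>>

SafetensorsLayout.tensor_bytes(file, {0, 4})
```

Because each tensor is a contiguous byte range at a known offset, a reader can serve it from a file read or a memory mapping without parsing the rest of the file.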

Test plan

  • mix test test/bumblebee_test.exs --only describe:"load_model/2" passes locally (5/5).
  • New test asserts the override is invoked for .safetensors files and that the resulting params have the same keys as the default reader.
  • Existing tests pass unchanged (default path is unaffected).

Allows callers to override the function used to read `.safetensors`
parameter files. The override receives a file path and must return a
map from tensor name to an `Nx.Tensor` or `Nx.LazyContainer`. Defaults
to the existing `&Safetensors.read!(&1, lazy: true)`, so behaviour is
unchanged when the option is not supplied.

This is a small seam for custom readers — for example, a memory-mapped
loader backed by a resource binary, which can keep peak memory bounded
to a single tensor when loading very large checkpoints.
@jonatanklosko jonatanklosko merged commit d0774e8 into elixir-nx:main May 18, 2026
2 checks passed