ANTsX/ANTsNormalizingFlows

ANTsNormalizingFlows

A PyTorch package for discrete normalizing flows, updated from normflows.

Quick start

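import torch
import antsnormflows as nf

# Base distribution (2D diagonal Gaussian)
base = nf.distributions.base.DiagGaussian(2)

# Real NVP: affine coupling layers with a simple MLP conditioner
flows = []
num_layers = 8
for _ in range(num_layers):
    # The MLP maps the 1 conditioning dimension to 2 outputs (scale and shift)
    param_map = nf.nets.MLP([1, 64, 64, 2], init_zeros=True)
    flows.append(nf.flows.AffineCouplingBlock(param_map))
    # Swap the two dimensions so each one is transformed in alternating layers
    flows.append(nf.flows.Permute(2, mode="swap"))

model = nf.NormalizingFlow(base, flows)

x = torch.randn(256, 2)  # placeholder data; replace with your samples, shape (batch, 2)
loss = model.forward_kld(x)  # forward KL divergence (negative mean log-likelihood)
loss.backward()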
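
For intuition about what each coupling layer in the quick start does, here is a minimal pure-Python sketch of a 2D affine coupling step. It is not the package's implementation: the `conditioner` function is a hypothetical stand-in for the MLP, and `coupling_forward`, `coupling_inverse`, and `swap` are illustrative names. It shows why the conditioner takes 1 input and returns 2 outputs, and why the layer is cheap to invert.

```python
import math


def conditioner(x1):
    """Hypothetical stand-in for the MLP: maps 1 input to (log_scale, shift)."""
    return math.tanh(x1), 0.5 * x1


def coupling_forward(x):
    """Keep x[0] unchanged and affinely transform x[1], conditioned on x[0].

    Returns the transformed point and the log-determinant of the Jacobian.
    """
    x1, x2 = x
    log_scale, shift = conditioner(x1)
    y2 = x2 * math.exp(log_scale) + shift
    # The Jacobian is triangular, so its log-determinant is just log_scale.
    return (x1, y2), log_scale


def coupling_inverse(y):
    """Undo coupling_forward; the conditioner itself is never inverted."""
    y1, y2 = y
    log_scale, shift = conditioner(y1)
    x2 = (y2 - shift) * math.exp(-log_scale)
    return (y1, x2), -log_scale


def swap(v):
    """Exchange the two dimensions so the next layer transforms the other one."""
    return (v[1], v[0])


x = (0.3, -1.2)
y, log_det = coupling_forward(x)
x_rec, inv_log_det = coupling_inverse(y)
```

Stacking several such steps with a swap in between lets every dimension be transformed, and the total log-determinant is the sum of the per-layer terms.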

Documentation

Citation

If you use antsnormflows, please cite the corresponding papers:

  • Stimper et al. (2023). normflows: A PyTorch Package for Normalizing Flows. Journal of Open Source Software, 8(86), 5361.

  • Tustison et al. (2026). Deep Computational Anatomy via Latent-Aligned Normalizing Flows. bioRxiv.
