An updated PyTorch package (derived from normflows) for discrete normalizing flows.
import torch

import antsnormflows as nf

# Dimensionality of the data / latent space.
latent_size = 2

# Base distribution (2D diagonal Gaussian)
base = nf.distributions.base.DiagGaussian(latent_size)

# Real NVP: alternate affine coupling blocks with a swap permutation so that
# both halves of the input get transformed across the stack.
flows = []
num_layers = 8
for _ in range(num_layers):
    # Conditioner MLP with two hidden layers of 64 units. It reads the
    # latent_size // 2 dimensions left unchanged by the coupling and outputs
    # latent_size values (the affine parameters for the other half).
    # init_zeros=True zero-initializes the last layer, which makes training
    # more stable.
    param_map = nf.nets.MLP(
        [latent_size // 2, 64, 64, latent_size], init_zeros=True
    )
    flows.append(nf.flows.AffineCouplingBlock(param_map))
    # Swap the two halves so the next coupling transforms the other one.
    flows.append(nf.flows.Permute(latent_size, mode="swap"))
model = nf.NormalizingFlow(base, flows)

# Placeholder data so the example runs as-is; replace with a batch of your
# training data of shape (batch, latent_size).
x = torch.randn(256, latent_size)
# Forward KL divergence loss, estimated from the data batch x.
loss = model.forward_kld(x)
loss.backward()

If you use antsnormflows, please cite the corresponding papers: