diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 20bf0b3..bffb162 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,7 +20,7 @@ repos: - id: mypy additional_dependencies: [numpy] args: [--strict] - exclude: '(\.pyi$|tests/|examples/)' + exclude: '(\.pyi$|tests/|examples/|benches/)' - repo: local hooks: - id: cargo-fmt diff --git a/Cargo.lock b/Cargo.lock index c994f4f..732320d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9,6 +9,7 @@ dependencies = [ "criterion", "nalgebra", "numpy", + "ordered-float", "pyo3", "rustc-hash", "thiserror", @@ -415,6 +416,15 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "ordered-float" +version = "4.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951" +dependencies = [ + "num-traits", +] + [[package]] name = "paste" version = "1.0.15" diff --git a/Cargo.toml b/Cargo.toml index 4a26d5a..43cfaf5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,11 +17,12 @@ crate-type = ["cdylib", "rlib"] criterion = "0.5" [dependencies] -pyo3 = { version = "0.23", features = ["extension-module", "abi3-py310"] } +pyo3 = { version = "0.23", features = ["abi3-py310"] } numpy = "0.23" nalgebra = "0.33" thiserror = "2" rustc-hash = "2" +ordered-float = "4" [profile.release] lto = true diff --git a/benches/bench_learner1d.py b/benches/bench_learner1d.py new file mode 100644 index 0000000..90b244d --- /dev/null +++ b/benches/bench_learner1d.py @@ -0,0 +1,89 @@ +"""Benchmarks for the Rust-powered Learner1D.""" + +from __future__ import annotations + +import math +import time + +from adaptive_triangulation import Learner1D + + +def bench(name: str, fn, *, n_iter: int = 10): + """Run a benchmark and print results.""" + times = [] + for _ in range(n_iter): + t0 = time.perf_counter() + fn() + 
times.append(time.perf_counter() - t0) + times.sort() + median = times[len(times) // 2] + print(f" {name}: {median * 1000:.2f} ms (median of {n_iter})") + return median + + +def bench_tell_single(): + print("=== tell_single ===") + for n in [1_000, 10_000, 100_000]: + step = 2.0 / n + + def run(n=n, step=step): + l = Learner1D(bounds=(-1.0, 1.0)) + for i in range(n): + x = -1.0 + step * i + l.tell(x, math.sin(x * 10)) + + bench(f"tell {n:,} points", run, n_iter=5 if n >= 100_000 else 10) + + +def bench_tell_many_batch(): + print("\n=== tell_many (force rebuild) ===") + for n in [1_000, 10_000]: + step = 2.0 / n + xs = [-1.0 + step * i for i in range(n)] + ys = [math.sin(x * 10) for x in xs] + + def run(xs=xs, ys=ys): + l = Learner1D(bounds=(-1.0, 1.0)) + l.tell_many(xs, ys, force=True) + + bench(f"tell_many {n:,} points", run) + + +def bench_ask(): + print("\n=== ask 100 points ===") + for n_existing in [100, 1_000, 10_000]: + step = 2.0 / n_existing + xs = [-1.0 + step * i for i in range(n_existing)] + ys = [math.sin(x * 10) for x in xs] + + def run(xs=xs, ys=ys): + l = Learner1D(bounds=(-1.0, 1.0)) + l.tell_many(xs, ys, force=True) + l.ask(100, tell_pending=False) + + bench(f"ask 100 (from {n_existing:,} pts)", run) + + +def bench_full_loop(): + print("\n=== full loop 10K points ===") + f = lambda x: math.sin(x * 10) + + def run_serial(): + l = Learner1D(bounds=(-1.0, 1.0)) + l.run(f, n_points=10_000, batch_size=1) + + def run_batched(): + l = Learner1D(bounds=(-1.0, 1.0)) + l.run(f, n_points=10_000, batch_size=100) + + bench("serial (batch=1)", run_serial, n_iter=3) + bench("batched (batch=100)", run_batched, n_iter=3) + + +if __name__ == "__main__": + print("Learner1D Benchmarks") + print("=" * 50) + bench_tell_single() + bench_tell_many_batch() + bench_ask() + bench_full_loop() diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..f36ad7e --- /dev/null +++ b/build.rs @@ -0,0 +1,28 @@ +use std::process::Command; + +fn main() { + let python = 
std::env::var("PYO3_PYTHON") + .or_else(|_| std::env::var("PYTHON_SYS_EXECUTABLE")) + .unwrap_or_else(|_| "python3".to_owned()); + + let Ok(output) = Command::new(&python) + .args([ + "-c", + "import sysconfig; print(sysconfig.get_config_var('LIBDIR') or '')", + ]) + .output() + else { + return; + }; + if !output.status.success() { + return; + } + + let Ok(libdir) = String::from_utf8(output.stdout) else { + return; + }; + let libdir = libdir.trim(); + if !libdir.is_empty() { + println!("cargo:rustc-link-arg=-Wl,-rpath,{libdir}"); + } +} diff --git a/pyproject.toml b/pyproject.toml index 012f4ee..b87090e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ select = ["ALL"] ignore = ["D", "COM812", "ISC001"] [tool.ruff.lint.per-file-ignores] +"benches/*" = ["INP001", "T201", "ANN", "E741", "PLR2004", "E731"] "tests/*" = ["S101", "PLR2004", "ANN"] "examples/*" = ["INP001", "T201", "S101", "N813", "RUF001", "RUF002", "SIM105", "PERF203"] diff --git a/python/adaptive_triangulation/__init__.py b/python/adaptive_triangulation/__init__.py index 8b4d4c1..5fef83f 100644 --- a/python/adaptive_triangulation/__init__.py +++ b/python/adaptive_triangulation/__init__.py @@ -23,6 +23,7 @@ from __future__ import annotations from ._rust import ( + Learner1D, SimplicesProxy, Triangulation, VertexToSimplicesProxy, @@ -41,6 +42,7 @@ ) __all__: list[str] = [ + "Learner1D", "SimplicesProxy", "Triangulation", "VertexToSimplicesProxy", diff --git a/python/adaptive_triangulation/_rust.pyi b/python/adaptive_triangulation/_rust.pyi index 09aa0b6..3ed1faf 100644 --- a/python/adaptive_triangulation/_rust.pyi +++ b/python/adaptive_triangulation/_rust.pyi @@ -14,6 +14,47 @@ TransformLike: TypeAlias = Sequence[Sequence[float]] | npt.ArrayLike __version__: str +class Learner1D: + def __init__( + self, + bounds: tuple[float, float], + loss_per_interval: object | None = None, + ) -> None: ... + def tell(self, x: float, y: float | Sequence[float]) -> None: ... 
+ def tell_many( + self, + xs: list[float], + ys: list[float | Sequence[float]], + force: bool = False, + ) -> None: ... + def tell_pending(self, x: float) -> None: ... + def ask(self, n: int, tell_pending: bool = True) -> tuple[list[float], list[float]]: ... + def run( + self, + f: object, + *, + goal: float | None = None, + n_points: int | None = None, + batch_size: int = 1, + ) -> int: ... + @property + def loss(self) -> float: ... + @property + def npoints(self) -> int: ... + @property + def vdim(self) -> int | None: ... + def to_numpy( + self, + ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: ... + def remove_unfinished(self) -> None: ... + @property + def pending_points(self) -> list[float]: ... + @property + def data(self) -> dict[float, float | list[float]]: ... + def intervals(self) -> list[tuple[float, float, float]]: ... + @property + def bounds(self) -> tuple[float, float]: ... + class SimplicesProxy: def __contains__(self, simplex: Simplex) -> bool: ... def __iter__(self) -> Iterator[Simplex]: ... diff --git a/src/learner1d/loss.rs b/src/learner1d/loss.rs new file mode 100644 index 0000000..e12680a --- /dev/null +++ b/src/learner1d/loss.rs @@ -0,0 +1,427 @@ +use std::collections::{BTreeSet, HashMap}; + +use pyo3::prelude::*; + +use super::{YValue, OF64}; + +// ---- Loss function enum ---- + +/// Built-in and custom loss functions for 1D adaptive sampling. +pub enum LossFunction { + /// `sqrt(dx² + dy²)` with scaling. + Default, + /// `dx` — uniform sampling. + Uniform, + /// Default loss clamped by interval width. + Resolution { min_length: f64, max_length: f64 }, + /// Weighted sum of triangle area + euclidean distance + horizontal distance. + Curvature { + area_factor: f64, + euclid_factor: f64, + horizontal_factor: f64, + }, + /// Average area of triangles formed by 4 neighbouring points. + Triangle, + /// Default loss on `log(|y|)`. + AbsMinLog, + /// Python callable `(xs, ys) -> float`. 
+ PythonCallback { + callback: PyObject, + nth_neighbors: usize, + }, +} + +impl LossFunction { + pub fn nth_neighbors(&self) -> usize { + match self { + Self::Default | Self::Uniform | Self::Resolution { .. } | Self::AbsMinLog => 0, + Self::Curvature { .. } | Self::Triangle => 1, + Self::PythonCallback { nth_neighbors, .. } => *nth_neighbors, + } + } +} + +impl std::fmt::Debug for LossFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Default => write!(f, "Default"), + Self::Uniform => write!(f, "Uniform"), + Self::Resolution { + min_length, + max_length, + } => f + .debug_struct("Resolution") + .field("min_length", min_length) + .field("max_length", max_length) + .finish(), + Self::Curvature { + area_factor, + euclid_factor, + horizontal_factor, + } => f + .debug_struct("Curvature") + .field("area_factor", area_factor) + .field("euclid_factor", euclid_factor) + .field("horizontal_factor", horizontal_factor) + .finish(), + Self::Triangle => write!(f, "Triangle"), + Self::AbsMinLog => write!(f, "AbsMinLog"), + Self::PythonCallback { nth_neighbors, .. 
} => f + .debug_struct("PythonCallback") + .field("nth_neighbors", nth_neighbors) + .finish(), + } + } +} + +impl Clone for LossFunction { + fn clone(&self) -> Self { + match self { + Self::Default => Self::Default, + Self::Uniform => Self::Uniform, + Self::Resolution { + min_length, + max_length, + } => Self::Resolution { + min_length: *min_length, + max_length: *max_length, + }, + Self::Curvature { + area_factor, + euclid_factor, + horizontal_factor, + } => Self::Curvature { + area_factor: *area_factor, + euclid_factor: *euclid_factor, + horizontal_factor: *horizontal_factor, + }, + Self::Triangle => Self::Triangle, + Self::AbsMinLog => Self::AbsMinLog, + Self::PythonCallback { + callback, + nth_neighbors, + } => Python::with_gil(|py| Self::PythonCallback { + callback: callback.clone_ref(py), + nth_neighbors: *nth_neighbors, + }), + } + } +} + +// ---- Loss computation dispatch ---- + +pub fn compute_loss(loss_fn: &LossFunction, xs: &[Option], ys: &[Option<&YValue>]) -> f64 { + match loss_fn { + LossFunction::Default => default_loss(xs, ys), + LossFunction::Uniform => uniform_loss(xs), + LossFunction::Resolution { + min_length, + max_length, + } => resolution_loss(xs, ys, *min_length, *max_length), + LossFunction::Curvature { + area_factor, + euclid_factor, + horizontal_factor, + } => curvature_loss(xs, ys, *area_factor, *euclid_factor, *horizontal_factor), + LossFunction::Triangle => triangle_loss_fn(xs, ys), + LossFunction::AbsMinLog => abs_min_log_loss(xs, ys), + LossFunction::PythonCallback { callback, .. 
} => python_callback_loss(callback, xs, ys), + } +} + +// ---- Individual loss functions ---- + +fn default_loss(xs: &[Option], ys: &[Option<&YValue>]) -> f64 { + let (x0, x1) = (xs[0].unwrap(), xs[1].unwrap()); + let (y0, y1) = (ys[0].unwrap(), ys[1].unwrap()); + let dx = x1 - x0; + match (y0, y1) { + (YValue::Scalar(a), YValue::Scalar(b)) => { + let dy = b - a; + (dx * dx + dy * dy).sqrt() + } + (YValue::Vector(a), YValue::Vector(b)) => a + .iter() + .zip(b.iter()) + .map(|(ai, bi)| { + let dy = (bi - ai).abs(); + (dx * dx + dy * dy).sqrt() + }) + .fold(f64::NEG_INFINITY, f64::max), + _ => 0.0, + } +} + +fn uniform_loss(xs: &[Option]) -> f64 { + xs[1].unwrap() - xs[0].unwrap() +} + +fn resolution_loss( + xs: &[Option], + ys: &[Option<&YValue>], + min_length: f64, + max_length: f64, +) -> f64 { + let dx = uniform_loss(xs); + if dx < min_length { + return 0.0; + } + if dx > max_length { + return f64::INFINITY; + } + default_loss(xs, ys) +} + +fn triangle_loss_fn(xs: &[Option], ys: &[Option<&YValue>]) -> f64 { + let points: Vec<(f64, &YValue)> = xs + .iter() + .zip(ys.iter()) + .filter_map(|(x, y)| match (x, y) { + (Some(xv), Some(yv)) => Some((*xv, *yv)), + _ => None, + }) + .collect(); + + if points.len() <= 1 { + return 0.0; + } + if points.len() == 2 { + return points[1].0 - points[0].0; + } + + let n_tri = points.len() - 2; + let is_vec = matches!(points[0].1, YValue::Vector(_)); + + let mut total = 0.0; + for i in 0..n_tri { + if is_vec { + let mk = |p: &(f64, &YValue)| -> Vec { + let mut v = vec![p.0]; + if let YValue::Vector(arr) = p.1 { + v.extend(arr); + } + v + }; + total += simplex_vol_tri(&mk(&points[i]), &mk(&points[i + 1]), &mk(&points[i + 2])); + } else { + let gy = |p: &(f64, &YValue)| -> f64 { + if let YValue::Scalar(v) = p.1 { + *v + } else { + 0.0 + } + }; + total += tri_area_2d( + points[i].0, + gy(&points[i]), + points[i + 1].0, + gy(&points[i + 1]), + points[i + 2].0, + gy(&points[i + 2]), + ); + } + } + total / n_tri as f64 +} + +fn 
curvature_loss( + xs: &[Option], + ys: &[Option<&YValue>], + area_factor: f64, + euclid_factor: f64, + horizontal_factor: f64, +) -> f64 { + let tri_l = triangle_loss_fn(xs, ys); + let def_l = default_loss(&xs[1..3], &ys[1..3]); + let dx = uniform_loss(&xs[1..3]); + area_factor * tri_l.sqrt() + euclid_factor * def_l + horizontal_factor * dx +} + +fn abs_min_log_loss(xs: &[Option], ys: &[Option<&YValue>]) -> f64 { + let ys_log: Vec> = ys + .iter() + .map(|y| { + y.map(|yv| { + let min_abs = match yv { + YValue::Scalar(v) => v.abs(), + YValue::Vector(v) => v.iter().map(|x| x.abs()).fold(f64::INFINITY, f64::min), + }; + YValue::Scalar(min_abs.ln()) + }) + }) + .collect(); + let refs: Vec> = ys_log.iter().map(|y| y.as_ref()).collect(); + default_loss(xs, &refs) +} + +fn python_callback_loss(callback: &PyObject, xs: &[Option], ys: &[Option<&YValue>]) -> f64 { + Python::with_gil(|py| { + let xs_py: Vec = xs + .iter() + .map(|x| match x { + Some(v) => v.into_pyobject(py).unwrap().into_any().unbind(), + None => py.None(), + }) + .collect(); + let ys_py: Vec = ys + .iter() + .map(|y| match y { + Some(YValue::Scalar(v)) => v.into_pyobject(py).unwrap().into_any().unbind(), + Some(YValue::Vector(v)) => { + pyo3::types::PyList::new(py, v).unwrap().into_any().unbind() + } + None => py.None(), + }) + .collect(); + let xs_tuple = pyo3::types::PyTuple::new(py, &xs_py).unwrap(); + let ys_tuple = pyo3::types::PyTuple::new(py, &ys_py).unwrap(); + match callback.call1(py, (xs_tuple, ys_tuple)) { + Ok(result) => result.extract::(py).unwrap_or_else(|e| { + e.print(py); + f64::INFINITY + }), + Err(err) => { + err.print(py); + f64::INFINITY + } + } + }) +} + +// ---- Geometry helpers ---- + +fn tri_area_2d(x0: f64, y0: f64, x1: f64, y1: f64, x2: f64, y2: f64) -> f64 { + ((x1 - x0) * (y2 - y0) - (x2 - x0) * (y1 - y0)).abs() / 2.0 +} + +fn simplex_vol_tri(v0: &[f64], v1: &[f64], v2: &[f64]) -> f64 { + let dim = v0.len(); + let e1: Vec = (0..dim).map(|i| v1[i] - v0[i]).collect(); + let e2: 
Vec = (0..dim).map(|i| v2[i] - v0[i]).collect(); + let d11: f64 = e1.iter().map(|x| x * x).sum(); + let d12: f64 = e1.iter().zip(e2.iter()).map(|(a, b)| a * b).sum(); + let d22: f64 = e2.iter().map(|x| x * x).sum(); + (d11 * d22 - d12 * d12).abs().sqrt() / 2.0 +} + +// ---- Finite-loss rounding ---- + +pub fn round_loss(loss: f64) -> f64 { + let fac = 1e12_f64; + (loss * fac + 0.5).floor() / fac +} + +pub fn finite_loss_value(loss: f64, left: f64, right: f64, x_scale: f64) -> f64 { + finite_loss_with_n(loss, left, right, 1, x_scale) +} + +pub fn finite_loss_with_n(loss: f64, left: f64, right: f64, n: usize, x_scale: f64) -> f64 { + let loss = if !loss.is_finite() { + (right - left) / x_scale / n as f64 + } else { + loss + }; + round_loss(loss) +} + +// ---- Priority-queue entry (sorted ascending → first() = max loss) ---- + +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct LossEntry { + pub neg_finite_loss: OF64, + pub left: OF64, + pub right: OF64, +} + +// ---- Interval key ---- + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct Interval { + pub left: OF64, + pub right: OF64, +} + +// ---- LossManager ---- + +/// Maps intervals to losses, with a priority queue sorted by finite-loss. 
+#[derive(Clone)] +pub struct LossManager { + interval_to_loss: HashMap, + queue: BTreeSet, + pub x_scale: f64, +} + +impl LossManager { + pub fn new(x_scale: f64) -> Self { + Self { + interval_to_loss: HashMap::new(), + queue: BTreeSet::new(), + x_scale, + } + } + + fn make_entry(&self, left: OF64, right: OF64, loss: f64) -> LossEntry { + let fl = finite_loss_value(loss, left.into_inner(), right.into_inner(), self.x_scale); + LossEntry { + neg_finite_loss: OF64::from(-fl), + left, + right, + } + } + + fn loss_for_entry(&self, entry: &LossEntry) -> f64 { + self.interval_to_loss[&Interval { + left: entry.left, + right: entry.right, + }] + } + + pub fn insert(&mut self, left: OF64, right: OF64, loss: f64) { + let ival = Interval { left, right }; + if let Some(old_loss) = self.interval_to_loss.remove(&ival) { + self.queue.remove(&self.make_entry(left, right, old_loss)); + } + self.interval_to_loss.insert(ival, loss); + self.queue.insert(self.make_entry(left, right, loss)); + } + + pub fn remove(&mut self, left: OF64, right: OF64) -> Option { + let ival = Interval { left, right }; + if let Some(loss) = self.interval_to_loss.remove(&ival) { + self.queue.remove(&self.make_entry(left, right, loss)); + Some(loss) + } else { + None + } + } + + pub fn get(&self, left: OF64, right: OF64) -> Option { + self.interval_to_loss + .get(&Interval { left, right }) + .copied() + } + + /// Raw loss of the highest-priority interval. + pub fn peek_max_loss(&self) -> Option { + self.queue + .iter() + .next() + .map(|entry| self.loss_for_entry(entry)) + } + + /// Iterate entries in priority order (highest loss first). + /// Yields `(left, right, raw_loss)`. + pub fn iter_by_priority(&self) -> impl Iterator + '_ { + self.queue + .iter() + .map(move |entry| (entry.left, entry.right, self.loss_for_entry(entry))) + } + + /// Collect all intervals (no particular order). 
+ pub fn all_intervals(&self) -> Vec<(OF64, OF64)> { + self.interval_to_loss + .keys() + .map(|iv| (iv.left, iv.right)) + .collect() + } +} diff --git a/src/learner1d/mod.rs b/src/learner1d/mod.rs new file mode 100644 index 0000000..14a83ec --- /dev/null +++ b/src/learner1d/mod.rs @@ -0,0 +1,773 @@ +pub mod loss; +pub mod python; + +use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::ops::Bound; + +use ordered_float::OrderedFloat; + +use self::loss::{compute_loss, finite_loss_value, finite_loss_with_n, LossFunction, LossManager}; + +pub type OF64 = OrderedFloat; + +/// Represents a function output value (scalar or vector). +#[derive(Clone, Debug)] +pub enum YValue { + Scalar(f64), + Vector(Vec), +} + +/// Adaptive 1D learner backed by `BTreeMap` internals. +pub struct Learner1D { + // Core data + pub(crate) data: BTreeMap, + /// Data points outside bounds — stored separately so they don't affect neighbor queries. + pub(crate) out_of_bounds_data: HashMap, + pub(crate) pending: BTreeSet, + /// Union of `data` keys and `pending`. 
+ combined_points: BTreeSet, + + // Loss tracking + pub(crate) losses: LossManager, + pub(crate) losses_combined: LossManager, + + // Scaling + pub(crate) bounds: (f64, f64), + x_scale: f64, + y_scale: f64, + old_y_scale: f64, + y_min: Vec, + y_max: Vec, + + // Config + pub(crate) loss_fn: LossFunction, + nth_neighbors: usize, + pub(crate) vdim: Option, + dx_eps: f64, +} + +// ---- BTreeMap / BTreeSet neighbour helpers ---- + +fn predecessor_map(map: &BTreeMap, x: OF64) -> Option { + map.range(..x).next_back().map(|(&k, _)| k) +} + +fn successor_map(map: &BTreeMap, x: OF64) -> Option { + map.range((Bound::Excluded(x), Bound::Unbounded)) + .next() + .map(|(&k, _)| k) +} + +fn predecessor_set(set: &BTreeSet, x: OF64) -> Option { + set.range(..x).next_back().copied() +} + +fn successor_set(set: &BTreeSet, x: OF64) -> Option { + set.range((Bound::Excluded(x), Bound::Unbounded)) + .next() + .copied() +} + +/// Return n-1 interior points evenly spaced inside `(left, right)`. +fn linspace_interior(left: f64, right: f64, n: usize) -> Vec { + if n <= 1 { + return vec![]; + } + let step = (right - left) / n as f64; + (1..n).map(|i| left + step * i as f64).collect() +} + +/// Update y-min/y-max bounds from a single data point. 
+fn update_y_bounds(y_min: &mut Vec, y_max: &mut Vec, y: &YValue) { + match y { + YValue::Scalar(v) => { + if y_min.is_empty() { + *y_min = vec![*v]; + *y_max = vec![*v]; + } else { + y_min[0] = y_min[0].min(*v); + y_max[0] = y_max[0].max(*v); + } + } + YValue::Vector(v) => { + if y_min.is_empty() { + *y_min = v.clone(); + *y_max = v.clone(); + } else { + for (i, val) in v.iter().enumerate() { + if i < y_min.len() { + y_min[i] = y_min[i].min(*val); + y_max[i] = y_max[i].max(*val); + } + } + } + } + } +} + +impl Learner1D { + // ---------------------------------------------------------------- + // Construction + // ---------------------------------------------------------------- + + pub fn new(bounds: (f64, f64), loss_fn: LossFunction) -> Self { + let x_scale = bounds.1 - bounds.0; + let nth_neighbors = loss_fn.nth_neighbors(); + let dx_eps = 2.0 * f64::max(bounds.0.abs(), bounds.1.abs()) * f64::EPSILON; + Self { + data: BTreeMap::new(), + out_of_bounds_data: HashMap::new(), + pending: BTreeSet::new(), + combined_points: BTreeSet::new(), + losses: LossManager::new(x_scale), + losses_combined: LossManager::new(x_scale), + bounds, + x_scale, + y_scale: 0.0, + old_y_scale: 0.0, + y_min: Vec::new(), + y_max: Vec::new(), + loss_fn, + nth_neighbors, + vdim: None, + dx_eps, + } + } + + // ---------------------------------------------------------------- + // Public API + // ---------------------------------------------------------------- + + pub fn tell(&mut self, x: f64, y: YValue) { + let xo = OF64::from(x); + if self.data.contains_key(&xo) || self.out_of_bounds_data.contains_key(&xo) { + return; + } + + self.set_vdim_if_unknown(&y); + + self.pending.remove(&xo); + + if !(self.bounds.0 <= x && x <= self.bounds.1) { + self.out_of_bounds_data.insert(xo, y); + return; + } + + self.data.insert(xo, y.clone()); + self.combined_points.insert(xo); + + self.update_scale(&y); + self.update_losses(xo, true); + + let should_recompute = (self.old_y_scale == 0.0 && self.y_scale > 0.0) 
+ || (self.old_y_scale > 0.0 && self.y_scale > 2.0 * self.old_y_scale); + if should_recompute { + let ivals: Vec<(OF64, OF64)> = self.losses.all_intervals(); + for (l, r) in ivals { + self.update_interpolated_loss_in_interval(l, r); + } + self.old_y_scale = self.y_scale; + } + } + + pub fn tell_pending(&mut self, x: f64) { + let xo = OF64::from(x); + if self.data.contains_key(&xo) { + return; + } + self.pending.insert(xo); + self.combined_points.insert(xo); + self.update_losses(xo, false); + } + + pub fn tell_many(&mut self, xs: &[f64], ys: &[YValue], force: bool) { + if xs.is_empty() { + return; + } + let should_rebuild = + force || (xs.len() > 2 && xs.len() as f64 > 0.5 * self.data.len() as f64); + if !should_rebuild { + // Incremental path + for (x, y) in xs.iter().zip(ys.iter()) { + self.tell(*x, y.clone()); + } + return; + } + // Fast rebuild path + for (x, y) in xs.iter().zip(ys.iter()) { + let xo = OF64::from(*x); + self.pending.remove(&xo); + self.set_vdim_if_unknown(y); + if !(self.bounds.0 <= *x && *x <= self.bounds.1) { + self.out_of_bounds_data.insert(xo, y.clone()); + } else { + self.data.insert(xo, y.clone()); + } + } + + // Rebuild combined_points + self.combined_points = self + .data + .keys() + .copied() + .chain(self.pending.iter().copied()) + .collect(); + + // Rebuild scale from all data + self.rebuild_scale(); + + // Rebuild losses + let real_pts: Vec = self.data.keys().copied().collect(); + let comb_pts: Vec = self.combined_points.iter().copied().collect(); + let intervals: Vec<(OF64, OF64)> = real_pts.windows(2).map(|w| (w[0], w[1])).collect(); + let intervals_combined: Vec<(OF64, OF64)> = + comb_pts.windows(2).map(|w| (w[0], w[1])).collect(); + + self.losses = LossManager::new(self.x_scale); + for &(l, r) in &intervals { + let loss = self.get_loss_in_interval(l, r); + self.losses.insert(l, r, loss); + } + + self.losses_combined = LossManager::new(self.x_scale); + let mut to_interpolate: Vec<(OF64, OF64)> = Vec::new(); + for &(l, r) in 
&intervals_combined { + match self.losses.get(l, r) { + Some(loss) => self.losses_combined.insert(l, r, loss), + None => { + self.losses_combined.insert(l, r, f64::INFINITY); + if matches!(to_interpolate.last(), Some((_, last_right)) if *last_right == l) { + to_interpolate.last_mut().unwrap().1 = r; + } else { + to_interpolate.push((l, r)); + } + } + } + } + for (l, r) in to_interpolate { + if self.losses.get(l, r).is_some() { + self.update_interpolated_loss_in_interval(l, r); + } + } + } + + /// Return `n` points that are expected to maximally reduce the loss, + /// plus the corresponding loss improvements. + pub fn ask(&mut self, n: usize, do_tell_pending: bool) -> (Vec, Vec) { + let (points, improvements) = self.ask_points_without_adding(n); + if do_tell_pending { + for &p in &points { + self.tell_pending(p); + } + } + (points, improvements) + } + + pub fn loss(&self, real: bool) -> f64 { + if self.is_missing_bound(self.bounds.0) || self.is_missing_bound(self.bounds.1) { + return f64::INFINITY; + } + let mgr = if real { + &self.losses + } else { + &self.losses_combined + }; + mgr.peek_max_loss().unwrap_or(f64::INFINITY) + } + + pub fn npoints(&self) -> usize { + self.data.len() + self.out_of_bounds_data.len() + } + + pub fn remove_unfinished(&mut self) { + self.pending.clear(); + self.losses_combined = self.losses.clone(); + self.combined_points = self.data.keys().copied().collect(); + } + + /// Sorted `(x, y)` data. + pub fn to_sorted_data(&self) -> Vec<(f64, YValue)> { + self.data + .iter() + .map(|(&x, y)| (x.into_inner(), y.clone())) + .collect() + } + + /// `(x_left, x_right, loss)` for all real intervals. 
+ pub fn intervals_with_loss(&self) -> Vec<(f64, f64, f64)> { + let pts: Vec = self.data.keys().copied().collect(); + pts.windows(2) + .map(|w| { + let l = w[0]; + let r = w[1]; + let loss = self.losses.get(l, r).unwrap_or(0.0); + (l.into_inner(), r.into_inner(), loss) + }) + .collect() + } + + // ---------------------------------------------------------------- + // Internal: scale management + // ---------------------------------------------------------------- + + fn set_vdim_if_unknown(&mut self, y: &YValue) { + if self.vdim.is_none() { + self.vdim = Some(match y { + YValue::Scalar(_) => 1, + YValue::Vector(values) => values.len(), + }); + } + } + + fn refresh_y_scale(&mut self) { + self.y_scale = self + .y_min + .iter() + .zip(self.y_max.iter()) + .map(|(lo, hi)| hi - lo) + .fold(0.0_f64, f64::max); + } + + fn update_scale(&mut self, y: &YValue) { + update_y_bounds(&mut self.y_min, &mut self.y_max, y); + self.refresh_y_scale(); + } + + fn rebuild_scale(&mut self) { + self.y_min.clear(); + self.y_max.clear(); + for y in self.data.values() { + update_y_bounds(&mut self.y_min, &mut self.y_max, y); + } + self.refresh_y_scale(); + // x_scale is always the domain width (matches Python where x_scale = bounds[1] - bounds[0]) + self.x_scale = self.bounds.1 - self.bounds.0; + self.old_y_scale = self.y_scale; + } + + fn scale_y(&self, y: &YValue) -> YValue { + let ys = if self.y_scale != 0.0 { + self.y_scale + } else { + 1.0 + }; + match y { + YValue::Scalar(v) => YValue::Scalar(v / ys), + YValue::Vector(v) => YValue::Vector(v.iter().map(|x| x / ys).collect()), + } + } + + // ---------------------------------------------------------------- + // Internal: neighbour queries + // ---------------------------------------------------------------- + + fn find_real_neighbors(&self, x: OF64) -> (Option, Option) { + (predecessor_map(&self.data, x), successor_map(&self.data, x)) + } + + fn find_combined_neighbors(&self, x: OF64) -> (Option, Option) { + ( + 
predecessor_set(&self.combined_points, x), + successor_set(&self.combined_points, x), + ) + } + + /// Intervals affected by adding `x` to the real set. + /// With `nth_neighbors = 0`: `[(x_left, x), (x, x_right)]`. + /// With `nth_neighbors = 1`: extends one further on each side. + fn get_affected_intervals(&self, x: OF64) -> Vec<(OF64, OF64)> { + let nn = self.nth_neighbors; + let mut pts: Vec = Vec::new(); + + // Collect up to nn+1 predecessors (including x_left, x_{left-1}, …) + let mut cur = x; + let mut before = Vec::new(); + for _ in 0..(nn + 1) { + if let Some(prev) = predecessor_map(&self.data, cur) { + before.push(prev); + cur = prev; + } else { + break; + } + } + before.reverse(); + pts.extend(before); + + pts.push(x); + + // Collect up to nn+1 successors + cur = x; + for _ in 0..(nn + 1) { + if let Some(next) = successor_map(&self.data, cur) { + pts.push(next); + cur = next; + } else { + break; + } + } + + pts.windows(2).map(|w| (w[0], w[1])).collect() + } + + /// Gather the neighbourhood points (with optional None padding) + /// for loss computation on interval `[x_left, x_right]`. 
+ fn get_neighborhood( + &self, + x_left: OF64, + x_right: OF64, + ) -> (Vec>, Vec>) { + let nn = self.nth_neighbors; + let mut xs: Vec> = Vec::with_capacity(2 + 2 * nn); + let mut ys: Vec> = Vec::with_capacity(2 + 2 * nn); + + // Predecessors of x_left + let mut before_x: Vec> = Vec::new(); + let mut cur = x_left; + for _ in 0..nn { + let prev = predecessor_map(&self.data, cur); + before_x.push(prev); + if let Some(p) = prev { + cur = p; + } else { + break; + } + } + before_x.resize(nn, None); + before_x.reverse(); + + for p in &before_x { + xs.push(p.map(|v| v.into_inner())); + ys.push(p.and_then(|v| self.data.get(&v))); + } + + // The interval endpoints + xs.push(Some(x_left.into_inner())); + ys.push(self.data.get(&x_left)); + xs.push(Some(x_right.into_inner())); + ys.push(self.data.get(&x_right)); + + // Successors of x_right + cur = x_right; + for _ in 0..nn { + let next = successor_map(&self.data, cur); + xs.push(next.map(|v| v.into_inner())); + ys.push(next.and_then(|v| self.data.get(&v))); + if let Some(n) = next { + cur = n; + } else { + break; + } + } + xs.resize(2 + 2 * nn, None); + ys.resize(2 + 2 * nn, None); + + (xs, ys) + } + + // ---------------------------------------------------------------- + // Internal: loss computation + // ---------------------------------------------------------------- + + fn get_loss_in_interval(&self, x_left: OF64, x_right: OF64) -> f64 { + let dx = x_right.into_inner() - x_left.into_inner(); + if dx < self.dx_eps { + return 0.0; + } + let (xs_raw, ys_raw) = self.get_neighborhood(x_left, x_right); + let xs_sc: Vec> = xs_raw.iter().map(|x| x.map(|v| v / self.x_scale)).collect(); + let ys_scaled: Vec> = + ys_raw.iter().map(|y| y.map(|v| self.scale_y(v))).collect(); + let ys_refs: Vec> = ys_scaled.iter().map(|y| y.as_ref()).collect(); + compute_loss(&self.loss_fn, &xs_sc, &ys_refs) + } + + fn update_interpolated_loss_in_interval(&mut self, x_left: OF64, x_right: OF64) { + let loss = self.get_loss_in_interval(x_left, x_right); 
+ self.losses.insert(x_left, x_right, loss); + + // Walk combined points between x_left and x_right, setting + // interpolated losses proportional to sub-interval width. + let dx = (x_right - x_left).into_inner(); + if dx == 0.0 { + return; + } + let mut a = x_left; + while a != x_right { + if let Some(b) = successor_set(&self.combined_points, a) { + if b > x_right { + break; + } + let sub_loss = (b - a).into_inner() * loss / dx; + self.losses_combined.insert(a, b, sub_loss); + a = b; + } else { + break; + } + } + } + + fn update_losses(&mut self, x: OF64, real: bool) { + let (x_left, x_right) = self.find_real_neighbors(x); + let (a, b) = self.find_combined_neighbors(x); + + // Remove the old combined interval that x now splits. + if let (Some(a_val), Some(b_val)) = (a, b) { + self.losses_combined.remove(a_val, b_val); + } + + if real { + // Recompute losses for all affected intervals + let affected = self.get_affected_intervals(x); + for (il, ir) in affected { + self.update_interpolated_loss_in_interval(il, ir); + } + // Remove the old real interval + if let (Some(xl), Some(xr)) = (x_left, x_right) { + self.losses.remove(xl, xr); + self.losses_combined.remove(xl, xr); + } + } else if let (Some(xl), Some(xr)) = (x_left, x_right) { + // Interpolate from the real interval + if let Some(loss) = self.losses.get(xl, xr) { + let dx = (xr - xl).into_inner(); + if let Some(a_val) = a { + let sub = (x - a_val).into_inner() * loss / dx; + self.losses_combined.insert(a_val, x, sub); + } + if let Some(b_val) = b { + let sub = (b_val - x).into_inner() * loss / dx; + self.losses_combined.insert(x, b_val, sub); + } + } + } + + // Handle unknown-loss edges (pending point with no real neighbour on a side). 
+        let left_unknown = x_left.is_none() || (!real && x_right.is_none());
+        let right_unknown = x_right.is_none() || (!real && x_left.is_none());
+
+        if let Some(a_val) = a {
+            if left_unknown {
+                self.losses_combined.insert(a_val, x, f64::INFINITY);
+            }
+        }
+        if let Some(b_val) = b {
+            if right_unknown {
+                self.losses_combined.insert(x, b_val, f64::INFINITY);
+            }
+        }
+    }
+
+    // ----------------------------------------------------------------
+    // Internal: ask algorithm
+    // ----------------------------------------------------------------
+
+    fn is_missing_bound(&self, bound: f64) -> bool {
+        let bound = OF64::from(bound);
+        !self.data.contains_key(&bound) && !self.pending.contains(&bound)
+    }
+
+    fn missing_bounds(&self) -> Vec<f64> {
+        [self.bounds.0, self.bounds.1]
+            .into_iter()
+            .filter(|&bound| self.is_missing_bound(bound))
+            .collect()
+    }
+
+    fn ask_points_without_adding(&self, n: usize) -> (Vec<f64>, Vec<f64>) {
+        if n == 0 {
+            return (vec![], vec![]);
+        }
+
+        let missing = self.missing_bounds();
+        if missing.len() >= n {
+            return (missing[..n].to_vec(), vec![f64::INFINITY; n]);
+        }
+
+        if self.data.is_empty() && self.pending.is_empty() {
+            let (a, b) = self.bounds;
+            let pts: Vec<f64> = if n == 1 {
+                vec![a]
+            } else {
+                (0..n)
+                    .map(|i| a + (b - a) * i as f64 / (n - 1) as f64)
+                    .collect()
+            };
+            return (pts, vec![f64::INFINITY; n]);
+        }
+
+        // --- Build quals via merge of losses_combined + quals priority queue ---
+
+        // Quals: sorted by (neg_finite_loss, left, right) → first() = max loss
+        // Each entry stores the interval, n (number of splits), and raw loss.
+        #[derive(Clone, Copy)]
+        struct Qual {
+            left: f64,
+            right: f64,
+            n: usize,
+            loss: f64,
+        }
+
+        // We store quals in a BTreeMap keyed by (neg_fl, left, right, n)
+        // to ensure uniqueness and efficient max access.
+        #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+        struct QKey {
+            neg_fl: OF64,
+            left: OF64,
+            right: OF64,
+            n: u64,
+        }
+
+        let x_scale = self.x_scale;
+        let mk_key = |q: &Qual| -> QKey {
+            let fl = finite_loss_with_n(q.loss, q.left, q.right, q.n, x_scale);
+            QKey {
+                neg_fl: OF64::from(-fl),
+                left: OF64::from(q.left),
+                right: OF64::from(q.right),
+                n: q.n as u64,
+            }
+        };
+
+        let mut quals: BTreeMap<QKey, Qual> = BTreeMap::new();
+        let insert_qual = |quals: &mut BTreeMap<QKey, Qual>, q: Qual| {
+            let key = mk_key(&q);
+            quals.insert(key, q);
+        };
+
+        // Add missing-bound intervals to quals
+        if !missing.is_empty() {
+            let all_pts: Vec<f64> = self
+                .data
+                .keys()
+                .chain(self.pending.iter())
+                .map(|x| x.into_inner())
+                .collect();
+            let min_pt = all_pts.iter().copied().fold(f64::INFINITY, f64::min);
+            let max_pt = all_pts.iter().copied().fold(f64::NEG_INFINITY, f64::max);
+            let bound_intervals = [(self.bounds.0, min_pt), (max_pt, self.bounds.1)];
+            for (ival, &bound) in bound_intervals.iter().zip(&[self.bounds.0, self.bounds.1]) {
+                if self.is_missing_bound(bound) {
+                    insert_qual(
+                        &mut quals,
+                        Qual {
+                            left: ival.0,
+                            right: ival.1,
+                            n: 1,
+                            loss: f64::INFINITY,
+                        },
+                    );
+                }
+            }
+        }
+
+        let points_to_go = n - missing.len();
+
+        // Collect combined losses in priority order
+        let combined_sorted: Vec<(f64, f64, f64)> = self
+            .losses_combined
+            .iter_by_priority()
+            .map(|(l, r, loss)| (l.into_inner(), r.into_inner(), loss))
+            .collect();
+        let i_max = combined_sorted.len();
+        let mut i = 0;
+
+        for _ in 0..points_to_go {
+            // Peek at best qual
+            let qual_top = quals.first_key_value().map(|(k, q)| (*k, *q));
+            // Peek at next from combined
+            let ival_entry = if i < i_max {
+                Some(combined_sorted[i])
+            } else {
+                None
+            };
+
+            let prefer_combined = match (&ival_entry, &qual_top) {
+                (Some((il, ir, loss_c)), Some((qk, _qq))) => {
+                    let fl_c = finite_loss_value(*loss_c, *il, *ir, x_scale);
+                    let fl_q = -qk.neg_fl.into_inner();
+                    // Python compares (loss, interval) tuples lexicographically
+                    if fl_c != fl_q {
+                        fl_c >= fl_q
+                    } else {
+                        (*il, *ir) >= (qk.left.into_inner(), qk.right.into_inner())
+                    }
+                }
+                (Some(_), None) => true,
+                (None, Some(_)) => false,
+                (None, None) => false,
+            };
+
+            if prefer_combined {
+                let (il, ir, loss_c) = ival_entry.unwrap();
+                i += 1;
+                insert_qual(
+                    &mut quals,
+                    Qual {
+                        left: il,
+                        right: ir,
+                        n: 2,
+                        loss: loss_c / 2.0,
+                    },
+                );
+            } else {
+                // Pop from quals and re-insert with n+1
+                let (qk, qq) = qual_top.unwrap();
+                quals.remove(&qk);
+                let new_n = qq.n + 1;
+                let new_loss = qq.loss * qq.n as f64 / new_n as f64;
+                insert_qual(
+                    &mut quals,
+                    Qual {
+                        left: qq.left,
+                        right: qq.right,
+                        n: new_n,
+                        loss: new_loss,
+                    },
+                );
+            }
+        }
+
+        // Generate points and loss_improvements from quals
+        // Quals are iterated in sorted order (highest loss first),
+        // which determines the order of points.
+        let mut points: Vec<f64> = Vec::with_capacity(n);
+        let mut improvements: Vec<f64> = Vec::with_capacity(n);
+
+        points.extend_from_slice(&missing);
+        improvements.extend(std::iter::repeat_n(f64::INFINITY, missing.len()));
+
+        for q in quals.values() {
+            let interior = linspace_interior(q.left, q.right, q.n);
+            improvements.extend(std::iter::repeat_n(q.loss, interior.len()));
+            points.extend(interior);
+        }
+
+        (points, improvements)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_basic_tell_and_ask() {
+        let mut learner = Learner1D::new((0.0, 1.0), LossFunction::Default);
+        learner.tell(0.0, YValue::Scalar(0.0));
+        learner.tell(1.0, YValue::Scalar(1.0));
+        let (pts, _) = learner.ask(1, false);
+        assert_eq!(pts.len(), 1);
+        assert!((pts[0] - 0.5).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_linspace_interior() {
+        assert_eq!(linspace_interior(0.0, 1.0, 1), Vec::<f64>::new());
+        assert_eq!(linspace_interior(0.0, 1.0, 2), vec![0.5]);
+        let pts = linspace_interior(0.0, 1.0, 4);
+        assert_eq!(pts.len(), 3);
+        assert!((pts[0] - 0.25).abs() < 1e-10);
+        assert!((pts[1] - 0.5).abs() < 1e-10);
+        assert!((pts[2] - 0.75).abs() < 1e-10);
+    }
+}
diff --git a/src/learner1d/python.rs b/src/learner1d/python.rs
new file mode 100644
index 0000000..4cc538b
--- /dev/null
+++ b/src/learner1d/python.rs
@@ -0,0 +1,233 @@
+use numpy::{PyArray1, PyArray2};
+use pyo3::exceptions::PyValueError;
+use pyo3::prelude::*;
+
+use super::loss::LossFunction;
+use super::{Learner1D, YValue};
+
+/// Extract a `YValue` from a Python object.
+fn extract_yvalue(obj: &Bound<'_, pyo3::types::PyAny>) -> PyResult<YValue> {
+    if let Ok(v) = obj.extract::<f64>() {
+        return Ok(YValue::Scalar(v));
+    }
+    if let Ok(v) = obj.extract::<Vec<f64>>() {
+        return Ok(YValue::Vector(v));
+    }
+    Err(PyValueError::new_err(
+        "y must be a float or a sequence of floats",
+    ))
+}
+
+/// Detect `nth_neighbors` attribute on a Python loss callable.
+fn detect_nth_neighbors(py: Python<'_>, obj: &PyObject) -> usize {
+    obj.bind(py)
+        .getattr("nth_neighbors")
+        .and_then(|a| a.extract::<usize>())
+        .unwrap_or(0)
+}
+
+#[pyclass(name = "Learner1D")]
+pub struct PyLearner1D {
+    pub(crate) inner: Learner1D,
+}
+
+#[pymethods]
+impl PyLearner1D {
+    #[new]
+    #[pyo3(signature = (bounds, loss_per_interval=None))]
+    fn new(
+        py: Python<'_>,
+        bounds: (f64, f64),
+        loss_per_interval: Option<PyObject>,
+    ) -> PyResult<Self> {
+        if bounds.0 >= bounds.1 {
+            return Err(PyValueError::new_err(
+                "bounds[0] must be strictly less than bounds[1]",
+            ));
+        }
+        let loss_fn = match loss_per_interval {
+            Some(obj) => {
+                let nn = detect_nth_neighbors(py, &obj);
+                LossFunction::PythonCallback {
+                    callback: obj,
+                    nth_neighbors: nn,
+                }
+            }
+            None => LossFunction::Default,
+        };
+        Ok(Self {
+            inner: Learner1D::new(bounds, loss_fn),
+        })
+    }
+
+    fn tell(&mut self, x: f64, y: &Bound<'_, pyo3::types::PyAny>) -> PyResult<()> {
+        let yv = extract_yvalue(y)?;
+        self.inner.tell(x, yv);
+        Ok(())
+    }
+
+    #[pyo3(signature = (xs, ys, force=false))]
+    fn tell_many(
+        &mut self,
+        xs: Vec<f64>,
+        ys: Vec<Bound<'_, pyo3::types::PyAny>>,
+        force: bool,
+    ) -> PyResult<()> {
+        let yvalues: Vec<YValue> = ys.iter().map(extract_yvalue).collect::<PyResult<_>>()?;
+        self.inner.tell_many(&xs, &yvalues, force);
+        Ok(())
+    }
+
+    fn tell_pending(&mut self, x: f64) {
+        self.inner.tell_pending(x);
+    }
+
+    #[pyo3(signature = (n, tell_pending=true))]
+    fn ask(&mut self, n: usize, tell_pending: bool) -> (Vec<f64>, Vec<f64>) {
+        self.inner.ask(n, tell_pending)
+    }
+
+    /// Run the full adaptive loop — only user function evals cross the PyO3 boundary.
+    #[pyo3(signature = (f, *, goal=None, n_points=None, batch_size=1))]
+    fn run(
+        &mut self,
+        py: Python<'_>,
+        f: PyObject,
+        goal: Option<f64>,
+        n_points: Option<usize>,
+        batch_size: usize,
+    ) -> PyResult<usize> {
+        let mut n_evaluated: usize = 0;
+        let bs = batch_size.max(1);
+        loop {
+            // Check stopping conditions
+            if let Some(g) = goal {
+                if self.inner.npoints() > 0 && self.inner.loss(true) <= g {
+                    break;
+                }
+            }
+            if let Some(np) = n_points {
+                if n_evaluated >= np {
+                    break;
+                }
+            }
+            if goal.is_none() && n_points.is_none() {
+                break;
+            }
+
+            let ask_n = if let Some(np) = n_points {
+                bs.min(np - n_evaluated)
+            } else {
+                bs
+            };
+
+            let (xs, _) = self.inner.ask(ask_n, true);
+            if xs.is_empty() {
+                break;
+            }
+
+            let yvalues: Vec<YValue> = xs
+                .iter()
+                .map(|&x| {
+                    let result = f.call1(py, (x,))?;
+                    extract_yvalue(result.bind(py))
+                })
+                .collect::<PyResult<_>>()?;
+
+            self.inner.tell_many(&xs, &yvalues, false);
+            n_evaluated += xs.len();
+        }
+        Ok(n_evaluated)
+    }
+
+    #[pyo3(signature = (real=true))]
+    fn loss(&self, real: bool) -> f64 {
+        self.inner.loss(real)
+    }
+
+    #[getter]
+    fn npoints(&self) -> usize {
+        self.inner.npoints()
+    }
+
+    #[getter]
+    fn vdim(&self) -> Option<usize> {
+        self.inner.vdim
+    }
+
+    fn to_numpy<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyArray1<f64>>, PyObject)> {
+        let data = self.inner.to_sorted_data();
+        if data.is_empty() {
+            let xs = PyArray1::from_vec(py, vec![]);
+            let ys: PyObject = PyArray1::from_vec(py, Vec::<f64>::new())
+                .into_any()
+                .unbind();
+            return Ok((xs, ys));
+        }
+
+        let xs: Vec<f64> = data.iter().map(|(x, _)| *x).collect();
+        let xs_arr = PyArray1::from_vec(py, xs);
+
+        let vdim = self.inner.vdim.unwrap_or(1);
+        if vdim == 1 {
+            let ys: Vec<f64> = data
+                .iter()
+                .map(|(_, y)| match y {
+                    YValue::Scalar(v) => *v,
+                    _ => 0.0,
+                })
+                .collect();
+            let ys_arr: PyObject = PyArray1::from_vec(py, ys).into_any().unbind();
+            Ok((xs_arr, ys_arr))
+        } else {
+            let rows: Vec<Vec<f64>> = data
+                .iter()
+                .map(|(_, y)| match y {
+                    YValue::Scalar(v) => vec![*v],
+                    YValue::Vector(v) => v.clone(),
+                })
+                .collect();
+            let ys_arr: PyObject = PyArray2::from_vec2(py, &rows)?.into_any().unbind();
+            Ok((xs_arr, ys_arr))
+        }
+    }
+
+    fn remove_unfinished(&mut self) {
+        self.inner.remove_unfinished();
+    }
+
+    #[getter]
+    fn pending_points(&self) -> Vec<f64> {
+        self.inner.pending.iter().map(|x| x.into_inner()).collect()
+    }
+
+    #[getter]
+    fn data<'py>(&self, py: Python<'py>) -> PyResult<PyObject> {
+        let dict = pyo3::types::PyDict::new(py);
+        // Include both in-bounds and out-of-bounds data (matches Python behavior)
+        for (&x, y) in self
+            .inner
+            .data
+            .iter()
+            .chain(self.inner.out_of_bounds_data.iter())
+        {
+            let key = x.into_inner();
+            let val: PyObject = match y {
+                YValue::Scalar(v) => v.into_pyobject(py)?.into_any().unbind(),
+                YValue::Vector(v) => pyo3::types::PyList::new(py, v)?.into_any().unbind(),
+            };
+            dict.set_item(key, val)?;
+        }
+        Ok(dict.into_any().unbind())
+    }
+
+    /// Return `[(x_left, x_right, loss), …]` for all real intervals.
+ fn intervals(&self) -> Vec<(f64, f64, f64)> { + self.inner.intervals_with_loss() + } + + #[getter] + fn bounds(&self) -> (f64, f64) { + self.inner.bounds + } +} diff --git a/src/lib.rs b/src/lib.rs index 7223c06..5498d7d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod geometry; +pub mod learner1d; pub mod triangulation; use pyo3::exceptions::{PyValueError, PyZeroDivisionError}; @@ -39,6 +40,7 @@ fn _rust(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add("__version__", env!("CARGO_PKG_VERSION"))?; m.add_function(wrap_pyfunction!(py_circumsphere, m)?)?; diff --git a/tests/test_learner1d.py b/tests/test_learner1d.py new file mode 100644 index 0000000..a95aa81 --- /dev/null +++ b/tests/test_learner1d.py @@ -0,0 +1,402 @@ +"""Comprehensive tests for the Rust-powered Learner1D.""" + +from __future__ import annotations + +import math + +import numpy as np +import pytest +from adaptive_triangulation import Learner1D + + +def sin10(x: float) -> float: + return math.sin(10 * x) + + +def parabola(x: float) -> float: + return x**2 + + +def vector_fn(x: float) -> list[float]: + return [math.sin(x), math.cos(x)] + + +class IntentionalLossError(ValueError): + pass + + +def make_learner(bounds=(0.0, 1.0), points=()): + learner = Learner1D(bounds=bounds) + for x, y in points: + learner.tell(x, y) + return learner + + +def unit_interval_learner() -> Learner1D: + return make_learner(points=((0.0, 0.0), (1.0, 1.0))) + + +def assert_points_close(points, expected) -> None: + assert len(points) == len(expected) + for point, expected_point in zip(sorted(points), expected, strict=True): + assert point == pytest.approx(expected_point) + + +def bad_loss(_xs, _ys): + raise IntentionalLossError + + +def string_loss(_xs, _ys): + return "not a float" + + +@pytest.mark.parametrize("bounds", [(0.0, 1.0), (-1.0, 2.0)]) +def test_create(bounds) -> None: + learner = make_learner(bounds=bounds) + assert 
learner.npoints == 0 + assert learner.bounds == bounds + + +def test_invalid_bounds() -> None: + with pytest.raises(ValueError, match=r"bounds\[0\] must be strictly less than bounds\[1\]"): + Learner1D(bounds=(1.0, 0.0)) + + +def test_tell_counts_unique_points() -> None: + learner = make_learner() + learner.tell(0.0, 0.0) + assert learner.npoints == 1 + learner.tell(1.0, 1.0) + assert learner.npoints == 2 + learner.tell(1.0, 99.0) + assert learner.npoints == 2 + assert learner.data[1.0] == 1.0 + + +def test_loss_starts_infinite() -> None: + assert math.isinf(make_learner().loss()) + + +@pytest.mark.parametrize( + ("points", "expected"), + [ + (((0.0, 0.0), (1.0, 1.0)), math.sqrt(2)), + (((0.0, 0.0), (1.0, 0.0)), 1.0), + ], +) +def test_loss_with_two_points(points, expected: float) -> None: + assert make_learner(points=points).loss() == pytest.approx(expected) + + +@pytest.mark.parametrize( + ("n", "expected"), + [ + (1, [0.5]), + (3, [0.25, 0.5, 0.75]), + ], +) +def test_ask_subdivides_interval(n: int, expected: list[float]) -> None: + points, improvements = unit_interval_learner().ask(n, tell_pending=False) + assert len(improvements) == n + assert_points_close(points, expected) + + +@pytest.mark.parametrize( + ("bounds", "n", "expected"), + [ + ((0.0, 1.0), 3, [0.0, 0.5, 1.0]), + ((0.0, 1.0), 5, [0.0, 0.25, 0.5, 0.75, 1.0]), + ((-1.0, 1.0), 2, [-1.0, 1.0]), + ], +) +def test_empty_ask_returns_linspace(bounds, n: int, expected: list[float]) -> None: + points, improvements = make_learner(bounds=bounds).ask(n, tell_pending=False) + assert all(math.isinf(improvement) for improvement in improvements) + assert_points_close(points, expected) + + +def test_tell_many_incremental_path() -> None: + learner = make_learner() + xs = [0.0, 0.5, 1.0] + ys = [0.0, 0.25, 1.0] + learner.tell_many(xs, ys) + assert learner.npoints == 3 + + +@pytest.mark.parametrize("force", [False, True]) +def test_tell_many_force_rebuild(force) -> None: + learner = make_learner() + xs = [i / 10 for i 
in range(11)] + ys = [x**2 for x in xs] + learner.tell_many(xs, ys, force=force) + assert learner.npoints == 11 + assert not math.isinf(learner.loss()) + + +def test_tell_many_large_batch_triggers_rebuild() -> None: + learner = unit_interval_learner() + xs = [0.2, 0.4, 0.6, 0.8, 0.5] + ys = [x**2 for x in xs] + learner.tell_many(xs, ys) + assert learner.npoints == 7 + + +def test_scalar_output_sets_vdim() -> None: + assert unit_interval_learner().vdim == 1 + + +def test_vector_output_supports_ask_and_to_numpy() -> None: + learner = make_learner(points=((0.0, [1.0, 2.0]), (0.5, [2.0, 3.0]), (1.0, [3.0, 4.0]))) + assert learner.vdim == 2 + points, _ = learner.ask(1, tell_pending=False) + assert len(points) == 1 + xs, ys = learner.to_numpy() + assert xs.shape == (3,) + assert ys.shape == (3, 2) + + +def test_tell_pending_and_tell_clears_pending() -> None: + learner = unit_interval_learner() + learner.tell_pending(0.5) + assert 0.5 in learner.pending_points + learner.tell(0.5, 0.25) + assert 0.5 not in learner.pending_points + + +@pytest.mark.parametrize(("tell_pending", "expected_pending"), [(True, 2), (False, 0)]) +def test_ask_pending_mode(tell_pending, expected_pending: int) -> None: + learner = unit_interval_learner() + points, _ = learner.ask(2, tell_pending=tell_pending) + assert len(points) == 2 + assert len(learner.pending_points) == expected_pending + + +def test_pending_splits_intervals() -> None: + learner = unit_interval_learner() + points_without_pending, _ = learner.ask(1, tell_pending=False) + learner.tell_pending(0.5) + points_with_pending, _ = learner.ask(1, tell_pending=False) + assert points_without_pending != points_with_pending + + +def test_run_with_n_points() -> None: + learner = make_learner(bounds=(-1.0, 1.0)) + evaluated = learner.run(sin10, n_points=20, batch_size=5) + assert evaluated == 20 + assert learner.npoints == 20 + + +@pytest.mark.parametrize("batch_size", [1, 10]) +def test_run_with_goal(batch_size: int) -> None: + learner = 
make_learner(bounds=(-1.0, 1.0)) + evaluated = learner.run(parabola, goal=0.01, batch_size=batch_size) + assert learner.loss() <= 0.01 + assert evaluated > 0 + + +def test_run_batch_size_1() -> None: + assert make_learner().run(parabola, n_points=10, batch_size=1) == 10 + + +def test_run_no_goal_or_npoints_returns_immediately() -> None: + assert make_learner().run(parabola) == 0 + + +def test_to_numpy_empty() -> None: + xs, ys = make_learner().to_numpy() + assert len(xs) == 0 + assert len(ys) == 0 + + +@pytest.mark.parametrize( + "points", + [ + ((0.0, 0.0), (0.5, 0.25), (1.0, 1.0)), + ((1.0, 1.0), (0.0, 0.0), (0.5, 0.25)), + ], +) +def test_to_numpy_scalar(points) -> None: + xs, ys = make_learner(points=points).to_numpy() + np.testing.assert_array_equal(xs, [0.0, 0.5, 1.0]) + np.testing.assert_array_equal(ys, [0.0, 0.25, 1.0]) + + +def test_to_numpy_sorted_order() -> None: + learner = make_learner() + learner.tell(1.0, 1.0) + learner.tell(0.0, 0.0) + learner.tell(0.5, 0.25) + xs, _ = learner.to_numpy() + assert list(xs) == [0.0, 0.5, 1.0] + + +def test_remove_unfinished_clears_pending_without_changing_loss() -> None: + learner = unit_interval_learner() + loss_before = learner.loss() + learner.ask(5, tell_pending=True) + assert len(learner.pending_points) == 5 + learner.remove_unfinished() + assert len(learner.pending_points) == 0 + assert learner.loss() == pytest.approx(loss_before) + + +def test_single_point_loss_stays_infinite() -> None: + learner = make_learner(points=((0.0, 0.0),)) + assert learner.npoints == 1 + assert math.isinf(learner.loss()) + + +def test_tell_at_bounds_yields_finite_loss() -> None: + learner = unit_interval_learner() + assert learner.npoints == 2 + assert not math.isinf(learner.loss()) + + +def test_ask_zero() -> None: + points, improvements = make_learner().ask(0) + assert points == [] + assert improvements == [] + + +@pytest.mark.parametrize("ask_n", [1, 10, 25]) +def test_many_points(ask_n: int) -> None: + learner = make_learner() + 
xs = [i / 100 for i in range(101)] + ys = [math.sin(x * 10) for x in xs] + learner.tell_many(xs, ys, force=True) + assert learner.npoints == 101 + points, _ = learner.ask(ask_n, tell_pending=False) + assert len(points) == ask_n + + +def test_scale_recompute_triggers_on_large_y_change() -> None: + learner = unit_interval_learner() + loss_before = learner.loss() + learner.tell(0.5, 100.0) + assert learner.loss() != loss_before + + +@pytest.mark.parametrize("y_right", [1.0, 100.0]) +def test_uniform_loss_callback(y_right: float) -> None: + def my_uniform(xs, _ys): + return xs[1] - xs[0] + + learner = Learner1D(bounds=(0.0, 1.0), loss_per_interval=my_uniform) + learner.tell(0.0, 0.0) + learner.tell(1.0, y_right) + points, _ = learner.ask(1, tell_pending=False) + assert points[0] == pytest.approx(0.5) + + +def test_custom_loss_run() -> None: + def my_loss(xs, _ys): + return abs(xs[1] - xs[0]) + + learner = Learner1D(bounds=(0.0, 1.0), loss_per_interval=my_loss) + assert learner.run(parabola, n_points=20, batch_size=5) == 20 + + +def test_data_and_intervals_properties() -> None: + learner = make_learner(points=((0.0, 0.0), (0.5, 0.25), (1.0, 1.0))) + assert learner.data == {0.0: 0.0, 0.5: 0.25, 1.0: 1.0} + + intervals = learner.intervals() + assert intervals == pytest.approx([(0.0, 0.5, intervals[0][2]), (0.5, 1.0, intervals[1][2])]) + assert all(loss >= 0 for _, _, loss in intervals) + + +def test_finite_loss_rounding() -> None: + learner = make_learner() + xs = [i * 0.1 for i in range(11)] + ys = [math.sin(x) for x in xs] + learner.tell_many(xs, ys, force=True) + assert learner.loss() >= 0 + assert math.isfinite(learner.loss()) + + +def test_tiny_interval() -> None: + learner = make_learner() + learner.tell(0.5, 0.0) + learner.tell(0.5 + 1e-15, 1.0) + assert learner.loss() >= 0 + + +def test_convergence() -> None: + learner = make_learner(bounds=(-1.0, 1.0), points=((-1.0, sin10(-1.0)), (1.0, sin10(1.0)))) + losses = [learner.loss()] + for _ in range(50): + points, _ = 
learner.ask(1, tell_pending=True) + for point in points: + learner.tell(point, sin10(point)) + losses.append(learner.loss()) + assert losses[-1] < losses[0] + + +def test_batch_run_matches_sequential() -> None: + learner_seq = make_learner() + learner_batch = make_learner() + + assert learner_seq.run(parabola, n_points=50, batch_size=1) == 50 + assert learner_batch.run(parabola, n_points=50, batch_size=10) == 50 + assert learner_seq.loss() < 0.1 + assert learner_batch.loss() < 0.1 + + +def test_full_adaptive_loop() -> None: + learner = make_learner(bounds=(-1.0, 1.0)) + learner.run(sin10, goal=0.05, batch_size=5) + assert learner.npoints > 0 + xs, ys = learner.to_numpy() + assert all(xs[i] <= xs[i + 1] for i in range(len(xs) - 1)) + for x, y in zip(xs, ys, strict=True): + assert y == pytest.approx(sin10(x)) + + +def test_out_of_bounds_points_are_stored_counted_and_ignored_in_loss() -> None: + learner = make_learner(bounds=(-1.0, 1.0), points=((-1.0, 0.0), (1.0, 0.0))) + loss_before = learner.loss() + learner.tell(-10.0, 999.0) + learner.tell(5.0, 42.0) + assert learner.npoints == 4 + assert learner.data[-10.0] == 999.0 + assert learner.data[5.0] == 42.0 + assert learner.loss() == pytest.approx(loss_before) + + +def test_oob_does_not_affect_neighbors() -> None: + learner = make_learner(bounds=(-1.0, 1.0)) + learner.tell(-10.0, 0.0) + learner.tell(-1.0, 0.0) + learner.tell(1.0, 0.0) + intervals = learner.intervals() + assert len(intervals) == 1 + assert intervals[0][:2] == (-1.0, 1.0) + + +def test_oob_duplicate_ignored() -> None: + learner = make_learner(bounds=(-1.0, 1.0)) + learner.tell(5.0, 1.0) + learner.tell(5.0, 99.0) + assert learner.npoints == 1 + assert learner.data[5.0] == 1.0 + + +@pytest.mark.parametrize("callback", [bad_loss, string_loss]) +def test_callback_failure_returns_infinity(callback) -> None: + learner = Learner1D(bounds=(0.0, 1.0), loss_per_interval=callback) + learner.tell(0.0, 0.0) + learner.tell(1.0, 1.0) + assert math.isinf(learner.loss()) 
+ + +def test_loss_defaults_to_real() -> None: + learner = unit_interval_learner() + assert learner.loss() == learner.loss(real=True) + + +@pytest.mark.parametrize("pending_x", [0.25, 0.5]) +def test_loss_real_false_accounts_for_pending(pending_x: float) -> None: + learner = unit_interval_learner() + real_loss = learner.loss(real=True) + learner.tell_pending(pending_x) + assert learner.loss(real=False) < real_loss