Skip to content

Commit 8571768

Browse files
committed
Check that fitted parameters have reasonable range
1 parent f405135 commit 8571768

3 files changed

Lines changed: 49 additions & 31 deletions

File tree

ratapi/inputs.py

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
import importlib
44
import os
55
import pathlib
6+
import warnings
67
from collections.abc import Callable
78

89
import numpy as np
910

1011
import ratapi
1112
import ratapi.wrappers
1213
from ratapi.rat_core import Checks, Control, NameStore, ProblemDefinition
13-
from ratapi.utils.enums import Calculations, Languages, LayerModels, TypeOptions
14+
from ratapi.utils.enums import Calculations, Languages, LayerModels, Procedures, TypeOptions
1415

1516
parameter_field = {
1617
"parameters": "params",
@@ -137,13 +138,13 @@ def make_input(project: ratapi.Project, controls: ratapi.Controls) -> tuple[Prob
137138
The controls object used in the compiled RAT code.
138139
139140
"""
140-
problem = make_problem(project)
141+
problem = make_problem(project, controls.procedure != Procedures.Calculate)
141142
cpp_controls = make_controls(controls)
142143

143144
return problem, cpp_controls
144145

145146

146-
def make_problem(project: ratapi.Project) -> ProblemDefinition:
147+
def make_problem(project: ratapi.Project, validate_range: bool = False) -> ProblemDefinition:
147148
"""Construct the problem input required for the compiled RAT code.
148149
149150
Parameters
@@ -351,40 +352,42 @@ def make_problem(project: ratapi.Project) -> ProblemDefinition:
351352
problem.domainContrastLayers = [
352353
domain_contrast_model if domain_contrast_model else [] for domain_contrast_model in domain_contrast_models
353354
]
354-
problem.fitParams = [
355-
param.value
356-
for class_list in ratapi.project.parameter_class_lists
357-
for param in getattr(project, class_list)
358-
if param.fit
359-
]
360-
problem.fitLimits = [
361-
[param.min, param.max]
362-
for class_list in ratapi.project.parameter_class_lists
363-
for param in getattr(project, class_list)
364-
if param.fit
365-
]
366-
problem.priorNames = [
367-
param.name for class_list in ratapi.project.parameter_class_lists for param in getattr(project, class_list)
368-
]
369-
problem.priorValues = [
370-
[prior_id[param.prior_type], param.mu, param.sigma]
371-
for class_list in ratapi.project.parameter_class_lists
372-
for param in getattr(project, class_list)
373-
]
355+
356+
fit_params = []
357+
fit_limits = []
358+
prior_names = []
359+
prior_values = []
360+
problem.checks = Checks()
361+
for class_list in ratapi.project.parameter_class_lists:
362+
field = parameter_field[class_list]
363+
check_list = []
364+
for param in getattr(project, class_list):
365+
prior_names.append(param.name)
366+
prior_values.append([prior_id[param.prior_type], param.mu, param.sigma])
367+
check_list.append(int(param.fit))
368+
if param.fit:
369+
if validate_range and (param.max - param.min) < 1e-10:
370+
warnings.warn(
371+
f'{class_list.replace("_", " ").title()} "{param.name}" was removed from the '
372+
"fit because its range is too small (< 1e-10).",
373+
stacklevel=2,
374+
)
375+
check_list[-1] = 0
376+
else:
377+
fit_params.append(param.value)
378+
fit_limits.append([param.min, param.max])
379+
setattr(problem.checks, field, check_list)
380+
problem.fitParams = fit_params
381+
problem.fitLimits = fit_limits
382+
problem.priorNames = prior_names
383+
problem.priorValues = prior_values
374384

375385
# Names
376386
problem.names = NameStore()
377387
for class_list in ratapi.project.parameter_class_lists:
378388
setattr(problem.names, parameter_field[class_list], [param.name for param in getattr(project, class_list)])
379389
problem.names.contrasts = [contrast.name for contrast in project.contrasts]
380390

381-
# Checks
382-
problem.checks = Checks()
383-
for class_list in ratapi.project.parameter_class_lists:
384-
setattr(
385-
problem.checks, parameter_field[class_list], [int(element.fit) for element in getattr(project, class_list)]
386-
)
387-
388391
check_indices(problem)
389392

390393
return problem

ratapi/run.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,9 @@ def run(project, controls):
130130
# Update parameter values in project
131131
for class_list in ratapi.project.parameter_class_lists:
132132
for index, value in enumerate(getattr(problem_definition, parameter_field[class_list])):
133-
getattr(project, class_list)[index].value = value
133+
param = getattr(project, class_list)[index]
134+
param.fit = bool(getattr(problem_definition.checks, parameter_field[class_list])[index])
135+
param.value = value
134136

135137
controls.delete_IPC()
136138

tests/test_inputs.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,19 @@ def test_make_problem(test_project, test_problem, request) -> None:
483483
check_problem_equal(problem, test_problem)
484484

485485

486+
def test_make_problem_validate_range(request) -> None:
487+
"""The problem should not contain fitted parameters with small range."""
488+
test_project = request.getfixturevalue("standard_layers_project")
489+
490+
test_project.scalefactors.set_fields(0, min=10, value=10, max=10, fit=True)
491+
problem = make_problem(test_project)
492+
assert problem.checks.scalefactors[0] == 1
493+
494+
with pytest.warns(UserWarning, match="was removed from the fit because its range is too small \(< 1e-10\)"):
495+
problem = make_problem(test_project, True)
496+
assert problem.checks.scalefactors[0] == 0
497+
498+
486499
@pytest.mark.parametrize("test_problem", ["standard_layers_problem", "custom_xy_problem", "domains_problem"])
487500
class TestCheckIndices:
488501
"""Tests for check_indices over a set of three test problems."""

0 commit comments

Comments
 (0)