|
3 | 3 | import importlib |
4 | 4 | import os |
5 | 5 | import pathlib |
| 6 | +import warnings |
6 | 7 | from collections.abc import Callable |
7 | 8 |
|
8 | 9 | import numpy as np |
9 | 10 |
|
10 | 11 | import ratapi |
11 | 12 | import ratapi.wrappers |
12 | 13 | from ratapi.rat_core import Checks, Control, NameStore, ProblemDefinition |
13 | | -from ratapi.utils.enums import Calculations, Languages, LayerModels, TypeOptions |
| 14 | +from ratapi.utils.enums import Calculations, Languages, LayerModels, Procedures, TypeOptions |
14 | 15 |
|
15 | 16 | parameter_field = { |
16 | 17 | "parameters": "params", |
@@ -137,13 +138,13 @@ def make_input(project: ratapi.Project, controls: ratapi.Controls) -> tuple[Prob |
137 | 138 | The controls object used in the compiled RAT code. |
138 | 139 |
|
139 | 140 | """ |
140 | | - problem = make_problem(project) |
| 141 | + problem = make_problem(project, controls.procedure != Procedures.Calculate) |
141 | 142 | cpp_controls = make_controls(controls) |
142 | 143 |
|
143 | 144 | return problem, cpp_controls |
144 | 145 |
|
145 | 146 |
|
146 | | -def make_problem(project: ratapi.Project) -> ProblemDefinition: |
| 147 | +def make_problem(project: ratapi.Project, validate_range: bool = False) -> ProblemDefinition: |
147 | 148 | """Construct the problem input required for the compiled RAT code. |
148 | 149 |
|
149 | 150 | Parameters |
@@ -351,40 +352,42 @@ def make_problem(project: ratapi.Project) -> ProblemDefinition: |
351 | 352 | problem.domainContrastLayers = [ |
352 | 353 | domain_contrast_model if domain_contrast_model else [] for domain_contrast_model in domain_contrast_models |
353 | 354 | ] |
354 | | - problem.fitParams = [ |
355 | | - param.value |
356 | | - for class_list in ratapi.project.parameter_class_lists |
357 | | - for param in getattr(project, class_list) |
358 | | - if param.fit |
359 | | - ] |
360 | | - problem.fitLimits = [ |
361 | | - [param.min, param.max] |
362 | | - for class_list in ratapi.project.parameter_class_lists |
363 | | - for param in getattr(project, class_list) |
364 | | - if param.fit |
365 | | - ] |
366 | | - problem.priorNames = [ |
367 | | - param.name for class_list in ratapi.project.parameter_class_lists for param in getattr(project, class_list) |
368 | | - ] |
369 | | - problem.priorValues = [ |
370 | | - [prior_id[param.prior_type], param.mu, param.sigma] |
371 | | - for class_list in ratapi.project.parameter_class_lists |
372 | | - for param in getattr(project, class_list) |
373 | | - ] |
| 355 | + |
| 356 | + fit_params = [] |
| 357 | + fit_limits = [] |
| 358 | + prior_names = [] |
| 359 | + prior_values = [] |
| 360 | + problem.checks = Checks() |
| 361 | + for class_list in ratapi.project.parameter_class_lists: |
| 362 | + field = parameter_field[class_list] |
| 363 | + check_list = [] |
| 364 | + for param in getattr(project, class_list): |
| 365 | + prior_names.append(param.name) |
| 366 | + prior_values.append([prior_id[param.prior_type], param.mu, param.sigma]) |
| 367 | + check_list.append(int(param.fit)) |
| 368 | + if param.fit: |
| 369 | + if validate_range and (param.max - param.min) < 1e-10: |
| 370 | + warnings.warn( |
| 371 | + f'{class_list.replace("_", " ").title()} "{param.name}" was removed from the ' |
| 372 | + "fit because its range is too small (< 1e-10).", |
| 373 | + stacklevel=2, |
| 374 | + ) |
| 375 | + check_list[-1] = 0 |
| 376 | + else: |
| 377 | + fit_params.append(param.value) |
| 378 | + fit_limits.append([param.min, param.max]) |
| 379 | + setattr(problem.checks, field, check_list) |
| 380 | + problem.fitParams = fit_params |
| 381 | + problem.fitLimits = fit_limits |
| 382 | + problem.priorNames = prior_names |
| 383 | + problem.priorValues = prior_values |
374 | 384 |
|
375 | 385 | # Names |
376 | 386 | problem.names = NameStore() |
377 | 387 | for class_list in ratapi.project.parameter_class_lists: |
378 | 388 | setattr(problem.names, parameter_field[class_list], [param.name for param in getattr(project, class_list)]) |
379 | 389 | problem.names.contrasts = [contrast.name for contrast in project.contrasts] |
380 | 390 |
|
381 | | - # Checks |
382 | | - problem.checks = Checks() |
383 | | - for class_list in ratapi.project.parameter_class_lists: |
384 | | - setattr( |
385 | | - problem.checks, parameter_field[class_list], [int(element.fit) for element in getattr(project, class_list)] |
386 | | - ) |
387 | | - |
388 | 391 | check_indices(problem) |
389 | 392 |
|
390 | 393 | return problem |
|
0 commit comments