Source code for aind_dynamic_foraging_models.generative_model.params.util

from typing import Any, Dict, List, Tuple

import numpy as np
from pydantic import ConfigDict, Field, create_model, model_validator

from aind_dynamic_foraging_models.generative_model.params import ParamsSymbols


def create_pydantic_models_dynamic(
    params_fields: Dict[str, Any],
    fitting_bounds: Dict[str, Tuple[float, float]],
) -> Tuple[type, type]:
    """Create Pydantic models dynamically based on the input fields and fitting bounds.

    Parameters
    ----------
    params_fields : Dict[str, Any]
        Field definitions forwarded unchanged to ``pydantic.create_model`` to
        build the ``ParamsModel`` class.
    fitting_bounds : Dict[str, Tuple[float, float]]
        Mapping from parameter name to its ``(lower, upper)`` fitting bounds.
        Each entry becomes a field of ``FittingBoundsModel``: a ``List[float]``
        of length exactly 2 whose default is ``[lower, upper]``.

    Returns
    -------
    Tuple[type, type]
        ``(params_model, fitting_bounds_model)``. Both models forbid extra
        fields and re-validate on attribute assignment. ``FittingBoundsModel``
        additionally rejects any entry whose lower bound is greater than its
        upper bound (pydantic reports the ``ValueError`` raised by the
        validator as a validation error).
    """
    # -- params_model --
    params_model = create_model(
        "ParamsModel",
        **params_fields,
        __config__=ConfigDict(
            extra="forbid",
            validate_assignment=True,
        ),
    )

    # -- fitting_bounds_model --
    fitting_bounds_fields = {}
    for name, (lower, upper) in fitting_bounds.items():
        fitting_bounds_fields[name] = (
            List[float],
            Field(
                default=[lower, upper],
                min_length=2,
                max_length=2,
                description=f"Fitting bounds for {name}",
            ),
        )

    # Pydantic calls an "after" model validator with the validated model
    # instance, so it is written as an instance method taking ``self``. A
    # first parameter named ``cls`` would make pydantic wrap the function as
    # a classmethod, and that style is deprecated for "after" validators.
    def validate_bounds(self):
        """Raise ``ValueError`` if any field's lower bound exceeds its upper bound."""
        # Every field is a 2-element list (enforced by min_length/max_length).
        for name, (lower_bound, upper_bound) in self.model_dump().items():
            if lower_bound > upper_bound:
                raise ValueError(f"Lower bound for {name} must be <= upper bound")
        return self

    fitting_bounds_model = create_model(
        "FittingBoundsModel",
        **fitting_bounds_fields,
        __validators__={"validate_bounds": model_validator(mode="after")(validate_bounds)},
        __config__=ConfigDict(
            extra="forbid",
            validate_assignment=True,
        ),
    )
    return params_model, fitting_bounds_model
def _cast_bound_like_default(bound, default):
    """Return ``bound`` converted to ``type(default)``, or ``bound`` unchanged if impossible.

    The bound is returned as-is when ``default`` is ``None`` (the schema has no
    default to match), or when the conversion fails, e.g. ``int(inf)``
    (``OverflowError``) or ``int(nan)`` (``ValueError``).
    """
    if default is None:
        return bound
    try:
        return type(default)(bound)
    except (TypeError, ValueError, OverflowError):
        return bound


def get_params_options(
    params_model,
    default_range=None,
    para_range_override=None,
) -> dict:
    """Get options for the params fields. Useful for the Streamlit app.

    Parameters
    ----------
    params_model : Pydantic model
        The Pydantic model for the parameters.
    default_range : list, optional
        The default range for the parameters. ``None`` (the default) means
        ``[-np.inf, np.inf]``. If a bound is not specified in the Pydantic
        model (``ge`` / ``le``, i.e. JSON-schema ``minimum`` / ``maximum``),
        the corresponding entry of this default range is used. The list
        passed in is never modified.
    para_range_override : dict, optional
        Mapping from parameter name to a user-specified range. ``None`` (the
        default) means no overrides. An overridden range is used as given:
        no schema lookup and no type conversion.

    Returns
    -------
    dict
        Mapping from parameter name to a dict with keys ``para_range``,
        ``para_default``, ``para_symbol`` (``ParamsSymbols[name]``) and
        ``para_desc``. Schema-derived ranges are converted to the type of
        the parameter's default where possible; infinite bounds are kept as
        floats for integer parameters.

    Example
    -------
    >>> ParamsModel, FittingBoundsModel = generate_pydantic_q_learning_params(
            number_of_learning_rate=1,
            number_of_forget_rate=1,
            choice_kernel="one_step",
            action_selection="softmax",
        )
    >>> params_options = get_params_options(ParamsModel)
    {'learn_rate': {'para_range': [0.0, 1.0],
                    'para_default': 0.5,
                    'para_symbol': <ParamsSymbols.learn_rate: '$\\alpha$'>,
                    'para_desc': 'Learning rate'},
     ...
    }
    """
    # None sentinels instead of mutable default arguments.
    if default_range is None:
        default_range = [-np.inf, np.inf]
    if para_range_override is None:
        para_range_override = {}

    # Get the schema
    params_schema = params_model.model_json_schema()["properties"]

    # Extract ge and le constraints
    param_options = {}
    for para_name, para_field in params_schema.items():
        default = para_field.get("default", None)
        para_desc = para_field.get("description", "")

        if para_name in para_range_override:
            para_range = para_range_override[para_name]
        else:  # Get from pydantic schema
            # Fresh copy so the caller's default range is never mutated.
            para_range = list(default_range)
            # Override the range if specified
            if "minimum" in para_field:
                para_range[0] = para_field["minimum"]
            if "maximum" in para_field:
                para_range[1] = para_field["maximum"]
            # Match the bounds' type to the default's (e.g. an integer
            # minimum of 0 becomes 0.0 for a float parameter).
            para_range = [_cast_bound_like_default(x, default) for x in para_range]

        param_options[para_name] = dict(
            para_range=para_range,
            para_default=default,
            para_symbol=ParamsSymbols[para_name],
            para_desc=para_desc,
        )
    return param_options