Source code for aind_dynamic_foraging_models.generative_model.foragers

"""Presets of forager models and utility functions to create groups of agents."""

import inspect
import itertools
from typing import Literal, get_origin, get_type_hints

import pandas as pd

from aind_dynamic_foraging_models import generative_model


class ForagerCollection:
    """Factory for forager agents, with a small registry of named presets.

    Agent classes are resolved by name as attributes of the
    ``generative_model`` package imported at the top of this module.
    A *preset* pairs an agent class name with a fixed set of ``agent_kwargs``
    under a short alias (the keys of ``FORAGER_PRESETS``).
    """

    # Agent class names enumerated by ``get_all_foragers`` and listed in the
    # error message of ``get_agent_class``.
    FORAGER_CLASSES = [
        "ForagerQLearning",
        "ForagerLossCounting",
        "ForagerCompareThreshold",
    ]

    # alias -> {description, agent_class, agent_kwargs}
    FORAGER_PRESETS = {
        "Bari2019": dict(
            description="The vanilla Bari2019 model",
            agent_class="ForagerQLearning",
            agent_kwargs=dict(
                number_of_learning_rate=1,
                number_of_forget_rate=1,
                choice_kernel="one_step",
                action_selection="softmax",
            ),
        ),
        "Hattori2019": dict(
            description="The vanilla Hattori2019 model",
            agent_class="ForagerQLearning",
            agent_kwargs=dict(
                number_of_learning_rate=2,
                number_of_forget_rate=1,
                choice_kernel="none",
                action_selection="softmax",
            ),
        ),
        "Rescorla-Wagner": dict(
            # NOTE(review): "disccused" is a typo, kept verbatim because this
            # string is runtime data that callers may display or compare.
            description="The vanilla Rescorla-Wagner model disccused in the Sutton & Barto book",
            agent_class="ForagerQLearning",
            agent_kwargs=dict(
                number_of_learning_rate=1,
                number_of_forget_rate=0,
                choice_kernel="none",
                action_selection="epsilon-greedy",
            ),
        ),
        "Win-Stay-Lose-Shift": dict(
            description="The vanilla Win-stay-lose-shift model",
            agent_class="ForagerLossCounting",
            agent_kwargs=dict(
                win_stay_lose_switch=True,
                choice_kernel="none",
            ),
        ),
        "CompareToThreshold": dict(
            description="Compare-to-threshold foraging model",
            agent_class="ForagerCompareThreshold",
            agent_kwargs=dict(
                choice_kernel="none",
            ),
        ),
    }

    def __init__(self):
        """Store the list of preset aliases in ``self.presets``."""
        self.presets = list(self.FORAGER_PRESETS.keys())

    def get_agent_class(self, agent_class_name):
        """Get an agent class by agent_class_name.

        Parameters
        ----------
        agent_class_name : str
            Attribute name to look up on the ``generative_model`` package.

        Returns
        -------
        The attribute of ``generative_model`` with that name.

        Raises
        ------
        ValueError
            If ``generative_model`` has no such attribute (or it is None).
        """
        agent_class = getattr(generative_model, agent_class_name, None)
        if agent_class is None:
            raise ValueError(
                f"{agent_class_name} is not found in the generative_model. "
                f"Available agent classes are: {self.FORAGER_CLASSES}"
            )
        return agent_class

    def get_forager(self, agent_class_name, agent_kwargs=None, **kwargs):
        """Get a forager by agent_class_name and agent_kwargs.

        Parameters
        ----------
        agent_class_name : str
            The class name of the forager.
        agent_kwargs : dict, optional
            The keyword arguments to pass to the forager. ``None`` (the
            default) means no agent-specific keyword arguments.
        **kwargs : dict
            Other keyword arguments to pass to the forager (like the rng seed).

        Returns
        -------
        An instance of the requested agent class.

        Raises
        ------
        ValueError
            If the agent class cannot be found (see ``get_agent_class``).
        """
        # A ``None`` sentinel replaces the former ``{}`` default so that no
        # dict object is shared between calls.
        if agent_kwargs is None:
            agent_kwargs = {}
        agent_class = self.get_agent_class(agent_class_name)
        return agent_class(**agent_kwargs, **kwargs)

    def get_preset_forager(self, alias, **kwargs):
        """Get a preset forager by its alias.

        Parameters
        ----------
        alias : str
            The alias of the forager (a key of ``FORAGER_PRESETS``).
        **kwargs : dict
            Other keyword arguments to pass to the forager (like the rng seed).

        Returns
        -------
        An instance of the preset's agent class, built with the preset's
        ``agent_kwargs`` plus ``kwargs``.

        Raises
        ------
        AssertionError
            If ``alias`` is not a known preset.
        """
        # Raised explicitly rather than via an ``assert`` statement so the check
        # still runs under ``python -O``; the exception type and message are
        # kept for callers that already catch AssertionError.
        if alias not in self.FORAGER_PRESETS:
            raise AssertionError(
                f"{alias} is not found in the preset foragers."
                f" Available presets are: {self.presets}"
            )

        agent = self.FORAGER_PRESETS[alias]
        return self.get_forager(agent["agent_class"], agent["agent_kwargs"], **kwargs)

    def is_preset(self, agent_class, agent_kwargs):
        """Check if a given agent is a preset forager.

        Parameters
        ----------
        agent_class : str
            The class name of the forager to query
        agent_kwargs : dict
            The keyword arguments of the forager to query

        Returns
        -------
        str or None
            The alias of the first preset whose ``agent_class`` and
            ``agent_kwargs`` both equal the query, otherwise None
        """
        for preset_name, preset_specs in self.FORAGER_PRESETS.items():
            if (
                preset_specs["agent_class"] == agent_class
                and preset_specs["agent_kwargs"] == agent_kwargs
            ):
                return preset_name
        # Only reached after every preset has been compared without a match.
        return None

    def get_all_foragers(self, **kwargs) -> pd.DataFrame:
        """Return all available foragers in a dataframe.

        For every class in ``FORAGER_CLASSES``, one forager is instantiated for
        each combination (Cartesian product) of the options of its
        ``Literal``-typed ``__init__`` arguments.

        Parameters
        ----------
        **kwargs : dict
            Other keyword arguments to pass to the forager (like the rng seed).

        Returns
        -------
        pd.DataFrame
            One row per forager with columns ``agent_class_name``,
            ``agent_kwargs``, ``agent_alias``, one column per agent kwarg,
            ``preset_name``, ``n_free_params``, ``params`` and ``forager``.
        """
        all_foragers = []

        # Loop over all agent classes
        for agent_class_name in self.FORAGER_CLASSES:
            agent_class = self.get_agent_class(agent_class_name)
            agent_kwargs_options = {
                key: value["options"]
                for key, value in self._get_agent_kwargs_options(agent_class).items()
            }
            agent_kwargs_combinations = itertools.product(*agent_kwargs_options.values())

            # Loop over all agent_kwargs_combinations
            for specs in agent_kwargs_combinations:
                # ``specs`` is ordered like the option dict's keys.
                agent_kwargs = dict(zip(agent_kwargs_options.keys(), specs))
                forager = self.get_forager(agent_class_name, agent_kwargs, **kwargs)
                preset_name = self.is_preset(agent_class_name, agent_kwargs)
                all_foragers.append(
                    dict(
                        agent_class_name=agent_class_name,
                        agent_kwargs=agent_kwargs,
                        agent_alias=forager.get_agent_alias(),
                        **agent_kwargs,  # Also unpack agent_kwargs
                        preset_name=preset_name,
                        n_free_params=len(forager.params_list_free),
                        params=forager.get_params_str(if_latex=True, if_value=False),
                        forager=forager,
                    )
                )

        return pd.DataFrame(all_foragers)

    @staticmethod
    def _get_agent_kwargs_options(agent_class) -> dict:
        """Given an agent class, return the agent's agent_kwargs that are of type Literal.

        This requires the agent class to have type hints as Literal in the
        __init__ method. Arguments whose hint is not a bare ``Literal[...]``
        are ignored.

        Parameters
        ----------
        agent_class : type
            The agent class itself (not its name).

        Returns
        -------
        dict
            ``{arg_name: {"options": <tuple of literal values>,
            "default": <default from the signature>}}``

        Example (output as documented for ``ForagerQLearning``):
        >>> ForagerCollection._get_agent_kwargs_options(
        ...     generative_model.ForagerQLearning
        ... )  # doctest: +SKIP
        {'number_of_learning_rate': {'options': (1, 2), 'default': 2},
         'number_of_forget_rate': {'options': (0, 1), 'default': 1},
         'choice_kernel': {'options': ('none', 'one_step', 'full'), 'default': 'none'},
         'action_selection': {'options': ('softmax', 'epsilon-greedy'), 'default': 'softmax'}}
        """
        type_hints = get_type_hints(agent_class.__init__)
        signature = inspect.signature(agent_class.__init__)

        agent_args_options = {}
        for arg, type_hint in type_hints.items():
            if get_origin(type_hint) is Literal:  # Check if the type hint is a Literal
                # Get options
                literal_values = type_hint.__args__
                default_value = signature.parameters[arg].default

                agent_args_options[arg] = {"options": literal_values, "default": default_value}
        return agent_args_options
if __name__ == "__main__":
    # Manual demo: build one preset agent, then enumerate every forager.
    collection = ForagerCollection()
    bari_agent = collection.get_preset_forager("Bari2019")
    print(collection.presets)
    print(bari_agent.params)

    # One row per (agent class, kwargs combination); show the aliases.
    foragers_df = collection.get_all_foragers()
    print(foragers_df.agent_alias)

    # Take the forager stored in the first row.
    first_forager = foragers_df.iloc[0].forager
    # Bare attribute access kept from the original; its value is discarded.
    # NOTE(review): presumably a leftover from interactive use — confirm.
    first_forager.params