aind_dynamic_foraging_models.generative_model package

Subpackages

Submodules

aind_dynamic_foraging_models.generative_model.act_functions module

Functions for action selection in generative models

aind_dynamic_foraging_models.generative_model.act_functions.act_epsilon_greedy(q_value_t: array, epsilon: float, bias_terms: array, choice_kernel=None, choice_kernel_relative_weight=None, rng=None)[source]

Action selection by epsilon-greedy method.

Steps:

  1. Compute adjusted Q values by adding bias terms and choice kernel:

    Q’ = Q + bias + choice_kernel_relative_weight * choice_kernel

  2. The epsilon-greedy method is equivalent to the following choice probabilities (for simplicity, we assume only two choices):

    If Q’_L != Q’_R:

    choice_prob[argmax(Q’)] = 1 - epsilon / 2
    choice_prob[argmin(Q’)] = epsilon / 2

    else:

    choice_prob[:] = 0.5

Parameters:
  • q_value_t (np.array) – Current Q-values

  • epsilon (float) – Probability of exploration

  • bias_terms (np.array) – Bias terms

  • choice_kernel (None or np.array, optional) – If not None, it will be added to Q-values, by default None

  • choice_kernel_relative_weight (optional) – If not None, it controls the relative weight of choice kernel, by default None

  • rng (optional) – Random number generator, by default None
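
The two steps described for act_epsilon_greedy can be sketched in a few lines of NumPy. This is a minimal illustration of the documented formulas, not the library's implementation: the helper name epsilon_greedy_choice_prob is hypothetical, and the sketch assumes exactly two choices and that choice_kernel_relative_weight is given whenever choice_kernel is.

```python
import numpy as np


def epsilon_greedy_choice_prob(q_value_t, epsilon, bias_terms,
                               choice_kernel=None,
                               choice_kernel_relative_weight=None):
    """Two-choice epsilon-greedy choice probabilities (hypothetical helper)."""
    # Step 1: adjusted Q values, Q' = Q + bias + w_ck * choice_kernel
    adjusted_q = np.asarray(q_value_t, dtype=float) + np.asarray(bias_terms, dtype=float)
    if choice_kernel is not None:
        adjusted_q = adjusted_q + choice_kernel_relative_weight * np.asarray(
            choice_kernel, dtype=float
        )

    # Step 2: convert adjusted Q values to choice probabilities
    if adjusted_q[0] == adjusted_q[1]:
        # Tie: both options are equally likely
        return np.array([0.5, 0.5])
    choice_prob = np.full(2, epsilon / 2)
    choice_prob[np.argmax(adjusted_q)] = 1 - epsilon / 2
    return choice_prob


rng = np.random.default_rng(seed=0)
choice_prob = epsilon_greedy_choice_prob([0.2, 0.6], epsilon=0.1, bias_terms=[0.0, 0.0])
choice = int(rng.choice(2, p=choice_prob))  # sample an action from the probabilities
print(choice_prob, choice)
```

Returning the probabilities rather than only the sampled choice matters for fitting, since the likelihood of an observed choice is read directly from them.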

aind_dynamic_foraging_models.generative_model.act_functions.act_loss_counting(previous_choice: int | None, loss_count: int, loss_count_threshold_mean: float, loss_count_threshold_std: float, bias_terms: array, choice_kernel=None, choice_kernel_relative_weight=None, rng=None)[source]

Action selection by loss counting method.

Parameters:
  • previous_choice (int) – Last choice

  • loss_count (int) – Current loss count

  • loss_count_threshold_mean (float) – Mean of the loss count threshold

  • loss_count_threshold_std (float) – Standard deviation of the loss count threshold

  • bias_terms (np.array) – Bias terms

  • choice_kernel (None or np.array, optional) – If not None, it will be added to Q-values, by default None

  • choice_kernel_relative_weight (optional) – If not None, it controls the relative weight of choice kernel, by default None

  • rng (optional) – Random number generator, by default None

aind_dynamic_foraging_models.generative_model.act_functions.act_softmax(q_value_t: array, softmax_inverse_temperature: float, bias_terms: array, choice_kernel_relative_weight=None, choice_kernel=None, rng=None)[source]

Given q values and softmax_inverse_temperature, return the choice and choice probability.

If choice_kernel is not None, it is incorporated into the softmax function as follows:

  1. Compute adjusted Q values by adding bias terms and choice kernel

    \(Q' = \beta * (Q + w_{ck} * choice\_kernel) + bias\)

    \(\beta\) ~ softmax_inverse_temperature

    \(w_{ck}\) ~ choice_kernel_relative_weight

  2. Compute choice probabilities by softmax function

    \(choice\_prob_i = exp(Q'_i) / \sum_j exp(Q'_j)\)

Parameters:
  • q_value_t (list or np.array) – array of q values

  • softmax_inverse_temperature (float) – inverse temperature of softmax function

  • bias_terms (np.array) – bias terms

  • choice_kernel_relative_weight (optional) – relative strength of choice kernel relative to Q in decision, by default None. If not None, choice kernel will have an inverse temperature of softmax_inverse_temperature * choice_kernel_relative_weight

  • choice_kernel (None or np.array, optional) – if not None, it is added to the Q-values weighted by choice_kernel_relative_weight, by default None

  • rng (optional) – random number generator, by default None

Returns:

The choice and the choice probability.
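
The two steps documented for act_softmax can be written out directly. The sketch below is an illustration of those formulas under stated assumptions, not the library code: softmax_choice_prob is a hypothetical name, and the maximum is subtracted before exponentiation purely for numerical stability (it leaves the probabilities unchanged).

```python
import numpy as np


def softmax_choice_prob(q_value_t, softmax_inverse_temperature, bias_terms,
                        choice_kernel=None, choice_kernel_relative_weight=None):
    """Softmax choice probabilities with bias and choice kernel (hypothetical helper)."""
    q = np.asarray(q_value_t, dtype=float)
    if choice_kernel is not None:
        # Q + w_ck * choice_kernel
        q = q + choice_kernel_relative_weight * np.asarray(choice_kernel, dtype=float)

    # Step 1: Q' = beta * (Q + w_ck * choice_kernel) + bias
    adjusted_q = softmax_inverse_temperature * q + np.asarray(bias_terms, dtype=float)

    # Step 2: softmax over Q' (shifted by max(Q') for numerical stability)
    exp_q = np.exp(adjusted_q - np.max(adjusted_q))
    return exp_q / exp_q.sum()


rng = np.random.default_rng(seed=0)
choice_prob = softmax_choice_prob([1.0, 0.0], softmax_inverse_temperature=1.0,
                                  bias_terms=[0.0, 0.0])
choice = int(rng.choice(len(choice_prob), p=choice_prob))
print(choice_prob, choice)
```

Because the bias is added after scaling by the inverse temperature, it acts in units of log-odds and does not shrink or grow when the inverse temperature changes.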

aind_dynamic_foraging_models.generative_model.act_functions.choose_ps(ps, rng=None)[source]

“Poisson”-choice process

aind_dynamic_foraging_models.generative_model.act_functions.softmax(x, rng=None)[source]

aind_dynamic_foraging_models.generative_model.base module

Base class for DynamicForagingAgent with MLE fitting

class aind_dynamic_foraging_models.generative_model.base.DynamicForagingAgentMLEBase(agent_kwargs: dict = {}, params: dict = {}, **kwargs)[source]

Bases: DynamicForagingAgentBase

Base class of “DynamicForagingAgentBase” + “MLE fitting”

act(observation)[source]

Chooses an action based on the current observation. This is copied from DynamicForagingAgentBase for clarity.

Parameters:

observation – The current observation from the environment.

Returns:

The action chosen by the agent.

Return type:

action

fit(fit_choice_history, fit_reward_history, fit_bounds_override: dict | None = {}, clamp_params: dict | None = None, k_fold_cross_validation: int | None = None, DE_kwargs: dict | None = {'workers': 1})[source]

Fit the model to the data using differential evolution.

It handles fit_bounds_override and clamp_params as follows:

  1. It first clamps the parameters specified in clamp_params.

  2. For the other parameters, if a bound is specified in fit_bounds_override, the specified bound is used; otherwise, the bound in the model’s ParamFitBounds is used.

For example, if fit_bounds_override and clamp_params are both empty, all parameters will be fitted with the default bounds in the model’s ParamFitBounds.

Supports both single-session and multi-session fitting.

Multi-session format is a list/tuple of per-session 1D arrays. In multi-session fitting, latent states (Q-values, choice kernels, etc.) are reset at the start of each session.

Parameters:
  • fit_choice_history (np.ndarray or List[np.ndarray]) – Choice history for fitting. Can be either:

    - A single 1D array for single-session fitting (backward compatible)
    - A list of 1D arrays for multi-session fitting, one array per session

  • fit_reward_history (np.ndarray or List[np.ndarray]) – Reward history for fitting. Must match the format of fit_choice_history.

  • fit_bounds_override (dict, optional) – Override the default bounds for fitting parameters in ParamFitBounds, by default {}

  • clamp_params (dict, optional) – Specify parameters to fix to certain values, by default None

  • k_fold_cross_validation (Optional[int], optional) –

    Whether to do cross-validation, by default None (no cross-validation). If k_fold_cross_validation > 1, it will do k-fold cross-validation and return the prediction accuracy of the test set for model comparison. Cross-validation behavior:

    - Single-session input: trial-level k-fold CV (backward compatible)
    - Multi-session input: session-level k-fold CV (entire sessions held out). Requires n_sessions >= k_fold_cross_validation.

  • DE_kwargs (dict, optional) –

    kwargs for differential_evolution, by default {‘workers’: 1}. For example:

    workers (int)

    Number of workers for differential evolution, by default 1. In CO, fitting a typical session of 1000 trials takes:

    - 1 worker: ~100 s
    - 4 workers: ~35 s
    - 8 workers: ~22 s
    - 16 workers: ~20 s

    (see https://github.com/AllenNeuralDynamics/aind-dynamic-foraging-models/blob/22075b85360c0a5db475a90bcb025deaa4318f05/notebook/demo_rl_mle_fitting_new_test_time.ipynb)

    That is to say, the parallel speedup in DE is sublinear. Therefore, given a constant number of total CPUs, it is more efficient to parallelize at the level of sessions than across DE’s workers.
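
The sublinear speedup can be made concrete with a little arithmetic on the approximate timings quoted for DE_kwargs. The sketch below uses only those four numbers; the 16-session, 16-CPU scenario is a hypothetical illustration that assumes sessions are independent and ignores scheduling overhead.

```python
# Approximate wall-clock times (seconds) per session, keyed by number of DE workers
timings_s = {1: 100, 4: 35, 8: 22, 16: 20}

baseline = timings_s[1]
speedup = {n: baseline / t for n, t in timings_s.items()}
efficiency = {n: speedup[n] / n for n in timings_s}

for n in timings_s:
    print(f"{n:>2} workers: speedup {speedup[n]:.2f}x, efficiency {efficiency[n]:.0%}")

# Hypothetical scenario: 16 sessions to fit on 16 CPUs
n_sessions = 16
# Option A: one session at a time, 16 DE workers each
time_de_parallel = n_sessions * timings_s[16]
# Option B: all 16 sessions at once, 1 DE worker each
time_session_parallel = timings_s[1]
print(time_de_parallel, time_session_parallel)
```

Under these assumptions, parallelizing over sessions finishes in about 100 s, whereas parallelizing over DE workers takes about 320 s.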

Returns:

  • fitting_result (OptimizeResult) – The fitting result object containing optimized parameters and statistics.

  • fitting_result_cross_validation (dict or None) – Cross-validation results if k_fold_cross_validation is specified, else None.
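
The precedence between clamp_params, fit_bounds_override, and the default bounds described for fit() can be sketched with plain dictionaries. This is an illustration of the documented rule only: resolve_fit_bounds and the parameter names and bounds used here are hypothetical and are not taken from the library.

```python
def resolve_fit_bounds(default_bounds, fit_bounds_override=None, clamp_params=None):
    """Return (bounds of parameters to fit, clamped parameter values)."""
    fit_bounds_override = fit_bounds_override or {}
    clamp_params = clamp_params or {}

    bounds_to_fit = {}
    for name, default in default_bounds.items():
        if name in clamp_params:
            # 1. Clamped parameters are fixed and excluded from fitting
            continue
        # 2. Otherwise use the override if given, else the default bound
        bounds_to_fit[name] = tuple(fit_bounds_override.get(name, default))
    return bounds_to_fit, dict(clamp_params)


# Hypothetical default bounds, standing in for a model's ParamFitBounds
default_bounds = {
    "learn_rate": (0.0, 1.0),
    "forget_rate": (0.0, 1.0),
    "softmax_inverse_temperature": (0.0, 100.0),
}

bounds_to_fit, clamped = resolve_fit_bounds(
    default_bounds,
    fit_bounds_override={"softmax_inverse_temperature": (0.0, 20.0)},
    clamp_params={"forget_rate": 0.0},
)
print(bounds_to_fit)
print(clamped)
```

With both optional arguments empty, every parameter is fitted with its default bound, which matches the example given in the description of fit().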

get_agent_alias()[source]

Get the agent alias for the model

Should be overridden by the subclass.

get_choice_history()[source]

Return the history of actions in a format that is compatible with other libraries such as aind_dynamic_foraging_basic_analysis

get_fitting_result_dict()[source]

Return the fitting result in a json-compatible dict for uploading to docDB etc.

get_latent_variables()[source]

Return the latent variables of the agent

This is agent-specific and should be implemented by the subclass.

get_p_reward()[source]

Return the reward probabilities for each arm in each trial, in a format that is compatible with other libraries such as aind_dynamic_foraging_basic_analysis

get_params()[source]

Get the model parameters in a dictionary format

get_params_str(if_latex=True, if_value=True, decimal=3)[source]

Get string of the model parameters

Parameters:
  • if_latex (bool, optional) – if True, return the latex format of the parameters, by default True

  • if_value (bool, optional) – if True, return the value of the parameters, by default True

  • decimal (int, optional)

get_reward_history()[source]

Return the history of rewards in a format that is compatible with other libraries such as aind_dynamic_foraging_basic_analysis

learn(observation, action, reward, next_observation, done)[source]

Updates the agent’s knowledge or policy based on the last action and its outcome. This is copied from DynamicForagingAgentBase for clarity.

This is the core method that should be implemented by all non-trivial agents. It could be Q-learning, policy gradients, neural networks, etc.

Parameters:
  • observation – The observation before the action was taken.

  • action – The action taken by the agent.

  • reward – The reward received after taking the action.

  • next_observation – The next observation after the action.

  • done – Whether the episode has ended.

perform(task: DynamicForagingTaskBase)[source]

Generative simulation of a task, or “open-loop” simulation

Override the base class method to include choice_prob caching etc.

In each trial loop (note when time ticks):

latent variable [t] --(agent.act())--> choice [t] --(task.step())--> reward [t] --(agent.learn())--> update latent variable [t+1]
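
The act / step / learn ordering of the trial loop can be illustrated with toy stand-ins. ToyTask, ToyAgent, and the free function perform below are hypothetical and much simpler than DynamicForagingTaskBase and the real agents; the sketch only shows how latent variables at trial t produce a choice, the task returns a reward, and the update yields the latent variables for trial t+1.

```python
import numpy as np


class ToyTask:
    """Two-armed bandit with fixed reward probabilities (hypothetical)."""

    def __init__(self, p_reward, n_trials, rng):
        self.p_reward = np.asarray(p_reward, dtype=float)
        self.n_trials = n_trials
        self.rng = rng
        self.trial = 0

    def step(self, action):
        reward = float(self.rng.random() < self.p_reward[action])
        self.trial += 1  # time ticks here
        return reward, self.trial >= self.n_trials


class ToyAgent:
    """Softmax Q-learner with a single learning rate (hypothetical)."""

    def __init__(self, learn_rate, inverse_temperature, rng):
        self.learn_rate = learn_rate
        self.inverse_temperature = inverse_temperature
        self.rng = rng
        self.q_value = np.zeros(2)

    def act(self):
        logits = self.inverse_temperature * self.q_value
        prob = np.exp(logits - logits.max())
        prob /= prob.sum()
        return int(self.rng.choice(2, p=prob))

    def learn(self, action, reward):
        self.q_value[action] += self.learn_rate * (reward - self.q_value[action])


def perform(agent, task):
    """Generative simulation: the agent's own choices drive the task."""
    choices, rewards, done = [], [], False
    while not done:
        action = agent.act()               # latent variable [t] -> choice [t]
        reward, done = task.step(action)   # choice [t] -> reward [t]
        agent.learn(action, reward)        # reward [t] -> latent variable [t+1]
        choices.append(action)
        rewards.append(reward)
    return np.array(choices), np.array(rewards)


rng = np.random.default_rng(seed=42)
choices, rewards = perform(ToyAgent(0.3, 5.0, rng), ToyTask([0.8, 0.2], 200, rng))
print(choices[:10], rewards[:10])
```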

perform_closed_loop(fit_choice_history, fit_reward_history)[source]

Simulates the agent over a fixed choice and reward history using its params. Also called “teacher forcing” or “closed-loop” simulation.

Unlike .perform() (“generative” simulation), this is called “predictive” simulation, which does not need a task and is used for model fitting.

perform_closed_loop_multi_session(fit_choice_history_sessions, fit_reward_history_sessions)[source]

Simulates the agent over multiple sessions, resetting latent states at each session start.

Unlike .perform() (“generative” simulation), this is called “predictive” simulation, which does not need a task and is used for model fitting across multiple sessions.

Parameters:
  • fit_choice_history_sessions (list of np.ndarray) – List of choice history arrays, one per session

  • fit_reward_history_sessions (list of np.ndarray) – List of reward history arrays, one per session

Returns:

choice_prob_sessions – List of choice probability arrays, one per session

Return type:

list of np.ndarray
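
Predictive simulation with a per-session reset can be sketched as follows. This is a toy illustration, not the library's implementation: the function names, the simple Q-learning update, and the (n_actions, n_trials) layout of each choice probability array are all assumptions made for the example.

```python
import numpy as np


def closed_loop_choice_prob(choice_history, reward_history, learn_rate, inverse_temperature):
    """Teacher-forced simulation of one session (hypothetical helper)."""
    n_trials = len(choice_history)
    q_value = np.zeros(2)  # latent state starts fresh for every session
    choice_prob = np.zeros((2, n_trials))
    for t in range(n_trials):
        # Predict the choice probabilities from the current latent state
        logits = inverse_temperature * q_value
        exp_logits = np.exp(logits - logits.max())
        choice_prob[:, t] = exp_logits / exp_logits.sum()
        # Update with the observed (not simulated) choice and reward
        choice = int(choice_history[t])
        q_value[choice] += learn_rate * (reward_history[t] - q_value[choice])
    return choice_prob


def closed_loop_multi_session(choice_sessions, reward_sessions, **params):
    """Run each session independently, so latent states never carry over."""
    return [
        closed_loop_choice_prob(choices, rewards, **params)
        for choices, rewards in zip(choice_sessions, reward_sessions)
    ]


choice_sessions = [np.array([0, 0, 1, 0]), np.array([1, 1, 1])]
reward_sessions = [np.array([1.0, 1.0, 0.0, 1.0]), np.array([0.0, 1.0, 1.0])]
choice_prob_sessions = closed_loop_multi_session(
    choice_sessions, reward_sessions, learn_rate=0.5, inverse_temperature=3.0
)
print([p.shape for p in choice_prob_sessions])
```

Because the latent state is reset, the first trial of every session is predicted from the same initial values, regardless of what happened in the previous session.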

plot_fitted_session(if_plot_latent=True)[source]

Plot session after .fit()

  1. choice and reward history will be the history used for fitting

  2. latent variables q_estimate and choice_prob will be plotted

  3. p_reward will be missing (since it is not used for fitting)

Parameters:

if_plot_latent (bool, optional) – Whether to plot latent variables, by default True

plot_latent_variables(ax, if_fitted=False)[source]

Add agent-specific latent variables to the plot

if_fitted: whether the latent variables are from the fitted model (styling purpose)

plot_session(if_plot_latent=True)[source]

Plot session after .perform(task)

Parameters:

if_plot_latent (bool, optional) – Whether to plot latent variables, by default True

set_params(**params)[source]

Update the model parameters and validate

aind_dynamic_foraging_models.generative_model.base.negLL(choice_prob, fit_choice_history, fit_reward_history, fit_trial_set=None)[source]

Compute total negLL of the trials in fit_trial_set given the data.
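
A negative log-likelihood of this kind sums the negative log of the probability assigned to each observed choice. The sketch below is an illustration under assumptions, not the library function: neg_log_likelihood is a hypothetical name, choice_prob is assumed to have shape (n_actions, n_trials), and the reward history argument of negLL is omitted because it is not needed for this computation.

```python
import numpy as np


def neg_log_likelihood(choice_prob, choice_history, trial_set=None):
    """Total negative log-likelihood of the observed choices (hypothetical helper)."""
    choice_history = np.asarray(choice_history, dtype=int)
    if trial_set is None:
        trials = np.arange(len(choice_history))  # use all trials
    else:
        trials = np.asarray(trial_set, dtype=int)
    # Probability the model assigned to the choice actually made on each trial
    likelihood_each_trial = choice_prob[choice_history[trials], trials]
    return -np.sum(np.log(likelihood_each_trial))


choice_prob = np.array([[0.5, 0.8, 0.3],
                        [0.5, 0.2, 0.7]])
choice_history = [0, 0, 1]

total = neg_log_likelihood(choice_prob, choice_history)
held_out = neg_log_likelihood(choice_prob, choice_history, trial_set=[1, 2])
print(total, held_out)
```

Restricting the sum to a trial set is what makes trial-level cross-validation possible: parameters are fitted on one subset of trials and evaluated on the rest.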

aind_dynamic_foraging_models.generative_model.forager_loss_counting module

Maximum likelihood fitting of foraging models

class aind_dynamic_foraging_models.generative_model.forager_loss_counting.ForagerLossCounting(win_stay_lose_switch: Literal[False, True] = False, choice_kernel: Literal['none', 'one_step', 'full'] = 'none', params: dict = {}, **kwargs)[source]

Bases: DynamicForagingAgentMLEBase

The family of loss counting models.

act(_)[source]

Action selection

get_agent_alias()[source]

Get the agent alias

get_latent_variables()[source]

Return the latent variables of the agent

This is agent-specific and should be implemented by the subclass.

learn(_, choice, reward, __, done)[source]

Update loss counter

Note that self.trial already increased by 1 before learn() in the base class

plot_latent_variables(ax, if_fitted=False)[source]

Plot Q values

aind_dynamic_foraging_models.generative_model.forager_q_learning module

Maximum likelihood fitting of foraging models

class aind_dynamic_foraging_models.generative_model.forager_q_learning.ForagerQLearning(number_of_learning_rate: Literal[1, 2] = 2, number_of_forget_rate: Literal[0, 1] = 1, choice_kernel: Literal['none', 'one_step', 'full'] = 'none', action_selection: Literal['softmax', 'epsilon-greedy'] = 'softmax', params: dict = {}, **kwargs)[source]

Bases: DynamicForagingAgentMLEBase

The family of simple Q-learning models.

act(_)[source]

Action selection

get_agent_alias()[source]

Get the agent alias

get_latent_variables()[source]

Return the latent variables of the agent

This is agent-specific and should be implemented by the subclass.

learn(_observation, choice, reward, _next_observation, done)[source]

Update Q values

Note that self.trial already increased by 1 before learn() in the base class

plot_latent_variables(ax, if_fitted=False)[source]

Plot Q values

aind_dynamic_foraging_models.generative_model.foragers module

Presets of forager models and utility functions to create group of agents.

class aind_dynamic_foraging_models.generative_model.foragers.ForagerCollection[source]

Bases: object

A class to create foragers.

FORAGER_CLASSES = ['ForagerQLearning', 'ForagerLossCounting', 'ForagerCompareThreshold']
FORAGER_PRESETS = {
    'Bari2019': {
        'agent_class': 'ForagerQLearning',
        'agent_kwargs': {'action_selection': 'softmax', 'choice_kernel': 'one_step', 'number_of_forget_rate': 1, 'number_of_learning_rate': 1},
        'description': 'The vanilla Bari2019 model'
    },
    'CompareToThreshold': {
        'agent_class': 'ForagerCompareThreshold',
        'agent_kwargs': {'choice_kernel': 'none'},
        'description': 'Compare-to-threshold foraging model'
    },
    'Hattori2019': {
        'agent_class': 'ForagerQLearning',
        'agent_kwargs': {'action_selection': 'softmax', 'choice_kernel': 'none', 'number_of_forget_rate': 1, 'number_of_learning_rate': 2},
        'description': 'The vanilla Hattori2019 model'
    },
    'Rescorla-Wagner': {
        'agent_class': 'ForagerQLearning',
        'agent_kwargs': {'action_selection': 'epsilon-greedy', 'choice_kernel': 'none', 'number_of_forget_rate': 0, 'number_of_learning_rate': 1},
        'description': 'The vanilla Rescorla-Wagner model discussed in the Sutton & Barto book'
    },
    'Win-Stay-Lose-Shift': {
        'agent_class': 'ForagerLossCounting',
        'agent_kwargs': {'choice_kernel': 'none', 'win_stay_lose_switch': True},
        'description': 'The vanilla Win-stay-lose-shift model'
    }
}
get_agent_class(agent_class_name)[source]

Get an agent class by agent_class_name

get_all_foragers(**kwargs) DataFrame[source]

Return all available foragers in a dataframe.

Parameters:

**kwargs (dict) – Other keyword arguments to pass to the forager (like the rng seed).

get_forager(agent_class_name, agent_kwargs={}, **kwargs)[source]

Get a forager by agent_class_name and agent_kwargs

Parameters:
  • agent_class_name (str) – The class name of the forager.

  • agent_kwargs (dict) – The keyword arguments to pass to the forager.

  • **kwargs (dict) – Other keyword arguments to pass to the forager (like the rng seed).

get_preset_forager(alias, **kwargs)[source]

Get a preset forager by its alias.

Parameters:
  • alias (str) – The alias of the forager.

  • **kwargs (dict) – Other keyword arguments to pass to the forager (like the rng seed).

is_preset(agent_class, agent_kwargs)[source]

Check if a given agent is a preset forager.

Parameters:
  • agent_class (str) – The class name of the forager to query

  • agent_kwargs (dict) – The keyword arguments of the forager to query

Returns:

The alias of the preset forager if it exists, otherwise None

Return type:

str or None
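
A preset lookup of this kind can be sketched as a search over the preset table. The two entries below are copied from FORAGER_PRESETS; the helper find_preset_alias is hypothetical, and the exact-match comparison of agent_kwargs is an assumption, since the matching rule used by is_preset is not documented here.

```python
# Two entries copied from FORAGER_PRESETS (descriptions omitted)
PRESETS = {
    "Bari2019": {
        "agent_class": "ForagerQLearning",
        "agent_kwargs": {
            "action_selection": "softmax",
            "choice_kernel": "one_step",
            "number_of_forget_rate": 1,
            "number_of_learning_rate": 1,
        },
    },
    "Hattori2019": {
        "agent_class": "ForagerQLearning",
        "agent_kwargs": {
            "action_selection": "softmax",
            "choice_kernel": "none",
            "number_of_forget_rate": 1,
            "number_of_learning_rate": 2,
        },
    },
}


def find_preset_alias(agent_class, agent_kwargs, presets=PRESETS):
    """Return the alias whose class and kwargs match exactly, else None."""
    for alias, preset in presets.items():
        # Assumption: a preset matches only if every keyword argument is identical
        if preset["agent_class"] == agent_class and preset["agent_kwargs"] == agent_kwargs:
            return alias
    return None


alias = find_preset_alias(
    "ForagerQLearning",
    {
        "number_of_learning_rate": 2,
        "number_of_forget_rate": 1,
        "choice_kernel": "none",
        "action_selection": "softmax",
    },
)
print(alias)
```

Dictionary equality in Python ignores key order, so the keyword arguments may be supplied in any order.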

aind_dynamic_foraging_models.generative_model.learn_functions module

Functions for updating latent variables in generative models.

aind_dynamic_foraging_models.generative_model.learn_functions.learn_RWlike(choice, reward, q_value_tminus1, forget_rates, learn_rates)[source]

Learning function for Rescorla-Wagner-like model.

Parameters:
  • choice (int) – this choice

  • reward (float) – this reward

  • q_value_tminus1 (np.ndarray) – array of old q values

  • forget_rates (list) – forget rates for [unchosen, chosen] sides

  • learn_rates – learning rates for [rewarded, unrewarded] sides

Returns:

array of new q values

Return type:

np.ndarray
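
One plausible form of a Rescorla-Wagner-like update with these parameters is sketched below. The update equations are an assumption for illustration (the source only documents the argument layout), and rw_like_update is a hypothetical name: the chosen side decays by the "chosen" forget rate and moves toward the reward with a reward-dependent learning rate, while unchosen sides only decay.

```python
import numpy as np


def rw_like_update(choice, reward, q_value_tminus1, forget_rates, learn_rates):
    """Rescorla-Wagner-like Q update (illustrative, assumed form)."""
    q_old = np.asarray(q_value_tminus1, dtype=float)
    forget_unchosen, forget_chosen = forget_rates
    learn_rewarded, learn_unrewarded = learn_rates

    # Reward-dependent learning rate
    learn_rate = learn_rewarded if reward else learn_unrewarded

    # Unchosen sides: passive decay only
    q_new = (1 - forget_unchosen) * q_old
    # Chosen side: decay plus prediction-error update
    q_new[choice] = (1 - forget_chosen) * q_old[choice] + learn_rate * (reward - q_old[choice])
    return q_new


q_new = rw_like_update(
    choice=0,
    reward=1.0,
    q_value_tminus1=[0.5, 0.5],
    forget_rates=[0.1, 0.0],   # [unchosen, chosen]
    learn_rates=[0.4, 0.2],    # [rewarded, unrewarded]
)
print(q_new)
```

In this form, setting both forget rates to zero and both learning rates equal recovers the standard Rescorla-Wagner rule.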

aind_dynamic_foraging_models.generative_model.learn_functions.learn_choice_kernel(choice, choice_kernel_tminus1, choice_kernel_step_size)[source]

Learning function for choice kernel.

Parameters:
  • choice (int) – this choice

  • choice_kernel_tminus1 (np.ndarray) – array of old choice kernel values

  • choice_kernel_step_size (float) – step size for choice kernel

Returns:

array of new choice kernel values

Return type:

np.ndarray
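
A common way to update a choice kernel is to move it toward a one-hot vector of the current choice. The sketch below assumes that form, which is not confirmed by the source; choice_kernel_update is a hypothetical name.

```python
import numpy as np


def choice_kernel_update(choice, choice_kernel_tminus1, choice_kernel_step_size):
    """Move the choice kernel toward the one-hot vector of the current choice."""
    ck_old = np.asarray(choice_kernel_tminus1, dtype=float)
    choice_vector = np.zeros_like(ck_old)
    choice_vector[choice] = 1.0
    return ck_old + choice_kernel_step_size * (choice_vector - ck_old)


ck = np.zeros(2)
ck = choice_kernel_update(choice=1, choice_kernel_tminus1=ck, choice_kernel_step_size=0.5)
print(ck)
ck = choice_kernel_update(choice=1, choice_kernel_tminus1=ck, choice_kernel_step_size=0.5)
print(ck)
```

With a step size of 1, the kernel depends only on the most recent choice; smaller step sizes make it an exponentially weighted history of past choices.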

aind_dynamic_foraging_models.generative_model.learn_functions.learn_loss_counting(choice, reward, just_switched, loss_count_tminus1) int[source]

Update loss counting

Returns the new loss count
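
One plausible loss-count update is sketched below as an illustration only, since the source does not spell out the rule. The assumptions are that a reward resets the count to zero, that an unrewarded trial right after a switch starts the count at one, and that any other unrewarded trial increments it. The name loss_count_update is hypothetical.

```python
def loss_count_update(choice, reward, just_switched, loss_count_tminus1):
    """Return the new loss count (illustrative, assumed rule)."""
    if reward:
        # Assumption: any reward resets the counter
        return 0
    if just_switched:
        # Assumption: the first loss on a new side starts the count at 1
        return 1
    return loss_count_tminus1 + 1


loss_count = 0
# (choice, reward, just_switched) for four consecutive trials
history = [(0, 0, False), (0, 0, False), (1, 0, True), (1, 1, False)]
counts = []
for choice, reward, just_switched in history:
    loss_count = loss_count_update(choice, reward, just_switched, loss_count)
    counts.append(loss_count)
print(counts)
```

The choice argument is unused in this sketch; it is kept only to mirror the documented signature.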

Module contents

Package for generative models of dynamic foraging behavior