# Source code for aind_dynamic_foraging_models.generative_model.learn_functions

"""Functions for updating latent variables in generative models."""

import numpy as np


def learn_RWlike(choice, reward, q_value_tminus1, forget_rates, learn_rates):
    """Learning function for Rescorla-Wagner-like model.

    Uses a reward-dependent step size and a choice-dependent forgetting rate
    ('Hattori2019')::

        chosen:   Q(n+1) = (1 - forget_rate_chosen) * Q(n) + learn_rate * (reward - Q(n))
        unchosen: Q(n+1) = (1 - forget_rate_unchosen) * Q(n)

    Parameters
    ----------
    choice : int
        Index (along axis 0 of ``q_value_tminus1``) of the option chosen on this trial.
    reward : float
        Reward received on this trial. A truthy value selects the rewarded
        learning rate; a falsy value (e.g. 0) selects the unrewarded one.
    q_value_tminus1 : np.ndarray
        Array of Q values from the previous trial, one entry per option.
        It is not modified.
    forget_rates : sequence of float
        Forgetting rates for [unchosen, chosen] options.
    learn_rates : sequence of float
        Learning rates (step sizes) for [rewarded, unrewarded] trials.

    Returns
    -------
    np.ndarray
        Array of new Q values with the same shape as ``q_value_tminus1``.
        Floating/complex dtypes are preserved; integer or bool inputs are
        promoted to float64.
    """
    q_old = np.asarray(q_value_tminus1)
    # np.zeros_like keeps the input dtype: with an integer Q array (e.g. an
    # initial np.array([0, 0])) every fractional update would be truncated.
    if not np.issubdtype(q_old.dtype, np.inexact):
        q_old = q_old.astype(float)

    # Reward-dependent step size ('Hattori2019')
    learn_rate = learn_rates[0] if reward else learn_rates[1]

    # Choice-dependent forgetting rate ('Hattori2019')
    n_options = q_old.shape[0]
    unchosen = np.arange(n_options) != choice

    q_value_t = np.zeros_like(q_old)
    # Chosen: Q(n+1) = (1 - forget_rate_chosen) * Q(n) + step_size * (Reward - Q(n))
    q_value_t[choice] = (1 - forget_rates[1]) * q_old[choice] + learn_rate * (
        reward - q_old[choice]
    )
    # Unchosen: Q(n+1) = (1 - forget_rate_unchosen) * Q(n)
    q_value_t[unchosen] = (1 - forget_rates[0]) * q_old[unchosen]
    return q_value_t
def learn_choice_kernel(choice, choice_kernel_tminus1, choice_kernel_step_size):
    """Learning function for choice kernel.

    Moves the kernel toward a one-hot vector of the current choice
    (see Model 5 of Wilson and Collins, 2019)::

        CK(n+1) = CK(n) + step_size * (onehot(choice) - CK(n))

    If ``choice_kernel_step_size == 1`` this degenerates to Bari 2019
    (the choice kernel equals the last choice only).

    Parameters
    ----------
    choice : int
        Index of the option chosen on this trial.
    choice_kernel_tminus1 : np.ndarray
        Array of choice kernel values from the previous trial, one entry per
        option. Any number of options is supported.
    choice_kernel_step_size : float
        Step size for the choice kernel update.

    Returns
    -------
    np.ndarray
        Array of new choice kernel values.
    """
    # One-hot vector for this choice, sized to the kernel (no longer fixed to 2 options).
    n_options = np.shape(choice_kernel_tminus1)[0]
    choice_vector = np.zeros(n_options, dtype=int)
    choice_vector[choice] = 1

    return choice_kernel_tminus1 + choice_kernel_step_size * (choice_vector - choice_kernel_tminus1)
def learn_loss_counting(choice, reward, just_switched, loss_count_tminus1) -> int:
    """Update the count of consecutive unrewarded trials (losses).

    Parameters
    ----------
    choice : int
        This choice (not used by this function).
    reward : float
        This reward; a truthy value counts as a win, a falsy value as a loss.
    just_switched : bool
        Whether the animal switched options on this trial.
    loss_count_tminus1 : int
        Loss count after the previous trial.

    Returns
    -------
    int
        The new loss count: 0 after a reward, 1 after a loss on a trial that
        just switched, otherwise the previous count plus one.
    """
    if not reward:
        # A loss right after switching starts a fresh streak; otherwise the streak grows.
        return 1 if just_switched else loss_count_tminus1 + 1
    # Any reward resets the streak.
    return 0