The prioritized sweeping of Mattar et al. 2018 explains the ordered structure of replay content in resting rodents reasonably well, but offline replay in humans is not well predicted by it, as shown by the data in Antonov et al. 2022 and Eldar et al. 2020. This suggests that humans are not purely driven by expected return. Online samples, moreover, depend heavily on the policy, so the HPC representation learned by TEM is also policy-dependent. Besides online experience, however, humans can spontaneously reorganize simulated experience at rest (e.g. Liu et al. 2019), and that experience carries rich sequence structure. Would replacing the reactivation of the Hopfield memory in TEM with such simulated experience help the learning of the environment model?
This is also the drawback of learning only p(s'|s) rather than p(s'|s,a) from experience: including actions would blow up the storage exponentially, and there is no input from motor cortex anyway. How should we balance model correction against value updates? After all, replay should not only be about value; what is replayed are state transitions. If replay is dopamine-driven, does the backward replay associated with reward amount to adding a preference over state transitions? On another note, once we have a cognitive map as a model of the environment, we can use it for RL tasks, solve the control problem, and optimize the policy.
Also, is the information compression during replay related to the information-bottleneck RL sampling of Zhu et al. 2020?
Here we walk through the PyTorch version of the TEM code, focusing on the computation logic and architecture of the model. Many thanks to Jacob Bakermans for providing the PyTorch implementation, which is far more readable than James Whittington's TensorFlow version of TEM.
def forward(self, walk, prev_iter = None, prev_M = None):
# The previous iteration may contain walks without action.
# These are new walks, for which some parameters need to be reset.
steps = self.init_walks(prev_iter)
# Forward pass: perform a TEM iteration for each set of [place, observation, action],
# and produce inferred and generated variables for each step.
for g, x, a in walk:
# If there is no previous iteration at all: all walks
# are new, initialise a whole new iteration object
if steps is None:
# Use an Iteration object to set initial values before any real iterations,
# initialising M, x_inf as zero. Set actions to None
# to indicate there was no previous action
steps = [self.init_iteration(g, x, [None for _ in range(len(a))], prev_M)]
# Perform TEM iteration using transition from previous iteration
L, M, g_gen, p_gen, x_gen, x_logits, x_inf, g_inf, p_inf = self.iteration(
x, g, steps[-1].a, steps[-1].M, steps[-1].x_inf, steps[-1].g_inf
)
# Store this iteration in iteration object in steps list
steps.append(Iteration(g, x, a, L, M, g_gen, p_gen, x_gen, x_logits, x_inf, g_inf, p_inf))
# The first step is either a step from a previous walk or initialisation rubbish, so remove it
steps = steps[1:]
# Return steps, which is a list of Iteration objects
return steps
Here we can see that each call to tem requires a walk and a prev_iter (i.e. steps) as inputs. The input data structure is as follows: 20 is the length of the walk, and batch_size denotes different training environments that share the same structure (i.e. the same struct but different sensory distributions). Note that the struct here has a fixed form (e.g. a square, possibly of different sizes; the main goal is to let tem learn the basic pattern of state transitions on a square) and cannot be an arbitrary graph. In other words, the sub-graphs of the graphs need to be as similar as possible so that tem can extract a general rule within the restricted struct-set; extracting transition rules from arbitrary graphs is not yet possible, which is also how Tim Behrens explained it at the BNU IDG/McGovern 10th-anniversary event. As for steps, it is cut down as soon as it enters the function, keeping only a single Iteration, which is either the result of initialisation or the last Iteration of the previous call. The function then iterates over walk; each resulting Iteration is appended to steps, which serves as prev_iter for the next call to forward.
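For orientation, here is a hedged sketch of how a call to forward might be driven; the exact shapes, the contents of the location dictionaries, and the name tem are assumptions for illustration rather than the repository's data pipeline:

```python
import torch

# Assumed shapes: walk length 20, batch of 16 environments, 45-dim one-hot observations.
walk_length, batch_size, n_x = 20, 16, 45
walk = []
for t in range(walk_length):
    g = [{'id': t, 'shiny': None} for _ in range(batch_size)]   # location info per environment
    x = torch.eye(n_x)[torch.randint(0, n_x, (batch_size,))]    # one-hot sensory observations
    a = [0] * batch_size                                        # action taken to reach this step
    walk.append([g, x, a])

# First call: no previous iteration, so `steps` is initialised inside `forward`.
# steps = tem.forward(walk, prev_iter=None)
# Later calls carry over the last Iteration (and its Hebbian memory M):
# steps = tem.forward(next_walk, prev_iter=[steps[-1]])
```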
def iteration(self, x, locations, a_prev, M_prev, x_prev, g_prev):
# First, do the transition step, as it will be necessary for both
# the inference and generative part of the model
gt_gen, gt_inf = self.gen_g(a_prev, g_prev, locations)
# Run inference model: infer grounded location p_inf (hippocampus),
# abstract location g_inf (entorhinal). Also keep filtered sensory observation (x_inf),
# and retrieved grounded location p_inf_x
x_inf, g_inf, p_inf_x, p_inf = self.inference(x, locations, M_prev, x_prev, gt_inf)
# Run generative model: since generative model is only used
# for training purposes, it will generate from *inferred* variables
# instead of *generated* variables (as it would when used for generation)
x_gen, x_logits, p_gen = self.generative(M_prev, p_inf, g_inf, gt_gen)
# Update generative memory with generated and inferred grounded location.
M = [self.hebbian(M_prev[0], torch.cat(p_inf,dim=1), torch.cat(p_gen,dim=1))]
# If using memory for grounded location inference: append inference memory
if self.hyper['use_p_inf']:
# Inference memory is identical to generative memory
# if using common memory, and updated separately if not
M.append(M[0] if self.hyper['common_memory'] else\
self.hebbian(
M_prev[1], torch.cat(p_inf,dim=1), torch.cat(p_inf_x,dim=1),
do_hierarchical_connections=False
)
)
# Calculate loss of this step
L = self.loss(gt_gen, p_gen, x_logits, x, g_inf, p_inf, p_inf_x, M_prev)
# Return all iteration values
return L, M, gt_gen, p_gen, x_gen, x_logits, x_inf, g_inf, p_inf
iteration is the only function called from forward. Its main steps are as follows:

gen_g: this is also the only place where gen_g is called. The function first converts the action into the corresponding grid-cell projection matrix D_a, and then transitions the previous abstract location into the current abstract location. The variable dependencies are:

gt_gen: (a_prev, g_prev) -> gt_gen
gt_inf: (a_prev, g_prev) -> gt_inf
inference: the inference part of tem. Based on the sensory input x, it infers the abstract location g_inf and the grounded (conjunctive) location p_inf, and also outputs intermediate variables such as the grounded location p_inf_x retrieved from the sensory input and the filtered sensory observation x_inf. The variable dependencies are:

x_inf: x -> x_c; (x_prev, x_c) -> x_f (x_inf)
p_inf_x: x_f -> x_; (x_, M_prev[inf]) -> p_x (p_inf_x)
g_inf: (p_x, gt_inf, x) -> g (g_inf)
p_inf: g -> g_; (x_, g_) -> p (p_inf)
generative: the generative part of tem. It is run only for training purposes, so it generates from inferred variables rather than from generated variables. That is, it takes the abstract location g_inf and the grounded location p_inf produced by inference, plus the gt_gen produced earlier by gen_g, and generates x_gen, which is then compared with x to produce the error used for training. The variable dependencies are:

p_gen: (g_inf, M_prev[gen]) -> p_g_inf (p_gen)
x_gen: p_inf[0] -> x_p; p_g_inf[0] -> x_g; (gt_gen, M_prev[gen]) -> p_g_gen[0] -> x_gt
x_logits: p_inf[0] -> x_p_logits; p_g_inf[0] -> x_g_logits; (gt_gen, M_prev[gen]) -> p_g_gen[0] -> x_gt_logits
hebbian: uses the results of the first three steps of the current iteration to update M. The update rule is as follows:

Here p_t corresponds to p_inf, which is obtained by feeding g_inf through an MLP; its counterpart p_t_ is p_gen, obtained by feeding g_inf through the Memory, so this update targets M[gen]. The other update rule (used if use_p_inf) is as follows:

Here p_t again corresponds to p_inf, obtained from g_inf via the MLP, while its counterpart p_xt is p_inf_x, the variable available before g_inf (and hence p_inf) is inferred; this update targets M[inf]. The variable dependencies are:

M[gen]: (p_inf, p_gen) -> M[gen]
M[inf]: (p_inf, p_inf_x) -> M[inf]
loss: finally, the function that computes the loss, yielding the various errors (L_p_g, L_p_x, L_x_gen, L_x_g, L_x_p, L_g, L_reg_g, L_reg_p) used later for training. The variable dependencies are:

L_p: (p_inf, p_gen) -> L_p_g; (p_inf, p_inf_x) -> L_p_x
L_g: (g_inf, gt_gen) -> L_g
L_x: (x, x_gt_logits) -> L_x_gen; (x, x_g_logits) -> L_x_g; (x, x_p_logits) -> L_x_p
L_reg: g_inf -> L_reg_g; p_inf -> L_reg_p
The generative model can be factorised into the following form, corresponding to the three parts of the generative process, gen_g, gen_p, and gen_x:
def generative(self, M_prev, p_inf, g_inf, g_gen):
# M_prev - (1/2, batch_size(16), sum(n_p)(400), sum(n_p)(400))
# p_inf - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
# g_inf - (n_f(5), batch_size(16), n_g(30, 30, 24, 18, 18))
# g_gen - (n_f(5), batch_size(16), n_g(30, 30, 24, 18, 18))
# Generate observation from inferred grounded location, using only
# the highest frequency. Also keep non-softmaxed logits which are used in the loss later
# x_p - (batch_size(16), n_x(45))
# x_p_logits - (batch_size(16), n_x(45))
x_p, x_p_logits = self.gen_x(p_inf[0])
# Retrieve grounded location from memory by pattern completion on inferred abstract location
# p_g_inf - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
p_g_inf = self.gen_p(g_inf, M_prev[0]) # was p_mem_gen
# And generate observation from the grounded location retrieved from inferred abstract location
# x_g - (batch_size(16), n_x(45))
# x_g_logits - (batch_size(16), n_x(45))
x_g, x_g_logits = self.gen_x(p_g_inf[0])
# Retrieve grounded location from memory by pattern completion on abstract location by transitioning
# p_g_gen - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
p_g_gen = self.gen_p(g_gen, M_prev[0])
# Generate observation from sampled grounded location
# x_gt - (batch_size(16), n_x(45))
# x_gt_logits - (batch_size(16), n_x(45))
x_gt, x_gt_logits = self.gen_x(p_g_gen[0])
# Return all generated observations and their corresponding logits
return (x_p, x_g, x_gt), (x_p_logits, x_g_logits, x_gt_logits), p_g_inf
The schematic of the generative process is as follows:
The update equations of the generative process are as follows:
generative is called from the iteration function. Its main modules are:

gen_x: the approach here is rather crude. The first frequency module of p_inf, of shape (batch_size(16), n_p[0](100)), is taken directly as input and passed straight to the f_x function (if do_sample were enabled, the sample would be corrected separately, but the source code does not do so; it remains a placeholder where the reparameterisation trick should be used). Since n_p[0](100) is n_g_subsampled[0](10) * n_x_c(10), p_inf[0] is summed over g_subsampled for each sensory node, giving an intermediate value of shape (batch_size(16), n_x_c(10)); weighting and adding a bias gives an x of shape (batch_size(16), n_x_c(10)); the trainable f_c_star function (an MLP) then decompresses it into logits of shape (batch_size(16), n_x(45)); finally softmax normalises these into a probability distribution. This corresponds to the right half of the schematic, and to the Observation Sample equation (although only the probability distribution is obtained here; the final sample into a one-hot vector is not taken). The details are as follows:
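As a complement to the schematic referenced above, here is a hedged sketch of the gen_x computation under the assumed shapes n_p[0]=100, n_g_subsampled[0]=10, n_x_c=10, n_x=45; the variable names and the two-layer MLP are illustrative, not the exact attributes of the repository:

```python
import torch, torch.nn as nn

batch_size, n_g_sub, n_x_c, n_x = 16, 10, 10, 45
p0 = torch.rand(batch_size, n_g_sub * n_x_c)              # first-frequency grounded location

# Sum over the grid-subsample dimension for every sensory unit.
x_compressed = p0.view(batch_size, n_g_sub, n_x_c).sum(dim=1)     # (16, 10)

w, b = torch.ones(n_x_c), torch.zeros(n_x_c)               # trainable weight/bias in the real model
x_weighted = w * x_compressed + b                          # (16, 10)

# Decompress to full observation logits and normalise to a probability distribution.
f_c_star = nn.Sequential(nn.Linear(n_x_c, 20 * n_x_c), nn.ELU(), nn.Linear(20 * n_x_c, n_x))
logits = f_c_star(x_weighted)                              # (16, 45)
probability = torch.softmax(logits, dim=-1)
```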
gen_p: this function takes two inputs: g of shape (n_f(5), batch_size(16), n_g(30, 30, 24, 18, 18)) (the normal, uncompressed representation), and M_prev[gen] of shape (batch_size(16), sum(n_p)(400), sum(n_p)(400)). Its steps are:

g2g_: takes g, first crudely compresses it with the f_g function to (n_f(5), batch_size(16), n_g_subsampled(10,10,8,6,6)), and then crudely stretches it to g_ of shape (n_f(5), batch_size(16), n_p(100,100,80,60,60)). This g_ is the p-index generated from g, which will be iterated in the Hopfield network below to obtain a converged p.

attractor: takes the g_ obtained above (the p_query), first reshapes it to (batch_size(16), sum(n_p)(400)), then applies the f_p function, which clamps values to (-1, 1) and activates with leaky_relu. It then updates according to the following equation:

Finally, a reshape yields p of shape (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60)), which serves as mu_p.

f_sigma_p: takes the obtained mu_p as input and maps it with a trainable MLP to a sigma_p of the same shape. Finally, p is sampled from N(μ=mu_p, σ=sigma_p) (if do_sample is False, mu_p is used directly). This corresponds to the lower-left part of the schematic, and to the Retrieve Memory equation.
Here, (x_p, x_g, x_gt) are obtained at different stages from different p via gen_x; the specific dependencies are as follows:

The gen_g function also belongs to the generative part of tem. It corresponds to the upper-left part of the schematic, and to the Transition Sample equation:

We naturally split gen_g into two parts, f_mu_g_path and f_sigma_g_path:
f_mu_g_path: takes two inputs, a_prev and g_prev. Here a_prev, of shape (batch_size(16),), is first converted into a one-hot vector a of shape (batch_size(16), n_actions). A trainable MLP then maps it to D_a of shape (n_f(5), batch_size(16), [sum(n_g_from), n_g_to]([120,30],[90,30],[60,24],[36,18],[18,18])); under the influence of shiny, all batches designated by no_direc have their D_a replaced by the trainable default weights D_no_a. Multiplying D_a with the corresponding g_in then gives delta:

This is not the end, though. For stability, delta is not the final g_step; the real g_step is obtained by adding g_prev and delta and clamping values outside (-1, 1) with the f_g_clamp function.

f_sigma_g_path: feeds g_prev directly into a trainable MLP to obtain from_g of shape (n_f(5), batch_size(16), n_g(30, 30, 24, 18, 18)). For environments whose walk has just started, the corresponding sigma_g is obtained by exponentiating the trainable parameter logsig_g_init. Finally, g_inf is sampled from N(μ=mu_g, σ=sigma_g) (if do_sample is False, mu_g is used directly). g_gen involves one more step than g_inf: if shiny objects are present, the D_a inside f_mu_g_path is replaced by D_no_a, and the recomputed mu_g is used directly as g_gen.
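A hedged sketch of the f_mu_g_path transition for a single frequency module; in the model each module has its own D_a built from an MLP over the one-hot action, shiny batches use D_no_a instead, and the shapes here are illustrative:

```python
import torch
import torch.nn.functional as F

batch_size, n_actions, n_g = 16, 4, 30
a_prev = torch.randint(0, n_actions, (batch_size,))
g_prev = torch.rand(batch_size, n_g) * 2 - 1

a_onehot = F.one_hot(a_prev, n_actions).float()                 # (16, 4)
mlp_D = torch.nn.Linear(n_actions, n_g * n_g, bias=False)       # stand-in for the D_a MLP
D_a = mlp_D(a_onehot).view(batch_size, n_g, n_g)                # action-specific transition matrix

delta = torch.squeeze(torch.matmul(D_a, torch.unsqueeze(g_prev, 2)))   # D_a @ g_prev
mu_g = torch.clamp(g_prev + delta, -1, 1)                       # f_g_clamp for stability
```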
The inference model can be factorised into the following form, corresponding to the two parts of the inference process, inf_p and inf_g:

The first factor can be decomposed further, so that inf_g in the inference process can reuse gen_g from the generative process (namely, to generate gt_inf):
def inference(self, x, locations, M_prev, x_prev, g_gen):
# Compress sensory observation from one-hot
# to two-hot (or alternatively, whatever an MLP makes of it)
# x - (batch_size(16), n_x(45))
# x_c - (batch_size(16), n_x_c(10))
x_c = self.f_c(x)
# Temporally filter sensory observation by mixing it with previous experience
# x_prev - (n_f(5), batch_size(16), n_x_c(10))
# x_f - (n_f(5), batch_size(16), n_x_c(10))
x_f = self.x_prev2x(x_prev, x_c)
# Prepare sensory experience for input to memory by normalisation and weighting
# x_ - (n_f(5), batch_size(16), n_p(100,100,80,60,60))
x_ = self.x2x_(x_f)
# Retrieve grounded location from memory by
# doing pattern completion on current sensory experience
# p_x - (n_f(5), batch_size(16), n_p(100,100,80,60,60))
p_x = self.attractor(x_, M_prev[1], retrieve_it_mask=self.hyper['p_retrieve_mask_inf'])\
if self.hyper['use_p_inf'] else None
# Infer abstract location by combining previous abstract location and
# grounded location retrieved from memory by current sensory experience
# g_gen - (n_f(5), batch_size(16), self.g_init[f](30, 30, 24, 18, 18))
# x - (batch_size(16), n_x(45)), one-hot code, each element corresponds to a sensory
# locations - (batch_size(16),), each element is dict, e.g. {'id': 24, 'shiny': None}
# g - (n_f(5), batch_size(16), self.g_init[f](30, 30, 24, 18, 18)),
# use x & g_gen(for inv_var_weight) to infer g
g = self.inf_g(p_x, g_gen, x, locations)
# Prepare abstract location for input to memory by downsampling and weighting
# g_ - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
g_ = self.g2g_(g)
# Infer grounded location from sensory experience and inferred abstract location
# p - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
p = self.inf_p(x_, g_)
# Return variables in order that they were created
return x_f, g, p_x, p
The schematic of the inference process is as follows:
The update equations of the inference process are as follows:
inference is called from the iteration function. Its main modules are:
inf_g: this corresponds to the upper half of the schematic. From the equations above, its inputs are (p_inf_x, gt_inf, x, locations). Consider first how p_inf_x is produced:

x, of shape (batch_size(16), n_x(45)), is compressed by the f_c function (a lookup-style compression) into x_c of shape (batch_size(16), n_x_c(10)); this is the Compress sensory observation step of the update equations. The x_prev2x function then applies temporal filtering to give x_f of shape (n_f(5), batch_size(16), n_x_c(10)); this is the Temporally filter sensorium step, where α of shape (n_f(5),) is a trainable parameter. The x2x_ function normalises x_f, then crudely repeats and weights it to give x_ of shape (n_f(5), batch_size(16), n_x_c*n_g_subsampled(100,100,80,60,60)); this is the Sensory input to hippocampus step, where w_p of shape (n_f(5),) is a trainable parameter. The attractor function then iterates with M_prev[inf] to gradually converge to p_inf_x of shape (n_f(5), batch_size(16), n_p(100,100,80,60,60)); this is the Retrieve memory step of the update equations.

If use_p_inf is not used, only gt_inf is used when computing g (mu_g = gt_inf[0], which is gt_gen before the shiny operation; sigma_g = gt_inf[1], which equals the σ of gt_gen before the shiny operation), and N(μ=mu_g, σ=sigma_g) is sampled just as when producing gt_gen (if do_sample is False, mu_g is used directly). If use_p_inf is used, p_inf_x is first crudely sum-downsampled to g_downsampled of shape (n_f(5), batch_size(16), n_g_subsampled(10,10,8,6,6)), then passed through the trainable MLP f_mu_g_mem and clamped to (-1, 1) to give mu_g_mem of shape (n_f(5), batch_size(16), n_g(30,30,24,18,18)). Since p_inf_x effectively corresponds to the mixture of x_{<=t} and M_{t-1} in the second equation of this section, we can feed it into the gen_x function to reconstruct x_hat, compute its MSE against x, and concatenate that error with g**2 as the input to the trainable MLP f_sigma_g_mem, which outputs sigma_g_mem; the product of the hyperparameters p2g_scale_offset and p2g_sig_val is then added to reduce the influence of the inferred p during early training. Afterwards, the inv_var_weight function combines (mu_g_path, mu_g_mem) and (sigma_g_path, sigma_g_mem); after accounting for the effect of shiny, sampling from N(μ=mu_g, σ=sigma_g) yields g_inf. This is the Infer entorhinal step of the update equations.
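A hedged sketch of the inverse-variance weighting used to combine the path-based and memory-based estimates of the abstract location; the exact parameterisation (e.g. working in log-sigma) in the repository may differ:

```python
import torch

def inv_var_weight(mus, sigmas):
    # Combine Gaussian estimates N(mu_i, sigma_i) by weighting each mean with its
    # inverse variance; returns the combined mean and standard deviation.
    inv_vars = [1.0 / (s ** 2 + 1e-8) for s in sigmas]
    total = sum(inv_vars)
    mu = sum(m * w for m, w in zip(mus, inv_vars)) / total
    sigma = torch.sqrt(1.0 / total)
    return mu, sigma

mu_g, sigma_g = inv_var_weight(
    [torch.zeros(16, 30), torch.ones(16, 30)],      # mu_g_path, mu_g_mem
    [torch.ones(16, 30), 2 * torch.ones(16, 30)],   # sigma_g_path, sigma_g_mem
)
```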
inf_p: this corresponds to the lower half of the schematic. Beforehand, g_inf is crudely downsampled by truncation and repeated to give g_ of shape (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60)); this is the Entorhinal input to hippocampus step. g_ is then multiplied element-wise with x_, clamped to (-1, 1), and passed through leaky-relu to give mu_p. Note that no attractor is used here, and sigma_p is set to 0 because it is unclear which function the paper refers to; p is then obtained by sampling from N(μ=mu_p, σ=sigma_p). This is the Infer hippocampus step of the update equations.

The following is the learning algorithm of the Memory component; the update rule is as follows:
def hebbian(self, M_prev, p_inferred, p_generated, do_hierarchical_connections=True):
# Create new ground memory for attractor network by setting weights to outer product of learned vectors
# p_inferred corresponds to p in the paper, and p_generated corresponds to p^.
# The order of p + p^ and p - p^ is reversed since these are row vectors,
# instead of column vectors in the paper.
# M_new - (batch_size(16), sum(n_p)(400), sum(n_p)(400)),
# calculated from (16,400,1) matmul (16,1,400), not element-wise mul
M_new = torch.squeeze(torch.matmul(
torch.unsqueeze(p_inferred + p_generated, 2),torch.unsqueeze(p_inferred - p_generated,1)
))
# Multiply by connection vector, e.g. only keeping weights
# from low to high frequencies for hierarchical retrieval.
if do_hierarchical_connections:
M_new = M_new * self.hyper['p_update_mask']
# Store grounded location in attractor network memory with weights M by Hebbian learning of pattern
M = torch.clamp(self.hyper['lambda'] * M_prev + self.hyper['eta'] * M_new, min=-1, max=1)
return M
After applying the update rule, clamp truncates any values outside (-1, 1).
The following is the learning algorithm of the Cortex component; the loss is as follows:
def loss(self, g_gen, p_gen, x_logits, x, g_inf, p_inf, p_inf_x, M_prev):
# Calculate loss function, separately for each component
# because you might want to reweight contributions later.
# L_p_g is squared error loss between inferred grounded location
# and grounded location retrieved from inferred abstract location
L_p_g = torch.sum(torch.stack(utils.squared_error(p_inf, p_gen), dim=0), dim=0)
# L_p_x is squared error loss between inferred grounded location
# and grounded location retrieved from sensory experience
L_p_x = torch.sum(torch.stack(utils.squared_error(p_inf, p_inf_x), dim=0), dim=0)\
if self.hyper['use_p_inf'] else torch.zeros_like(L_p_g)
# L_g is squared error loss between generated abstract location and inferred abstract location
L_g = torch.sum(torch.stack(utils.squared_error(g_inf, g_gen), dim=0), dim=0)
# L_x is a cross-entropy loss between sensory experience and different model predictions.
# First get true labels from sensory experience
labels = torch.argmax(x, 1)
# L_x_gen: losses generated by generative model from g_prev -> g -> p -> x
L_x_gen = utils.cross_entropy(x_logits[2], labels)
# L_x_g: Losses generated by generative model from g_inf -> p -> x
L_x_g = utils.cross_entropy(x_logits[1], labels)
# L_x_p: Losses generated by generative model from p_inf -> x
L_x_p = utils.cross_entropy(x_logits[0], labels)
# L_reg are regularisation losses, L_reg_g on L2 norm of g
L_reg_g = torch.sum(torch.stack([torch.sum(g ** 2, dim=1) for g in g_inf], dim=0), dim=0)
# And L_reg_p regularisation on L1 norm of p
L_reg_p = torch.sum(torch.stack([torch.sum(torch.abs(p), dim=1) for p in p_inf], dim=0), dim=0)
# Return total loss as list of losses, so you can possibly reweight them
L = [L_p_g, L_p_x, L_x_gen, L_x_g, L_x_p, L_g, L_reg_g, L_reg_p]
return L
With this loss, the model is updated via BPTT with Adam.
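As a rough illustration of that outer loop, here is a hedged sketch (the names tem and walks, the uniform loss weighting, and how carried-over state is detached are assumptions, not the repository's training script):

```python
import torch

def train(tem, walks, n_epochs=10, lr=1e-3):
    # `tem` is a Model instance and `walks` an iterable of walks, as assumed above.
    optimizer = torch.optim.Adam(tem.parameters(), lr=lr)
    for epoch in range(n_epochs):
        prev_iter = None
        for walk in walks:
            steps = tem.forward(walk, prev_iter=prev_iter)
            # Each Iteration carries a list of per-batch loss components;
            # average over the batch and sum over components and time steps.
            loss = sum(torch.mean(torch.stack(step.L)) for step in steps)
            optimizer.zero_grad()
            loss.backward()          # BPTT through the walk
            optimizer.step()
            # Assumes the carried-over Iteration is detached inside the model
            # (otherwise it must be detached here before the next backward pass).
            prev_iter = [steps[-1]]
```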
TEM's learning process is completely offline; there is no possibility of interacting with the environment and dynamically updating the policy. The policy used to sample the environment therefore affects the representations it forms, whether grid cells or place cells, which may trigger the discussion about SR in #11. In the TEM simulations, non-diffusive transitions are modelled with a policy that prefers spending time near boundaries and approaching objects. On this point, James Whittington writes in his PhD thesis:
This actually touches on the representation of p(s'|s). Naturally, under non-diffusive transitions p(s'|s) is not uniform, so the state-transition matrix T learned by the agent can reflect the policy and has the basic ingredients for forming an SR; SR is exactly an account of the policy-induced bias in p(s'|s). Since TEM learns to predict the next state and its corresponding sensory input, and the only preferences here are for boundaries and objects, only object-vector cells emerge, which is one form of SR. If we further enrich the kinds of non-diffusive transitions, we should be able to observe more kinds of SR.

Of course, TEM also learns the environment's own transition structure: in the gen_g part of generative, D_a*g_{t-1} represents the (s,a) pair and is used to generate g_{t}, which effectively represents p(s'|s,a). Since full observability of the MDP is assumed here and there is no POMDP issue, the environment dynamics can be written as p(s',r|s,a); and since no reward information is provided, this simplifies to p(s'|s,a). So we can say that TEM has learned the environment's own transition structure, represented in μ=D_a*g_{t-1}. The σ=f_sigma_g(g_{t-1}) term expresses the uncertainty about μ, but since the environment is deterministic, this distribution represents neither stochasticity nor uncertainty and should in principle be 0; see #13 for the discussion of stochasticity vs. uncertainty. The main point is that changes in the sensory distribution of x cannot affect the ideal graph, so the value of g_{t} should be deterministic, with no distribution induced by stochasticity or uncertainty.
Moreover, I think a model's ability to represent states and to perform inference on the state graph does not restrict its representation to diffusive transitions, because the two leave each other considerable freedom; this does not contradict TEM's ability to form an SR. This idea mainly comes from the constraint on states mentioned in the TEM paper:
When introducing the Algorithm, we mentioned that the sub-graphs of the graphs need to be as similar as possible for tem to extract a general rule within a restricted struct-set, and that extracting transition rules from arbitrary graphs is not yet possible. The main reason is:
This is also why both HPC and ERC have modules of different frequencies: we can represent the world at different scales, and this hierarchical organisation is extremely efficient (like independent binary coding), helping us extract the world's rules in a way we consider natural. Meanwhile, Saxe et al. 2019 show that learning exhibits a non-linear pattern: the vector corresponding to the first (largest) singular value is filled in first, and the vectors of the later singular values are corrected afterwards. Given that hierarchy also boils down to an eigenvalue problem, this hints at a basic principle of how people understand things: difficulty should increase gradually, solving the unified problem first and the details later, so that we can grasp the unified rule behind them. As for simultaneously learning graphs that contain different sub-graphs, we do not know whether the learned grid-cell firing patterns would still be hierarchical, nor whether TEM can predict grid-cell firing under such a paradigm. Note that this means learning simultaneously, not learning one environment and then remapping in a new one; whether it can be learned at all is another question, and it is genuinely hard to say whether TEM currently has this capability.
Another conjecture: during learning, TEM probably first forms large-scale grid cells (those corresponding to the largest eigenvalues) and then gradually fills in the rest, again reflecting the unified-problem-first principle. This brings us to generalization: is this the same generalization problem that deep RL studies in PCG environments? Jiang et al. 2021 approach this class of problems from the angle of sample efficiency; whether that helps with how humans balance generalization and reward maximization during offline replay is unclear, and otherwise it is hard to formalize the kind of generalization problem humans solve in tasks. Another point is that the problems humans solve in daily life are usually POMDPs, which gives rise to belief-state representations; for example, Gershman et al. 2019 propose mPFC as a candidate location for computing belief states. Give the belief to TEM~
During inf_g, however, sigma_g_input is generated from p_inf_x. For TEM just entering a new environment, this amounts to not knowing the true parameters, a form of epistemic uncertainty, which turns this into an epistemic POMDP problem. TEM solves it rather well: after visiting all nodes of a new environment once, it can immediately infer all edges, achieving generalization in the epistemic POMDP setting.
# Not in paper, but this greatly improves zero-shot inference: provide
# the uncertainty function of the inferred abstract location with measures of memory quality
with torch.no_grad():
# For the first measure, use the grounded location inferred from memory to generate an observation
x_hat, x_hat_logits = self.gen_x(p_x[0])
# Then calculate the error between the generated observation and the actual observation:
# if the memory is working well, this error should be small
err = utils.squared_error(x, x_hat)
# The second measure is the vector norm of the inferred abstract location; good memories should have
# similar vector norms. Concatenate the two measures as input for the abstract location uncertainty function
# sigma_g_input - (n_f(5), batch_size(16), n_measure(2))
sigma_g_input = [torch.cat(
(torch.sum(g ** 2, dim=1, keepdim=True), torch.unsqueeze(err, dim=1)), dim=1
) for g in mu_g_mem]
...
# And get standard deviation/uncertainty of inferred abstract location by
# providing uncertainty function with memory quality measures
# sigma_g_mem - (n_f(5), batch_size(16), n_g(30,30,24,18,18))
sigma_g_mem = self.f_sigma_g_mem(sigma_g_input)
- Put σ=f_sigma_g(g_{t-1}) to use, or borrow from other RL work on POMDPs. Does the generalization problem of PCG environments count as a special case of extracting the environment dynamics of a POMDP (i.e. real-world environments are even more complex than these simulated ones)?
- g as MEC, p as HPC, x as LEC; combine x and g to get p.
- … p and observe the simulation result.
- logsig_offset, logsig_ratio, which are used to modify the logsig generated by the MLP as follows:
logsigmas = [self.p2g_logsig[i](x) for i, x in enumerate(logsig_input)]
logsigma = tf.concat(logsigmas, axis=1) * self.par.logsig_ratio + self.par.logsig_offset
Is that used to ensure the regularity of the latent space #21 generated by MLP? In the _loss
part of the original model, there are some regularization items:
# Calculate L_reg* losses.
L_reg_g = tf.reduce_sum(tf.stack([tf.reduce_sum(g ** 2, axis=1)\
for g in g_inf], axis=0), axis=0)
L_reg_p = tf.reduce_sum(tf.stack([tf.reduce_sum(tf.abs(p), axis=1)\
for p in p_inf], axis=0), axis=0)
It seems we don't regularize the sigma
term of distributions, so we need logsig_offset
, logsig_ratio
, and p2g_sig_val
to ensure the regularity of the latent space?
In the _inf_g part of the original model, there are also some regularization terms:
# Then calculate the error between the generated observation and the actual observation:
# if the memory is working well, this error should be small. `x_hat` as target.
err = self._loss_mse(x, x_hat)
# The second measure is the vector norm of the inferred abstract location;
# good memories should have similar vector norms. Concatenate the two measures
# as input for the abstract location uncertainty function.
# logsig_g_in - (n_f[list], batch_size, 2)
logsig_g_in = [tf.concat((tf.reduce_sum(g ** 2, axis=1, keepdims=True),
tf.expand_dims(err, axis=1)), axis=1) for g in mu_g_mem]
We use the regularization of mu_g
, e.g. tf.reduce_sum(g ** 2, axis=1, keepdims=True)
, and err
as the inputs to the MLP_logsig_g_mem
. What is the meaning of the input here?
- _hebbian and _final_hebbian? Want a batch-update of M?
- Clamp to (-1,1)? Why use leaky_relu to modify the activation of p; is that more bio-plausible?
The p part of the model doesn't have inner dynamics itself. For example, when we use x_ to retrieve memory, we will get one attractor. We ignore the previous state of memory, and the attractor can only reflect the dynamics of x instead of p. Specifically, we use x_prev to retrieve p_inf_x_prev, and update x_prev:
# x - (n_f[list], batch_size, n_x_c)
x = [(1 - alpha[f]) * x_prev[f] + alpha[f] * x_c\
for f in range(self.model_params.neuron.n_f)]
And then we use x_ to retrieve p_inf_x, so the transition from p_inf_x_prev to p_inf_x only reflects the recurrent dynamics of x, as does g-indexed retrieval. So the attractor space of p is determined by both x and g, or both LEC and MEC. Of course, this is perfectly suitable for the tasks we want to explain, but most of the time we don't know which brain areas a new task is related to, so we should set up multiple brain areas responsible for different graph structures, and use the degree to which each matches the task as the weight of the index. After all, the Hopfield model can be regarded as a transformer (see Ramsauer et al. 2020 for more details), and that enables the hippocampus to dynamically organize modules, as Lewis 2021 mentioned. Can we cast the problem of exploration as inferring a graph structure? This step precedes the so-called planning, which can only be executed after the map has been formed in the hippocampus. During exploration, we keep using the sensory input x to modify the synaptic connections within the hippocampus, and form a map in the hippocampus just like the map in the sensory cortex, e.g. LEC. The hippocampus then uses this activation as the index to call other cortices, e.g. MEC. If the dynamics of the hippocampus matches the dynamics of one cortex, then the activation of that cortex will reinforce the dynamics in the hippocampus; we can say this cortex is attended, since its activation has a high similarity with the activation of the hippocampus, and at the same time the synaptic connections within the hippocampus will also be modified. It seems that two (maybe more) maps in different cortices are trying to reach a consensus. In this case, the Levy flight can be seen as trying to find the most consistent map (maybe a composition of these maps), stored in other cortices, as soon as possible. The local movement corresponds to the local map (in one cortex), and the long jump corresponds to the transition from one module to another. If the surrounding modules correspond to the same map (in the cortex), this will reinforce those modules to be clustered in the activation of the hippocampus. The hippocampus then forms a large-scale map whose components correspond to clusters.
When encountering a new graph that is completely different from the learned graphs, the hippocampus forms a new map by itself, which means this map can only match the dynamics of the sensory cortex and cannot match other cortices. The hippocampus will then form a fine-grained map, in which even context is reflected, and transfer it to the cortex, maybe by replay in the random-walk mode?
p doesn't have its own dynamics. Without the position code from g (which has not been matched yet), it will take a long time for p (and also g) to learn the underlying sequence. This is not the case where g is already learned and provides the position coding. Should we consider the inner dynamics of the hippocampus itself? Maybe this would help to form an attractor space in which representations of temporally adjacent objects are represented adjacently in the hippocampus (see Schapiro et al. 2016 for more details), and help to generate replay from the hippocampus, not from the cortex.

Not sure what some variables refer to:

| variable | my understanding |
|---|---|
| r | (purely?) reward tuning cell in hippocampus |
| d | (purely?) direction tuning cell in hippocampus |
In the inf_l part of the ovc model, l is inferred from the activation of both g and ovc, without the activation of x; is this to match the bio-anatomical evidence? g and ovc are separated: does this correspond to the result of Obenhaus et al. 2022? That paper states that the majority of cell types were intermingled, but grid and object-vector cells exhibited little overlap. So does g mainly account for generalization, and ovc mainly for the maximization of reward?

The following is a report about the Tolman-Eichenbaum Machine (TEM); the corresponding ppt can be downloaded from here.
Here, we can see the architecture of the key component of the Transformer, i.e. the self-attention block. Given a number of entities (without sequential order, though entities may contain position embeddings themselves), we use $W_{q}$, $W_{k}$, $W_{v}$ to calculate $q$, $k$, $v$ for each entity.
$$ Q=HW_{q} \quad K=HW_{k} \quad V=HW_{v}, $$
where $H=[X,E]$ contains the feature embeddings $X$ and the position embeddings $E$. After calculating $Q$, $K$, and $V$, we use each item of $Q$ to query $K$ and get the probability of each value item. This process is like aggregating information from all nodes in the graph, but it differs from a Graph Neural Network (GNN): it doesn't exploit the graph adjacency matrix explicitly; the probability is calculated from correlations.
$$ Prob_{l,i}=softmax\left(\frac{q_{l,i}K_{l}^{T}}{\sqrt{d_{k}}}\right). $$
Then we directly multiply $Prob_{l,i}$ with $V$ to get the updated entity $h_{l+1,i}$.
$$ h_{l+1,i}=Prob_{l,i}V_{l}. $$
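To make the equations concrete, here is a minimal single-head self-attention sketch; the shapes and the absence of masking or multiple heads are illustrative choices:

```python
import torch, math

n_entities, d_h, d_k, d_v = 5, 16, 8, 8
H = torch.rand(n_entities, d_h)                          # feature + position embeddings
W_q, W_k, W_v = (torch.rand(d_h, d) for d in (d_k, d_k, d_v))

Q, K, V = H @ W_q, H @ W_k, H @ W_v
Prob = torch.softmax(Q @ K.T / math.sqrt(d_k), dim=-1)   # attention weights
H_next = Prob @ V                                        # updated entities
```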
Now, let's see how we can re-formulate TEM as such a self-attention block, i.e. what $W_{*}$ and $H$ actually correspond to. Let's look at the generative process of TEM. Obviously, the input entity embeddings contain feature embeddings $x$ (e.g. the sensory observation or the neural activities of LEC) and position embeddings $g$ (e.g. the neural activities of MEC). Here, an RNN is used to update the abstract location $g$, which models the neural activities of MEC. We use an action-specific weight matrix $W_{a}$ (the output of an MLP with no bias in either layer, whose input is the one-hot version of $a$) to get $\Delta g$.
$$ \Delta g_{t-1}=W_{a}g_{t-1}. $$
Then we can update $g_{t-1}$ with $\Delta g_{t-1}$ and clamp the updated value to get $g_{t}$.
$$ g_{t}=f_{g}(g_{t-1}+\Delta g_{t-1}). $$
Both $g_{t-1}$ and $g_{t}$ are the position embeddings of the corresponding entity. So, how could we get the corresponding query vector $\tilde{g}$? How should we express the query matrix $W_{q}$? We first downsample $g$ and then repeat the downsampled value to get the query vector $\tilde{g}$. Therefore, the query vector $\tilde{g}$ and the query matrix $W_{q}$ can be formulated as follows:
$$ \tilde{g}_{t}=W_{repeat}f_{down}(g_{t}) \quad W_{q}=W_{repeat}f_{down}(\cdot). $$
We can easily see that the query matrix $W_{q}$ of TEM is not like that of the self-attention block in the Transformer, i.e. the query matrix $W_{q}$ of TEM is hand-coded instead of learnable.
Now, let's focus on the memory retrieval part of the generative process. Firstly, we have to understand what the Hebbian memory $M_{t-1}$ (hippocampus) is exactly doing. The calculation equation of $p$ and the update equation of $M$ are as follows:
$$ p=flatten(x^{T}g) \quad M_{t}=\sum_{\tau=0}^{t}p_{\tau}^{T}p_{\tau}. $$
From the above equation, we can see that $M_{t}$ is binding every $g$ with every $x$. We should note that we ignore some computational details in the memory update process:
$$ M_{t}=\lambda M_{t-1}+\eta (p_{t}-\hat{p}_{t})(p_{t}+\hat{p}_{t})^{T}. $$
Now, we take one attractor step in TEM with no non-linearity as an example,
$$ \begin{aligned} \tilde{x}_{t}^{retrieved}&=sum(unflatten(p_{t}^{retrieved}), 1)\\ \tilde{p}_{t}^{retrieved}&=q_{t}M_{t-1}=q_{t}\sum_{\tau=0}^{t-1}p_{\tau}^{T}p_{\tau}=\sum_{\tau=0}^{t-1}[q_{t}p_{\tau}^{T}]p_{\tau}. \end{aligned} $$
Given that
$$ \begin{aligned} &[q_{t}p_{\tau}^{T}]=\bar{\tilde{x}}_{\tau} [\tilde{g}_{t} \cdot \tilde{g}_{\tau}],\\ &where \quad \bar{\tilde{x}}_{\tau}=\sum_{i}(\tilde{x}_{\tau})_{i}. \end{aligned} $$
Then we have
$$ p_{t}^{retrieved}=\tilde{g}_{t}\tilde{G}^{T}\Lambda_{x}P. $$
Finally, we get the following equation:
$$ \tilde{x}_{t}^{retrieved}=(\alpha \tilde{g}_{t}\tilde{G}^{T})\tilde{X} \quad c.f. \quad softmax\left(\frac{\tilde{g}_{t}\tilde{G}^{T}}{\sqrt{d_{k}}}\right)\tilde{X}. $$
Therefore, the key matrix $W_{k}$ can be formulated as follows:
$$ W_{k}=W_{q}=W_{repeat}f_{down}(\cdot). $$
We can also find that the key matrix $W_{k}$ of TEM is not like that of the self-attention block in the Transformer, i.e. the key matrix $W_{k}$ of TEM is hand-coded instead of learnable.
We can see that the generative process in TEM is doing exactly self-attention (but can only attend to past experience, not including future experience).
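A small numeric check of the identity behind this claim: retrieving from an outer-product (Hebbian) memory with a query equals attending over the stored patterns with unnormalised dot-product weights (the sizes here are arbitrary):

```python
import torch

n_stored, n_p = 7, 400
P = torch.rand(n_stored, n_p)            # stored grounded locations p_tau (rows)
q = torch.rand(n_p)                      # query built from the current abstract location

M = P.T @ P                              # Hebbian memory: sum of outer products p_tau^T p_tau
retrieved_via_memory = q @ M
retrieved_via_attention = (q @ P.T) @ P  # dot-product "attention" over past patterns

assert torch.allclose(retrieved_via_memory, retrieved_via_attention, atol=1e-4)
```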
Now, let's focus on the inference process of TEM. The Hebbian memory in the inference process is doing exactly the same thing as the Hebbian memory in the generative process; they may use different components (either $g$ or $x$; here it is $x$) to calculate the update weights. All we have to know is that $M_{t}$ in the inference process of TEM also binds every $x$ with every $g$. Firstly, we use $\gamma$s (different modules may have different $\gamma$) to filter the original $x$:
$$ \tilde{x}_{t}=filter(x_{t})=(1-\gamma)\tilde{x}_{t-1}+\gamma x_{t}. $$
Similarly, one step attractor can be formulated as follows:
$$ \tilde{g}_{t}^{retrieved}=(\alpha \tilde{x}_{t}\tilde{X}^{T})\tilde{G} \quad c.f. \quad softmax\left(\frac{\tilde{x}_{t}\tilde{X}^{T}}{\sqrt{d_{k}}}\right)\tilde{G}. $$
We can easily find that the inference process of TEM is also doing self-attention, e.g.
$$ W_{q}=filter(\cdot) \quad W_{k}=filter(\cdot) \quad W_{v}=MLP(\cdot) $$
Different from the generative process of TEM, there are learnable parameters in $W_{q}$, $W_{k}$, $W_{v}$ (although most parts are hand-coded in $filter(\cdot)$, which is also part of $W_{v}$ in the generative process).
Now, we can conclude that
In the Successor Representation (SR) model, the problem formulation of non-spatial reinforcement-learning tasks is used to tackle the rodent navigation task; after all, the spatial navigation task is just a special case of a reinforcement learning task. We all know that in RL, the value of a state can be expressed as follows:
$$ V(s)=\mathbb{E}_{\pi}\left[\sum_{t=0}^{\infty}\gamma^{t}R(s_{t})|s_{0}=s\right]. $$
But we do not care about the reward function $R(s)$ (after all, the standard place cells do not care about the reward), so we can decompose the value function $V(s)$ into two parts as follows:
$$ V(s)=\sum_{s'}M(s,s')R(s') \quad M(s,s')=\mathbb{E}_{\pi}\left[\sum_{t=0}^{\infty}\gamma^{t}\mathbb{I}(s_{t}=s')|s_{0}=s\right]. $$
$M(s,s')$ is a sort of function that predicts how likely we are to get to the target location $s'$, given the current location $s$ (with exponential discounting over time, of course). The SR model uses $M(s,s')$ to describe a place cell whose firing peaks at the target location $s'$: when the rodent is at location $s$, the firing rate of place cell $s'$ predicts the discounted, accumulated likelihood of visiting location $s'$ in the future.
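To make the decomposition concrete, here is a hedged sketch on a toy 3-state chain (the transition matrix and reward are made up for illustration): under a fixed policy, $M=\sum_{t\ge 0}\gamma^{t}T^{t}=(I-\gamma T)^{-1}$ and $V=MR$.

```python
import torch

gamma = 0.9
T = torch.tensor([[0.0, 1.0, 0.0],     # policy-induced state-transition matrix p(s'|s)
                  [0.5, 0.0, 0.5],
                  [0.0, 1.0, 0.0]])
R = torch.tensor([0.0, 0.0, 1.0])      # reward placed at the last state

M = torch.linalg.inv(torch.eye(3) - gamma * T)   # successor representation M(s, s')
V = M @ R                                        # V(s) = sum_s' M(s, s') R(s')
```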
Pretty simple? But there are some details in the equation we have to notice.
Now, we can see that claim clearly. $M(s,s')$ is a conjunctive representation of the global transition structure $g$ (i.e. the space representation factor) and the local transition structure $o$ (i.e. the action representation factor). As we can see, the action representation factor $o$ is associated with the abstract value $r$, but the action representation factor $o$ and the feature representation factor $f$ can still be separated, given that we can change the [location, reward value] of the object, which means the action representation factor exists independently (orthogonally) of the object representation. From that view, we can understand why $M(s,s')$ is policy-dependent. Besides, $M(s,s')$ is also limited to the specified task, i.e. neither the location nor the reward value of the object will be changed, which means that $M(s,s')$ directly learns that conjunctive representation instead of composing two representation factors together. For example, when the walking policy is strict, e.g. we cannot choose any action that moves the agent away from the object, the agent will predict that there is no way to walk away from the goal location, which differs greatly from the prediction made using the space representation factor alone. Therefore, the SR model cannot provide a generalizable graph representation (the space representation factor).
Now, we can agree on one point: the SR model decomposes the conjunctive representation into multiple representation factors heuristically, not automatically. So how could we decompose the conjunctive representation (i.e. the task representation) into multiple representation factors (i.e. orthogonal task factors in the task (representation) space) automatically? Here is a work from James Whittington. He finds that adding some biological constraints (namely nonnegativity and energy efficiency in both activity and weights) to the loss function leads to disentangled representations, i.e. it decomposes the task representation into multiple representation factors automatically.
$$ \begin{aligned} \mathcal{L}&=\underbrace{\mathcal{L}_{non\text{-}neg}+\mathcal{L}_{activity}+\mathcal{L}_{weight}}_{Biological\ constraints}+\underbrace{\mathcal{L}_{prediction}}_{Functional\ constraints}\\ &\mathcal{L}_{non\text{-}neg}=\beta_{non\text{-}neg}\sum_{i}max(-a_{i},0)\\ &\mathcal{L}_{activity}=\beta_{activity}\sum_{l}||a_{l}||^{2}\\ &\mathcal{L}_{weight}=\beta_{weight}\sum_{l}||W_{l}||^{2}\\ &\mathcal{L}_{prediction}=\beta_{prediction}||\hat{y}-y||^{2} \end{aligned} $$
In the first figure, the biological constraints are not added to the model, and we can see that each hidden neuron may respond to changes in multiple task factors. But if we add the biological constraints to the model, each hidden neuron will respond to the change of only one task factor, i.e. a disentangled representation. This means that we automatically decompose the task representation into orthogonal task factors, with each hidden neuron representing at most one task factor. Now, we can conclude that the SR model is just a subset of TEM-disentangled.
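A hedged PyTorch sketch of such a biologically constrained loss; the function name, the β values, and how activations and weights are collected are placeholders rather than the paper's exact implementation:

```python
import torch

def constrained_loss(y_hat, y, activations, weights,
                     beta_nonneg=1.0, beta_act=1e-3, beta_w=1e-3, beta_pred=1.0):
    # Biological constraints: nonnegativity plus energy costs on activity and weights.
    L_nonneg = beta_nonneg * sum(torch.clamp(-a, min=0).sum() for a in activations)
    L_activity = beta_act * sum((a ** 2).sum() for a in activations)
    L_weight = beta_w * sum((w ** 2).sum() for w in weights)
    # Functional constraint: prediction error.
    L_pred = beta_pred * ((y_hat - y) ** 2).sum()
    return L_nonneg + L_activity + L_weight + L_pred
```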
Now, we come to the Clone-Structured Cognitive Graph (CSCG). This model can also learn a graph representation, but just like the SR model, the learned graph representation is not generalizable. And because the walking policy is diffusive by default, the graph representation learned by CSCG is exactly the space representation (without being combined with the action representation). Despite its non-generalizable graph representation, CSCG can learn such a representation quickly. And when first entering environments with a brand-new structure, the hippocampus serves as a graph rather than as associative memory. Therefore, we can use CSCG to model the hippocampus, and then generate replays from CSCG to facilitate the structure-abstraction process of TEM. Cheers~
The Spatial Memory Pipeline (SMP) is another model built to explain the computational mechanisms of the hippocampal-entorhinal system. The SMP model has a similar architecture to TEM, i.e. a VAE, but it is much more complex. SMP uses a machine-learning-style memory bank to model the hippocampus, which is not as bio-plausible as an attractor network (after all, the attractor network was built by the computational neuroscience community), and its path-integration part is also complex. Of course, such a complex model provides much more powerful information processing than TEM: it can directly process raw visual sequences! And we can observe similar neural representations in the SMP model. However, due to its complexity, the SMP model is not as good as TEM at explaining the key principles of the hippocampal-entorhinal system.
We use a discrete $16\times 16$ world (so 256 locations; $n_{l}=256$) and optimize an independent representation, $z(x) \in \mathbb{R}^{n_{c}}$, at each location. We now detail each component of the following loss
$$ \mathcal{L}=\underbrace{\mathcal{L}_{nonneg}+\mathcal{L}_{activity}+\mathcal{L}_{weight}}_{Biological\ constraints}+\underbrace{\mathcal{L}_{location}+\mathcal{L}_{actions}+\mathcal{L}_{objects}}_{Functional\ constraints}+\underbrace{\mathcal{L}_{path\ integration}}_{Structure\ constraints}. $$
$$ \begin{aligned} &\mathcal{L}_{nonneg}=\frac{\beta_{nonneg}}{n_{l}}\sum_{x}\sum_{i}max(-z_{i}(x),0)\\ &\mathcal{L}_{weight}=\beta_{weight}\sum_{t}||W_{t}||^{2}\\ &\mathcal{L}_{activity}=\frac{\beta_{activity}}{n_{l}}\sum_{x}||z(x)||^{2}, \end{aligned} $$
where $z_{i}(x)$ is a neuron in representation $z(x)$, $t$ indexes the task (i.e. object, action, location, prediction), and the $\beta$ determines the regularization strength.
$$ \mathcal{L}_{location}=-\frac{\beta_{location}}{n_{l}}\sum_{x}ln\frac{e^{l_{x}\cdot z(x)}}{\sum_{x'}e^{l_{x'}\cdot z(x)}}. $$
$$ \mathcal{L}_{object}=-\frac{\beta_{object}}{n_{l}}\sum_{x}\left[\mathbb{I}(object\ at\ x)ln\sigma(W_{o}z(x))+\mathbb{I}(object\ not\ at\ x)ln(1-\sigma(W_{o}z(x)))\right]. $$
$$ \mathcal{L}_{action}=-\frac{\beta_{action}}{n_{l}}\sum_{x}\sum_{a}\left[\mathbb{I}(a=a(x))ln\sigma(t_{a}\cdot z(x))+\mathbb{I}(a\neq a(x))ln(1-\sigma(t_{a}\cdot z(x)))\right]. $$
$$ \mathcal{L}_{path\ integration}=\frac{\beta_{path\ integration}}{n_{l}}\sum_{x}\sum_{a}||z(x)-f(W_{a}z(x-d_{a}))||^{2}, $$
where $W_{a} \in \mathbb{R}^{n_{c}\times n_{c}}$ is a weight matrix that depends on the action, $a$ (i.e. there are $4$ trainable weight matrices, one for each action), and $d_{a}$ is the displacement in the underlying space (the space of $x$) that the action $a$ corresponds to.
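A hedged sketch of the path-integration term on the flattened $16\times 16$ grid; the choice of $f=$ ReLU, the toroidal wrap-around of torch.roll at the borders, and the weight initialisation are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

n_side, n_c, n_actions = 16, 32, 4
z = torch.rand(n_side * n_side, n_c)                    # z(x) for all 256 locations
W = [torch.rand(n_c, n_c) * 0.1 for _ in range(n_actions)]

# Displacement of each action on the flattened 16x16 grid (right, left, down, up).
shifts = [1, -1, n_side, -n_side]

L_path = 0.0
for a in range(n_actions):
    z_shifted = torch.roll(z, shifts[a], dims=0)        # z(x - d_a), wrapping at borders
    L_path = L_path + ((z - F.relu(z_shifted @ W[a].T)) ** 2).sum()
L_path = L_path / (n_side * n_side)
```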
Pattern forming dynamics. The overall loss can be optimized with respect to the weights. However, it can also be optimized directly with respect to $z$. This is particularly interesting for us.
This is a necessity for us as we need to represent objects which may move between tasks. Planning? To optimize both $z$ (task particularities) and weights (task generalities), we do so in two stages. First, we optimize with respect to $z$ to "infer" a representation for the current task. Second, we optimize with respect to the weights to learn parameters that are general across tasks.
When optimizing with respect to $z$ we only optimize two terms in the loss: $\mathcal{L}_{object}$ and $\mathcal{L}_{path\ integration}$. We optimize the first term so the system has the ability to know where the objects are. We optimize the second term so that information can be propagated around (effectively via path integration).
The dynamics of the $\mathcal{L}_{object}$ are:
$$ \frac{d\mathcal{L}_{object}}{dz(x)}=-W_{objects}^{T}(\mathbb{I}(object\ at\ x)-\sigma(W_{o}z(x))). $$
This says if you get the object prediction wrong, then update $z$ to better predict the object. We restrict this update to only take place where the object is, so it is just an object signal. This update is equivalent to a rodent observing that it is at an object.
The dynamics of the $\mathcal{L}_{path\ integration}$ are:
$$ \begin{aligned} \frac{d\mathcal{L}_{path\ integration}}{dz(x)}=&\sum_{a}[-(z(x)-f(W_{a}z(x-d_{a})))\\ &+W_{a}^{T}(z(x+d_{a})-f(W_{a}z(x)))\odot f'(W_{a}z(x))]. \end{aligned} $$
The two terms in the above equation can be easily understood. The first says that the representation at each location, $z(x)$, should be updated according to what its neighbors think it should be (this is the same update rule as path integration!). The second term says the representation at each location, $z(x)$, should be updated if it did not predict its neighbors correctly. This equation tells representations to update based on their neighbors. This is just like a cellular automaton, but instead of a discrete value being updated on the basis of its neighbors, it is a whole population vector whose elements are continuous. Indeed, just like a cellular automaton, it is also possible to initialize a single "cell" (location) and have that representation propagate throughout the space. In this case, it's just like path integration, but spreading through all of space at once. We note, however, that in our simulations we initialize representations at all locations (for each walk).
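A hedged sketch of this "optimize $z$ directly" stage using autograd rather than the hand-derived gradients; all names, the single action shown, and the wrap-around shift are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

n_loc, n_c = 256, 32
z = torch.rand(n_loc, n_c, requires_grad=True)   # representations optimized directly
W_a = torch.rand(n_c, n_c) * 0.1                 # frozen action weights (one action shown)
W_o = torch.rand(1, n_c) * 0.1                   # frozen object readout
object_at = torch.zeros(n_loc, 1)
object_at[100] = 1.0                             # a single object, placed arbitrarily

opt = torch.optim.SGD([z], lr=1e-2)
for _ in range(50):
    z_prev = torch.roll(z, 1, dims=0)                    # z(x - d_a) for this action
    L_path = ((z - F.relu(z_prev @ W_a.T)) ** 2).mean()  # path-integration consistency
    L_obj = F.binary_cross_entropy_with_logits(z @ W_o.T, object_at)
    opt.zero_grad()
    (L_path + L_obj).backward()
    opt.step()
```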
We note that while we simulated this on a discrete grid, the same principles apply to continuous cases. In this case the sums over location/actions need to be replaced with integrals.
This is a very general approach for understanding representations. The structural loss does not have to be related to the rules of path integration. It can be anything. It could be the rules of a graph. It could be the rules of topology. It could have one set of rules at some locations and another set of rules at other locations.
If there are structure or rules in the world or behavior, our constraints say that representations should understand that structures or rules. In mathematics this is known as a homeomorphism. In sum, understanding representations via constraints is very general.
Whittington J C R, Muller T H, Mark S, et al. The Tolman-Eichenbaum machine: unifying space and relational memory through generalization in the hippocampal formation. Cell, 2020.