NorbertZheng / read-papers

My paper reading notes.

Cell '20 | The Tolman-Eichenbaum machine: Unifying space and relational memory through generalization in the hippocampal formation. #16

NorbertZheng opened 2 years ago

NorbertZheng commented 2 years ago

Whittington J C R, Muller T H, Mark S, et al. The Tolman-Eichenbaum machine: Unifying space and relational memory through generalization in the hippocampal formation. Cell, 2020.

NorbertZheng commented 2 years ago

Related literature

NorbertZheng commented 2 years ago

The prioritized sweeping account proposed by Mattar et al. 2018 explains the ordering of replayed content during rest in rodents fairly well, but human offline replay is not well predicted by it, as shown by the data in Antonov et al. 2022 and Eldar et al. 2020. This suggests that humans are not driven purely by expected return. Moreover, online sampling depends heavily on the policy, so the hippocampal (hpc) representations TEM learns are also policy-dependent. Humans, however, can spontaneously reorganize experience during rest into simulated experience beyond the online experience (e.g. Liu et al. 2019), with rich internal sequence structure. Would using such sequences in place of TEM's reactivation of the Hopfield memory help the learning of the environment model?

NorbertZheng commented 2 years ago

This is also the drawback of only learning p(s'|s) from experience rather than p(s'|s,a): including actions would blow up the memory exponentially, and there is no information from motor cortex here anyway. How should we balance model correction against value updates? After all, replay should not only be about value; what is replayed are state transitions. If replay is dopamine-driven, the backward replay associated with reward can be seen as adding a preference over state transitions? Another point: with the cognitive map as a model of the environment, we can then run RL on it, solve the control problem, and optimize the policy.

NorbertZheng commented 2 years ago

Also, is the information compression in replay related to the information-bottleneck view of RL sampling (Zhu et al. 2020)?

NorbertZheng commented 2 years ago

Tolman-Eichenbaum Machine

Here we walk through the PyTorch version of the TEM code, focusing on the model's computational logic and architecture. Many thanks to Jacob Bakermans for providing the PyTorch implementation, which is far more readable than James Whittington's TensorFlow version of the TEM code.

NorbertZheng commented 2 years ago

Algorithm

forward

def forward(self, walk, prev_iter = None, prev_M = None):
    # The previous iteration may contain walks without action.
    # These are new walks, for which some parameters need to be reset.
    steps = self.init_walks(prev_iter)
    # Forward pass: perform a TEM iteration for each set of [place, observation, action],
    # and produce inferred and generated variables for each step.
    for g, x, a in walk:
        # If there is no previous iteration at all: all walks
        # are new, initialise a whole new iteration object
        if steps is None:
            # Use an Iteration object to set initial values before any real iterations,
    # initialising M, x_inf as zero. Set actions to None
            # to indicate there was no previous action
            steps = [self.init_iteration(g, x, [None for _ in range(len(a))], prev_M)]
        # Perform TEM iteration using transition from previous iteration
        L, M, g_gen, p_gen, x_gen, x_logits, x_inf, g_inf, p_inf = self.iteration(
            x, g, steps[-1].a, steps[-1].M, steps[-1].x_inf, steps[-1].g_inf
        )
        # Store this iteration in iteration object in steps list
        steps.append(Iteration(g, x, a, L, M, g_gen, p_gen, x_gen, x_logits, x_inf, g_inf, p_inf))    
    # The first step is either a step from a previous walk or initialisation rubbish, so remove it
    steps = steps[1:]
    # Return steps, which is a list of Iteration objects
    return steps

Here we can see that each call to tem takes walk and prev_iter (i.e. the steps returned by the previous call) as inputs. The structure of these inputs is as follows:

Iteration over walk then begins; every Iteration produced is appended to steps, which serves as prev_iter for the next call to forward. See the sketch below.
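As an illustration of that data flow (a sketch with placeholder names such as `run_walk`, `walk_chunks` and `train_step`; this is not the repository's actual training script):

```python
def run_walk(tem, walk_chunks, train_step):
    """Sketch: feed a long walk to the model chunk by chunk, carrying the
    last Iteration of each chunk over as prev_iter for the next call."""
    prev_iter = None
    for chunk in walk_chunks:              # each chunk: a list of [location, observation, action] steps
        steps = tem(chunk, prev_iter)      # forward() returns a list of Iteration objects
        train_step(steps)                  # compute the loss from steps and update parameters
        prev_iter = [steps[-1].detach()]   # assumes Iteration exposes a detach() helper
    return prev_iter
```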

NorbertZheng commented 2 years ago

iteration

def iteration(self, x, locations, a_prev, M_prev, x_prev, g_prev):
    # First, do the transition step, as it will be necessary for both
    # the inference and generative part of the model
    gt_gen, gt_inf = self.gen_g(a_prev, g_prev, locations)
    # Run inference model: infer grounded location p_inf (hippocampus),
    # abstract location g_inf (entorhinal). Also keep filtered sensory observation (x_inf),
    # and retrieved grounded location p_inf_x
    x_inf, g_inf, p_inf_x, p_inf = self.inference(x, locations, M_prev, x_prev, gt_inf)
    # Run generative model: since generative model is only used
    # for training purposes, it will generate from *inferred* variables
    # instead of *generated* variables (as it would when used for generation)
    x_gen, x_logits, p_gen = self.generative(M_prev, p_inf, g_inf, gt_gen)
    # Update generative memory with generated and inferred grounded location.
    M = [self.hebbian(M_prev[0], torch.cat(p_inf,dim=1), torch.cat(p_gen,dim=1))]
    # If using memory for grounded location inference: append inference memory
    if self.hyper['use_p_inf']:
        # Inference memory is identical to generative memory
        # if using common memory, and updated separately if not
        M.append(M[0] if self.hyper['common_memory'] else\
            self.hebbian(
                M_prev[1], torch.cat(p_inf,dim=1), torch.cat(p_inf_x,dim=1),
                do_hierarchical_connections=False
            )
        )
    # Calculate loss of this step
    L = self.loss(gt_gen, p_gen, x_logits, x, g_inf, p_inf, p_inf_x, M_prev)
    # Return all iteration values
    return L, M, gt_gen, p_gen, x_gen, x_logits, x_inf, g_inf, p_inf

iteration is the core routine called inside forward. Its main steps are as follows:

NorbertZheng commented 2 years ago

generative

The generative model can be factorized into the following form, corresponding to the three parts of the generative process: gen_g, gen_p, and gen_x:

def generative(self, M_prev, p_inf, g_inf, g_gen):
    # M_prev - (1/2, batch_size(16), sum(n_p)(400), sum(n_p)(400))
    # p_inf - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
    # g_inf - (n_f(5), batch_size(16), n_g(30, 30, 24, 18, 18))
    # g_gen - (n_f(5), batch_size(16), n_g(30, 30, 24, 18, 18))
    # Generate observation from inferred grounded location, using only
    # the highest frequency. Also keep non-softmaxed logits which are used in the loss later
    # x_p - (batch_size(16), n_x(45))
    # x_p_logits - (batch_size(16), n_x(45))
    x_p, x_p_logits = self.gen_x(p_inf[0])
    # Retrieve grounded location from memory by pattern completion on inferred abstract location
    # p_g_inf - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
    p_g_inf = self.gen_p(g_inf, M_prev[0]) # was p_mem_gen
    # And generate observation from the grounded location retrieved from inferred abstract location
    # x_g - (batch_size(16), n_x(45))
    # x_g_logits - (batch_size(16), n_x(45))
    x_g, x_g_logits = self.gen_x(p_g_inf[0])
    # Retrieve grounded location from memory by pattern completion on abstract location by transitioning
    # p_g_gen - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
    p_g_gen = self.gen_p(g_gen, M_prev[0])
    # Generate observation from sampled grounded location
    # x_gt - (batch_size(16), n_x(45))
    # x_gt_logits - (batch_size(16), n_x(45))
    x_gt, x_gt_logits = self.gen_x(p_g_gen[0])
    # Return all generated observations and their corresponding logits
    return (x_p, x_g, x_gt), (x_p_logits, x_g_logits, x_gt_logits), p_g_inf

The schematic and the update equations of the generative process are given in the figures of the paper (the images are not preserved here). generative is called from the iteration function; its main modules are as follows:

Here, x_p, x_g, and x_gt are each produced by gen_x from a different p at a different stage; their exact dependencies are shown in the corresponding figure.

NorbertZheng commented 2 years ago

gen_g

This also belongs to the generative part of TEM. It corresponds to the upper-left part of the schematic, and its equation is the transition sample. We naturally split gen_g into two parts, f_mu_g_path and f_sigma_g_path.

Finally, g_inf is sampled from N(μ=mu_g, σ=sigma_g) (if do_sample is False, mu_g is used directly). g_gen takes one extra step compared with g_inf: if shiny objects are present, D_a inside f_mu_g_path is replaced by D_no_a, and the recomputed mu_g is used directly as g_gen. A simplified sketch follows.
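A simplified, single-frequency-module sketch of this transition step (the function and argument names here are illustrative, not the repository's):

```python
import torch

def gen_g_sketch(g_prev, D_a, f_g, f_sigma_g, do_sample=True):
    """Illustrative sketch of gen_g for one frequency module.

    g_prev:    (batch, n_g) previous abstract location
    D_a:       (n_g, n_g) action-specific transition weights (built by an MLP from the
               one-hot action in the real model; swapped for D_no_a at shiny objects)
    f_g:       clamping non-linearity applied to the updated abstract location
    f_sigma_g: maps g_prev to a per-unit standard deviation
    """
    delta_g = g_prev @ D_a.T                  # D_a * g_{t-1} (row-vector convention)
    mu_g = f_g(g_prev + delta_g)              # f_mu_g_path: mean of g_t
    sigma_g = f_sigma_g(g_prev)               # f_sigma_g_path: uncertainty around the mean
    if do_sample:
        return mu_g + sigma_g * torch.randn_like(mu_g)   # sample from N(mu_g, sigma_g)
    return mu_g                               # otherwise use the mean directly
```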

NorbertZheng commented 2 years ago

inference

The inference model can be factorized into the following form, corresponding to the two parts of the inference process, inf_p and inf_g. The first term can be decomposed further, so that inf_g in the inference process can reuse gen_g from the generative process (i.e. to produce gt_inf).

def inference(self, x, locations, M_prev, x_prev, g_gen):
    # Compress sensory observation from one-hot
    # to two-hot (or alternatively, whatever an MLP makes of it)
    # x - (batch_size(16), n_x(45))
    # x_c - (batch_size(16), n_x_c(10))
    x_c = self.f_c(x)
    # Temporally filter sensory observation by mixing it with previous experience
    # x_prev - (n_f(5), batch_size(16), n_x_c(10))
    # x_f - (n_f(5), batch_size(16), n_x_c(10))
    x_f = self.x_prev2x(x_prev, x_c)
    # Prepare sensory experience for input to memory by normalisation and weighting
    # x_ - (n_f(5), batch_size(16), n_p(100,100,80,60,60))
    x_ = self.x2x_(x_f)
    # Retrieve grounded location from memory by
    # doing pattern completion on current sensory experience
    # p_x - (n_f(5), batch_size(16), n_p(100,100,80,60,60))
    p_x = self.attractor(x_, M_prev[1], retrieve_it_mask=self.hyper['p_retrieve_mask_inf'])\
        if self.hyper['use_p_inf'] else None
    # Infer abstract location by combining previous abstract location and
    # grounded location retrieved from memory by current sensory experience
    # g_gen - (n_f(5), batch_size(16), self.g_init[f](30, 30, 24, 18, 18))
    # x - (batch_size(16), n_x(45)), one-hot code, each element corresponds to a sensory
    # locations - (batch_size(16),), each element is dict, e.g. {'id': 24, 'shiny': None}
    # g - (n_f(5), batch_size(16), self.g_init[f](30, 30, 24, 18, 18)),
    # use x & g_gen(for inv_var_weight) to infer g
    g = self.inf_g(p_x, g_gen, x, locations)
    # Prepare abstract location for input to memory by downsampling and weighting
    # g_ - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
    g_ = self.g2g_(g)
    # Infer grounded location from sensory experience and inferred abstract location
    # p - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
    p = self.inf_p(x_, g_)
    # Return variables in order that they were created
    return x_f, g, p_x, p

The schematic and the update equations of the inference process are given in the figures of the paper (the images are not preserved here). inference is called from the iteration function; its main modules are as follows:

NorbertZheng commented 2 years ago

hebbian

This is the learning algorithm of the memory component; its update equation is shown below.
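The figure with the update rule is not preserved here; reconstructed from the code below (paper notation, where $p_{t}$ is the inferred and $\hat{p}_{t}$ the generated grounded location), it is approximately:

$$ M_{t}=clamp\left(\lambda M_{t-1}+\eta\,(p_{t}-\hat{p}_{t})(p_{t}+\hat{p}_{t})^{T},\ -1,\ 1\right). $$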

def hebbian(self, M_prev, p_inferred, p_generated, do_hierarchical_connections=True):
    # Create new ground memory for attractor network by setting weights to outer product of learned vectors
    # p_inferred corresponds to p in the paper, and p_generated corresponds to p^. 
    # The order of p + p^ and p - p^ is reversed since these are row vectors,
    # instead of column vectors in the paper.
    # M_new - (batch_size(16), sum(n_p)(400), sum(n_p)(400)),
    # calculated from (16,400,1) matmul (16,1,400), not element-wise mul
    M_new = torch.squeeze(torch.matmul(
        torch.unsqueeze(p_inferred + p_generated, 2),torch.unsqueeze(p_inferred - p_generated,1)
    ))
    # Multiply by connection vector, e.g. only keeping weights
    # from low to high frequencies for hierarchical retrieval.
    if do_hierarchical_connections:
        M_new = M_new * self.hyper['p_update_mask']
    # Store grounded location in attractor network memory with weights M by Hebbian learning of pattern
    M = torch.clamp(self.hyper['lambda'] * M_prev + self.hyper['eta'] * M_new, min=-1, max=1)
    return M

After applying the update rule, clamp truncates any values outside (-1, 1).

NorbertZheng commented 2 years ago

loss

This is the learning algorithm of the cortex component; the loss is computed as follows:

def loss(self, g_gen, p_gen, x_logits, x, g_inf, p_inf, p_inf_x, M_prev):
    # Calculate loss function, separately for each component
    # because you might want to reweight contributions later.
    # L_p_gen is squared error loss between inferred grounded location
    # and grounded location retrieved from inferred abstract location
    L_p_g = torch.sum(torch.stack(utils.squared_error(p_inf, p_gen), dim=0), dim=0)
    # L_p_inf is squared error loss between inferred grounded location
    # and grounded location retrieved from sensory experience
    L_p_x = torch.sum(torch.stack(utils.squared_error(p_inf, p_inf_x), dim=0), dim=0)\
        if self.hyper['use_p_inf'] else torch.zeros_like(L_p_g)
    # L_g is squared error loss between generated abstract location and inferred abstract location
    L_g = torch.sum(torch.stack(utils.squared_error(g_inf, g_gen), dim=0), dim=0)         
    # L_x is a cross-entropy loss between sensory experience and different model predictions.
    # First get true labels from sensory experience
    labels = torch.argmax(x, 1)
    # L_x_gen: losses generated by generative model from g_prev -> g -> p -> x
    L_x_gen = utils.cross_entropy(x_logits[2], labels)
    # L_x_g: Losses generated by generative model from g_inf -> p -> x
    L_x_g = utils.cross_entropy(x_logits[1], labels)
    # L_x_p: Losses generated by generative model from p_inf -> x
    L_x_p = utils.cross_entropy(x_logits[0], labels)
    # L_reg are regularisation losses, L_reg_g on L2 norm of g
    L_reg_g = torch.sum(torch.stack([torch.sum(g ** 2, dim=1) for g in g_inf], dim=0), dim=0)
    # And L_reg_p regularisation on L1 norm of p
    L_reg_p = torch.sum(torch.stack([torch.sum(torch.abs(p), dim=1) for p in p_inf], dim=0), dim=0)
    # Return total loss as list of losses, so you can possibly reweight them
    L = [L_p_g, L_p_x, L_x_gen, L_x_g, L_x_p, L_g, L_reg_g, L_reg_p]
    return L

With this loss, the model is updated by BPTT + Adam. A sketch of such an update step follows.
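A sketch of what such an update step might look like (the weighting scheme, learning rate and function names are placeholders; the repository's training script may differ):

```python
def train_step(steps, optimizer, loss_weights):
    """Sketch of one BPTT + Adam update from the per-step loss components above.

    steps:        list of Iteration objects returned by forward()
    loss_weights: eight scalars, one per component
                  [L_p_g, L_p_x, L_x_gen, L_x_g, L_x_p, L_g, L_reg_g, L_reg_p]
    """
    total = 0.0
    for step in steps:
        for w, L_i in zip(loss_weights, step.L):
            total = total + w * L_i.mean()   # average each component over the batch
    optimizer.zero_grad()
    total.backward()                         # backpropagate through time over the whole chunk
    optimizer.step()
    return total.item()

# Hypothetical usage: optimizer = torch.optim.Adam(tem.parameters(), lr=1e-4)
```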

NorbertZheng commented 2 years ago

Discussion

Successor Representation

TEM's learning is completely offline; it never interacts with the environment to update a policy dynamically. The sampling policy over the environment therefore shapes the learned representations, for grid cells as well as place cells, which may trigger the discussion about SR in #11. In TEM's simulations, non-diffusive transitions are modelled with a policy that prefers spending time near boundaries and approaching objects. On this point, James Whittington writes in his PhD thesis:

This touches on how p(s'|s) is represented. Under non-diffusive transitions, p(s'|s) is naturally not uniform, so the state-transition matrix T that the agent learns reflects the policy and has the basic ingredients for forming an SR; the SR is precisely an account of the policy-induced bias in p(s'|s). Because TEM learns to predict the next state and its corresponding sensory input, and the only biases here are a preference for boundaries and objects, only object-vector cells emerge, which is one form of SR. If we add more kinds of non-diffusive transitions, we should observe more kinds of SR.

Of course, TEM also learns the environment's own transition structure: in the gen_g part of the generative model, D_a*g_{t-1} represents the (s,a) pair and is used to generate g_{t}, so it effectively encodes p(s'|s,a). Since the MDP is assumed fully observable (there is no POMDP issue), the environment dynamics can be written as p(s',r|s,a); with no reward information provided, they simplify to p(s'|s,a). We can therefore say that the environment's own transition structure has been learned and is represented in μ=D_a*g_{t-1}. σ=f_sigma_g(g_{t-1}) expresses uncertainty about μ, but because the environment is deterministic this distribution reflects neither stochasticity nor uncertainty and should ideally shrink to zero (see #13 for the discussion of stochasticity vs. uncertainty). What varies is mainly the sensory distribution of x, which cannot affect the ideal graph; the value of g_{t} should be deterministic, with no distribution induced by stochasticity or uncertainty.

In addition, I think a model that can represent states and perform inference on the state graph is not thereby restricted to diffusive-transition representations; the two leave each other plenty of freedom, so this does not contradict TEM's ability to form SR. This view mainly comes from the constraint on states mentioned in the TEM paper:

NorbertZheng commented 2 years ago

Hierarchies in the Map

When introducing the Algorithm, we mentioned that the sub-graphs within a graph should be as similar as possible, so that TEM can extract a common rule from the restricted set of structures; it cannot yet extract transition rules from arbitrary graphs. The main reason is:

This is also why both HPC and ERC contain multiple frequencies: we can represent the world at several scales, and this hierarchical organization is extremely efficient (akin to independent binary coding), helping us extract the world's rules in the way we find natural. Saxe et al. 2019 also shows that learning unfolds non-linearly: the network first fills in the direction of the largest singular value and only then corrects the directions of the later singular values. Since hierarchy likewise corresponds to an eigenvalue problem, this hints at a basic principle of how people come to understand things: difficulty should increase gradually, solving the unified problem first and the details later, so that the underlying common rule can be grasped. As for simultaneously learning graphs that contain different subgraphs, we do not know whether the learned grid-cell firing patterns would still be hierarchical, nor whether TEM can predict grid-cell firing under such a paradigm; note that this means learning them at the same time, not learning one environment and then remapping in a new one. If it simply cannot learn this, that is another matter; it is genuinely unclear whether TEM has such a capability.

A further conjecture: during learning, TEM probably first forms large-scale grid cells (the eigenvector with the largest eigenvalue) and then gradually fills in the rest, again reflecting the unified-problem-first principle. This raises the question of generalization: it is unclear whether this matches the generalization problem that deep RL studies in procedurally generated (PCG) environments. Jiang et al. 2021 tackles that class of problems from a sample-efficiency angle; whether that helps explain how humans balance generalization and reward maximization during offline replay is unclear, and without it, it is hard to formalize the generalization problem humans face when performing tasks. Another point: the problems humans solve in daily life are usually POMDPs, which gives rise to belief-state representations; for example, Gershman et al. 2019 proposes mPFC as a candidate site for computing belief-state representations. Give the belief to TEM~

However, during inf_g, sigma_g_input is generated from p_inf_x; for TEM just entering a new environment this amounts to not being able to determine the true parameters, an epistemic uncertainty, which turns the setting into an epistemic POMDP. TEM solves this well: after visiting every node of a new environment once, it can immediately infer all the edges, achieving generalization on the epistemic POMDP problem.

# Not in paper, but this greatly improves zero-shot inference: provide
# the uncertainty function of the inferred abstract location with measures of memory quality
with torch.no_grad():
    # For the first measure, use the grounded location inferred from memory to generate an observation
    x_hat, x_hat_logits = self.gen_x(p_x[0])
    # Then calculate the error between the generated observation and the actual observation:
    # if the memory is working well, this error should be small
    err = utils.squared_error(x, x_hat)
# The second measure is the vector norm of the inferred abstract location; good memories should have
# similar vector norms. Concatenate the two measures as input for the abstract location uncertainty function
# sigma_g_input - (n_f(5), batch_size(16), n_measure(2))
sigma_g_input = [torch.cat(
    (torch.sum(g ** 2, dim=1, keepdim=True), torch.unsqueeze(err, dim=1)), dim=1
) for g in mu_g_mem]
...
# And get standard deviation/uncertainty of inferred abstract location by
# providing uncertainty function with memory quality measures
# sigma_g_mem - (n_f(5), batch_size(16), n_g(30,30,24,18,18))
sigma_g_mem = self.f_sigma_g_mem(sigma_g_input)
NorbertZheng commented 2 years ago

Questions

NorbertZheng commented 2 years ago

Summary

NorbertZheng commented 2 years ago

Future Plan

NorbertZheng commented 2 years ago

Questions

TEM

NorbertZheng commented 1 year ago

Report

The following is a report about the Tolman-Eichenbaum Machine (TEM); the corresponding ppt can be downloaded from here.

NorbertZheng commented 1 year ago

TEM as Transformer/GNN

Here we can see the architecture of the key component of the Transformer, i.e. the self-attention block. Given a number of entities (without sequential order, though the entities may carry position embeddings themselves), we use $W_{q}$, $W_{k}$, $W_{v}$ to calculate the $q$, $k$, $v$ corresponding to each entity.

$$ Q=HW_{q} \quad K=HW_{k} \quad V=HW_{v}, $$

where $H=[X,E]$ contains the feature embeddings $X$ and the position embeddings $E$. After calculating $Q$, $K$, and $V$, we use each item of $Q$ to query $K$ and obtain a probability over the value items. This process is like aggregating information from all nodes in a graph, but it differs from a Graph Neural Network (GNN): it does not exploit the graph adjacency matrix explicitly; the probabilities are calculated from correlations.

$$ Prob_{l,i}=softmax\left(\frac{q_{l,i}K_{l}^{T}}{\sqrt{d_{k}}}\right). $$

Then we directly multiply $Prob_{l,i}$ with $V$ to get the updated entity $h_{l+1,i}$.

$$ h_{l+1,i}=Prob_{l,i}V_{l}. $$
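A minimal PyTorch sketch of this single-query attention step (shapes and names are assumptions for illustration):

```python
import torch
import torch.nn.functional as F

def single_query_attention(h_i, H, W_q, W_k, W_v):
    """One entity attends over all entities, as in the equations above.

    h_i: (d_h,)   embedding of entity i at layer l
    H:   (n, d_h) embeddings of all entities at layer l
    W_q, W_k, W_v: (d_h, d_k) projection matrices
    """
    q = h_i @ W_q                                            # q_{l,i}
    K, V = H @ W_k, H @ W_v                                  # K_l, V_l
    prob = F.softmax(q @ K.T / K.shape[-1] ** 0.5, dim=-1)   # Prob_{l,i}
    return prob @ V                                          # h_{l+1,i}
```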

Now let's see how we can re-formulate TEM as such a self-attention block, i.e. what $W_{*}$ and $H$ actually correspond to. Let's look at the generative process of TEM. Obviously, the input entity embeddings contain feature embeddings $x$ (e.g. the sensory observation, or the neural activities of LEC) and position embeddings $g$ (e.g. the neural activities of MEC). Here an RNN is used to update the abstract location $g$, which models the neural activities of MEC. We use an action-specific weight matrix $W_{a}$ (the output of an MLP with no bias in either layer, whose input is the one-hot version of $a$) to get $\Delta g$.

$$ \Delta g_{t-1}=W_{a}g_{t-1}. $$

Then we update $g_{t-1}$ with $\Delta g_{t-1}$ and clamp the result to get $g_{t}$.

$$ g_{t}=f_{g}(g_{t-1}+\Delta g_{t-1}). $$

Both $g_{t-1}$ and $g_{t}$ are the position embeddings of the corresponding entities. So how could we get the corresponding query vector $\tilde{g}$, and how is the query matrix $W_{q}$ expressed? We first downsample $g$ and then repeat the downsampled value to get the query vector $\tilde{g}$. Therefore, the query vector $\tilde{g}$ and the query matrix $W_{q}$ can be formulated as follows:

$$ \tilde{g}_{t}=W_{repeat}f_{down}(g_{t}) \quad W_{q}=W_{repeat}f_{down}(\cdot). $$

We can easily see that the query matrix $W_{q}$ of TEM is not like that of the self-attention block in the Transformer: the query matrix $W_{q}$ of TEM is hand-coded rather than learnable.

Now let's focus on the memory-retrieval part of the generative process. First, we have to understand what the Hebbian memory $M_{t-1}$ (hippocampus) is actually doing. The computation of $p$ and the update of $M$ are as follows:

$$ p=flatten(x^{T}g) \quad M_{t}=\sum_{\tau=0}^{t}p_{\tau}^{T}p_{\tau}. $$

From the above equation, we can see that $M_{t}$ is binding every $g$ with every $x$. We should note that we ignore some computational details in the memory update process:

$$ M_{t}=\lambda M_{t-1}+\eta (p_{t}-\hat{p}_{t})(p_{t}+\hat{p}_{t})^{T}. $$

Now, we take one attractor step in TEM with no non-linearity as an example,

$$ \begin{aligned} \tilde{x}_{t}^{retrieved}&=sum(unflatten(p_{t}^{retrieved}), 1)\\ p_{t}^{retrieved}&=q_{t}M_{t-1}=q_{t}\sum_{\tau=0}^{t-1}p_{\tau}^{T}p_{\tau}=\sum_{\tau=0}^{t-1}[q_{t}p_{\tau}^{T}]p_{\tau}. \end{aligned} $$

Given that

$$ \begin{aligned} &[q_{t}p_{\tau}^{T}]=\bar{\tilde{x}}_{\tau} [\tilde{g}_{t} \cdot \tilde{g}_{\tau}],\\ &where \quad \bar{\tilde{x}}_{\tau}=\sum_{i}(\tilde{x}_{\tau})_{i}. \end{aligned} $$

Then we have

$$ p_{t}^{retrieved}=\tilde{g}_{t}\tilde{G}^{T}\Lambda_{x}P. $$

Finally, we get the following equation:

$$ \tilde{x}_{t}^{retrieved}=(\alpha \tilde{g}_{t}\tilde{G}^{T})\tilde{X} \quad c.f. \quad softmax\left(\frac{\tilde{g}_{t}\tilde{G}^{T}}{\sqrt{d_{k}}}\right)\tilde{X}. $$

Therefore, the key matrix $W_{k}$ can be formulated as follows:

$$ W_{k}=W_{q}=W_{repeat}f_{down}(\cdot). $$

We can also find that the key matrix $W_{k}$ of TEM is not like that of the self-attention block in the Transformer: the key matrix $W_{k}$ of TEM is hand-coded rather than learnable.

We can see that the generative process in TEM is doing exactly self-attention (but it can only attend to past experience, not future experience).

Now let's focus on the inference process of TEM. The Hebbian memory in the inference process does exactly the same thing as the Hebbian memory in the generative process; they may use different components (either $g$ or $x$, here $x$) to calculate the update weights. All we need to know here is that $M_{t}$ in the inference process of TEM also binds every $x$ with every $g$. First, we use the $\gamma$ values (different modules may have different $\gamma$) to filter the original $x$:

$$ \tilde{x}_{t}=filter(x_{t})=(1-\gamma)\tilde{x}_{t-1}+\gamma x_{t}. $$

Similarly, one attractor step can be formulated as follows:

$$ \tilde{g}_{t}^{retrieved}=(\alpha \tilde{x}_{t}\tilde{X}^{T})\tilde{G} \quad c.f. \quad softmax\left(\frac{\tilde{x}_{t}\tilde{X}^{T}}{\sqrt{d_{k}}}\right)\tilde{G}. $$
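As a numerical illustration of this correspondence (all tensors below are made-up toy values and alpha is an arbitrary scale, not a fitted parameter):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d_x, d_g, alpha = 6, 10, 30, 0.1
X_tilde = torch.randn(T, d_x)      # stored (filtered) observations
G_tilde = torch.randn(T, d_g)      # stored abstract locations
x_query = torch.randn(1, d_x)      # current filtered observation

# TEM-style one-step attractor: linear attention over past experience
g_linear = (alpha * x_query @ X_tilde.T) @ G_tilde

# Transformer-style retrieval with the same keys and values
g_softmax = F.softmax(x_query @ X_tilde.T / d_x ** 0.5, dim=-1) @ G_tilde

print(g_linear.shape, g_softmax.shape)   # both (1, d_g)
```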

We can easily find that the inference process of TEM is also doing self-attention, i.e.

$$ W_{q}=filter(\cdot) \quad W_{k}=filter(\cdot) \quad W_{v}=MLP(\cdot). $$

Different from the generative process of TEM, there are learnable parameters in $W_{q}$, $W_{k}$, $W_{v}$ (although $filter(\cdot)$, which is also part of $W_{v}$ in the generative process, is mostly hand-coded).

Now we can conclude that both the generative and the inference process of TEM implement a form of self-attention over past experience, with largely hand-coded rather than learned projection matrices.

NorbertZheng commented 1 year ago

Relation to other models

In the Successor Representation (SR) model, the problem formulation of non-spatial reinforcement-learning tasks is used to tackle rodent navigation; after all, the spatial navigation task is just a special case of a reinforcement-learning task. We all know that in RL the value of a state can be expressed as follows:

$$ V(s)=\mathbb{E}_{\pi}\left[\sum_{t=0}^{\infty}\gamma^{t}R(s_{t})|s_{0}=s\right]. $$

But we do not care about the reward function $R(s)$ (after all, the standard place cells do not care about the reward), so we can decompose the value function $V(s)$ into two parts as follows:

$$ V(s)=\sum_{s'}M(s,s')R(s') \quad M(s,s')=\mathbb{E}_{\pi}\left[\sum_{t=0}^{\infty}\gamma^{t}\mathbb{I}(s_{t}=s')|s_{0}=s\right]. $$

$M(s,s')$ is a function that predicts how likely we are to get to the target location $s'$, given the current location $s$, with exponential discounting over time. The SR model uses $M(s,s')$ to describe a place cell with its maximal firing at the target location $s'$: when the rodent is at location $s$, the firing rate of place cell $s'$ predicts the accumulated (discounted) likelihood of reaching location $s'$ in the future. A toy computation is sketched below.
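A toy computation of $M$ and $V$ under a diffusive policy (the numbers are made up purely for illustration):

```python
import numpy as np

gamma = 0.9
# Policy-dependent transition matrix p(s'|s) on a 4-state ring, diffusive policy
T = np.array([[0.0, 0.5, 0.0, 0.5],
              [0.5, 0.0, 0.5, 0.0],
              [0.0, 0.5, 0.0, 0.5],
              [0.5, 0.0, 0.5, 0.0]])

# M(s, s') = sum_t gamma^t P(s_t = s' | s_0 = s) = (I - gamma * T)^(-1)
M = np.linalg.inv(np.eye(4) - gamma * T)
R = np.array([0.0, 0.0, 1.0, 0.0])   # reward only at state 2
V = M @ R                            # V(s) = sum_{s'} M(s, s') R(s')
print(np.round(M, 2), np.round(V, 2))
```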

Pretty simple, right? But there are some details in the equation that we have to notice.

Now we can see that claim clearly. $M(s,s')$ is a conjunctive representation of the global transition structure $g$ (the space representation factor) and the local transition structure $o$ (the action representation factor). The action representation factor $o$ is associated with the abstract value $r$, but $o$ and the feature representation factor $f$ can still be separated, because we can change the [location, reward value] of the object, which implies that the action representation factor exists independently (orthogonally) of the object representation. From this view we can understand why $M(s,s')$ is policy-dependent. Besides, $M(s,s')$ is also limited to the specific task, i.e. neither the location nor the reward value of the object changes, which means $M(s,s')$ learns the conjunctive representation directly instead of composing the two representation factors. For example, when the walking policy is strict, e.g. we cannot choose any action that moves the agent away from the object, the agent will predict that there is no way to walk away from the goal location, which differs greatly from the prediction obtained from the space representation factor alone. Therefore, the SR model cannot provide a generalizable graph representation (the space representation factor).

Now we can agree on one point: the SR model tries to decompose the conjunctive representation into multiple representation factors heuristically, instead of automatically. So how could we decompose the conjunctive representation (i.e. the task representation) into multiple representation factors (i.e. orthogonal task factors in the task-representation space) automatically? Here is a work from James Whittington. He finds that adding some biological constraints (namely non-negativity and energy efficiency of both activity and weights) to the loss function leads to a disentangled representation, i.e. it decomposes the task representation into multiple representation factors automatically.

$$ \begin{aligned} \mathcal{L}&=\underbrace{\mathcal{L}_{non-neg}+\mathcal{L}_{activity}+\mathcal{L}_{weight}}_{Biological\ constraints}+\underbrace{\mathcal{L}_{prediction}}_{Functional\ constraints}\\ &\mathcal{L}_{non-neg}=\beta_{non-neg}\sum_{i}max(-a_{i},0)\\ &\mathcal{L}_{activity}=\beta_{activity}\sum_{l}||a_{l}||^{2}\\ &\mathcal{L}_{weight}=\beta_{weight}\sum_{l}||W_{l}||^{2}\\ &\mathcal{L}_{prediction}=\beta_{prediction}||\hat{y}-y||^{2} \end{aligned} $$

In the first figure, the biological constraints are not added to the model, and we can see that each hidden neuron may respond to changes in multiple task factors. But if we add the biological constraints, each hidden neuron responds to changes in only one task factor, i.e. a disentangled representation. This means we automatically decompose the task representation into orthogonal task factors, with each hidden neuron representing at most one task factor. Now, we can conclude that the SR model is just a subset of TEM-disentangled.

Now we come to the Clone-Structured Cognitive Graph (CSCG). This model can also learn a graph representation, but, just like the SR model, the learned graph representation is not generalizable. And because the walking policy is diffusive by default, the graph representation learned by CSCG is exactly the space representation (not combined with the action representation). Despite its non-generalizable graph representation, CSCG learns such a representation quickly. Moreover, when first entering an environment with a brand-new structure, the hippocampus serves as a graph rather than as an associative memory. Therefore, we could use CSCG to model the hippocampus and then generate replays from CSCG to facilitate the structure-abstraction process of TEM. Cheers~

The Spatial Memory Pipeline (SMP) is another model built to explain the computational mechanisms of the hippocampal-entorhinal system. SMP has an architecture similar to TEM, i.e. a VAE, but it is much more complex. SMP uses a machine-learning-style memory bank to model the hippocampus, which is not as biologically plausible as an attractor network (after all, attractor networks were developed by the computational-neuroscience community), and its path-integration part is also complex. Of course, such a complex model provides much more powerful information processing than TEM: it can directly process raw visual sequences, and we can observe similar neural representations in the SMP model. However, due to its complexity, SMP is not as good as TEM at explaining the key principles of the hippocampal-entorhinal system.

NorbertZheng commented 1 year ago

More about disentangled representation

We use a discrete $16\times 16$ world (so 256 locations; $n_{l}=256$) and optimize an independent representation $z(x) \in \mathbb{R}^{n_{c}}$ at each location. We now detail each component of the following loss:

$$ \mathcal{L}=\underbrace{\mathcal{L}_{nonneg}+\mathcal{L}_{activity}+\mathcal{L}_{weight}}_{Biological\ constraints}+\underbrace{\mathcal{L}_{location}+\mathcal{L}_{actions}+\mathcal{L}_{objects}}_{Functional\ constraints}+\underbrace{\mathcal{L}_{path\ integration}}_{Structure\ constraints}. $$

$$ \begin{aligned} &\mathcal{L}_{nonneg}=\frac{\beta_{nonneg}}{n_{l}}\sum_{x}\sum_{i}max(-z_{i}(x),0)\\ &\mathcal{L}_{weight}=\beta_{weight}\sum_{t}||W_{t}||^{2}\\ &\mathcal{L}_{activity}=\frac{\beta_{activity}}{n_{l}}\sum_{x}||z(x)||^{2}, \end{aligned} $$

where $z_{i}(x)$ is a neuron in the representation $z(x)$, $t$ indexes the task (i.e. object, action, location, prediction), and the $\beta$s determine the regularization strength.

$$ \mathcal{L}_{location}=-\frac{\beta_{location}}{n_{l}}\sum_{x}\ln\frac{e^{l_{x}\cdot z(x)}}{\sum_{x'}e^{l_{x'}\cdot z(x)}}. $$

$$ \mathcal{L}_{object}=-\frac{\beta_{object}}{n_{l}}\sum_{x}\left[\mathbb{I}(object\ at\ x)\ln\sigma(W_{o}z(x))+\mathbb{I}(object\ not\ at\ x)\ln(1-\sigma(W_{o}z(x)))\right]. $$

$$ \mathcal{L}_{action}=-\frac{\beta_{action}}{n_{l}}\sum_{x}\sum_{a}\left[\mathbb{I}(a=a(x))\ln\sigma(t_{a}\cdot z(x))+\mathbb{I}(a\neq a(x))\ln(1-\sigma(t_{a}\cdot z(x)))\right]. $$

$$ \mathcal{L}_{path\ integration}=\frac{\beta_{path\ integration}}{n_{l}}\sum_{x}\sum_{a}||z(x)-f(W_{a}z(x-d_{a}))||^{2}, $$

where $W_{a} \in \mathbb{R}^{n_{c}\times n_{c}}$ is a weight matrix that depends on the action $a$ (i.e. there are $4$ trainable weight matrices, one for each action), and $d_{a}$ is the displacement in the underlying space (the space of $x$) that the action $a$ corresponds to. A code sketch of these terms follows.
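A compact PyTorch sketch of these constraint terms (the β values, shapes and helper names are placeholders, not taken from the paper's code):

```python
import torch

def constraint_losses(z, task_weights, W_a_list, prev_idx_list, f=torch.relu,
                      beta_nonneg=1.0, beta_activity=1.0, beta_weight=1.0, beta_pi=1.0):
    """Sketch of the biological and structural constraint terms above.

    z:             (n_l, n_c) one representation vector per location
    task_weights:  weight matrices of the task read-outs (for the weight penalty)
    W_a_list:      one (n_c, n_c) transition matrix per action
    prev_idx_list: prev_idx_list[a][x] gives the index of location x - d_a
    """
    n_l = z.shape[0]
    L_nonneg = beta_nonneg / n_l * torch.clamp(-z, min=0).sum()         # sum_x sum_i max(-z_i(x), 0)
    L_activity = beta_activity / n_l * (z ** 2).sum()                   # energy cost on activity
    L_weight = beta_weight * sum((W ** 2).sum() for W in task_weights)  # energy cost on weights
    L_pi = 0.0
    for W_a, prev_idx in zip(W_a_list, prev_idx_list):
        L_pi = L_pi + ((z - f(z[prev_idx] @ W_a.T)) ** 2).sum()         # ||z(x) - f(W_a z(x - d_a))||^2
    return L_nonneg, L_activity, L_weight, beta_pi / n_l * L_pi
```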

NorbertZheng commented 1 year ago

Pattern forming dynamics. The overall loss can be optimized with respect to the weights. However, it can also be optimized directly with respect to $z$. This is particularly interesting for us, as it lets a representation be "inferred" for the particularities of the current task without changing the weights.

This is a necessity for us as we need to represent objects which may move between tasks. Planning? To optimize both $z$ (task particularities) and weights (task generalities), we do so in two stages. First, we optimize with respect to $z$ to "infer" a representation for the current task. Second, we optimize with respect to the weights to learn parameters that are general across tasks.

When optimizing with respect to $z$ we only optimize two terms in the loss: $\mathcal{L}_{object}$ and $\mathcal{L}_{path\ integration}$. We optimize the first term so the system has the ability to know where the objects are. We optimize the second term so that information can be propagated around (effectively via path integration).

The dynamics of $\mathcal{L}_{object}$ are:

$$ \frac{d\mathcal{L}_{object}}{dz(x)}=-W_{o}^{T}(\mathbb{I}(object\ at\ x)-\sigma(W_{o}z(x))). $$

This says: if you get the object prediction wrong, update $z$ to better predict the object. We restrict this update to take place only where the object is, so it is just an object signal. This update is equivalent to a rodent observing that it is at an object.

The dynamics of $\mathcal{L}_{path\ integration}$ are:

$$ \begin{aligned} \frac{d\mathcal{L}_{path\ integration}}{dz(x)}=&\sum_{a}[-(z(x)-f(W_{a}z(x-d_{a})))\\ &+W_{a}^{T}(z(x+d_{a})-f(W_{a}z(x)))\odot f'(W_{a}z(x))]. \end{aligned} $$

The two terms in the above equation can be easily understood. The first says that the representation at each location, $z(x)$, should be updated according to what its neighbours think it should be (this is the same update rule as path integration!). The second term says that the representation at each location, $z(x)$, should be updated if it did not predict its neighbours correctly. This equation tells representations to update based on their neighbours. This is just like a cellular automaton, but instead of a discrete value being updated on the basis of its neighbours, it is a whole population vector whose elements are continuous. Indeed, just like a cellular automaton, it is also possible to initialize a single "cell" (location) and have that representation propagate throughout the space. In this case, it is just like path integration, but spreading through all of space at once. We note, however, that in our simulations we initialize representations at all locations (for each walk). A minimal relaxation sketch follows.
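A minimal sketch of this relaxation on $z$ (using autograd in place of the hand-derived gradients above; names and step size are placeholders):

```python
import torch

def relax_z(z, W_a_list, prev_idx_list, f=torch.relu, lr=0.1, n_steps=50):
    """Iteratively update each z(x) toward what its neighbours predict for it,
    as in the cellular-automaton picture above (illustrative sketch only)."""
    z = z.detach().clone().requires_grad_(True)
    for _ in range(n_steps):
        loss = 0.0
        for W_a, prev_idx in zip(W_a_list, prev_idx_list):
            loss = loss + ((z - f(z[prev_idx] @ W_a.T)) ** 2).sum()
        grad, = torch.autograd.grad(loss, z)
        with torch.no_grad():
            z -= lr * grad        # each z(x) moves toward agreement with its neighbours
    return z.detach()
```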

We note that while we simulated this on a discrete grid, the same principles apply to continuous cases; the sums over locations/actions then need to be replaced with integrals.

This is a very general approach for understanding representations. The structural loss does not have to be related to the rules of path integration; it can be anything. It could be the rules of a graph, the rules of a topology, or one set of rules at some locations and another set of rules at other locations.

If there are structures or rules in the world or in behaviour, our constraints say that representations should reflect those structures or rules. In mathematics this is known as a homeomorphism. In sum, understanding representations via constraints is very general.