berlino / gated_linear_attention


How to get St? #6

Closed: JL-er closed this issue 8 months ago

JL-er commented 10 months ago

[image]

I wanted to get St, but I couldn't find exactly where it is computed. I returned memory_cache, but it doesn't seem to be equal to St.

sustcsonglin commented 10 months ago

Hi, we didn't include the recurrent form of the model. If you need this feature, we will consider adding it soon. Our code is chunkwise, i.e., it stores the hidden state of the last element of each contiguous chunk.
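
For reference, the recurrent form being asked about updates the state as S_t = diag(α_t) S_{t-1} + k_tᵀ v_t and reads out o_t = q_t S_t. A minimal (slow) PyTorch sketch of that recurrence, with tensor shapes assumed here rather than taken from this repo, looks like:

```python
import torch

def gla_recurrent_reference(q, k, v, alpha):
    """Token-by-token GLA recurrence: S_t = diag(alpha_t) S_{t-1} + k_t^T v_t, o_t = q_t S_t.

    Assumed shapes: q, k, alpha are (B, H, T, d_k); v is (B, H, T, d_v).
    Returns outputs of shape (B, H, T, d_v) and the final state S_T of shape (B, H, d_k, d_v).
    """
    B, H, T, d_k = q.shape
    d_v = v.shape[-1]
    S = q.new_zeros(B, H, d_k, d_v)  # S_0 = 0
    outputs = []
    for t in range(T):
        # gate the previous state per key dimension, then add the outer product k_t^T v_t
        S = alpha[:, :, t, :, None] * S + k[:, :, t, :, None] * v[:, :, t, None, :]
        outputs.append(torch.einsum('bhk,bhkv->bhv', q[:, :, t], S))
    return torch.stack(outputs, dim=2), S
```

In a chunkwise implementation, this S is only materialized at chunk boundaries, which matches the comment above about storing the state of the last element of each chunk.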

JL-er commented 10 months ago

> Hi, we didn't include the recurrent form of the model. If you need this feature, we will consider adding it soon. Our code is chunkwise, i.e., it stores the hidden state of the last element of each contiguous chunk.

Thanks. Because I need to run inference, I need the final hidden state. And I want to know what memory_cache is. [image]

sustcsonglin commented 10 months ago

I will add an interface to obtain the final hidden state soon.

JL-er commented 10 months ago

> I will add an interface to obtain the final hidden state soon.

Thank you so much

sustcsonglin commented 9 months ago

Hi, you can check out impl1 and impl2 to obtain the final hidden state.
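
Not the repo's actual impl1/impl2, just the underlying math: a pure-PyTorch sketch of how a chunkwise pass carries the hidden state across chunks (and hence yields the final state), checked against a plain token-by-token loop. The shapes and gate ranges below are assumptions for the test:

```python
import torch

def final_state_chunkwise(k, v, alpha, chunk=16):
    """Carry the state across chunks: S <- total_decay * S + sum_t decay_to_end(t) * k_t^T v_t."""
    B, H, T, d_k = k.shape
    S = k.new_zeros(B, H, d_k, v.shape[-1])
    for s in range(0, T, chunk):
        a = alpha[:, :, s:s + chunk]                  # (B, H, C, d_k), gates in (0, 1)
        cum = torch.cumprod(a, dim=2)                 # prod_{u<=t} a_u within the chunk
        total = cum[:, :, -1]                         # decay over the whole chunk
        decay_to_end = total[:, :, None, :] / cum     # prod_{u>t} a_u (real kernels do this more stably, e.g. in log space)
        kv = torch.einsum('bhtk,bhtv->bhkv',
                          decay_to_end * k[:, :, s:s + chunk], v[:, :, s:s + chunk])
        S = total[:, :, :, None] * S + kv
    return S

# Check against the token-by-token recurrence.
B, H, T, d_k, d_v = 2, 4, 64, 32, 64
k, v = torch.randn(B, H, T, d_k), torch.randn(B, H, T, d_v)
alpha = 0.9 + 0.1 * torch.rand(B, H, T, d_k)
S_ref = k.new_zeros(B, H, d_k, d_v)
for t in range(T):
    S_ref = alpha[:, :, t, :, None] * S_ref + k[:, :, t, :, None] * v[:, :, t, None, :]
assert torch.allclose(final_state_chunkwise(k, v, alpha), S_ref, atol=1e-3)
```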

JL-er commented 9 months ago

> Hi, you can check out impl1 and impl2 to obtain the final hidden state.

Thanks a lot! Last week I learned a bit of Triton and wrote the last-state extraction myself; it tested fine. Now this is a good chance to study how you initialize the state. One more question: the intra-chunk part sub-divides each chunk, so the sequence length t has to be a multiple of 16. At inference time I therefore have to pad, and then mask during the computation. The padding feels a bit wasteful; do you have a better solution?
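
A minimal sketch of the pad-then-slice workaround described above; the (B, H, T, D) layout and the `chunk_forward` call at the end are assumptions for illustration, not this repo's API:

```python
import torch
import torch.nn.functional as F

def pad_seq(x, multiple=16, value=0.0):
    """Right-pad the sequence dim (assumed to be dim 2 of a (B, H, T, D) tensor)
    to a multiple of `multiple`; returns the padded tensor and the pad length."""
    pad_len = (-x.size(2)) % multiple
    if pad_len == 0:
        return x, 0
    return F.pad(x, (0, 0, 0, pad_len), value=value), pad_len

# q/k/v are padded with zeros (padded positions then contribute nothing to the state);
# the gates are padded with ones so the padded steps leave the hidden state unchanged.
# q_p, pad_len = pad_seq(q); k_p, _ = pad_seq(k); v_p, _ = pad_seq(v)
# alpha_p, _ = pad_seq(alpha, value=1.0)
# out = chunk_forward(q_p, k_p, v_p, alpha_p)[:, :, :q.size(2)]  # drop the padded tail; name assumed
```

Since attention is causal and the padding sits at the end, the outputs at valid positions are unaffected; only the padded tail needs to be sliced off.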

sustcsonglin commented 9 months ago

What do you mean by inference here: autoregressive decoding, or evaluating the PPL of a given sentence? Padding to a multiple of 16 doesn't seem very costly to me, so I haven't added an internal padding mechanism to GLA yet (it would require changing the CUDA kernel, which is a bit of a hassle). The RetNet implementation does add a masking mechanism inside the Triton kernel, so it doesn't need padding.

JL-er commented 9 months ago

> What do you mean by inference here: autoregressive decoding, or evaluating the PPL of a given sentence? Padding to a multiple of 16 doesn't seem very costly to me, so I haven't added an internal padding mechanism to GLA yet (it would require changing the CUDA kernel, which is a bit of a hassle). The RetNet implementation does add a masking mechanism inside the Triton kernel, so it doesn't need padding.

Both. My inference currently doesn't distinguish between chunks. One more question: when I tested different chunk sizes, I found that smaller chunks ran faster. Is that expected?

sustcsonglin commented 9 months ago

If the head dimension isn't very large, the I/O cost isn't significant. In that case, smaller chunks mean higher parallelism, so it's quite plausible that they run faster. But to get better quality the head dimension is usually larger, and then smaller chunks can mean more I/O, so the gain from extra parallelism no longer outweighs the I/O cost. There's a tradeoff here.
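
A rough way to see this tradeoff on a given GPU is to time the forward pass over a few chunk sizes. `run_fn` below is a placeholder for however the chunkwise kernel is invoked, not an API from this repo:

```python
import time
import torch

def time_chunk_sizes(run_fn, chunk_sizes=(16, 32, 64, 128), iters=20):
    """Average wall-clock time of run_fn(chunk_size) over `iters` runs, per chunk size.

    run_fn is assumed to launch the chunkwise forward pass on the GPU with the
    given chunk size (hypothetical hook; adapt to the actual entry point).
    """
    results = {}
    for c in chunk_sizes:
        run_fn(c)                         # warmup (kernel compilation, caches)
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            run_fn(c)
        torch.cuda.synchronize()
        results[c] = (time.time() - start) / iters
    return results

# e.g. time_chunk_sizes(lambda c: forward(q, k, v, alpha, chunk_size=c))  # `forward` is assumed
```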

JL-er commented 9 months ago

> If the head dimension isn't very large, the I/O cost isn't significant. In that case, smaller chunks mean higher parallelism, so it's quite plausible that they run faster. But to get better quality the head dimension is usually larger, and then smaller chunks can mean more I/O, so the gain from extra parallelism no longer outweighs the I/O cost. There's a tradeoff here.

Thanks. With batch size 8, chunk size 16 being faster is probably due to insufficient parallelism; with a larger batch size, a chunk size of 128 should be a better fit and should also be faster.