OpenGVLab / Vision-RWKV

Vision-RWKV: Efficient and Scalable Visual Perception with RWKV-Like Architectures
https://arxiv.org/abs/2403.02308
Apache License 2.0
288 stars, 11 forks

Question about the CUDA code #5

Closed: yxchng closed this issue 2 months ago

yxchng commented 3 months ago

How should I understand the following loop?

for (int i = _t; i < (_t + _tokenLength); i++){
    const int ii = i * C;
    F no = max(o1, k[ii] - w * (i - _t));
    F e1 = exp(o1 - no);
    F e3 = exp(k[ii] - w * (i - _t) - no);
    c = e1 * c + e3 * v[ii];
    d = e1 * d + e3;
    o1 = no;
    const int ni = 2 * _t + _tokenLength - 1 - i;
    const int nini = ni * C;
    const int exp_w = _t + _tokenLength - ni;
    no = max(o2, k[nini] - w * exp_w);
    F e2 = exp(o2 - no);
    e3 = exp(k[nini] - w * exp_w - no);
    a = e2 * a + e3 * v[nini];
    b = e2 * b + e3;
    o2 = no;
}

https://github.com/OpenGVLab/Vision-RWKV/blob/master/segmentation/mmseg_custom/models/backbones/base/cuda/wkv_cuda.cu#L32C5-L50C1

It does not seem to correspond to any of the equations in the paper.

duanduanduanyuchen commented 3 months ago

Hi, thank you for your interest in our work! This loop computes the initial values of the states (a, b, c, d) in each thread.
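Reading the quoted loop directly gives a closed form for those initial states. This assumes the states start at a = b = c = d = 0 and that o1, o2 start at a very negative value; that initialisation is not part of the quoted lines. Write t_0 for _t, L for _tokenLength, and k_i, v_i for k[i*C], v[i*C]. Each step of the forward half keeps c, d divided by e^{o_1}, and the backward half does the same for a, b with e^{o_2}:

```latex
% one forward step, with x_i = k_i - w (i - t_0) and n = max(o_1, x_i):
\begin{aligned}
c' &= e^{o_1 - n}\, c + e^{x_i - n}\, v_i
   &&\Longrightarrow\quad c'\, e^{n} = c\, e^{o_1} + e^{x_i}\, v_i \\
d' &= e^{o_1 - n}\, d + e^{x_i - n}
   &&\Longrightarrow\quad d'\, e^{n} = d\, e^{o_1} + e^{x_i}
\end{aligned}

% after all L steps (o_1, o_2 end up as the largest exponents seen):
\begin{aligned}
c\, e^{o_1} &= \sum_{i=t_0}^{t_0+L-1} e^{\,k_i - w\,(i - t_0)}\, v_i, &
d\, e^{o_1} &= \sum_{i=t_0}^{t_0+L-1} e^{\,k_i - w\,(i - t_0)} \\
a\, e^{o_2} &= \sum_{j=t_0}^{t_0+L-1} e^{\,k_j - w\,(t_0 + L - j)}\, v_j, &
b\, e^{o_2} &= \sum_{j=t_0}^{t_0+L-1} e^{\,k_j - w\,(t_0 + L - j)}
\end{aligned}
```

So (c, d) weight each token by its distance from the first token _t, while (a, b) weight it by its distance from _t + _tokenLength. Because o1 and o2 are the largest exponents seen, the stored d and b lie between 1 and _tokenLength. How these partial sums are combined into the paper's Bi-WKV output depends on the rest of the kernel, which is not quoted here.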

yxchng commented 3 months ago

Do you mind explaining why there are so many max operations and subtractions of no in this initialization? Also, why is the index computed as const int ni = 2 * _t + _tokenLength - 1 - i;?

duanduanduanyuchen commented 3 months ago

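The max and the subtraction of no are there to compute the exponentials safely. no is the running maximum of the exponents, so o1 - no, o2 - no, and the k[...] - w * ... - no terms are never positive, and every exp() returns a value no larger than 1. Subtracting the maximum in this way avoids the risk of overflow.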
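To see the effect, the loop can be replayed on the host. Below is a minimal C++ sketch that ports the quoted snippet for a single channel, keeping the kernel's update order. The BiStates struct and the init_states function are hypothetical names. The starting values (zero states, o1 = o2 = -1e38f) are an assumption, since the initialisation is outside the quoted lines.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical holder for one thread's running states (names mirror the snippet).
struct BiStates {
    float a, b, c, d;  // scaled numerators (a, c) and denominators (b, d)
    float o1, o2;      // log-scales: true sum = stored value * exp(o)
};

// Host-side replay of the quoted loop for one channel with stride C.
// Assumption: states start at 0 and o1, o2 start very negative, so the
// first exp(o - no) underflows to 0 and the old (empty) sum is dropped.
BiStates init_states(const std::vector<float>& k, const std::vector<float>& v,
                     float w, int t0, int tokenLength, int C) {
    assert(tokenLength > 0);
    BiStates s{0.f, 0.f, 0.f, 0.f, -1e38f, -1e38f};
    for (int i = t0; i < t0 + tokenLength; i++) {
        const int ii = i * C;
        // Forward half: decay grows with the distance from t0.
        float no = std::max(s.o1, k[ii] - w * (i - t0));  // running maximum
        float e1 = std::exp(s.o1 - no);                   // argument <= 0
        float e3 = std::exp(k[ii] - w * (i - t0) - no);   // argument <= 0
        s.c = e1 * s.c + e3 * v[ii];
        s.d = e1 * s.d + e3;
        s.o1 = no;
        // Backward half: ni visits the same tokens from last to first.
        const int ni = 2 * t0 + tokenLength - 1 - i;
        const int nini = ni * C;
        const int exp_w = t0 + tokenLength - ni;
        no = std::max(s.o2, k[nini] - w * exp_w);
        float e2 = std::exp(s.o2 - no);
        e3 = std::exp(k[nini] - w * exp_w - no);
        s.a = e2 * s.a + e3 * v[nini];
        s.b = e2 * s.b + e3;
        s.o2 = no;
    }
    return s;
}
```

With k[i] = 100 for every token, std::exp(100.0f) on its own already overflows a float, but the scaled states stay finite because every argument passed to std::exp is at most 0. The true sums can then be recovered as c * exp(o1), d * exp(o1), and so on, in whatever precision they fit.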

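The line const int ni = 2 * _t + _tokenLength - 1 - i; is a convenient index for visiting the current thread's tokens in reverse order, from the last to the first. It can be read as ni = (_t + _tokenLength - 1) - (i - _t): as i moves forward from the first token _t, ni moves backward from the last token _t + _tokenLength - 1.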
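The index arithmetic can be checked in isolation. This small C++ sketch (reverse_schedule is a hypothetical name) replays only the index part of the loop and records the (ni, exp_w) pair produced at each step.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// For each loop step i in [t0, t0 + tokenLength), return (ni, exp_w)
// computed exactly as in the quoted kernel loop.
std::vector<std::pair<int, int>> reverse_schedule(int t0, int tokenLength) {
    assert(tokenLength > 0);
    std::vector<std::pair<int, int>> out;
    for (int i = t0; i < t0 + tokenLength; i++) {
        // Mirror of i inside [t0, t0 + tokenLength - 1].
        const int ni = 2 * t0 + tokenLength - 1 - i;
        // 1 for the last token, tokenLength for the first one.
        const int exp_w = t0 + tokenLength - ni;
        out.emplace_back(ni, exp_w);
    }
    return out;
}
```

For a thread starting at _t = 3 with _tokenLength = 4, this gives ni = 6, 5, 4, 3 with exp_w = 1, 2, 3, 4, so the backward states see the last token with the smallest decay. Computing ni inside the same loop lets a single pass fill both the forward (c, d) and the backward (a, b) states.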

BlinkDL commented 2 months ago

Please check the RWKV-4 paper for more details: https://arxiv.org/abs/2305.13048