OpenGVLab / Vision-RWKV

Vision-RWKV: Efficient and Scalable Visual Perception with RWKV-Like Architectures
https://arxiv.org/abs/2403.02308
Apache License 2.0
371 stars, 14 forks

About the Q-shift operation in the code #37

Open kkkkkk123-ops opened 2 months ago

kkkkkk123-ops commented 2 months ago

In jit_func of class VRWKV_SpatialMix_V6, why do we need to subtract x after calling shift_func? There is no such subtraction in _inner_forward of class VRWKV_ChannelMix when xx is computed.

Class VRWKV_SpatialMix_V6

def jit_func(self, x, patch_resolution):
    # Mix x with the previous timestep to produce xk, xv, xr
    B, T, C = x.size()

    xx = self.shift_func(x, self.shift_pixel, patch_resolution=patch_resolution,
                         with_cls_token=self.with_cls_token) - x
    xxx = x + xx * self.time_maa_x  # [B, T, C]

Class VRWKV_ChannelMix

def _inner_forward(x):
    xx = self.shift_func(x, self.shift_pixel, patch_resolution=patch_resolution,
                         with_cls_token=self.with_cls_token)

duanduanduanyuchen commented 2 weeks ago

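Hi! The subtraction in the interpolation in VRWKV6 gives a formula equivalent to the one in VRWKV.

VRWKV6:

    xx  = shift(x) - x
    xxx = x + \mu xx = x + \mu (shift(x) - x) = (1 - \mu) x + \mu shift(x)

VRWKV:

    xx  = shift(x)
    xxx = \mu x + (1 - \mu) xx = \mu x + (1 - \mu) shift(x)

Both are linear interpolations between x and shift(x). They differ only in which term the learnable weight \mu is attached to, so replacing \mu with 1 - \mu turns one form into the other.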
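
To check the equivalence numerically, here is a minimal pure-Python sketch. It is not the repository's code: `shift` is a hypothetical stand-in for `shift_func` (a plain one-token shift with zero padding rather than the real Q-shift), and `mix_v6` / `mix_v4` are illustrative names for the two interpolation forms. The identity holds for any shift, since it only relies on the algebra.

```python
import random

def shift(x):
    # Hypothetical stand-in for shift_func: move every token one step
    # to the right and zero-pad the first position.
    return [0.0] + x[:-1]

def mix_v6(x, mu):
    # VRWKV6 form: xx = shift(x) - x, then xxx = x + mu * xx
    xx = [s - a for s, a in zip(shift(x), x)]
    return [a + mu * d for a, d in zip(x, xx)]

def mix_v4(x, mu):
    # VRWKV form: xx = shift(x), then mu * x + (1 - mu) * xx
    xx = shift(x)
    return [mu * a + (1 - mu) * s for a, s in zip(x, xx)]

random.seed(0)
x = [random.uniform(-1.0, 1.0) for _ in range(8)]
mu = 0.3

# The two forms agree once mu is swapped for 1 - mu.
a = mix_v6(x, mu)
b = mix_v4(x, 1 - mu)
max_diff = max(abs(p - q) for p, q in zip(a, b))
print(max_diff < 1e-12)
```

Because \mu is a learned parameter in both versions, this relabeling does not change what the layer can represent; only the stored value of the parameter differs.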