OpenGVLab / Vision-RWKV

Vision-RWKV: Efficient and Scalable Visual Perception with RWKV-Like Architectures
https://arxiv.org/abs/2403.02308
Apache License 2.0

Confusion on the spatial decay vector "w" in wkv6_cuda.cu #21

Closed: zhouyiks closed this issue 3 months ago

zhouyiks commented 5 months ago

https://github.com/OpenGVLab/Vision-RWKV/blob/9e458035730790ca2faafaed1b2d9446f5b80b1d/classification/mmcls_custom/models/backbones/cuda_v6/wkv6_cuda.cu#L26C9-L26C27

The original RWKV-6 computes the decay as exp(-exp(w)), which keeps the value in (0, 1). Here, however, exp(w) is used instead, and I'm confused by this part of the code.
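To make the question concrete, here is a minimal sketch comparing the two expressions on scalar values. It is not the code from wkv6_cuda.cu; the helper names `rwkv6_style_decay` and `exp_only` are hypothetical and exist only for this illustration.

```python
import math

def rwkv6_style_decay(w: float) -> float:
    """Hypothetical helper: decay exp(-exp(w)), as the question describes for RWKV-6.

    exp(w) > 0 for every real w, so the outer exponent is negative and the
    result lies strictly in (0, 1) for moderate w (at extreme w, float
    rounding can return exactly 0.0 or 1.0).
    """
    return math.exp(-math.exp(w))

def exp_only(w: float) -> float:
    """Hypothetical helper: exp(w) on its own is always positive but unbounded above."""
    return math.exp(w)

for w in (-3.0, 0.0, 3.0):
    print(f"w={w:+.1f}  exp(-exp(w))={rwkv6_style_decay(w):.4f}  exp(w)={exp_only(w):.4f}")
```

For w = 3.0 the first expression is already about 1.9e-9, while exp(w) alone is about 20.1, so it is no longer a value in (0, 1).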

duanduanduanyuchen commented 4 months ago

Using exp(-exp(w)) as the new decay vector limits the decay to values between 0 and 1. This may introduce extra nonlinearity and reduce the flexibility of the spatial decay. We remove this limitation and instead add a normalization factor T to the decay (w/T) to avoid potential overflow.
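A minimal sketch of why the w/T factor can help once w is unconstrained. It assumes, as we understand the Bi-WKV formulation in the Vision-RWKV paper, that T is the number of tokens and that the decay exponent for a token `distance` positions away is -(distance / T) * w; this is an illustration of the overflow argument, not a transcription of the CUDA kernel. The helpers `decay_exponent` and `decay_factor` and the values of `T` and `w` are hypothetical.

```python
import math
from typing import Optional

def decay_exponent(w: float, distance: int, num_tokens: Optional[int] = None) -> float:
    """Hypothetical helper: exponent -(distance / T) * w for a token `distance` steps away.

    With num_tokens=None no normalization is applied (exponent = -distance * w).
    """
    scale = distance / num_tokens if num_tokens is not None else float(distance)
    return -scale * w

def decay_factor(w: float, distance: int, num_tokens: Optional[int] = None) -> float:
    """exp of the exponent above; math.exp raises OverflowError if it is too large."""
    return math.exp(decay_exponent(w, distance, num_tokens))

T = 4096   # illustrative sequence length (number of tokens)
w = -2.0   # unconstrained w; a negative value makes the factor grow with distance

try:
    decay_factor(w, distance=T - 1)  # exponent 8190: overflows float64
except OverflowError:
    print("without /T: overflow")

# With /T, distance / T < 1, so the exponent stays below |w| = 2.
print(f"with /T: {decay_factor(w, distance=T - 1, num_tokens=T):.4f}")
```

Because distance / T never exceeds 1, the normalized exponent is bounded by |w| regardless of sequence length. If the computation runs in float32, exp already overflows for arguments above roughly 88.7, so an unnormalized exponent would fail at much shorter distances than in this float64 example.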