ideAxel opened 7 months ago
There is a typo in the comment at line 282: `# reshape '(b t) p c -> (b t) p c'`. It should be: `# reshape '(b t) c p -> (b t) p c'`.
At line 298, `x` should be of size `(b t) p c`, so you should extract `BT, P, C` from its shape. To get `B`, I guess you should take it from the input `x` of the forward:
`B = x.shape[0]` at line 292, before `prepare_tokens`.

Then you could replace line 298 with:

```python
bt, P, C = x.shape
```

And line 299 with:

```python
T = int(bt / B)
x = x.unflatten(0, (B, T))  # reshape '(b t) p c -> b t p c'
x = x.permute(0, 2, 1, 3)   # reshape 'b t p c -> b p t c'
x = x.flatten(0, 1)         # reshape 'b p t c -> (b p) t c'
```

Maybe I'm wrong, tell me what you think about it @ideAxel
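A minimal NumPy sketch of the proposed reshape sequence, with dummy dimensions chosen for illustration (`unflatten` ≈ `reshape`, `permute` ≈ `transpose`); this only checks the tensor-layout logic, not the actual module:

```python
import numpy as np

# Hypothetical sizes: batch B, frames T, patches P, channels C.
B, T, P, C = 2, 4, 3, 5

# Input as it arrives at line 298: '(b t) p c'.
x = np.arange(B * T * P * C).reshape(B * T, P, C)

bt, P_, C_ = x.shape
T_ = bt // B                       # T = int(bt / B)
y = x.reshape(B, T_, P_, C_)       # '(b t) p c -> b t p c'
y = y.transpose(0, 2, 1, 3)        # 'b t p c -> b p t c'
y = y.reshape(B * P_, T_, C_)      # 'b p t c -> (b p) t c'

print(y.shape)  # (6, 4, 5), i.e. ((b p), t, c)
```

Inverting the three steps recovers the original array exactly, which suggests the permutation order is consistent.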
Hello Alex, great take. I made a first attempt to make it work with the `PatchEmbed` method; we also need to adapt `interpolate_pos_encoding` in the future.
Improved the tensor manipulation in the `SpatioTemporalTransformer` module. The data is reshaped as follows:
1. Spatial attention: `b t p c -> (b t) p c`
2. Temporal attention: `b t p c -> (b p) t c`
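The two reshapes above can be sketched in NumPy (dimensions are made-up for illustration; the real module would use torch, but the index arithmetic is identical):

```python
import numpy as np

# Hypothetical sizes: batch B, frames T, patches P, channels C.
B, T, P, C = 2, 4, 3, 5
x = np.random.rand(B, T, P, C)

# 1. Spatial attention: attend over patches within each frame.
#    'b t p c -> (b t) p c'
xs = x.reshape(B * T, P, C)

# 2. Temporal attention: attend over frames for each patch position.
#    'b t p c -> (b p) t c' (permute t and p, then merge b and p)
xt = x.transpose(0, 2, 1, 3).reshape(B * P, T, C)

# Undoing step 2 recovers the original 'b t p c' layout.
back = xt.reshape(B, P, T, C).transpose(0, 2, 1, 3)
print(np.array_equal(back, x))  # True
```

The key point is that the temporal path needs the `permute` before flattening: a bare `reshape(B * P, T, C)` on the `b t p c` layout would silently mix frames and patches.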