alexcbb / Genie-Generative-Interactive-Environments

This repo aims to reproduce, and openly share, the results of Google DeepMind's "Generative Interactive Environments".
MIT License
5 stars, 2 forks

Improved the Tensor manipulation #7

Open. ideAxel opened this pull request 5 months ago.

ideAxel commented 5 months ago

Improved the tensor manipulation in the SpatioTemporalTransformer module. The data is now reshaped as follows:
1. Spatial_Attention: b t p c -> (b t) p c
2. Temporal_Attention: b t p c -> (b p) t c
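The two reshapes can be sketched as follows. This is a minimal illustration, assuming NumPy as a stand-in for the PyTorch tensors the module actually uses; the function names to_spatial and to_temporal are hypothetical and do not come from the repo.

```python
import numpy as np

def to_spatial(x):
    """'b t p c -> (b t) p c': every frame becomes its own sequence of P patches."""
    b, t, p, c = x.shape
    # b and t are already adjacent, so a plain reshape merges them.
    return x.reshape(b * t, p, c)

def to_temporal(x):
    """'b t p c -> (b p) t c': every patch position becomes a sequence over T frames."""
    b, t, p, c = x.shape
    # Move the patch axis next to the batch axis before merging the two.
    return x.transpose(0, 2, 1, 3).reshape(b * p, t, c)

# Usage with small, arbitrary sizes (B=2, T=3, P=4, C=5).
x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
print(to_spatial(x).shape)   # (6, 4, 5)
print(to_temporal(x).shape)  # (8, 3, 5)
```

Spatial attention then mixes the P patches of a single frame, while temporal attention mixes the T frames seen at a single patch position.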

alexcbb commented 5 months ago

There is a typo in the comment at l.282: "# reshape '(b t) p c -> (b t) p c'". It should be: "# reshape '(b t) c p -> (b t) p c'".
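As originally written, the comment maps a shape to itself, which hides the fact that the step swaps two axes. A small NumPy sketch of what the corrected comment describes, assuming (not confirmed by the thread) that the tensor arrives with channels before patches, as a convolutional patch embedding typically produces; the sizes are hypothetical:

```python
import numpy as np

# Hypothetical sizes: (b t) = 6 frames, C = 5 channels, P = 4 patches.
bt, C, P = 6, 5, 4
x = np.arange(bt * C * P).reshape(bt, C, P)   # layout '(b t) c p'

# '(b t) c p -> (b t) p c' swaps the channel and patch axes.
y = x.transpose(0, 2, 1)
print(y.shape)  # (6, 4, 5)
```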

At line 298, x should have shape (b t) p c, so you should extract BT, P and C from its shape. To get B, I guess you should take it from the input x of forward:

B = x.shape[0] at line 292 before prepare_tokens

Then you could replace line 298 with:

bt, P, C = x.shape

And line 299 with:

T = bt // B
x = x.unflatten(0, (B, T))  # reshape '(b t) p c -> b t p c'
x = x.permute(0, 2, 1, 3)   # reshape 'b t p c -> b p t c'
x = x.flatten(0, 1)         # reshape 'b p t c -> (b p) t c'

Maybe I'm wrong; tell me what you think about it, @ideAxel.
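The proposed sequence can be checked end to end with a short sketch. It assumes NumPy equivalents of the torch calls (reshape for unflatten and flatten, transpose for permute), and spatial_to_temporal is a hypothetical name used only for illustration.

```python
import numpy as np

def spatial_to_temporal(x, B):
    """Convert a '(b t) p c' tensor into '(b p) t c', given the batch size B."""
    bt, P, C = x.shape
    assert bt % B == 0, "leading dimension must be a multiple of B"
    T = bt // B                      # integer division, no float round-trip
    x = x.reshape(B, T, P, C)        # '(b t) p c -> b t p c'  (unflatten)
    x = x.transpose(0, 2, 1, 3)      # 'b t p c -> b p t c'    (permute)
    return x.reshape(B * P, T, C)    # 'b p t c -> (b p) t c'  (flatten)

# Compare against building the temporal layout directly from 'b t p c'.
B, T, P, C = 2, 3, 4, 5
full = np.arange(B * T * P * C).reshape(B, T, P, C)
temporal = spatial_to_temporal(full.reshape(B * T, P, C), B)
print(temporal.shape)  # (8, 3, 5)
print(np.array_equal(temporal, full.transpose(0, 2, 1, 3).reshape(B * P, T, C)))  # True
```

B has to be passed in because, once b and t are merged into one axis, the shape alone no longer says how to split it. In PyTorch, Tensor.flatten also accepts the non-contiguous result of permute, copying only when a view is impossible.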

ideAxel commented 4 months ago

Hello Alex, great take. I made a first attempt at making it work with the PatchEmbed method; we will also need to adapt interpolate_pos_encoding in the future.