eloialonso / iris

Transformers are Sample-Efficient World Models. ICLR 2023, notable top 5%.
https://openreview.net/forum?id=vhFu1Acb0xb
GNU General Public License v3.0
805 stars 80 forks

World model working #24

Closed: suryadheeshjith closed this issue 8 months ago

suryadheeshjith commented 9 months ago

Hello,

Thank you for the great work you have done! I have a question about how the world model works.

From the paper, 'Our autoregressive Transformer is based on the implementation of minGPT (Karpathy, 2020). It takes as input a sequence of L(K + 1) tokens and embeds it into an L(K + 1) x D tensor using an A x D embedding table for actions, and an N x D embedding table for frame tokens.'
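To make that concrete, here is a minimal sketch of that embedding step (my own illustrative module, not the repo's actual code; all names and sizes are assumptions). Each timestep contributes K frame tokens followed by 1 action token; frame positions use the N x D table and action positions use the A x D table, giving an L(K + 1) x D output.

# Minimal illustrative sketch; names and sizes are assumptions, not the repo's API.
import torch
import torch.nn as nn

class TokenEmbedder(nn.Module):
    def __init__(self, num_frame_tokens: int, num_actions: int, tokens_per_frame: int, embed_dim: int):
        super().__init__()
        self.tokens_per_frame = tokens_per_frame
        self.frame_emb = nn.Embedding(num_frame_tokens, embed_dim)   # the N x D table
        self.action_emb = nn.Embedding(num_actions, embed_dim)       # the A x D table

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (B, L*(K+1)) integer indices, laid out as [K frame tokens, 1 action token] per step
        positions = torch.arange(tokens.shape[1], device=tokens.device)
        is_action = (positions % (self.tokens_per_frame + 1)) == self.tokens_per_frame
        out = torch.empty(*tokens.shape, self.frame_emb.embedding_dim, device=tokens.device)
        out[:, ~is_action] = self.frame_emb(tokens[:, ~is_action])
        out[:, is_action] = self.action_emb(tokens[:, is_action])
        return out  # (B, L*(K+1), D)

# Example with L=2 timesteps and K=4 frame tokens per frame:
embedder = TokenEmbedder(num_frame_tokens=512, num_actions=6, tokens_per_frame=4, embed_dim=256)
frames = torch.randint(0, 512, (1, 2, 4))             # (B, L, K) frame token indices
actions = torch.randint(0, 6, (1, 2, 1))              # (B, L, 1) action indices
seq = torch.cat([frames, actions], dim=2).flatten(1)  # (B, L*(K+1))
print(embedder(seq).shape)                            # torch.Size([1, 10, 256])

The point is that action positions and frame-token positions are distinguished purely by their offset within each (K + 1)-long block, so two embedding tables of different sizes (A and N) can share one input sequence.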

In the code world_model.py,

def compute_loss(self, batch: Batch, tokenizer: Tokenizer, **kwargs: Any) -> LossWithIntermediateLosses:
    with torch.no_grad():
        # discrete token indices from the tokenizer's codebook; no gradients flow back to the tokenizer here
        obs_tokens = tokenizer.encode(batch['observations'], should_preprocess=True).tokens

Why do we use the token indices as input and then embed them into a new space, when we already have the frame encodings from the tokenizer?

vmicheli commented 9 months ago

Hi,

Thanks for the kind words!

In early experiments we tried using the frame encodings from the tokenizer, but this approach did not yield better results than learning frame token embeddings from scratch.
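For anyone reading along, here is a minimal sketch of the two options being compared (hypothetical names and made-up sizes, not the repo's API): (a) learning frame token embeddings from scratch from the discrete indices, versus (b) looking up the tokenizer's continuous codes and projecting them to the Transformer width D.

# Illustrative comparison only; vocab_size, code_dim and embed_dim are made-up numbers.
import torch
import torch.nn as nn

vocab_size, code_dim, embed_dim = 512, 64, 256

# (a) embeddings learned from scratch, trained jointly with the world model
frame_embedding = nn.Embedding(vocab_size, embed_dim)

# (b) reusing the tokenizer's codebook vectors (a random stand-in here) and projecting them to D
codebook = torch.randn(vocab_size, code_dim)   # placeholder for the frozen tokenizer codebook
projection = nn.Linear(code_dim, embed_dim)

tokens = torch.randint(0, vocab_size, (4, 16)) # (B, K) token indices for one frame
emb_a = frame_embedding(tokens)                # (4, 16, 256)
emb_b = projection(codebook[tokens])           # (4, 16, 256)

Either way the Transformer receives a (B, K, D) input per frame; the difference is whether the per-token vectors are free parameters of the world model or are tied to the tokenizer's latent space.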