lucidrains / q-transformer

Implementation of Q-Transformer, Scalable Offline Reinforcement Learning via Autoregressive Q-Functions, out of Google DeepMind

A simple question about the code #8

Closed: KID0031 closed this issue 10 months ago

KID0031 commented 10 months ago

Hi @lucidrains, I'm a beginner trying to use Q-Transformer and ran into a question while reading the code. In the QHeadMultipleActions class, I noticed that Q-Transformer encodes each action bin into an embedding using self.action_bin_embeddings. However, when computing the Q-values, it multiplies the attention output by self.action_bin_embeddings once again. Is there a specific reason for deriving the Q-values this way instead of applying a separate MLP layer to the attention output? I've shared the relevant code below. Thank you!

def maybe_append_actions(self, sos_tokens, actions: Optional[Tensor] = None):
    if not exists(actions):
        return sos_tokens

    batch, num_actions = actions.shape
    action_embeddings = self.action_bin_embeddings[:num_actions]

    # broadcast the (num_actions, num_bins, dim) embedding table over the batch
    action_embeddings = repeat(action_embeddings, 'n a d -> b n a d', b = batch)

    # expand the chosen bin indices so they can index into the bin dimension
    past_action_bins = repeat(actions, 'b n -> b n 1 d', d = action_embeddings.shape[-1])

    # look up the embedding of the bin actually taken for each action slot
    bin_embeddings = action_embeddings.gather(-2, past_action_bins)
    bin_embeddings = rearrange(bin_embeddings, 'b n 1 d -> b n d')

    # prepend the start-of-sequence tokens to the past action bin embeddings
    tokens, _ = pack((sos_tokens, bin_embeddings), 'b * d')
    tokens = tokens[:, :self.num_actions] # last action bin not needed for the proposed q-learning
    return tokens

def get_q_values(self, embed):
    num_actions = embed.shape[-2]

    # the same table used to embed past action bins is reused here as the
    # output projection (weight tying)
    action_bin_embeddings = self.action_bin_embeddings[:num_actions]

    if self.dueling:
        # advantage of each bin: dot product of the attention output with that bin's embedding
        advantages = einsum('b n d, n a d -> b n a', embed, action_bin_embeddings)

        # state value, one scalar per action slot
        values = einsum('b n d, n d -> b n', embed, self.to_values[:num_actions])
        values = rearrange(values, 'b n -> b n 1')

        # dueling combination: value plus mean-centered advantages
        q_values = values + (advantages - reduce(advantages, '... a -> ... 1', 'mean'))
    else:
        q_values = einsum('b n d, n a d -> b n a', embed, action_bin_embeddings)

    return q_values.sigmoid()
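
For context, a minimal standalone sketch of what the non-dueling branch is doing, with toy sizes chosen here just to show the shapes: the Q-value of each bin is the dot product of the attention output with that bin's embedding, so the embedding table doubles as the output projection.

import torch
from torch import einsum

# hypothetical toy sizes, only to illustrate the shapes involved
batch, num_actions, num_bins, dim = 2, 3, 256, 64

embed = torch.randn(batch, num_actions, dim)                     # transformer output, one vector per action slot
action_bin_embeddings = torch.randn(num_actions, num_bins, dim)  # same table used to embed past action bins

# weight-tied readout: dot the attention output against every bin embedding
q_values = einsum('b n d, n a d -> b n a', embed, action_bin_embeddings)
assert q_values.shape == (batch, num_actions, num_bins)
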
lucidrains commented 10 months ago

@KID0031 hey thanks for your interest

yes, you are right that a linear projection would probably do fine there as well. i'm following the weight-tied output embedding technique from earlier transformer architectures (which in theory should allow the network to learn better embeddings), but that has since been shown to be unnecessary

i'll make it an option to do it the way you describe
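
A minimal sketch of the untied alternative being discussed, i.e. a separate MLP head that maps each attention output straight to num_bins Q-values with no weights shared with the input action-bin embeddings. The class name UntiedQHead is hypothetical, not part of the repo.

import torch
from torch import nn

class UntiedQHead(nn.Module):
    def __init__(self, dim, num_bins):
        super().__init__()
        # small MLP, separate from the action bin embedding table
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, num_bins)
        )

    def forward(self, embed):             # embed: (batch, num_actions, dim)
        return self.mlp(embed).sigmoid()  # -> (batch, num_actions, num_bins)
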

lucidrains commented 10 months ago

@KID0031 try setting this to False

on reflection, i think i had a bug in the weight-tied action bin embeddings, so thanks for raising this
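
For anyone landing here later, a hedged sketch of what toggling that option could look like. The keyword name below is my assumption about the flag being referenced (the original comment linked to the code), so check the current QRoboticTransformer signature before relying on it.

from q_transformer import QRoboticTransformer

# assumed flag name; setting it to False should use a separate projection
# for the Q-values instead of the weight-tied readout
model = QRoboticTransformer(
    # ... usual constructor arguments omitted here ...
    weight_tie_action_bin_embed = False
)
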