lucidrains / q-transformer

Implementation of Q-Transformer, Scalable Offline Reinforcement Learning via Autoregressive Q-Functions, out of Google Deepmind

question about Q-head #11

2M-kotb closed this issue 4 months ago

2M-kotb commented 5 months ago

Hi, thank you for the code, really nice work.

I am new to the transformer architecture, and I am using this code as guidance to implement a simple Q-transformer that works on a single task (i.e. not language-conditioned), using states as observations rather than images.

So I think only the `QHeadMultipleActions` class is needed in my case. However, do I still need the cross-attention layer? There is no language or images in my case.

Thank you

lucidrains commented 5 months ago

@2M-kotb hey Mostafa

what you are planning to do deviates from the paper a lot

i can build you a simpler variant for non-image states w/o language guidance, if you plan on testing out the autoregressive Q learning. if you are doing just simple Q learning, then you can prob just use a way simpler library

2M-kotb commented 5 months ago

Thank you @lucidrains for your response

I really appreciate any help. I want to test the new Bellman updates presented in the paper; I am not 100% convinced by them. That is why I decided to start off easy, with states as observations and without language conditioning.

Also, the autoregressive prediction of action dims at inference seems weird, and it seems you lose a lot, since each dim only attends to the previously predicted dims. What if the current dim is entangled with a dim not yet predicted?
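
To make the concern concrete, the inference loop I mean looks roughly like this (a sketch with my own names, not the repo's exact API):

    import torch

    def decode_actions(q_head, state_tokens, num_action_dims):
        # greedily pick one action dim at a time; each step conditions
        # only on the state and the dims already decoded, never on
        # dims still to come - which is exactly my worry
        actions = []
        for _ in range(num_action_dims):
            q_values = q_head(state_tokens, actions)  # (batch, action_bins)
            actions.append(q_values.argmax(dim = -1))
        return torch.stack(actions, dim = -1)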

What I did so far:

- I took the `QHeadMultipleActions` class and added a `Linear` layer as a state embedder.
- I did not use cross attention in the transformer, only the self-attention and feedforward layers.
- I am testing on MetaWorld tasks with 4 action dims, but the performance is still not good.
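
Roughly, the setup looks like this (a minimal sketch; the dims and layer choices are mine, not the repo's exact code):

    import torch
    from torch import nn

    STATE_DIM = 39    # MetaWorld state observations
    ACTION_BINS = 256
    EMBED_DIM = 64

    # Linear state embedder in place of the vision backbone
    state_embedder = nn.Linear(STATE_DIM, EMBED_DIM)

    # self-attention + feedforward only, no cross attention
    layer = nn.TransformerEncoderLayer(
        d_model = EMBED_DIM,
        nhead = 8,
        dim_feedforward = 4 * EMBED_DIM,
        batch_first = True
    )
    transformer = nn.TransformerEncoder(layer, num_layers = 2)

    # project each token to Q-values over the discretized action bins
    to_q_values = nn.Linear(EMBED_DIM, ACTION_BINS)

    state = torch.randn(1, STATE_DIM)
    tokens = state_embedder(state).unsqueeze(1)   # (batch, 1, dim)
    q_values = to_q_values(transformer(tokens))   # (batch, 1, action_bins)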

Thanks

lucidrains commented 5 months ago

@2M-kotb what are the state dimensions for MetaWorld?

2M-kotb commented 5 months ago

@lucidrains the state dim is 39 in MetaWorld

lucidrains commented 5 months ago

@2M-kotb is it robotics? this?

2M-kotb commented 5 months ago

@lucidrains Yes, it is robotic manipulation tasks

lucidrains commented 5 months ago

@2M-kotb 39 seems too small? it looks like it has images, from a quick perusal?

2M-kotb commented 5 months ago

@lucidrains MetaWorld supports using states instead of images as observations

lucidrains commented 5 months ago

@2M-kotb i see, not the states for the objects the arm is going to pick up, right? that would be a bit of cheating

lucidrains commented 5 months ago

@2M-kotb ok, i understand now, will see what i can do

have you tried the image setting on MetaWorld with what is in this repository? that should work if you strip out the cross attention

2M-kotb commented 5 months ago

@lucidrains No, I have not tried with images. I want to make sure the algorithm is working before moving to images.

Thank you in advance :)

lucidrains commented 5 months ago

@2M-kotb sounds good, the most immediate thing i can help you with is removing the language conditioning

you can test the images route that way first

but i would like to build out a simple MLP for the arm states, as i'm planning to do some RL training myself in the coming year

lucidrains commented 5 months ago

give me about a week

lucidrains commented 4 months ago

@2M-kotb oh, i have already built in the option to remove language conditioning here

2M-kotb commented 4 months ago

@lucidrains yes, I noticed. I am working on using images instead of states at the moment.

Thank you again for your contribution. I will let you know if I have any questions.

2M-kotb commented 4 months ago

@lucidrains Regarding the MaxViT model: it works with high-resolution images (224x224), but in my environment the images are 84x84. Do I still need MaxViT, or is a regular ViT enough?

lucidrains commented 4 months ago

@2M-kotb just stick with maxvit, regular vit is hard to train from scratch

even the maxvit in this repo can benefit from some pretrained weights, which is not done

could you try the following vit setting in the latest version? it should work for 84x84

    vit = dict(
        num_classes = 1000,
        dim_conv_stem = 64,
        dim = 64,
        dim_head = 64,
        depth = (2, 4),
        window_size = 7,
        conv_stem_downsample = False,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1
    ),
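
for reference, that dict just gets passed through to the model constructor, roughly as in the README usage (the surrounding kwargs below are the README defaults; double check them against your version):

    from q_transformer import QRoboticTransformer

    model = QRoboticTransformer(
        vit = dict(
            num_classes = 1000,
            dim_conv_stem = 64,
            dim = 64,
            dim_head = 64,
            depth = (2, 4),
            window_size = 7,
            conv_stem_downsample = False,
            mbconv_expansion_rate = 4,
            mbconv_shrinkage_rate = 0.25,
            dropout = 0.1
        ),
        num_actions = 4,     # the 4 MetaWorld action dims in this thread
        action_bins = 256,
        depth = 1,
        heads = 8,
        dim_head = 64,
        cond_drop_prob = 0.2,
        dueling = True
    )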
lucidrains commented 4 months ago

@2M-kotb oh, saw you closed this

did it work?

2M-kotb commented 4 months ago

@lucidrains using states as observations finally worked.

It turns out that the conservative loss is not needed in my case as I am not using offline RL.
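
(For anyone else reading: the conservative term in the paper regularizes the Q-values of unselected action bins toward zero, which matters offline but can be dropped online. A sketch in my own notation:)

    import torch.nn.functional as F

    def q_loss(q_selected, q_unselected, td_target, conservative_weight = 0.0):
        # TD error on the chosen action bins
        td_loss = F.mse_loss(q_selected, td_target)
        # paper's conservative regularizer: push Q of bins not taken
        # in the dataset toward zero; weight 0 in the online setting
        conservative_loss = (q_unselected ** 2).mean()
        return td_loss + conservative_weight * conservative_loss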

I used a simple GPT-like transformer with learned positional embeddings and I also used the dueling trick that you presented in your implementation.
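
The dueling head is the usual value/advantage decomposition, roughly (a sketch, not your exact code):

    from torch import nn

    class DuelingHead(nn.Module):
        # dueling decomposition: Q = V + A - mean(A)
        def __init__(self, dim, action_bins):
            super().__init__()
            self.to_value = nn.Linear(dim, 1)
            self.to_advantages = nn.Linear(dim, action_bins)

        def forward(self, x):
            value = self.to_value(x)
            advantages = self.to_advantages(x)
            # subtracting the mean advantage keeps V and A identifiable
            return value + advantages - advantages.mean(dim = -1, keepdim = True)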

Here is a screenshot of the "open-drawer" task from MetaWorld:

[screenshot: success-rate training curve, 2024-05-15]

You can see that after 200k steps the success rate is 100%. Learning is a bit slow, and when I try harder tasks than this one, it is not stable, which is a common problem with Q-learning.

lucidrains commented 4 months ago

@2M-kotb woohoo! is this the simple Q, or the autoregressive flavor the authors proposed?

lucidrains commented 4 months ago

@2M-kotb do you know what the current best solutions are for dealing with overestimation bias?

2M-kotb commented 4 months ago

@lucidrains

This is the autoregressive Q-learning with 4 action dims.

Regarding overestimation bias, the best solution is to use two separate Q-networks and take the min of their Q estimates as the target. Some papers even use more than 2 networks.
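
Concretely, the clipped double-Q target (as in TD3/SAC) looks like this (a sketch with my own names):

    import torch

    @torch.no_grad()
    def double_q_target(q1_target, q2_target, reward, done,
                        next_obs, next_action, gamma = 0.99):
        # evaluate both target networks and take the elementwise min
        # to counteract overestimation bias
        q1 = q1_target(next_obs, next_action)
        q2 = q2_target(next_obs, next_action)
        return reward + gamma * (1.0 - done) * torch.min(q1, q2)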

lucidrains commented 4 months ago

@2M-kotb ok cool, yea i'm running into Q networks as i dive into the SAC side of things, so maybe i'll start building out some of these bandaid solutions for overestimation bias, here and in other repos

lucidrains commented 4 months ago

@2M-kotb let's definitely keep in touch! i don't meet many students working with robotics, but it's something i want to get into over the next year

2M-kotb commented 4 months ago

@lucidrains Sure, let's keep in touch.

I will keep you updated. Once I successfully solve the overestimation problem, I will let you know what solution I took.