DigiRL-agent / digirl

Official repo for paper DigiRL: Training In-The-Wild Device-Control Agents with Autonomous Reinforcement Learning.
Apache License 2.0

Errors in loading model from checkpoint #16

Closed tlc4418 closed 2 months ago

tlc4418 commented 3 months ago

Hi, I had the same issue as @mousewu in Issue #11, where there were dimension errors when loading this checkpoint.

Doing as suggested in the issue and adding the information to the config.json fixed this problem and I managed to load the model. However, I now get the following warnings about both unused and missing weights:

Some weights of the model checkpoint at JackBAI/webshop-off2on-digirl were not used when initializing T5ForMultimodalGeneration: ['decoder.block.0.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.0.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.1.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.1.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.10.layer.0.SelfAttention.k.weight', 'decoder.block.10.layer.0.SelfAttention.o.weight', 'decoder.block.10.layer.0.SelfAttention.q.weight', 'decoder.block.10.layer.0.SelfAttention.v.weight', 'decoder.block.10.layer.0.layer_norm.weight', 'decoder.block.10.layer.1.EncDecAttention.k.weight', 'decoder.block.10.layer.1.EncDecAttention.o.weight', 'decoder.block.10.layer.1.EncDecAttention.q.weight', 'decoder.block.10.layer.1.EncDecAttention.v.weight', 'decoder.block.10.layer.1.layer_norm.weight', 'decoder.block.10.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.10.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.10.layer.2.DenseReluDense.wo.weight', 'decoder.block.10.layer.2.layer_norm.weight', 'decoder.block.11.layer.0.SelfAttention.k.weight', 'decoder.block.11.layer.0.SelfAttention.o.weight', 'decoder.block.11.layer.0.SelfAttention.q.weight', 'decoder.block.11.layer.0.SelfAttention.v.weight', 'decoder.block.11.layer.0.layer_norm.weight', 'decoder.block.11.layer.1.EncDecAttention.k.weight', 'decoder.block.11.layer.1.EncDecAttention.o.weight', 'decoder.block.11.layer.1.EncDecAttention.q.weight', 'decoder.block.11.layer.1.EncDecAttention.v.weight', 'decoder.block.11.layer.1.layer_norm.weight', 'decoder.block.11.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.11.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.11.layer.2.DenseReluDense.wo.weight', 'decoder.block.11.layer.2.layer_norm.weight', 'decoder.block.2.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.2.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.3.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.3.layer.2.DenseReluDense.wi_1.weight', 
'decoder.block.4.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.4.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.5.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.5.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.6.layer.0.SelfAttention.k.weight', 'decoder.block.6.layer.0.SelfAttention.o.weight', 'decoder.block.6.layer.0.SelfAttention.q.weight', 'decoder.block.6.layer.0.SelfAttention.v.weight', 'decoder.block.6.layer.0.layer_norm.weight', 'decoder.block.6.layer.1.EncDecAttention.k.weight', 'decoder.block.6.layer.1.EncDecAttention.o.weight', 'decoder.block.6.layer.1.EncDecAttention.q.weight', 'decoder.block.6.layer.1.EncDecAttention.v.weight', 'decoder.block.6.layer.1.layer_norm.weight', 'decoder.block.6.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.6.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.6.layer.2.DenseReluDense.wo.weight', 'decoder.block.6.layer.2.layer_norm.weight', 'decoder.block.7.layer.0.SelfAttention.k.weight', 'decoder.block.7.layer.0.SelfAttention.o.weight', 'decoder.block.7.layer.0.SelfAttention.q.weight', 'decoder.block.7.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.0.layer_norm.weight', 'decoder.block.7.layer.1.EncDecAttention.k.weight', 'decoder.block.7.layer.1.EncDecAttention.o.weight', 'decoder.block.7.layer.1.EncDecAttention.q.weight', 'decoder.block.7.layer.1.EncDecAttention.v.weight', 'decoder.block.7.layer.1.layer_norm.weight', 'decoder.block.7.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.7.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.7.layer.2.DenseReluDense.wo.weight', 'decoder.block.7.layer.2.layer_norm.weight', 'decoder.block.8.layer.0.SelfAttention.k.weight', 'decoder.block.8.layer.0.SelfAttention.o.weight', 'decoder.block.8.layer.0.SelfAttention.q.weight', 'decoder.block.8.layer.0.SelfAttention.v.weight', 'decoder.block.8.layer.0.layer_norm.weight', 'decoder.block.8.layer.1.EncDecAttention.k.weight', 'decoder.block.8.layer.1.EncDecAttention.o.weight', 
'decoder.block.8.layer.1.EncDecAttention.q.weight', 'decoder.block.8.layer.1.EncDecAttention.v.weight', 'decoder.block.8.layer.1.layer_norm.weight', 'decoder.block.8.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.8.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.8.layer.2.DenseReluDense.wo.weight', 'decoder.block.8.layer.2.layer_norm.weight', 'decoder.block.9.layer.0.SelfAttention.k.weight', 'decoder.block.9.layer.0.SelfAttention.o.weight', 'decoder.block.9.layer.0.SelfAttention.q.weight', 'decoder.block.9.layer.0.SelfAttention.v.weight', 'decoder.block.9.layer.0.layer_norm.weight', 'decoder.block.9.layer.1.EncDecAttention.k.weight', 'decoder.block.9.layer.1.EncDecAttention.o.weight', 'decoder.block.9.layer.1.EncDecAttention.q.weight', 'decoder.block.9.layer.1.EncDecAttention.v.weight', 'decoder.block.9.layer.1.layer_norm.weight', 'decoder.block.9.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.9.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.9.layer.2.DenseReluDense.wo.weight', 'decoder.block.9.layer.2.layer_norm.weight', 'encoder.block.0.layer.1.DenseReluDense.wi_0.weight', 'encoder.block.0.layer.1.DenseReluDense.wi_1.weight', 'encoder.block.1.layer.1.DenseReluDense.wi_0.weight', 'encoder.block.1.layer.1.DenseReluDense.wi_1.weight', 'encoder.block.10.layer.0.SelfAttention.k.weight', 'encoder.block.10.layer.0.SelfAttention.o.weight', 'encoder.block.10.layer.0.SelfAttention.q.weight', 'encoder.block.10.layer.0.SelfAttention.v.weight', 'encoder.block.10.layer.0.layer_norm.weight', 'encoder.block.10.layer.1.DenseReluDense.wi_0.weight', 'encoder.block.10.layer.1.DenseReluDense.wi_1.weight', 'encoder.block.10.layer.1.DenseReluDense.wo.weight', 'encoder.block.10.layer.1.layer_norm.weight', 'encoder.block.11.layer.0.SelfAttention.k.weight', 'encoder.block.11.layer.0.SelfAttention.o.weight', 'encoder.block.11.layer.0.SelfAttention.q.weight', 'encoder.block.11.layer.0.SelfAttention.v.weight', 'encoder.block.11.layer.0.layer_norm.weight', 
'encoder.block.11.layer.1.DenseReluDense.wi_0.weight', 'encoder.block.11.layer.1.DenseReluDense.wi_1.weight', 'encoder.block.11.layer.1.DenseReluDense.wo.weight', 'encoder.block.11.layer.1.layer_norm.weight', 'encoder.block.2.layer.1.DenseReluDense.wi_0.weight', 'encoder.block.2.layer.1.DenseReluDense.wi_1.weight', 'encoder.block.3.layer.1.DenseReluDense.wi_0.weight', 'encoder.block.3.layer.1.DenseReluDense.wi_1.weight', 'encoder.block.4.layer.1.DenseReluDense.wi_0.weight', 'encoder.block.4.layer.1.DenseReluDense.wi_1.weight', 'encoder.block.5.layer.1.DenseReluDense.wi_0.weight', 'encoder.block.5.layer.1.DenseReluDense.wi_1.weight', 'encoder.block.6.layer.0.SelfAttention.k.weight', 'encoder.block.6.layer.0.SelfAttention.o.weight', 'encoder.block.6.layer.0.SelfAttention.q.weight', 'encoder.block.6.layer.0.SelfAttention.v.weight', 'encoder.block.6.layer.0.layer_norm.weight', 'encoder.block.6.layer.1.DenseReluDense.wi_0.weight', 'encoder.block.6.layer.1.DenseReluDense.wi_1.weight', 'encoder.block.6.layer.1.DenseReluDense.wo.weight', 'encoder.block.6.layer.1.layer_norm.weight', 'encoder.block.7.layer.0.SelfAttention.k.weight', 'encoder.block.7.layer.0.SelfAttention.o.weight', 'encoder.block.7.layer.0.SelfAttention.q.weight', 'encoder.block.7.layer.0.SelfAttention.v.weight', 'encoder.block.7.layer.0.layer_norm.weight', 'encoder.block.7.layer.1.DenseReluDense.wi_0.weight', 'encoder.block.7.layer.1.DenseReluDense.wi_1.weight', 'encoder.block.7.layer.1.DenseReluDense.wo.weight', 'encoder.block.7.layer.1.layer_norm.weight', 'encoder.block.8.layer.0.SelfAttention.k.weight', 'encoder.block.8.layer.0.SelfAttention.o.weight', 'encoder.block.8.layer.0.SelfAttention.q.weight', 'encoder.block.8.layer.0.SelfAttention.v.weight', 'encoder.block.8.layer.0.layer_norm.weight', 'encoder.block.8.layer.1.DenseReluDense.wi_0.weight', 'encoder.block.8.layer.1.DenseReluDense.wi_1.weight', 'encoder.block.8.layer.1.DenseReluDense.wo.weight', 'encoder.block.8.layer.1.layer_norm.weight', 
'encoder.block.9.layer.0.SelfAttention.k.weight', 'encoder.block.9.layer.0.SelfAttention.o.weight', 'encoder.block.9.layer.0.SelfAttention.q.weight', 'encoder.block.9.layer.0.SelfAttention.v.weight', 'encoder.block.9.layer.0.layer_norm.weight', 'encoder.block.9.layer.1.DenseReluDense.wi_0.weight', 'encoder.block.9.layer.1.DenseReluDense.wi_1.weight', 'encoder.block.9.layer.1.DenseReluDense.wo.weight', 'encoder.block.9.layer.1.layer_norm.weight']
- This IS expected if you are initializing T5ForMultimodalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing T5ForMultimodalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of T5ForMultimodalGeneration were not initialized from the model checkpoint at JackBAI/webshop-off2on-digirl and are newly initialized: ['decoder.block.0.layer.2.DenseReluDense.wi.weight', 'decoder.block.1.layer.2.DenseReluDense.wi.weight', 'decoder.block.2.layer.2.DenseReluDense.wi.weight', 'decoder.block.3.layer.2.DenseReluDense.wi.weight', 'decoder.block.4.layer.2.DenseReluDense.wi.weight', 'decoder.block.5.layer.2.DenseReluDense.wi.weight', 'encoder.block.0.layer.1.DenseReluDense.wi.weight', 'encoder.block.1.layer.1.DenseReluDense.wi.weight', 'encoder.block.2.layer.1.DenseReluDense.wi.weight', 'encoder.block.3.layer.1.DenseReluDense.wi.weight', 'encoder.block.4.layer.1.DenseReluDense.wi.weight', 'encoder.block.5.layer.1.DenseReluDense.wi.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

When I try to perform inference using this model, the output is gibberish, as pointed out in Issue #14. I believe this nonsensical output might be linked to the above warnings, perhaps due to a mismatch between the model architecture provided in the code and the model checkpoint you provide on Hugging Face. Would you be able to shed some light on this and explain how best to load your final model checkpoint for inference?
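One way to read the two warning lists is to summarize which block indices and which feed-forward projection names appear in each. The sketch below does this for a handful of keys copied from the warnings; the helper name `summarize_keys` is hypothetical and not part of the digirl code. In Hugging Face T5, a gated feed-forward layer has `wi_0` and `wi_1` weights, while a non-gated one has a single `wi`, so a difference here suggests the instantiated model and the checkpoint were built from different configs.

```python
import re
from collections import defaultdict

# Matches keys such as "decoder.block.11.layer.0.SelfAttention.k.weight"
BLOCK_RE = re.compile(r"^(encoder|decoder)\.block\.(\d+)\.")


def summarize_keys(keys):
    """Return (block indices per stack, feed-forward projection names) for a key list.

    Hypothetical helper for reading the warnings; not part of digirl.
    """
    blocks = defaultdict(set)
    ff_names = set()
    for key in keys:
        match = BLOCK_RE.match(key)
        if match:
            blocks[match.group(1)].add(int(match.group(2)))
        if ".DenseReluDense." in key:
            # The name right after DenseReluDense: "wi", "wi_0", "wi_1" or "wo"
            ff_names.add(key.split(".DenseReluDense.")[1].split(".")[0])
    return {stack: sorted(idx) for stack, idx in blocks.items()}, ff_names


# A few keys copied from the "were not used" list
unused = [
    "decoder.block.0.layer.2.DenseReluDense.wi_0.weight",
    "decoder.block.11.layer.0.SelfAttention.k.weight",
    "encoder.block.6.layer.0.SelfAttention.q.weight",
]
# A few keys copied from the "newly initialized" list
missing = [
    "decoder.block.0.layer.2.DenseReluDense.wi.weight",
    "encoder.block.5.layer.1.DenseReluDense.wi.weight",
]

unused_blocks, unused_ff = summarize_keys(unused)
missing_blocks, missing_ff = summarize_keys(missing)
print("checkpoint-only:", unused_blocks, unused_ff)
print("model-only:", missing_blocks, missing_ff)
```

Applied to the full lists, this kind of summary shows that the checkpoint contains blocks 0 to 11 with `wi_0`/`wi_1` weights, while the newly initialized weights only cover blocks 0 to 5 with a single `wi`. That pattern is consistent with the model being instantiated from a config that does not match the checkpoint, though it does not by itself say which config is the right one.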

Thank you

BiEchi commented 2 months ago

It might be a problem with the config.json file. I just deleted it on HuggingFace. Would you like to delete the config.json file and try things again?

I released two versions of the final checkpoints. The Google Drive version does not have a config.json file while the HuggingFace version has it. I did not receive any issues with the GDrive version checkpoint, thus it's probable that deleting config.json solves the problem.

BiEchi commented 2 months ago

By the way, we don't support Hugging Face's .from_pretrained() method. Please use trainer.load() to load the model, as defined in this file.

StarWalkin commented 2 months ago

Hi @BiEchi, thanks for your inspiring work. Currently, I am confused about how to reproduce your results. When I evaluated with the AutoUI-base model, it went well. But when I downloaded the JackBAI/general-off2on-digirl checkpoint from Hugging Face, changed the "policy model" in scripts/config/main/default.yaml to its path, and copied the tokenizing and generating configs from AutoUI without other adjustment, I got a dimension mismatch when loading the model, as mentioned in issue #11. Doing as suggested in that issue and adding the information to the config.json fixed the size mismatch and I managed to load the model, but I got untranslatable actions, as issue #14 and this issue #16 mention. And if I remove the config.json as you suggest, I get the "does not appear to have a file named config.json" error from the T5ForMultimodalGeneration.from_pretrained function.

For us to reproduce the results, could you detail the steps to change the settings and solve all these weird issues (for example, what to do with the config, and which loading method to use)? That would be really helpful. Thanks!

BiEchi commented 2 months ago

@StarWalkin Thanks for following up on this. Does the error only occur if you use the Hugging Face checkpoint, or does it still occur when you use the checkpoint from the Google Drive link?

BiEchi commented 2 months ago

@StarWalkin Also, the policy model in the default.yaml config should still be AutoUI instead of the path to the checkpoint. The path of the checkpoint should be specified in the save_path item in eval_only.yaml.
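The split described above can be sketched as a small sanity check over the two configs, represented here as plain dicts. This is only an illustration under stated assumptions: the key name `policy_lm` and all paths are hypothetical placeholders, `save_path` is the item named in the comment, and `check_eval_config` is not a function from the digirl code.

```python
def check_eval_config(default_cfg, eval_cfg, checkpoint_dir):
    """Return a list of likely mistakes in an evaluation setup.

    default_cfg / eval_cfg stand in for default.yaml / eval_only.yaml.
    "policy_lm" is an assumed key name for the policy model entry.
    """
    problems = []
    if default_cfg.get("policy_lm") == checkpoint_dir:
        problems.append("policy model should stay AutoUI, not the DigiRL checkpoint")
    if eval_cfg.get("save_path") != checkpoint_dir:
        problems.append("save_path should point to the DigiRL checkpoint")
    return problems


# Hypothetical placeholder paths
checkpoint_dir = "/path/to/general-off2on-digirl"
autoui_dir = "/path/to/Auto-UI-Base"

# The setup described earlier in the thread: checkpoint path put in the policy model entry
wrong = check_eval_config({"policy_lm": checkpoint_dir}, {}, checkpoint_dir)

# The setup suggested in the comment above
right = check_eval_config(
    {"policy_lm": autoui_dir}, {"save_path": checkpoint_dir}, checkpoint_dir
)
print(wrong)
print(right)
```

The first call reports both problems and the second reports none, which mirrors the advice in the comment: the base policy model and the trained checkpoint are configured in two different places.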

BiEchi commented 2 months ago

Please let me know whether it works. If it does, I'll add a patch to the README.

StarWalkin commented 2 months ago

@BiEchi Thanks for your response. I think some of the people troubled by issue #11, issue #12 and issue #16 may have made the same mistake as me. Here is how I worked it out.

It is currently running successfully. I haven't got the final success rate yet, but I think this is the right approach.

BiEchi commented 2 months ago

Thanks a lot for confirming! I would also suggest running the Hugging Face version after you've got the success rates with the Google Drive version. I'll reply to every issue after your confirmation.

BiEchi commented 2 months ago

Just modified the README. I'll leave this issue open for one more week in case you've got any updates. @StarWalkin

BiEchi commented 2 months ago

Closing as the problem is now solved.