MrZihan / GridMM

Official implementation of GridMM: Grid Memory Map for Vision-and-Language Navigation (ICCV'23).

Model config files #4

Closed. Mercy2Green closed this issue 10 months ago.

Mercy2Green commented 10 months ago

Hi! Could you upload your model files? Thanks!

Best regards.

```python
# (imports used by this excerpt)
import torch
from transformers import AutoModel, AutoTokenizer, PretrainedConfig

# Model config
model_config = PretrainedConfig.from_json_file(opts.model_config)
model_config.pretrain_tasks = []
for train_dataset_config in opts.train_datasets.values():
    model_config.pretrain_tasks.extend(train_dataset_config['tasks'])
model_config.pretrain_tasks = set(model_config.pretrain_tasks)

tokenizer = AutoTokenizer.from_pretrained('../../bert-base')

# Prepare model
if opts.checkpoint:
    checkpoint = torch.load(opts.checkpoint, map_location=lambda storage, loc: storage)
else:
    checkpoint = {}
    if opts.init_pretrained == 'bert':
        # Initialise the language encoder from a HuggingFace BERT checkpoint.
        tmp = AutoModel.from_pretrained('../../bert-base')
        for param_name, param in tmp.named_parameters():
            checkpoint[param_name] = param
        if model_config.lang_bert_name == 'xlm-roberta-base':
            # embeddings.token_type_embeddings.weight (1 -> 2, the second is for image embedding)
            checkpoint['embeddings.token_type_embeddings.weight'] = torch.cat(
                [checkpoint['embeddings.token_type_embeddings.weight']] * 2, 0
            )
        del tmp
    elif opts.init_pretrained == 'lxmert':
        # Initialise from the LXMERT checkpoint, remapping parameter names
        # onto this model's module layout.
        tmp = torch.load(
            '../datasets/pretrained/LXMERT/model_LXRT.pth',
            map_location=lambda storage, loc: storage
        )
        for param_name, param in tmp.items():
            param_name = param_name.replace('module.', '')
            if 'bert.encoder.layer' in param_name:
                param_name = param_name.replace('bert.encoder.layer', 'bert.lang_encoder.layer')
                checkpoint[param_name] = param
            elif 'bert.encoder.x_layers' in param_name:
                # Cross-modal layers initialise the local, global and grid-text encoders.
                param_name1 = param_name.replace('bert.encoder.x_layers', 'bert.local_encoder.encoder.x_layers')
                param_name2 = param_name.replace('bert.encoder.x_layers', 'bert.global_encoder.encoder.x_layers')
                param_name3 = param_name.replace('bert.encoder.x_layers', 'bert.grid_txt_encoder.encoder.x_layers')
                checkpoint[param_name1] = checkpoint[param_name2] = checkpoint[param_name3] = param
            elif 'cls.predictions' in param_name:
                param_name = param_name.replace('cls.predictions', 'mlm_head.predictions')
                checkpoint[param_name] = param
            else:
                checkpoint[param_name] = param
        del tmp

model_class = GlocalTextPathCMTPreTraining
```
MrZihan commented 10 months ago

Hi, I have uploaded the "bert-base" directory and "pretrained/LXMERT/model_LXRT.pth". In fact, they are the same model files used in VLN_DUET. If you have any questions, please feel free to contact me and I will try to help.
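For anyone hitting the same question: a quick way to check that the downloaded files are readable is to load them from the same relative paths the pretraining snippet above uses. This is only a minimal sketch assuming the default directory layout; adjust the two paths if yours differs.

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Paths as used in the pretraining snippet above; adjust to your local layout.
BERT_DIR = '../../bert-base'
LXMERT_CKPT = '../datasets/pretrained/LXMERT/model_LXRT.pth'

# The "bert-base" folder is a standard HuggingFace checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(BERT_DIR)
bert = AutoModel.from_pretrained(BERT_DIR)
print('bert parameters:', sum(p.numel() for p in bert.parameters()))

# model_LXRT.pth is a torch-saved state dict; its keys may carry the
# "module." prefix that the pretraining code strips before remapping.
state = torch.load(LXMERT_CKPT, map_location='cpu')
print('lxmert tensors:', len(state), '| first key:', next(iter(state)))
```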

Mercy2Green commented 10 months ago

Thank you! You are awesome!