ruotianluo / self-critical.pytorch

Unofficial PyTorch implementation of Self-critical Sequence Training for Image Captioning, and other methods.
MIT License

BertCapModel #235

Open sssilence opened 3 years ago

sssilence commented 3 years ago

When I train BertCapModel, I get an AssertionError:

    File "/anaconda3/lib/python3.7/site-packages/transformers/modeling_bert.py", line 410, in forward
        ), f"If encoder_hidden_states are passed, {self} has to be instantiated with cross-attention layers by setting config.add_cross_attention=True"
    AssertionError: If encoder_hidden_states are passed, BertLayer(
      (attention): BertAttention(
        (self): BertSelfAttention(
          (query): Linear(in_features=512, out_features=512, bias=True)
          (key): Linear(in_features=512, out_features=512, bias=True)
          (value): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (output): BertSelfOutput(
          (dense): Linear(in_features=512, out_features=512, bias=True)
          (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (intermediate): BertIntermediate(
        (dense): Linear(in_features=512, out_features=512, bias=True)
      )
      (output): BertOutput(
        (dense): Linear(in_features=512, out_features=512, bias=True)
        (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    ) has to be instantiated with cross-attention layers by setting config.add_cross_attention=True

Now I don't know where to set 'config.add_cross_attention=True'. Thank you!!! @ruotianluo

blalalt commented 3 years ago
from transformers import BertConfig

# tgt_vocab, d_model, N_dec, h, d_ff and dropout are the model hyperparameters
# already in scope where the decoder config is built.
dec_config = BertConfig(vocab_size=tgt_vocab,
                        hidden_size=d_model,
                        num_hidden_layers=N_dec,
                        num_attention_heads=h,
                        intermediate_size=d_ff,
                        hidden_dropout_prob=dropout,
                        attention_probs_dropout_prob=dropout,
                        max_position_embeddings=512,
                        type_vocab_size=1,
                        is_decoder=True,
                        add_cross_attention=True)  # in here
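
For anyone hitting the same assertion, here is a minimal, self-contained sketch (not from this repo; the vocabulary size, layer counts, and dummy tensors are made up for illustration) showing that a BertModel built from a decoder config with add_cross_attention=True accepts encoder_hidden_states; without that flag the BertLayer has no cross-attention block, and forward() raises the assertion quoted above:

import torch
from transformers import BertConfig, BertModel

d_model = 512  # matches the 512-dim layers in the traceback above

dec_config = BertConfig(vocab_size=100,
                        hidden_size=d_model,
                        num_hidden_layers=2,
                        num_attention_heads=8,
                        intermediate_size=2048,
                        is_decoder=True,
                        add_cross_attention=True)  # the fix: adds cross-attention layers
decoder = BertModel(dec_config)

input_ids = torch.randint(0, 100, (1, 10))  # dummy target-side tokens
enc_out = torch.randn(1, 20, d_model)       # dummy encoder states (e.g., image features)
sequence_output = decoder(input_ids=input_ids,
                          encoder_hidden_states=enc_out)[0]
print(sequence_output.shape)  # torch.Size([1, 10, 512])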