GuyTevet / motion-diffusion-model

The official PyTorch implementation of the paper "Human Motion Diffusion Model"
MIT License

The error in fp16, using bert as text encoder #140

Closed CrazyLearner98 closed 1 year ago

CrazyLearner98 commented 1 year ago

When I swap CLIP for a BERT text encoder, I only changed two places in mdm.py, because the code is quite complex:

```python
def load_and_freeze_clip(self, clip_version):
    # Load a frozen BERT in place of CLIP.
    bert_model = BertModel.from_pretrained('bert-base-uncased')
    bert_model.eval()
    for p in bert_model.parameters():
        p.requires_grad = False
    return bert_model

def encode_text(self, raw_text):
    encoded_text = tokenizer(raw_text, padding='max_length',
                             max_length=default_context_len,
                             truncation=True, return_tensors='pt').to(device)
    # Project BERT's 768-dim output down to the 512 dims MDM expects.
    self.projection_layer = nn.Linear(768, 512).to(device)
    bert_outputs = self.clip_model(**encoded_text).last_hidden_state.mean(dim=1)
    final_outputs = self.projection_layer(bert_outputs).half()
    print("bert", final_outputs.shape)
    print(final_outputs)
    return final_outputs
```

(The setup for `tokenizer`, `device`, and `default_context_len` is sketched at the end of this comment.) I know I haven't added any HumanML3D-specific handling yet, since I first want to make sure MDM trains successfully before tuning. Now it runs up to step 0 and prints something like:

```
Loading CLIP...
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']
```

I am really confused about this error and don't know how to solve it. Can someone help me with this problem? I would really appreciate it!
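For completeness, `tokenizer`, `device`, and `default_context_len` used in `encode_text` above are defined near the top of my mdm.py; roughly like this (a sketch, and the exact names and values are my own choices):

```python
# Rough sketch of the setup encode_text relies on (my own names/values):
import torch
from torch import nn
from transformers import BertModel, BertTokenizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
default_context_len = 77  # chosen to match CLIP's context length
```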

CrazyLearner98 commented 1 year ago

Now, while debugging to find where the error comes from, I found that `len(list(self.model.parameters()))` changed from 305 to 307 after line 233 in training_loop.py: `losses = compute_losses()`
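To sanity-check that number: a bare `nn.Linear` registers exactly two parameter tensors (weight and bias), which is the same size as the jump I'm seeing:

```python
from torch import nn

layer = nn.Linear(768, 512)
# Two parameter tensors: weight [512, 768] and bias [512]
print(len(list(layer.parameters())))  # -> 2
```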

I also don't understand what this functools.partial call is actually running in the code:

```python
compute_losses = functools.partial(
    self.diffusion.training_losses,
    self.ddp_model,
    micro,  # [bs, ch, image_size, image_size]
    t,      # bs sampled timesteps
    model_kwargs=micro_cond,
    dataset=self.data.dataset,
)
```

Could someone help me?
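If I read the Python docs right, `functools.partial` just pre-binds arguments and returns a new callable, so `compute_losses()` should end up calling `self.diffusion.training_losses(self.ddp_model, micro, t, model_kwargs=micro_cond, dataset=self.data.dataset)`. A toy example of the mechanism (stand-in names, not the real MDM code):

```python
import functools

def training_losses(model, batch, t, model_kwargs=None, dataset=None):
    # Stand-in for diffusion.training_losses: just report what it received.
    return {'model': model, 't': t, 'dataset': dataset}

# Bind all the arguments now...
compute_losses = functools.partial(
    training_losses, 'ddp_model', 'micro_batch', 'timesteps',
    model_kwargs={'text': 'a person walks'}, dataset='humanml')

# ...and call with no arguments later, like training_loop.py does.
losses = compute_losses()
print(losses)  # {'model': 'ddp_model', 't': 'timesteps', 'dataset': 'humanml'}
```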