SivilTaram / Persona-Dialogue-Generation

The code of ACL 2020 paper "You Impress Me: Dialogue Generation via Mutual Persona Perception"
MIT License

ERROR occurs when running train_psquare.py #7

Closed. Lireanstar closed this issue 4 years ago.

Lireanstar commented 4 years ago

Hi, after training the receiver and transmitter models, I ran train_psquare.py. My local environment has two GPUs, and I launch the script from the terminal with the following command:

CUDA_VISIBLE_DEVICES=0,1 python train_psquare.py

Then the errors occur as below:

[loading fbdialog data:/home/Persona-Dialogue-Generation/data/ConvAI2/train_self_original_no_cands.txt]
[loading fbdialog data:/home/Persona-Dialogue-Generation/data/ConvAI2/train_self_original_selfplay.txt]
[ training... ]
.> [ Saving tensorboard logs here: ./tmp/psquare/tensorboard ]
/pytorch/aten/src/ATen/native/IndexingUtils.h:20: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.

Traceback (most recent call last):
  File "train_psquare.py", line 105, in <module>
    TrainLoop(opt).train()
  File "/home/Persona-Dialogue-Generation/scripts/train_model_selfplay.py", line 270, in train
    world.parley_episode(is_training=True, is_display=is_display)
  File "/home/Persona-Dialogue-Generation/worlds/selfplay.py", line 186, in parley_episode
    self.parley(is_display)
  File "/home/Persona-Dialogue-Generation/worlds/selfplay.py", line 90, in parley
    acts[0] = agents[0].act(is_display)
  File "/home/Persona-Dialogue-Generation/agents/psquare/psquare.py", line 604, in act
    act = self.batch_act(self.observation)
  File "/home/Persona-Dialogue-Generation/agents/psquare/psquare.py", line 624, in batch_act
    cand_inds, is_training)
  File "/home/Persona-Dialogue-Generation/agents/psquare/psquare.py", line 814, in transmitter_predict
    raise e
  File "/home/Persona-Dialogue-Generation/agents/psquare/psquare.py", line 787, in transmitter_predict
    sampling=True)
  File "/home/Persona-Dialogue-Generation/agents/transmitter/gpt/model.py", line 120, in forward
    predictions, scores, hidden_states = self.sample_decoding(batch_size, prior_context, prior_dis, self.topk)
  File "/home/Persona-Dialogue-Generation/agents/transmitter/gpt/model.py", line 383, in sample_decoding
    is_end = is_end | (predict_tok == self.end_idx).view(-1)

RuntimeError: Expected object of scalar type Byte but got scalar type Bool for argument #2 'other' in call to _th_or

How can I fix this error? Thanks!

SivilTaram commented 4 years ago

@Libincn-HNU Which version of PyTorch are you using? We recommend PyTorch 1.0.

Lireanstar commented 4 years ago

@SivilTaram I am using torch 1.3.0. I will try changing the version to fix it.
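Whether the original line breaks comes down to the installed version. According to the PyTorch 1.2 release notes, comparison operators such as `==` return `torch.bool` from 1.2 onward instead of `torch.uint8`, so `is_end | (predict_tok == self.end_idx)` ends up mixing the two dtypes, which is what fails on the 1.3.0 install reported here. Below is a stdlib-only sketch of a check a training script could run before starting; the function name `comparisons_return_bool` is hypothetical and not part of the repository, and in practice the version string would come from `torch.__version__`.

```python
import re


def comparisons_return_bool(version: str) -> bool:
    """Return True if this PyTorch version makes `==` produce torch.bool.

    Assumption based on the PyTorch 1.2 release notes: from 1.2 on,
    comparison ops return torch.bool instead of torch.uint8.
    """
    # Only the leading "major.minor" matters; suffixes like "+cu100" are ignored.
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (1, 2)


print(comparisons_return_bool("1.0.1"))  # False: masks stay uint8, original code works
print(comparisons_return_bool("1.3.0"))  # True: needs the bool cast
```

A script could use this to warn early instead of failing deep inside `sample_decoding`.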

SivilTaram commented 4 years ago

@Libincn-HNU Looking forward to your further feedback.

Lireanstar commented 4 years ago

Hi, I changed the source code to

is_end = is_end.bool() | (predict_tok == self.end_idx).bool().view(-1)

and it finally works, apart from this warning, which I ignore:
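A minimal sketch of the same fix factored into a standalone function, assuming PyTorch 1.2 or later. The name `update_is_end` and the sample tensors are hypothetical; `is_end`, `predict_tok`, and `end_idx` mirror the names in the failing line of `sample_decoding`. Casting both operands to `torch.bool` keeps `|` from mixing `uint8` and `bool`. On PyTorch 1.0, the version the maintainer recommends, comparisons still return `uint8`, so the original line works unchanged there.

```python
import torch


def update_is_end(is_end: torch.Tensor, predict_tok: torch.Tensor, end_idx: int) -> torch.Tensor:
    """Mark sequences that just sampled the end token as finished.

    Both operands are cast to torch.bool so `|` never sees uint8 and bool together.
    """
    just_ended = (predict_tok == end_idx).view(-1).bool()
    return is_end.bool() | just_ended


# Hypothetical decoding state: batch of 3, end token id 2.
is_end = torch.zeros(3, dtype=torch.uint8)   # legacy uint8 "finished" mask
predict_tok = torch.tensor([[2], [5], [2]])  # one sampled token per sequence
is_end = update_is_end(is_end, predict_tok, end_idx=2)
print(is_end.tolist())  # [True, False, True]
```

Once `is_end` is bool, later steps keep it bool, so the cast on `is_end` only matters for the first step.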

/pytorch/aten/src/ATen/native/IndexingUtils.h:20: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.
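The remaining UserWarning does not stop training; it is raised when a tensor is indexed with a `uint8` mask somewhere in the code, and the thread does not show which line does this. A minimal sketch, assuming PyTorch 1.2 or later, of the remedy the warning itself suggests: convert the mask with `.bool()` before indexing. `select_by_mask` and the sample values are hypothetical.

```python
import torch


def select_by_mask(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Convert a legacy uint8 mask to torch.bool before boolean-mask indexing,
    # which avoids the "indexing with dtype torch.uint8 is now deprecated" warning.
    return values[mask.bool()]


scores = torch.tensor([1.0, 2.0, 3.0])
legacy_mask = torch.tensor([1, 0, 1], dtype=torch.uint8)  # old-style uint8 mask
print(select_by_mask(scores, legacy_mask).tolist())  # [1.0, 3.0]
```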

Training now runs and logs:

[ time:127s parleys:54 ] {'reward_var': 0.35737055998582107, 'reward': 0.0036568641662597656, 'num_selfplay_episode': 13, 'num_selfplay_turns': 78, 'total_reward': -0.01236748007627634}

Besides, I also lowered the batch_size :)