sw32-seo / GTA

Official code for Graph Truncated Attention for Retrosynthesis (AAAI 2021)
MIT License

RuntimeError on training GTA #3

Closed — taein98 closed this 10 months ago

taein98 commented 1 year ago

I appreciate your nice work on retrosynthesis. However, I get a RuntimeError when running train.py after starting from preprocessing of the USPTO-50k_no_rxn dataset. (torchtext 0.4.0, PyTorch 1.8.2, Python 3.8.0, cudatoolkit 11.1.7)

[screenshot: RuntimeError traceback]

I also found a similar issue (#1), and I noticed that the tensors going into the loss have different dtypes:

[screenshots: dtype mismatch in the loss computation]

I would deeply appreciate any help debugging this!

sw32-seo commented 1 year ago

Could you give me the full error message and the command you used? I don't think the error comes from loss.div(); it happens during backprop.

taein98 commented 1 year ago

I agree with you. That was an error message I captured while debugging. Here is the full error message (the error below repeats):

[screenshot: full error traceback (repeated)]

And here is the command I used:


python train.py -data data/${dataset}/${dataset} \
    -save_model experiments/${dataset}_${model_name} \
    -seed 2023 -gpu_ranks 0 -world_size 1 \
    -save_checkpoint_steps 1000 -keep_checkpoint 11 \
    -train_steps 1000 -valid_steps 1000 -report_every 1000 \
    -param_init 0 -param_init_glorot \
    -batch_size 4 -batch_type tokens -normalization tokens \
    -dropout 0.3 -max_grad_norm 0 -accum_count 4 \
    -optim adam -adam_beta1 0.9 -adam_beta2 0.998 \
    -decay_method noam -warmup_steps 8000 \
    -learning_rate 2 -label_smoothing 0.0 \
    -enc_layers 6 -dec_layers 6 -rnn_size 256 -word_vec_size 256 \
    -encoder_type transformer -decoder_type transformer \
    -share_embeddings -position_encoding -max_generator_batches 0 \
    -global_attention general -global_attention_function softmax \
    -self_attn_type scaled-dot -max_relative_positions 4 \
    -heads 8 -transformer_ff 2048 -max_distance 1 2 3 4 \
    -early_stopping 40 -alpha 1.0 \
    -tensorboard -tensorboard_logdir runs/${dataset}_${model_name} 2>&1 | tee train_${model_name}.log

PyTorch 1.4.0 is not compatible with our GPU (RTX 3090):

[screenshot: PyTorch 1.4.0 / RTX 3090 incompatibility warning]

And the oldest PyTorch version compatible with CUDA 11 (1.7.0) doesn't work either.

[screenshot: error with PyTorch 1.7.0]
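(Editorial note, not part of the original report: a quick way to check in Python whether an installed PyTorch build supports the 3090. This is a sketch; it assumes a CUDA-enabled build, and torch.cuda.get_arch_list is only available in recent PyTorch versions.)

```python
import torch

# The RTX 3090 has compute capability 8.6 (sm_86); only CUDA 11.x builds
# of PyTorch ship kernels for it.
print(torch.__version__, torch.version.cuda)
print(torch.cuda.get_device_capability(0))  # expect (8, 6) on an RTX 3090

# Lists the GPU architectures this PyTorch binary was compiled for;
# 'sm_86' must be present for the 3090 to run without kernel warnings.
print(torch.cuda.get_arch_list())
```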

YoujunZhao commented 1 year ago

I got the same problem when I used the code for training. Have you solved the problem yet?

Seojin-Kim commented 1 year ago

Although I'm not the author of this paper, I'm gonna share my solution below.

You can simply replace "cross_mask" on line 279 of onmt/utils/loss.py with "cross_mask.float()".

The main issue was a dtype mismatch between the input (float) and the target (integer) of the MSE loss.
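(Editorial note: to make the fix concrete, here is a minimal standalone sketch, not the actual contents of onmt/utils/loss.py; tensor names and shapes are illustrative.)

```python
import torch
import torch.nn.functional as F

# Float attention weights vs. an integer (Long) guidance mask, mirroring
# the reported mismatch.
cross_attn = torch.rand(2, 8, 10, 12, requires_grad=True)
cross_mask = torch.randint(0, 2, (2, 8, 10, 12))  # dtype: torch.int64

# Passing the integer mask directly to F.mse_loss can raise
# "RuntimeError: Found dtype Long but expected Float".
# Casting the target to float, as suggested above, resolves it:
loss = F.mse_loss(cross_attn, cross_mask.float())
loss.backward()
print(loss.item(), cross_attn.grad.dtype)
```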

sw32-seo commented 10 months ago

Thank you @Seojin-Kim, you helped me a lot!