chao1224 / MoleculeSTM

Multi-modal Molecule Structure-text Model for Text-based Editing and Retrieval, Nat Mach Intell 2023 (https://www.nature.com/articles/s42256-023-00759-6)
https://chao1224.github.io/MoleculeSTM
Other

RuntimeError: RNG state is wrong size #18

Closed Lzcstan closed 3 months ago

Lzcstan commented 6 months ago

Hello! Thank you for your excellent work! I would like to try the scripts you provided and have downloaded the relevant checkpoints following your tutorial. But when I ran the pretraining script with python pretrain.py --verbose --batch_size=32 --molecule_type=SMILES --epochs=2, the following error occurred:

arguments        Namespace(seed=42, device=0, dataspace_path='../data', dataset='PubChemSTM',        
text_type='SciBERT', molecule_type='SMILES', representation_frozen=False, batch_size=32,             
text_lr=0.0001, mol_lr=1e-05, text_lr_scale=1, mol_lr_scale=1, num_workers=8, epochs=2, decay=0,     
verbose=True, output_model_dir=None, max_seq_len=512,                                                
megamolbart_input_dir='../data/pretrained_MegaMolBART/checkpoints',                                  
vocab_path='../MoleculeSTM/bart_vocab.txt', pretrain_gnn_mode='GraphMVP_G', gnn_emb_dim=300,         
num_layer=5, JK='last', dropout_ratio=0.5, gnn_type='gin', graph_pooling='mean', SSL_loss='EBM_NCE', 
SSL_emb_dim=256, CL_neg_samples=1, T=0.1, normalize=True)                                            
len of CID2text: 324292                                                                              
len of CID2SMILES: 324270                                                                            
CID 28145 missing
CID 24606 missing
CID 24594 missing
CID 61654 missing
CID 24637 missing
CID 28117 missing
CID 21863527 missing
CID 61861 missing
CID 24258 missing
CID 61851 missing
CID 28127 missing
CID 28116 missing
CID 6857667 missing
CID 5460533 missing
CID 5460520 missing
CID 139646 missing
CID 5460519 missing
CID 11966241 missing
CID 24906310 missing
CID 13847619 missing
CID 60211070 missing
CID 28299 missing
missing 22
len of text_list: 361885
using world size: 1 and model-parallel size: 1 
using torch.float32 for parameters ...
-------------------- arguments --------------------
  adam_beta1 ...................... 0.9
  adam_beta2 ...................... 0.999
  adam_eps ........................ 1e-08
  adlr_autoresume ................. False
  adlr_autoresume_interval ........ 1000
  apply_query_key_layer_scaling ... False
  apply_residual_connection_post_layernorm  False
  attention_dropout ............... 0.1
  attention_softmax_in_fp32 ....... False
  batch_size ...................... None
  bert_load ....................... None
  bias_dropout_fusion ............. False
  bias_gelu_fusion ................ False
  block_data_path ................. None
  checkpoint_activations .......... False
  checkpoint_in_cpu ............... False
  checkpoint_num_layers ........... 1
  clip_grad ....................... 1.0
  contigious_checkpointing ........ False
  cpu_optimizer ................... False
  cpu_torch_adam .................. False
  data_impl ....................... infer
  data_path ....................... None
  dataset_path .................... None
  DDP_impl ........................ local
  deepscale ....................... False
  deepscale_config ................ None
  deepspeed ....................... False
  deepspeed_activation_checkpointing  False
  deepspeed_config ................ None
  deepspeed_mpi ................... False
  distribute_checkpointed_activations  False                                                         
  distributed_backend ............. nccl                                                             
  dynamic_loss_scale .............. True                                                             
  eod_mask_loss ................... False
  eval_interval ................... 1000
  eval_iters ...................... 100
  exit_interval ................... None
  faiss_use_gpu ................... False
  finetune ........................ False
  fp16 ............................ False
  fp16_lm_cross_entropy ........... False
  fp32_allreduce .................. False
  gas ............................. 1
  hidden_dropout .................. 0.1
  hidden_size ..................... 256
  hysteresis ...................... 2
  ict_head_size ................... None
  ict_load ........................ None
  indexer_batch_size .............. 128
  indexer_log_interval ............ 1000
  init_method_std ................. 0.02
  load ............................ ../data/pretrained_MegaMolBART/checkpoints                       
  local_rank ...................... None
  log_interval .................... 100
  loss_scale ...................... None
  loss_scale_window ............... 1000
  lr .............................. None
  lr_decay_iters .................. None
  lr_decay_style .................. linear
  make_vocab_size_divisible_by .... 128
  mask_prob ....................... 0.15
  max_position_embeddings ......... 512
  merge_file ...................... None
  min_lr .......................... 0.0
  min_scale ....................... 1
  mmap_warmup ..................... False
  model_parallel_size ............. 1
  no_load_optim ................... False
  no_load_rng ..................... False
  no_save_optim ................... False
  no_save_rng ..................... False
  num_attention_heads ............. 8
  num_layers ...................... 4
  num_unique_layers ............... None
  num_workers ..................... 2
  onnx_safe ....................... None
  openai_gelu ..................... False
  override_lr_scheduler ........... False
  param_sharing_style ............. grouped
  params_dtype .................... torch.float32
  partition_activations ........... False
  pipe_parallel_size .............. 0
  profile_backward ................ False
  query_in_block_prob ............. 0.1
  rank ............................ 0
  report_topk_accuracies .......... []
  reset_attention_mask ............ False
  reset_position_ids .............. False
  save ............................ None
  save_interval ................... None
  scaled_masked_softmax_fusion .... False
  scaled_upper_triang_masked_softmax_fusion  False 
  seed ............................ 1234
  seq_length ...................... None
  short_seq_prob .................. 0.1
  split ........................... 969, 30, 1
  synchronize_each_layer .......... False
  tensorboard_dir ................. None
  titles_data_path ................ None
  tokenizer_type .................. GPT2BPETokenizer
  train_iters ..................... None
  use_checkpoint_lr_scheduler ..... False
  use_cpu_initialization .......... False
  use_one_sent_docs ............... False
  vocab_file ...................... ../MoleculeSTM/bart_vocab.txt
  warmup .......................... 0.01
  weight_decay .................... 0.01
  world_size ...................... 1
  zero_allgather_bucket_size ...... 0.0
  zero_contigious_gradients ....... False
  zero_reduce_bucket_size ......... 0.0
  zero_reduce_scatter ............. False
  zero_stage ...................... 1.0
---------------- end of arguments ---------------- 
> initializing torch distributed ...
> initializing model parallel with size 1                                                            
> setting random seeds to 1234 ...                                                                   
> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234
Loading vocab from ../MoleculeSTM/bart_vocab.txt.
Loading from ../data/pretrained_MegaMolBART/checkpoints
global rank 0 is loading checkpoint ../data/pretrained_MegaMolBART/checkpoints/iter_0134000/mp_rank_00/model_optim_rng.pt
could not find arguments in the checkpoint ...
Traceback (most recent call last):
  File "/data02/luozc/chemgpt/MoleculeSTM/scripts/pretrain.py", line 270, in <module>
    MegaMolBART_wrapper = MegaMolBART(
  File "/data02/luozc/chemgpt/MoleculeSTM/scripts/../MoleculeSTM/models/mega_molbart/mega_mol_bart.py", line 98, in __init__
    self.model = self.load_model(args, self.tokenizer, decoder_max_seq_len)
  File "/data02/luozc/chemgpt/MoleculeSTM/scripts/../MoleculeSTM/models/mega_molbart/mega_mol_bart.py", line 157, in load_model
    self.iteration = load_checkpoint(model, None, None)
  File "/home/oem/anaconda3/envs/chemgpt/lib/python3.10/site-packages/megatron/checkpointing.py", line 287, in load_checkpoint
    torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
  File "/home/oem/anaconda3/envs/chemgpt/lib/python3.10/site-packages/torch/cuda/random.py", line 75, in set_rng_state
    _lazy_call(cb)
  File "/home/oem/anaconda3/envs/chemgpt/lib/python3.10/site-packages/torch/cuda/__init__.py", line 229, in _lazy_call
    callable()
  File "/home/oem/anaconda3/envs/chemgpt/lib/python3.10/site-packages/torch/cuda/random.py", line 73, in cb
    default_generator.set_state(new_state_copy)
RuntimeError: RNG state is wrong size

How should I fix this? I'm using a server with an NVIDIA H800, with CUDA 12.1 and PyTorch 2.1.2. Thanks again 🙏

chao1224 commented 6 months ago

Hi @Lzcstan,

I suspect this is due to the CUDA and PyTorch versions. I use CUDA 11 (with PyTorch 1.9). Can you try downgrading them?

Lzcstan commented 6 months ago

But I suspect the NVIDIA H800 cannot be used with CUDA < 12. I tried CUDA 11.3 with PyTorch 1.10.1 and got the following output:

[2024-01-10 03:20:35,380] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
/home/oem/anaconda3/envs/mol_stm/lib/python3.7/site-packages/apex/pyprof/__init__.py:5: FutureWarning: pyprof will be removed by the end of June, 2022
  warnings.warn("pyprof will be removed by the end of June, 2022", FutureWarning)
arguments        Namespace(CL_neg_samples=1, JK='last', SSL_emb_dim=256, SSL_loss='EBM_NCE', T=0.1, batch_size=32, dataset='PubChemSTM', dataspace_path='../data', decay=0, device=0, dropout_ratio=0.5, epochs=2, gnn_emb_dim=300, 
gnn_type='gin', graph_pooling='mean', max_seq_len=512, megamolbart_input_dir='../data/pretrained_MegaMolBART/checkpoints', mol_lr=1e-05, mol_lr_scale=1, molecule_type='SMILES', normalize=True, num_layer=5, num_workers=8, 
output_model_dir=None, pretrain_gnn_mode='GraphMVP_G', representation_frozen=False, seed=42, text_lr=0.0001, text_lr_scale=1, text_type='SciBERT', verbose=True, vocab_path='../MoleculeSTM/bart_vocab.txt')
Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
/home/oem/anaconda3/envs/mol_stm/lib/python3.7/site-packages/torch/cuda/__init__.py:143: UserWarning: 
NVIDIA H800 with CUDA capability sm_90 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_61 sm_70 sm_75 sm_80 sm_86 compute_37.
If you want to use the NVIDIA H800 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/

  warnings.warn(incompatible_device_warn.format(device_name, capability, " ".join(arch_list), device_name))

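This warning supports my guess: the PyTorch 1.10.1 build only includes kernels up to sm_86, while the H800 has compute capability sm_90.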
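The compatibility check behind this warning can be reproduced in a few lines. This is a minimal sketch, not part of MoleculeSTM or Megatron: is_arch_supported and report_device are hypothetical helper names, and the rule they encode (a build is usable if it ships an sm_XY kernel with the same major version X as the GPU) is a simplified assumption modeled on the warning text. PTX entries such as compute_37 are ignored. With the arch list from the warning above, an H800 (sm_90) comes out as unsupported.

```python
def is_arch_supported(capability, arch_list):
    """Return True if a PyTorch build listing `arch_list` has kernels for `capability`.

    Simplified assumption: a cubin built for sm_XY runs on GPUs with the same
    major version X. PTX entries such as 'compute_37' are ignored here.
    """
    major, _minor = capability
    built_sm = [int(arch.split("_")[1]) for arch in arch_list if arch.startswith("sm_")]
    return any(sm // 10 == major for sm in built_sm)


def report_device(index=0):
    """Print the PyTorch/CUDA versions and whether this build supports the GPU."""
    import torch  # imported lazily so is_arch_supported() works without torch

    capability = torch.cuda.get_device_capability(index)
    supported = is_arch_supported(capability, torch.cuda.get_arch_list())
    print(f"torch {torch.__version__}, CUDA {torch.version.cuda}, "
          f"{torch.cuda.get_device_name(index)} sm_{capability[0]}{capability[1]}, "
          f"supported by this build: {supported}")
    return supported


# Arch list copied from the warning above (PyTorch 1.10.1 + CUDA 11.3 build).
ARCHS_FROM_WARNING = ["sm_37", "sm_50", "sm_60", "sm_61", "sm_70",
                      "sm_75", "sm_80", "sm_86", "compute_37"]
print(is_arch_supported((9, 0), ARCHS_FROM_WARNING))  # H800 (sm_90) -> False
print(is_arch_supported((8, 0), ARCHS_FROM_WARNING))  # e.g. A100 (sm_80) -> True
```

Running report_device() in the target environment prints the same version information that was reported by hand in this thread.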

chao1224 commented 6 months ago

Hi @Lzcstan,

I'm afraid I don't have an H800 to test the code on. According to the exception messages you listed above, the problem is an incompatibility between the H800 (with its CUDA and PyTorch versions) and Megatron.

Lzcstan commented 3 months ago

Hi @Lzcstan,

I'm afraid I don't have an H800 to test the code on. According to the exception messages you listed above, the problem is an incompatibility between the H800 (with its CUDA and PyTorch versions) and Megatron.

  • I checked the Megatron GitHub repo, and according to this link, running Megatron on the H800 should work (which is good to know).

  • What I am not sure about is how compatible Megatron-LM-v1.1.5 is with the H800. Can you try the following commands?


cd MolBART/megatron_molbart/Megatron-LM-v1.1.5-3D_parallelism

pip install .

  • Also, you currently have Megatron installed via pip install megatron-lm. Another thing that might help is to print out state_dict['cuda_rng_state'] before the line torch.cuda.set_rng_state(state_dict['cuda_rng_state']) in the source code.

Hi, I checked the shape of the CUDA RNG state and found that, on the H800, it does not match the one stored in the checkpoint. Switching to a different GPU solved my problem, so I will close this issue. Thank you for your kind reply :-)
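The check described in this thread (comparing the RNG state saved in the checkpoint with the one the current GPU/PyTorch setup produces) can be run outside Megatron. This is a minimal sketch with hypothetical helper names (compare_rng_state_sizes, inspect_checkpoint); it assumes the checkpoint is a plain torch.save() dictionary with a top-level 'cuda_rng_state' key, as the traceback above indicates. If the two lengths differ, that is consistent with the "RNG state is wrong size" error raised by torch.cuda.set_rng_state().

```python
def compare_rng_state_sizes(saved_state, current_state):
    """Compare the length of a saved CUDA RNG state with the current one.

    Both arguments only need len(); in practice they would be the uint8
    tensors state_dict['cuda_rng_state'] and torch.cuda.get_rng_state().
    """
    saved, current = len(saved_state), len(current_state)
    return {"saved_bytes": saved, "current_bytes": current, "same_size": saved == current}


def inspect_checkpoint(path):
    """Load a Megatron checkpoint on the CPU and report both RNG state sizes.

    Assumption: the file is a torch.save() dict with a top-level
    'cuda_rng_state' entry, as in the traceback in this issue.
    """
    import torch  # imported lazily so compare_rng_state_sizes() works without torch

    # On PyTorch >= 2.6, torch.load defaults to weights_only=True; a checkpoint
    # holding other pickled Python objects may need weights_only=False there.
    state_dict = torch.load(path, map_location="cpu")
    report = compare_rng_state_sizes(state_dict["cuda_rng_state"],
                                     torch.cuda.get_rng_state())
    print(report)
    return report


# Example (path taken from the log in this issue; adjust to your own layout):
# inspect_checkpoint("../data/pretrained_MegaMolBART/checkpoints/"
#                    "iter_0134000/mp_rank_00/model_optim_rng.pt")
```

Loading with map_location="cpu" keeps the inspection independent of which GPU the checkpoint was saved on, so only the final get_rng_state() call touches the local device.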