HPDL-Group / Merak

Apache License 2.0

Runtime error when trying to use the run_gpt language modelling example #12

Open prajwal1210 opened 6 months ago

prajwal1210 commented 6 months ago

Hello!

I have been trying to run the run_gpt.py file in the language_modelling examples and I get the following error:

    train_result = trainer.train()
  File "/global/u1/p/prajwal/venvs/merakenv/lib/python3.9/site-packages/Merak/train_func.py", line 133, in train
    self.create_optimizer_and_scheduler(num_training_steps=max_steps)
  File "/global/u1/p/prajwal/venvs/merakenv/lib/python3.9/site-packages/Merak/merak_trainer.py", line 226, in create_optimizer_and_scheduler
    model, model_layers, input_to_shard_dic = convert_to_sequential(self.model, self.args, self.add_model, self.leaf_modules)
  File "/global/u1/p/prajwal/venvs/merakenv/lib/python3.9/site-packages/Merak/autoshard/convert.py", line 110, in convert_to_sequential
    traced, dummy_inputs = symbolic_trace(
  File "/global/u1/p/prajwal/venvs/merakenv/lib/python3.9/site-packages/Merak/autoshard/convert.py", line 389, in symbolic_trace
    traced_graph = tracer.trace(model, concrete_args=concrete_args)
  File "/global/u1/p/prajwal/venvs/merakenv/lib/python3.9/site-packages/Merak/autoshard/convert.py", line 301, in trace
    graph.erase_node(node)
  File "/global/common/software/nersc/pm-2022q4/sw/pytorch/2.0.1/lib/python3.9/site-packages/torch/fx/graph.py", line 873, in erase_node
    raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} '
RuntimeError: Tried to erase Node past_key_values_1 but it still had 1 users in the graph: {_assert_is_none: None}!

Is there a specific version of Hugging Face transformers or PyTorch that should be used with Merak to avoid this error? I am using PyTorch 2.0.1 and transformers 4.15.0.
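For context on the last frame of the traceback: `torch.fx.Graph.erase_node` refuses to remove a node that other nodes still consume, and the message names `_assert_is_none` as the remaining user. The sketch below reproduces that behaviour on a tiny hand-built graph. It is a generic illustration and not Merak's code or a fix for it; `build_graph`, `try_erase`, and the use of `operator.is_` as a stand-in for the assertion node are assumptions made for the example.

```python
import operator

import torch.fx as fx


def build_graph():
    """Build a small graph where a placeholder is consumed by a check node."""
    graph = fx.Graph()
    x = graph.placeholder("x")
    past = graph.placeholder("past_key_values")
    # Stand-in for the assertion node named in the error message (assumption).
    check = graph.call_function(operator.is_, (past, None))
    graph.output(x)
    return graph, past, check


def try_erase(graph, node):
    """Return True if the node was erased, False if fx refused."""
    try:
        graph.erase_node(node)
        return True
    except RuntimeError:
        # Raised when the node still has users in the graph.
        return False


if __name__ == "__main__":
    graph, past, check = build_graph()
    print("users of placeholder:", list(past.users))
    print("erase placeholder first:", try_erase(graph, past))
    print("erase its user:", try_erase(graph, check))
    print("erase placeholder again:", try_erase(graph, past))
```

Erasing the consuming node first (or redirecting its uses) is the general way to make a node removable in `torch.fx`; whether that is the right change inside Merak's tracer is not something this thread establishes.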

prajwal1210 commented 6 months ago

Any update on this issue?

lucasleesw commented 5 months ago

Thanks for using our repo! Support for PyTorch 2.0.1 is still under development. For now, please try torch 1.10.0.
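A minimal sketch for checking whether the active environment matches the suggested torch release before launching a run. It uses only the standard library; the helper names (`release_tuple`, `matches_suggested`, `installed_torch_version`) are hypothetical and not part of Merak, and the only version taken from this thread is 1.10.0.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Version suggested in this thread; no other versions are confirmed here.
SUGGESTED_TORCH = "1.10.0"


def release_tuple(version_string):
    """Turn '1.10.0+cu113' into (1, 10, 0), ignoring local build suffixes."""
    public = version_string.split("+", 1)[0]
    numbers = []
    for part in public.split(".")[:3]:
        match = re.match(r"\d+", part)
        if match is None:
            break
        numbers.append(int(match.group()))
    return tuple(numbers)


def matches_suggested(installed, suggested=SUGGESTED_TORCH):
    """Compare only the numeric release components of two version strings."""
    return release_tuple(installed) == release_tuple(suggested)


def installed_torch_version():
    """Return the installed torch version string, or None if torch is absent."""
    try:
        return version("torch")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_torch_version()
    if found is None:
        print("torch is not installed in this environment")
    elif matches_suggested(found):
        print(f"torch {found} matches the suggested {SUGGESTED_TORCH}")
    else:
        print(f"torch {found} differs from the suggested {SUGGESTED_TORCH}")
```

Comparing release tuples rather than raw strings keeps CUDA-suffixed builds such as `1.10.0+cu113` from being reported as a mismatch.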