PiotrNawrot / nanoT5

Fast & Simple repository for pre-training and fine-tuning T5-style models
Apache License 2.0

Pre-training fails at step 30155 out of 32768 steps every time #22

Closed: QinengWang-Aiden closed this issue 1 year ago.

QinengWang-Aiden commented 1 year ago

I encountered the error "Cannot call sizes() on tensor with symbolic sizes/strides" during pre-training. Whenever I pre-train nanoT5 with the google/t5-v1_1-small config, the program fails at step 30155 of the 32768 total steps. I only modified get_tokenizer and load_dataset_splits to load a custom dataset and tokenizer. Apart from an added wandb logger, the rest of the program is unchanged. However, when I resume training from checkpoint-30000, the problem does not occur, whether I set args.current_train_step=1 or args.current_train_step=30000. Below is the full traceback:

/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/overrides.py:111: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
  torch.has_cuda,
/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/overrides.py:112: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'
  torch.has_cudnn,
/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/overrides.py:118: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
  torch.has_mps,
/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/overrides.py:119: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'
  torch.has_mkldnn,
Traceback (most recent call last):
  File "/data/user/projects/nanoT5/nanoT5/main.py", line 89, in main
    train(model, train_dataloader, test_dataloader, accelerator,
  File "/data/user/projects/nanoT5/nanoT5/utils/train_utils.py", line 190, in train
    loss, stats = forward(model, batch)
  File "/data/user/projects/nanoT5/nanoT5/utils/train_utils.py", line 88, in forward
    outputs = model(**batch)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 333, in _fn
    return fn(*args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1521, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1357, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/accelerate/utils/operations.py", line 581, in forward
    return model_forward(*args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/accelerate/utils/operations.py", line 569, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
    return hijacked_callback(frame, cache_size, hooks, frame_state)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 637, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 371, in _convert_frame_assert
    return _compile(
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 567, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 181, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 466, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 433, in transform
    tracer.run()
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2071, in run
    super().run()
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2159, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 853, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 953, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 181, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1020, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1005, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/backends/distributed.py", line 436, in compile_fn
    submod_compiler.run(*example_inputs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/fx/interpreter.py", line 138, in run
    self.env[node] = self.run_node(node)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/backends/distributed.py", line 430, in run_node
    return curr_submod(*new_args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/fx/graph_module.py", line 678, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/fx/graph_module.py", line 284, in __call__
    raise e
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/fx/graph_module.py", line 274, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.251", line 25, in forward
    l__self___encoder_block_0_layer_0_self_attention_q = self.L__self___encoder_block_0_layer_0_SelfAttention_q(type_as)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
RuntimeError: Cannot call sizes() on tensor with symbolic sizes/strides

This is the command I use to run the code:

accelerate launch --multi_gpu --num_processes=2 -m nanoT5.main model.compile=true
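The crash happens at the same step on every fresh run but not after resuming from checkpoint-30000. One hypothesis, which this thread does not confirm, is that the batch reached at that point has tensor shapes that differ from earlier batches, and that this makes torch.compile recompile the model with symbolic shapes. The hypothetical helper first_shape_change below is a minimal sketch for testing that idea. It walks a dataloader that yields dicts of tensors (as the model(**batch) call in the traceback implies) and reports the first batch whose shapes differ from those of the first batch.

```python
from typing import Any, Iterable, Mapping, Optional, Tuple


def first_shape_change(
    batches: Iterable[Mapping[str, Any]],
) -> Optional[Tuple[int, str, Optional[tuple], Optional[tuple]]]:
    """Hypothetical diagnostic helper (not part of nanoT5).

    Returns (batch_index, key, expected_shape, actual_shape) for the first batch
    whose shapes differ from those of the first batch, or None if all match.
    batch_index is 1-based and counts batches, not optimizer steps. Any value
    with a .shape attribute works (torch tensors, NumPy arrays).
    """
    reference = None
    for index, batch in enumerate(batches, start=1):
        shapes = {key: tuple(value.shape) for key, value in batch.items()}
        if reference is None:
            reference = shapes  # the first batch defines the expected shapes
            continue
        # Check the union of keys so that added or missing fields are reported too.
        for key in sorted(reference.keys() | shapes.keys()):
            expected, actual = reference.get(key), shapes.get(key)
            if expected != actual:
                return index, key, expected, actual
    return None
```

For example, call first_shape_change(train_dataloader) on a single process, where train_dataloader is whatever the modified load_dataset_splits feeds into the training loop. The returned index counts batches, so with gradient accumulation or several processes it has to be converted before it can be compared with step 30155.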
PiotrNawrot commented 1 year ago

Is it possible that the dataset you're loading has only 30155 batches?

QinengWang-Aiden commented 1 year ago

No, the dataset actually has more than 40,000 batches.
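Whether the data runs out before 32768 steps depends on how many batches each optimizer step consumes. The following is a minimal sketch of that arithmetic. It assumes that batches are counted on one unsharded dataloader, that they are split evenly across processes, and that each optimizer step uses grad_acc micro-batches per process. The function name and parameters are hypothetical, not nanoT5's API.

```python
def optimizer_steps_per_epoch(num_batches: int, num_processes: int = 1, grad_acc: int = 1) -> int:
    """Hypothetical helper: optimizer steps that one pass over the data supports.

    Assumes num_batches is counted before the dataloader is sharded across
    processes, that leftover batches (an uneven shard or an incomplete
    accumulation group) are dropped, and that every optimizer step consumes
    grad_acc batches on each process.
    """
    if num_batches < 0 or num_processes < 1 or grad_acc < 1:
        raise ValueError("num_batches must be >= 0; num_processes and grad_acc must be >= 1")
    batches_per_process = num_batches // num_processes
    return batches_per_process // grad_acc
```

Assume no gradient accumulation, 40,000 batches counted before sharding, and the 2 processes from the command above. Under those assumptions, optimizer_steps_per_epoch(40_000, num_processes=2) gives 20000 steps per pass. If the 40,000 figure is already per process, pass num_processes=1 instead, which gives 40000.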

QinengWang-Aiden commented 1 year ago

After I changed the input data format, the problem was solved.

PiotrNawrot commented 1 year ago

That's great!

PiotrNawrot commented 1 year ago

What do you mean by changing the input data format?

QinengWang-Aiden commented 1 year ago

I changed my data from a format without spaces to one with spaces and then retrained the tokenizer. So far, I haven't encountered the issue again.
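The thread does not establish why the data format mattered. If the no-space format produced very different tokenized lengths, comparing per-example token counts under both formats would show it. The sketch below assumes tokenize is any callable that returns a list of tokens, for instance a Hugging Face tokenizer's tokenize method. The function name is hypothetical.

```python
from statistics import mean
from typing import Callable, Dict, Iterable, List


def token_length_summary(texts: Iterable[str], tokenize: Callable[[str], List]) -> Dict[str, float]:
    """Hypothetical helper: count, min, max and mean token counts per example."""
    lengths = [len(tokenize(text)) for text in texts]
    if not lengths:
        raise ValueError("no texts given")
    return {
        "count": len(lengths),
        "min": min(lengths),
        "max": max(lengths),
        "mean": mean(lengths),
    }
```

You can run it twice to get two summaries to compare. First, run it on samples in the old format with the old tokenizer. Then run it on the same samples in the new format with the retrained tokenizer.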

Update: I just found another promising solution in the discussion "Cannot call sizes() on tensor with symbolic sizes/strides" and tried it. It works for me! Anyone using PyTorch nightly who runs into a similar issue can refer to that discussion for a feasible solution.
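The thread does not say what the linked solution is, so the following is a separate, untested option and not necessarily the approach from that discussion. The traceback fails inside torch/_dynamo/backends/distributed.py. That file is Dynamo's DDP graph-splitting backend, which is toggled by torch._dynamo.config.optimize_ddp. A minimal sketch, assuming it is applied wherever the script wraps the model with torch.compile (the wrapper name is hypothetical):

```python
from typing import Optional

import torch
import torch._dynamo


def compile_without_ddp_optimizer(model: torch.nn.Module, dynamic: Optional[bool] = None) -> torch.nn.Module:
    """Hypothetical wrapper around torch.compile; a workaround sketch, not a confirmed fix.

    Disabling optimize_ddp stops Dynamo from splitting the graph at DDP bucket
    boundaries (the step that failed in the traceback). This can reduce the
    overlap between gradient communication and backward computation.
    """
    torch._dynamo.config.optimize_ddp = False
    # dynamic=False asks torch.compile to specialize on concrete input shapes
    # (recompiling when they change) instead of using symbolic shapes;
    # None keeps PyTorch's default behavior.
    return torch.compile(model, dynamic=dynamic)
```

Compilation is lazy, so the config flag only needs to be set before the first forward call. A simpler fallback is to rerun with model.compile=false instead of model.compile=true, which presumably skips torch.compile entirely at the cost of its speedups.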

QinengWang-Aiden commented 1 year ago

I have moved my new question to another issue :)