facebookresearch / fairscale

PyTorch extensions for high performance and large scale training.

Got error when training GPT2 with FSDP and activation checkpoint #934

Open ver217 opened 2 years ago

ver217 commented 2 years ago

I'm trying to train GPT2 with FSDP.

My environment is as follows:

- PyTorch: 1.10.0+cu113
- Fairscale: 0.4.5
- transformers: 4.16.2
- GPUs: 8x Tesla A100

When I set CUDA_LAUNCH_BLOCKING=1, I got:

Traceback (most recent call last):
  File "/home/lclhx/DeepSpeed-FairScale-Benchmark/run_benchmark.py", line 65, in <module>
    main()
  File "/home/lclhx/DeepSpeed-FairScale-Benchmark/run_benchmark.py", line 51, in main
    run_iter_func(model, optimizer, criterion)
  File "/home/lclhx/DeepSpeed-FairScale-Benchmark/benchmark/fairscale/fairscale.py", line 23, in run_iter
    loss.backward()
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/autograd/function.py", line 199, in apply
    return user_fn(self, *args)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/fairscale/nn/checkpoint/checkpoint_activations.py", line 348, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/autograd/function.py", line 199, in apply
    return user_fn(self, *args)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 122, in backward
    outputs = ctx.run_function(*detached_inputs)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/contextlib.py", line 137, in __exit__
    self.gen.throw(typ, value, traceback)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/random.py", line 129, in fork_rng
    torch.cuda.set_rng_state(gpu_rng_state, device)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/cuda/random.py", line 64, in set_rng_state
    _lazy_call(cb)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/cuda/__init__.py", line 153, in _lazy_call
    callable()
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/cuda/random.py", line 62, in cb
    default_generator.set_state(new_state_copy)
RuntimeError: CUDA error: an illegal memory access was encountered

When CUDA_LAUNCH_BLOCKING was not set, I got:

Traceback (most recent call last):
  File "/home/lclhx/DeepSpeed-FairScale-Benchmark/run_benchmark.py", line 65, in <module>
    main()
  File "/home/lclhx/DeepSpeed-FairScale-Benchmark/run_benchmark.py", line 51, in main
    run_iter_func(model, optimizer, criterion)
  File "/home/lclhx/DeepSpeed-FairScale-Benchmark/benchmark/fairscale/fairscale.py", line 23, in run_iter
    loss.backward()
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/autograd/function.py", line 199, in apply
    return user_fn(self, *args)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/fairscale/nn/checkpoint/checkpoint_activations.py", line 348, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/autograd/function.py", line 199, in apply
    return user_fn(self, *args)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 122, in backward
    outputs = ctx.run_function(*detached_inputs)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 881, in custom_forward
    return module(*inputs, use_cache, output_attentions)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 400, in forward
    attn_outputs = self.attn(
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 321, in forward
    raise e
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 318, in forward
    query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lclhx/.conda/envs/colossal/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1837, in forward
    x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
RuntimeError: setStorage: sizes [16384, 3072], strides [0, 1], storage offset 332776448, and itemsize 2 requiring a storage size of 665559040 are out of bounds for storage of size 0

I train my model like this:

def run_iter(model, optimizer, criterion):
    # img = torch.rand(BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE).cuda()
    # label = torch.randint(0, NUM_CLASS, (BATCH_SIZE, )).cuda()
    input_ids = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN)).cuda()
    mask = torch.ones((BATCH_SIZE, SEQ_LEN), dtype=torch.int64, device=torch.cuda.current_device())
    optimizer.zero_grad()
    model.zero_grad(set_to_none=True)

    with torch.cuda.amp.autocast():
        # out = model(img)
        # loss = criterion(out, label)
        out = model(input_ids, mask)
        loss = criterion(out, input_ids)
    loss.backward()
    optimizer.step()

def init(model, criterion, stage):

    if stage < 2:
        model = DDP(model, device_ids=[dist.get_rank()])
    elif stage == 3:
        model = FullyShardedDataParallel(model, mixed_precision=True, reshard_after_forward=False, disable_reshard_on_root=False)

    if stage == 1 or stage == 2:
        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=0.001)
    else:
        optimizer = OPTIM(model.parameters(), lr=0.001)

    if stage == 2:
        # ShardedDataParallel needs the OSS optimizer, so the model is wrapped after the optimizer is created
        model = ShardedDataParallel(model, optimizer)

    # criterion = torch.nn.CrossEntropyLoss()
    return model, optimizer, criterion

The GPT2 provided by transformers uses torch's checkpoint. I also tried using fairscale's checkpoint_wrapper by modifying the transformers source code, but I still got the error. Could you help me figure out this problem?
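
For context, a minimal sketch of the kind of modification described above: replacing transformers' built-in gradient checkpointing with fairscale's checkpoint_wrapper around each transformer block. The model name and attribute paths are the standard GPT2LMHeadModel ones; treat this as an illustration rather than the exact patch used here.

import torch.nn as nn
from transformers import GPT2LMHeadModel
from fairscale.nn.checkpoint import checkpoint_wrapper

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.config.use_cache = False              # caching conflicts with activation checkpointing
model.gradient_checkpointing_disable()      # avoid double-checkpointing via torch.utils.checkpoint

# Wrap every GPT2Block with fairscale's checkpoint_wrapper.
model.transformer.h = nn.ModuleList(
    checkpoint_wrapper(block) for block in model.transformer.h
)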

min-xu-ai commented 2 years ago

Have you tried disabling flattening? Also, there is a prototype version of FSDP in PyTorch; can you give that version a try as well?

ver217 commented 2 years ago

Hi, I tried disabling flattening and this solved my problem. However, I wonder why I can't enable flattening.
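
For reference, flattening in fairscale's FSDP is controlled by the flatten_parameters constructor argument (it defaults to True); a minimal sketch mirroring the wrapping call earlier in the thread:

from fairscale.nn.data_parallel import FullyShardedDataParallel

model = FullyShardedDataParallel(
    model,
    mixed_precision=True,
    reshard_after_forward=False,
    disable_reshard_on_root=False,
    flatten_parameters=False,  # the workaround discussed above
)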

min-xu-ai commented 2 years ago

> Hi, I tried disabling flattening and this solved my problem. However, I wonder why I can't enable flattening.

Unfortunately, I think there is a bug somewhere in the flattening code. I won't know where it is until this issue has been debugged in detail.

ver217 commented 2 years ago

> > Hi, I tried disabling flattening and this solved my problem. However, I wonder why I can't enable flattening.
>
> Unfortunately, I think there is a bug somewhere in the flattening code. I won't know where it is until this issue has been debugged in detail.

Hi, I'm trying to debug this issue and have found that the combination of activation checkpointing and mixed precision leads to it. I noticed the comments say we get two different gradient accumulation objects in mixed-precision mode. I also found that the backward post hook is registered but never triggered for parameters inside a checkpointed region.
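
For context, the backward post hook mentioned here is registered on the parameter's grad accumulator (its AccumulateGrad autograd node) rather than on the parameter tensor itself. A stripped-down illustration of that pattern, independent of FSDP:

import torch

p = torch.nn.Parameter(torch.randn(4, 4))

# Leaf tensors have no grad_fn, so the accumulator is reached through an expanded view.
p_tmp = p.expand_as(p)
grad_acc = p_tmp.grad_fn.next_functions[0][0]  # the AccumulateGrad node for p

def post_backward_hook(*unused):
    print("post-backward hook fired for p")

handle = grad_acc.register_hook(post_backward_hook)

(p * 2).sum().backward()  # the hook fires when backward reaches p's accumulator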

min-xu-ai commented 2 years ago

Oh, that's cool! It is indeed an issue that we have never resolved in a good way. cc @zhaojuanmao too.

If you have a small reproducible test case, that'd be great.

Also, please try different pytorch versions. Maybe it will behave differently across different versions?

ver217 commented 2 years ago

> Oh, that's cool! It is indeed an issue that we have never resolved in a good way. cc @zhaojuanmao too.
>
> If you have a small reproducible test case, that'd be great.
>
> Also, please try different pytorch versions. Maybe it will behave differently across different versions?

OK, I will try different PyTorch versions and test more cases. However, when I register the hook directly on the parameter instead of on the grad accumulator object, the hook is triggered normally. Could you tell me why you don't just register the hook on the parameter?

min-xu-ai commented 2 years ago

If I recall correctly, we want the hook to fire once the gradient has been computed. If you register the hook on the parameters, does it fire at the right time, i.e. after the gradient is computed?
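
For comparison, a hook registered directly on the parameter tensor fires when the gradient with respect to that parameter has been computed (just before it is accumulated into p.grad) and receives the gradient itself; a minimal sketch:

import torch

p = torch.nn.Parameter(torch.randn(4, 4))

def grad_hook(grad):
    # Receives the freshly computed gradient; returning None leaves it unchanged.
    print("parameter hook fired, grad norm:", grad.norm().item())

p.register_hook(grad_hook)
(p * 2).sum().backward()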

ruipeterpan commented 2 years ago

Hi @ver217, would you mind sharing a bit about how you are modifying the source code of transformers to use checkpoint_wrapper?

anj-s commented 2 years ago

> Oh, that's cool! It is indeed an issue that we have never resolved in a good way. cc @zhaojuanmao too.
>
> If you have a small reproducible test case, that'd be great.
>
> Also, please try different pytorch versions. Maybe it will behave differently across different versions?

@min-xu-ai Have we run into this issue before? Do you remember the context?

@ver217 Friendly reminder to share a small repro if possible to help us fix this.

edward-io commented 2 years ago

I'm also experiencing this issue:

  File "/fsx/users/hack/fairscale/fairscale/nn/model_parallel/layers.py", line 290, in forward
    output_parallel = F.linear(input_parallel, self.weight, self.bias)
RuntimeError: setStorage: sizes [512, 512], strides [1, 512], storage offset 70686720, and itemsize 2 requiring a storage size of 141897728 are out of bounds for storage of size 0

flatten_parameters: False mitigates the problem, but it would be good not to leave performance on the table :)

@min-xu-ai @anj-s A repro can be found at P521821959 (meta only)

min-xu-ai commented 2 years ago

@edward-io, have you tried the PyTorch version of FSDP? It likely already has better performance in the flattened case. @zhaojuanmao
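
For reference, the PyTorch-native FSDP mentioned here lives under torch.distributed.fsdp from roughly PyTorch 1.11 onward. A hedged sketch of swapping it in, assuming the process group is already initialized (e.g. via torchrun) and build_model() stands in for whatever constructs the GPT2 model:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = build_model().cuda()   # build_model() is a placeholder for the model construction
model = FSDP(model)            # with default settings, parameters are sharded across the default process group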

edward-io commented 2 years ago

@min-xu-ai thanks for the recommendation! I've used the pytorch distributed FSDP, but haven't tried it with model parallel yet.

ssnnoo commented 1 year ago

I have a similar issue: setStorage: sizes [144, 1940], strides [1, 144], storage offset 1881800, and itemsize 4 requiring a storage size of 8644640 are out of bounds for storage of size 0

This is with the fsdp_native strategy. However, there is no flatten_parameters option for that one?

min-xu-ai commented 1 year ago

Are you using the PyTorch version or the FairScale version?

ssnnoo commented 1 year ago

This is with the PyTorch version -> fsdp_native

min-xu-ai commented 1 year ago

I see. Can you please open an issue with the PyTorch team if you haven't already? It is better to include small reproduction code so people can debug it.

ssnnoo commented 1 year ago

Oops, yes, sorry.

I have now tried the fairscale one and get "mat2 must be a matrix, got 1-D tensor" with flatten_parameters: False. Strange. I will investigate further.

min-xu-ai commented 1 year ago

No worries. The error you are getting is probably because your code tries to use params in a matmul outside of the forward call. Outside of forward, the params your module originally had are flattened into 1-D tensors.
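
To see this concretely, a small diagnostic sketch; fsdp_model is an illustrative name for a fairscale FSDP-wrapped module with the default flatten_parameters=True. (fairscale's FSDP also documents a summon_full_params() context manager for cases where the full parameters are needed outside of forward.)

# With flatten_parameters=True, the parameters visible outside of forward are
# the flattened 1-D storage rather than the original 2-D weights, which is why
# using them directly in a matmul fails.
for name, p in fsdp_model.named_parameters():
    print(name, tuple(p.shape))  # typically a single flat_param with ndim == 1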

aifartist commented 1 year ago

I get the exact same stack, from set_rng_state to set_state with an illegal memory access, with a different app (Stable Diffusion). How does one disable flattening?