logix-project / logix

AI Logging for Interpretability and Explainability🔬

FSDP error #93

Closed sangkeun00 closed 7 months ago

sangkeun00 commented 7 months ago
Traceback (most recent call last):
  File "/home/sangkeuc/workspace/AlpaCare/extract_log.py", line 225, in <module>
    train()
  File "/home/sangkeuc/workspace/AlpaCare/extract_log.py", line 221, in train
    trainer.extract_log()
  File "/home/sangkeuc/workspace/analog/analog/huggingface/patch.py", line 83, in extract_log
    self.train(*args, **kwargs)
  File "/home/sangkeuc/miniconda3/envs/chatdoc/lib/python3.10/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
  File "/home/sangkeuc/miniconda3/envs/chatdoc/lib/python3.10/site-packages/transformers/trainer.py", line 1868, in _inner_training_loop
    with self.accelerator.accumulate(model):
  File "/home/sangkeuc/miniconda3/envs/chatdoc/lib/python3.10/contextlib.py", line 135, in __enter__
    return next(self.gen)
  File "/home/sangkeuc/miniconda3/envs/chatdoc/lib/python3.10/site-packages/accelerate/accelerator.py", line 997, in accumulate
    cm_stack.enter_context(contextlib.nullcontext() if self.sync_gradients else self.no_sync(m))
  File "/home/sangkeuc/miniconda3/envs/chatdoc/lib/python3.10/contextlib.py", line 492, in enter_context
    result = _cm_type.__enter__(cm)
  File "/home/sangkeuc/miniconda3/envs/chatdoc/lib/python3.10/contextlib.py", line 135, in __enter__
    return next(self.gen)
  File "/home/sangkeuc/miniconda3/envs/chatdoc/lib/python3.10/site-packages/accelerate/accelerator.py", line 889, in no_sync
    with context():
  File "/home/sangkeuc/miniconda3/envs/chatdoc/lib/python3.10/contextlib.py", line 135, in __enter__
    return next(self.gen)
  File "/home/sangkeuc/miniconda3/envs/chatdoc/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1040, in no_sync
    _lazy_init(self, self)
  File "/home/sangkeuc/miniconda3/envs/chatdoc/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 170, in _lazy_init
    _share_state_and_init_handle_attrs(state, root_module)
  File "/home/sangkeuc/miniconda3/envs/chatdoc/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 240, in _share_state_and_init_handle_attrs
    _p_assert(
  File "/home/sangkeuc/miniconda3/envs/chatdoc/lib/python3.10/site-packages/torch/distributed/utils.py", line 146, in _p_assert
    raise AssertionError(s)
AssertionError: Non-root FSDP instance's `_is_root` should not have been set yet or should have been set to `False`

The above error occurs when using the HF Trainer + FSDP + an auto-wrap policy. I'm not sure whether this is related to this issue in the original PyTorch repo.
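
For context, here is a rough sketch of the kind of HF Trainer configuration that exercises this path (argument values are illustrative, and the fsdp_config key name varies across transformers versions):

from transformers import TrainingArguments

# Illustrative only: any Trainer run that enables FSDP together with an
# auto-wrap policy goes through the accumulate()/no_sync() path in the
# traceback above.
training_args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    fsdp="full_shard auto_wrap",
    # older transformers versions use the "fsdp_transformer_layer_cls_to_wrap" key
    fsdp_config={"transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"]},
)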

sangkeun00 commented 7 months ago

Here is a minimal example that reproduces this bug:

import os
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
import torch.distributed as dist
import torch.nn as nn
import analog

def init_distributed_mode():
    # LOCAL_RANK is provided by the distributed launcher (e.g. torchrun)
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend='nccl')

# Define the model
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return x

def main():
    init_distributed_mode()
    run = analog.init("fsdp")
    model = MyModel()

    # Wrapping each Conv2d in its own FSDP unit via auto_wrap_policy triggers the error
    fsdp_model = FSDP(model.cuda(), auto_wrap_policy=ModuleWrapPolicy({torch.nn.Conv2d}))
    # fsdp_model = FSDP(model.cuda())  # no error without auto_wrap_policy
    analog.watch(fsdp_model)

    run.setup({"log": "grad"})
    with run(data_id=["1", "2"]):
        x = torch.randn(2, 3, 2, 2).cuda()
        out = fsdp_model(x)
        loss = out.pow(2).sum()
        loss.backward()
    log = run.get_log()

if __name__ == "__main__":
    main()

The above error only occurs when using auto_wrap_policy in FSDP.
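
For reference, the repro assumes a standard distributed launch that sets LOCAL_RANK, e.g. (script name and process count are illustrative):

torchrun --nproc_per_node=2 fsdp_repro.py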

sangkeun00 commented 7 months ago

Interestingly, the error occurred at the line below:

https://github.com/sangkeun00/analog/blob/7358ef2addb58820bcc99731d8e86d36c3b58b35/analog/analog.py#L100

where model.state_dict() is called, which is similar to the issue in the original PyTorch repo. When I replaced state_dict() with named_modules(), the error disappeared.
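
A minimal sketch of that workaround, assuming the goal at that line is only to enumerate module names (the function name here is illustrative, not the actual analog API):

import torch.nn as nn

def collect_module_names(model: nn.Module) -> list[str]:
    # named_modules() only walks the module tree, whereas state_dict() on an
    # FSDP-wrapped model appears to trigger FSDP's lazy initialization on the
    # auto-wrapped submodules prematurely, leading to the `_is_root` assertion.
    return [name for name, _ in model.named_modules() if name]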