hpcaitech / ColossalAI

Making large AI models cheaper, faster and more accessible
https://www.colossalai.org
Apache License 2.0

[BUG]: InterleavedPipelineSchedule fails to run #3187

Open KimmiShi opened 1 year ago

KimmiShi commented 1 year ago

🐛 Describe the bug

I am trying to reproduce the ResNet-50 pipeline parallelism demo on this page: https://colossalai.org/docs/features/pipeline_parallel

The code on that page runs fine for me. However, I'd like to try the interleaved schedule, so I tried the following:

  1. Simply setting NUM_CHUNKS=2 in the above example. I got an error message like: RuntimeError: Given groups=1, weight of size [512, 2048, 1, 1], expected input[64, 1024, 14, 14] to have 2048 channels, but got 1024 channels instead
  2. I read the code and found that I have to set model.num_chunks in CONFIG to create an InterleavedPipelineSchedule object. I did so and got another error message:
    Traceback (most recent call last):
    File "test_pp.py", line 166, in <module>
    engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model, optimizer, criterion,
    File "/mnt/.local/lib/python3.8/site-packages/colossalai/initialize.py", line 463, in initialize
    engine = Engine(model=model,
    File "/mnt/.local/lib/python3.8/site-packages/colossalai/engine/_base_engine.py", line 94, in __init__
    self._schedule.pre_processing(self)
    File "/mnt/.local/lib/python3.8/site-packages/colossalai/engine/schedule/_pipeline_schedule.py", line 490, in pre_processing
    elif isinstance(engine.model[0], NaiveAMPModel):
    TypeError: 'PipelinableModel' object is not subscriptable

    code script:

# imports used below (assumed, as in the pipeline-parallel tutorial; adjust to your version)
# setup_distributed_slurm() and get_dataloaders() are user-defined helpers (not shown here)
import os
import time

import torch
import torch.nn as nn
from torchvision.models import resnet50

import colossalai
import colossalai.nn as col_nn
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer

# Define some config
BATCH_SIZE = 512
NUM_EPOCHS = 2 
NUM_CHUNKS = 2 # update at 2023-3-22
# CONFIG = dict(NUM_MICRO_BATCHES=8, parallel=dict(pipeline=2))
CONFIG = dict(NUM_MICRO_BATCHES=8, parallel=dict(pipeline=2), model=dict(num_chunks=NUM_CHUNKS))

# Train
disable_existing_loggers()
parser = colossalai.get_default_parser()
# args = parser.parse_args()
rank, world_size, port, addr = setup_distributed_slurm()  # user-defined SLURM launcher helper
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host=addr)
logger = get_dist_logger()
pipelinable = PipelinableContext()

# build model
with pipelinable:
    model = resnet50()

exec_seq = [
    'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool',
    (lambda x: torch.flatten(x, 1), "behind"), 'fc'
]
pipelinable.to_layer_list(exec_seq)
#pipelinable.policy = "uniform"
#pipelinable.to_layer_list()

model = pipelinable.partition(NUM_CHUNKS, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))

# build criterion
criterion = nn.CrossEntropyLoss()

# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# build dataloader
root = os.environ.get('DATA', './data')
# train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, padding=4, crop=32, resize=32)
train_dataloader, test_dataloader = get_dataloaders(BATCH_SIZE, [3, 224, 224], 40000, 100)  # user-defined dataloader helper
test_dataloader = None
lr_scheduler = col_nn.lr_scheduler.LinearWarmupLR(optimizer, NUM_EPOCHS, warmup_steps=1)
engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model, optimizer, criterion,
                                                                                train_dataloader, test_dataloader,
                                                                                lr_scheduler)
timer = MultiTimer()

trainer = Trainer(engine=engine, timer=timer, logger=logger)

hook_list = [
    hooks.LossHook(),
    hooks.AccuracyHook(col_nn.metric.Accuracy()),
    hooks.LogMetricByEpochHook(logger),
    hooks.LRSchedulerHook(lr_scheduler, by_epoch=True)
]

beg = time.time()
trainer.fit(train_dataloader=train_dataloader,
            epochs=NUM_EPOCHS,
            test_dataloader=test_dataloader,
            test_interval=1,
            hooks=hook_list,
            display_progress=True)

Environment

No response

JThh commented 1 year ago

You were still having NUM_CHUNKS=1. How about setting it to 2?

KimmiShi commented 1 year ago

You were still having NUM_CHUNKS=1. How about setting it to 2?

Thanks, I've set it to 2 (but that is not shown in the code above).

File "/mnt/.../.local/lib/python3.8/site-packages/colossalai/engine/schedule/_pipeline_schedule.py", line 490, in pre_processing
    elif isinstance(engine.model[0], NaiveAMPModel):
TypeError: 'PipelinableModel' object is not subscriptable

So I think this might be a bug, since it tries to access engine.model[0].
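
For context, here is a minimal stand-alone illustration of the failure mode. WrapperModel below is only a hypothetical stand-in for PipelinableModel (not ColossalAI's actual class): a single nn.Module that keeps its pipeline chunks in an internal list, so indexing the wrapper itself raises exactly this TypeError.

    import torch.nn as nn

    class WrapperModel(nn.Module):
        """Hypothetical stand-in: one module holding its pipeline chunks
        in an internal ModuleList instead of being indexable itself."""

        def __init__(self, chunks):
            super().__init__()
            self._module_list = nn.ModuleList(chunks)

    wrapped = WrapperModel([nn.Linear(4, 4), nn.Linear(4, 4)])
    try:
        wrapped[0]  # mirrors engine.model[0] in pre_processing
    except TypeError as err:
        print(err)  # 'WrapperModel' object is not subscriptable
    print(wrapped._module_list[0])  # the chunks are still reachable via the inner list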

JThh commented 1 year ago

It turns out that PipelinableModel wraps a list of modules without exposing them. Can you try temporarily patching this function as below?

    def pre_processing(self, engine):
        from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
        if isinstance(engine.model, ShardedModelV2):
            self.dtype = torch.half
        modules = engine.model._module_list
        if isinstance(modules[0], NaiveAMPModel):
            self.dtype = torch.half
        for model in modules:
            if isinstance(model, NaiveAMPModel):
                model = model.model
            sig = inspect.signature(model.forward)
            for p in sig.parameters.values():
                assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'

We will seek a better solution soon!
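
In the meantime, if editing the installed package is inconvenient, the same change can be tried as a runtime monkey-patch applied before colossalai.initialize() builds the engine. This is only a sketch: the class is imported from the private module path shown in your traceback, and the NaiveAMPModel import path is assumed.

    import inspect

    import torch
    from colossalai.amp.naive_amp import NaiveAMPModel  # assumed import path
    from colossalai.engine.schedule._pipeline_schedule import InterleavedPipelineSchedule

    def patched_pre_processing(self, engine):
        from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
        if isinstance(engine.model, ShardedModelV2):
            self.dtype = torch.half
        # PipelinableModel keeps its chunks in an internal list rather than being indexable
        modules = engine.model._module_list
        if isinstance(modules[0], NaiveAMPModel):
            self.dtype = torch.half
        for model in modules:
            if isinstance(model, NaiveAMPModel):
                model = model.model
            sig = inspect.signature(model.forward)
            for p in sig.parameters.values():
                assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'

    # Apply the override before colossalai.initialize() is called.
    InterleavedPipelineSchedule.pre_processing = patched_pre_processing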

JThh commented 1 year ago

@ver217, could you take a look at this and confirm whether it is indeed a bug? Thanks!

KimmiShi commented 1 year ago

It turns out that PipelinableModel wraps a list of modules without exposing them. Can you try temporarily patching this function as below?

    def pre_processing(self, engine):
        from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
        if isinstance(engine.model, ShardedModelV2):
            self.dtype = torch.half
        modules = engine.model._module_list
        if isinstance(modules[0], NaiveAMPModel):
            self.dtype = torch.half
        for model in modules:
            if isinstance(model, NaiveAMPModel):
                model = model.model
            sig = inspect.signature(model.forward)
            for p in sig.parameters.values():
                assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'

We will seek a better solution soon!

Thanks, I tried it and got another error message:

  File "/mnt/xxx/.local/lib/python3.8/site-packages/colossalai/engine/schedule/_pipeline_schedule.py", line 586, in forward_backward_step
    input_objs = [[] for _ in range(len(model))]
TypeError: object of type 'PipelinableModel' has no len()

JThh commented 1 year ago

Yes, this error is to be expected. Let's wait for a reply!
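
The second traceback points at the same assumption: forward_backward_step calls len() on engine.model, which again presumes a list of chunks rather than a single PipelinableModel. A temporary local patch would need the same unwrapping there too, roughly as sketched below (untested; _module_list is the same internal attribute used in the patch above).

    # _pipeline_schedule.py, inside forward_backward_step (sketch only, untested):
    # unwrap the chunk list before the schedule calls len() or indexes the model
    model = engine.model._module_list if hasattr(engine.model, '_module_list') else engine.model
    input_objs = [[] for _ in range(len(model))]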

ver217 commented 1 year ago

@YuliangLiu0306 Do pipelinable models support interleaved 1F1B?