HPDL-Group / Merak

Apache License 2.0

Introspecting pipeline stage partitioning results #6

Open jaywonchung opened 1 year ago

jaywonchung commented 1 year ago

Hi,

First of all, thanks a lot for the great codebase. I'm benefiting significantly from this.

I was wondering what's a good way to look deeper into how my model is split into multiple stages for pipeline parallelism. Currently, Merak outputs something like this:

stage=0 layers=14
     0: GraphModule
     1: GraphModule
     2: GraphModule
     3: GraphModule
     4: GraphModule
     5: GraphModule
     6: GraphModule
     7: GraphModule
     8: GraphModule
     9: GraphModule
    10: GraphModule
    11: GraphModule
    12: GraphModule
    13: GraphModule
stage=1 layers=14
    14: GraphModule
    15: GraphModule
    16: GraphModule
    17: GraphModule
    18: GraphModule
    19: GraphModule
    20: GraphModule
    21: GraphModule
    22: GraphModule
    23: GraphModule
    24: GraphModule
    25: GraphModule
    26: GraphModule
    27: GraphModule
  loss: loss_fn

But is there a way to identify, for example, that layer 17 is a transformer layer? What is inside each of the GraphModule layers, i.e., at which points was the original PyTorch module split?

Thanks a lot.

lucasleesw commented 1 year ago

Hello! Thanks for using Merak. We are happy to hear from you! To see what is inside each GraphModule, set the parameter --wall_clock_breakdown True. The information is printed here.
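If you can get hold of the GraphModule objects themselves, plain torch.fx also lets you list what each one computes. The sketch below is not part of Merak: `summarize_graph_module` and `ToyChunk` are hypothetical names, and how informative the module names are inside Merak's split GraphModules depends on how Merak builds them, which we have not checked. It walks `gm.graph.nodes` and reports the `call_module` targets (qualified names of the original submodules) plus any free functions or methods.

```python
import torch
import torch.fx as fx
from typing import List


def summarize_graph_module(gm: fx.GraphModule) -> List[str]:
    """List the operations a traced GraphModule performs, in execution order."""
    lines = []
    for node in gm.graph.nodes:
        if node.op == "call_module":
            # Qualified name of the original submodule, e.g. "encoder.block.3"
            lines.append(f"call_module: {node.target}")
        elif node.op in ("call_function", "call_method"):
            # Functions carry __name__; call_method targets are plain strings
            name = getattr(node.target, "__name__", str(node.target))
            lines.append(f"{node.op}: {name}")
    return lines


class ToyChunk(torch.nn.Module):
    """Toy stand-in for one piece of a model (not a real Merak stage)."""

    def __init__(self):
        super().__init__()
        self.attn = torch.nn.Linear(4, 4)
        self.ffn = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.ffn(torch.relu(self.attn(x))) + x


summary = summarize_graph_module(fx.symbolic_trace(ToyChunk()))
print("\n".join(summary))
```

Printing `gm.code` on a GraphModule is another quick way to see the generated forward pass.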

We are sorry that the output is messy when wall_clock_breakdown is enabled. We plan to improve it in the near future.

jaywonchung commented 1 year ago

I just tried it out for a couple models and it works perfectly! Thanks.

Slightly unrelated to the original issue, but if I wanted to influence where to cut these GraphModules, what would be a way to do that? For instance, I just saw that one transformer encoder layer gets split into two different GraphModules for T5, i.e. self attention ends a GraphModule, and then the fully connected layer for that transformer encoder starts the next GraphModule. What would be a way to force each GraphModule to contain exactly one transformer encoder/decoder layer?

lucasleesw commented 1 year ago


Hello @jaywonchung, this is a very interesting question. It is hard to tell the algorithm what a transformer encoder is. The contents of a GraphModule are decided by the number of parameters and the connections between graph nodes. But setting the parameter shard_count as described here might be worth a try.
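To illustrate why the cut points follow parameter counts rather than layer names, here is a simplified sketch. It is not Merak's `_split_layers`; `split_by_params` is a hypothetical helper that greedily cuts an ordered list of per-node parameter counts into `num_shards` contiguous, non-empty chunks of roughly equal size. Raising the shard count simply makes every chunk smaller, so cuts can land anywhere inside a transformer layer.

```python
from typing import List


def split_by_params(param_counts: List[int], num_shards: int) -> List[List[int]]:
    """Greedily split node indices into `num_shards` contiguous, non-empty chunks
    whose parameter totals are roughly equal (illustration only)."""
    if not 1 <= num_shards <= len(param_counts):
        raise ValueError("num_shards must be between 1 and the number of nodes")
    target = sum(param_counts) / num_shards
    shards: List[List[int]] = []
    current: List[int] = []
    running = 0  # parameters already in the current chunk
    for idx, count in enumerate(param_counts):
        nodes_left = len(param_counts) - idx         # including this node
        shards_after = num_shards - len(shards) - 1  # chunks still to open
        # Cut if every remaining chunk needs this node, or if adding it
        # would overshoot the per-chunk target.
        must_cut = nodes_left == shards_after
        want_cut = running + count > target and shards_after > 0
        if current and (must_cut or want_cut):
            shards.append(current)
            current, running = [], 0
        current.append(idx)
        running += count
    shards.append(current)
    return shards


print(split_by_params([10, 10, 10, 10], 2))
print(split_by_params([100, 1, 1, 1], 3))
```

The chunks only see numbers, so nothing in this procedure knows that two adjacent nodes belong to the same encoder layer.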

jaywonchung commented 1 year ago

I see. I reviewed Merak.autoshard.shard_layers._split_layers, and indeed it just operates on torch.fx-traced nodes and only views them as chunks of parameters. Yeah, I guess it's very tricky to keep this function general across all models and still make it layer-name aware.

As you suggested, I tried doubling shard_count from the default, which made the split roughly twice as fine-grained as before. I was then able to group GraphModules belonging to the same transformer layer into the same pipeline stage by hardcoding layer indices as a new partition_method here: https://github.com/HPDL-Group/Merak/blob/e8a2a779fea878be9b778f8a808a192364766f36/Merak/modules/module.py#L376
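The hardcoded partition_method at that link is not reproduced here. As a hypothetical sketch of the arithmetic only, `layer_aligned_parts` (our name) computes stage boundaries that never fall inside a transformer layer, assuming that each layer spans exactly `modules_per_layer` consecutive GraphModules and that boundaries are expressed as a `parts` list where stage i owns GraphModules `parts[i]` to `parts[i+1] - 1`. That convention comes from DeepSpeed's PipelineModule; we assume, but have not verified, that Merak uses the same one.

```python
from typing import List


def layer_aligned_parts(num_modules: int, modules_per_layer: int,
                        num_stages: int) -> List[int]:
    """Stage boundaries that keep every transformer layer's GraphModules together.

    Hypothetical layout: GraphModules come in consecutive groups of
    `modules_per_layer`, one group per transformer layer.
    """
    if num_modules % modules_per_layer:
        raise ValueError("num_modules must be a multiple of modules_per_layer")
    num_layers = num_modules // modules_per_layer
    if not 1 <= num_stages <= num_layers:
        raise ValueError("num_stages must be between 1 and the number of layers")
    # Spread whole layers as evenly as possible; earlier stages take the remainder.
    base, extra = divmod(num_layers, num_stages)
    parts = [0]
    for stage in range(num_stages):
        layers_here = base + (1 if stage < extra else 0)
        parts.append(parts[-1] + layers_here * modules_per_layer)
    return parts


print(layer_aligned_parts(56, 2, 2))
print(layer_aligned_parts(12, 2, 4))
```

A real T5 split will also contain GraphModules that are not part of any layer (embeddings, final norm), which this sketch ignores; those would need their own hardcoded offsets.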

Not the ideal solution but a hack that works for now. Thanks!

lucasleesw commented 1 year ago

Happy to hear that works for you. There might be another way to make the algorithm layer-name aware. The leaf modules option in Merak.trainer lets torch.fx treat a module as a whole, so passing the transformer layer class could help. The usage is here. The related code is here and here. Sorry for the late information; I hope this is also helpful.
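The underlying torch.fx mechanism can be shown in isolation. This sketch is not Merak's tracer: it only overrides the standard `torch.fx.Tracer.is_leaf_module` hook, which is presumably what a leaf-modules option builds on. `LeafAwareTracer`, `Block` and `Model` are hypothetical names; `Block` is a toy stand-in for a layer such as T5Block.

```python
import torch
import torch.fx as fx
from typing import List, Tuple, Type


class LeafAwareTracer(fx.Tracer):
    """torch.fx tracer that keeps instances of the given module types opaque."""

    def __init__(self, leaf_types: Tuple[Type[torch.nn.Module], ...]):
        super().__init__()
        self.leaf_types = leaf_types

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        # Whole blocks become single call_module nodes; everything else falls
        # back to the default rule (built-in torch.nn modules are leaves).
        if isinstance(m, self.leaf_types):
            return True
        return super().is_leaf_module(m, module_qualified_name)


class Block(torch.nn.Module):
    """Toy stand-in for one transformer layer."""

    def __init__(self, dim: int):
        super().__init__()
        self.attn = torch.nn.Linear(dim, dim)
        self.ffn = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.ffn(torch.relu(self.attn(x)))


class Model(torch.nn.Module):
    def __init__(self, dim: int = 8, depth: int = 2):
        super().__init__()
        self.blocks = torch.nn.ModuleList([Block(dim) for _ in range(depth)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


def call_module_targets(model: torch.nn.Module, leaf_types=()) -> List[str]:
    graph = LeafAwareTracer(leaf_types).trace(model)
    gm = fx.GraphModule(model, graph)
    return [n.target for n in gm.graph.nodes if n.op == "call_module"]


print(call_module_targets(Model()))
print(call_module_targets(Model(), (Block,)))
```

Without leaf types the graph exposes each block's inner Linear layers; with `Block` as a leaf, each block shows up as one node, so any downstream splitter can only cut between blocks.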

jaywonchung commented 1 year ago

Thanks for the extra information! T5Block would be the right leaf node to use. I tried it like this in the run_t5.py example:

from transformers.models.t5.modeling_t5 import T5Block

trainer = MerakTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=dataset, 
    # Data collator will default to DataCollatorWithPadding, so we change it.
    data_collator=default_data_collator,
    leaf_modules=(T5Block,),
)

Unfortunately, it dies during tracing:

Traceback (most recent call last):
  File "/workspace/merak/examples/language-modeling/run_t5.py", line 93, in <module>
    main()
  File "/workspace/merak/examples/language-modeling/run_t5.py", line 86, in main
    train_result = trainer.train()
  File "/workspace/merak/Merak/merak_trainer.py", line 729, in train
    self.create_optimizer_and_scheduler(num_training_steps=max_steps)
  File "/workspace/merak/Merak/merak_trainer.py", line 260, in create_optimizer_and_scheduler
    model, model_layers, input_to_shard_dic = convert_to_sequential(self.model, self.args, self.leaf_modules)
  File "/workspace/merak/Merak/autoshard/convert.py", line 107, in convert_to_sequential
    traced, dummy_inputs = symbolic_trace(
  File "/workspace/merak/Merak/autoshard/convert.py", line 378, in symbolic_trace
    traced_graph = tracer.trace(model, concrete_args=concrete_args)
  File "/workspace/merak/Merak/autoshard/convert.py", line 274, in trace
    graph = torch.fx.Tracer.trace(self, root, concrete_args=concrete_args)
  File "/root/.local/miniconda3/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py", line 615, in trace
    self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
  File "/root/.local/miniconda3/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 1611, in forward
    decoder_outputs = self.decoder(
  File "/root/.local/miniconda3/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py", line 604, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/root/.local/miniconda3/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py", line 422, in call_module
    return forward(*args, **kwargs)
  File "/root/.local/miniconda3/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py", line 600, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/root/.local/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/.local/miniconda3/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 912, in forward
    batch_size, seq_length = input_shape
ValueError: too many values to unpack (expected 2)

Removing the leaf_modules line runs fine.