pytorch / PiPPy

Pipeline Parallelism for PyTorch
BSD 3-Clause "New" or "Revised" License
676 stars 81 forks source link

Issue with the unet example on a different model, reduction dim mismatch #955

Open muellerzr opened 4 months ago

muellerzr commented 4 months ago

After an issue was raised on the accelerate repo, I tried adapting https://github.com/pytorch/PiPPy/pull/943 to work with stable-diffusion-v1-5. However, I found that the model can no longer be traced: dynamo raises an error that seems to originate around the split point.

Environment:

- `Accelerate` version: 0.28.0.dev0
- `Diffusers` version: 0.26.3
- Platform: Linux-6.5.0-21-generic-x86_64-with-glibc2.35
- Python version: 3.10.13
- Numpy version: 1.26.2
- PyTorch version (GPU?): 2.2.0+cu118 (True)
- System RAM: 93.41 GB
- GPU type: NVIDIA GeForce RTX 4090x2

Reproduction

Script:

# Copyright (c) Meta Platforms, Inc. and affiliates

# Minimum effort to run this example:
# $ torchrun --nproc-per-node 2 pippy_unet.py

import argparse
import os

import torch
import torch.distributed as dist

from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
from pippy.PipelineStage import PipelineStage

from diffusers import UNet2DConditionModel

from hf_utils import get_number_of_params

def run(args: argparse.Namespace) -> None:
    """Build a two-stage pipeline from the Stable Diffusion v1.5 UNet and run it.

    The function loads the fp16 UNet, creates example inputs, marks a split
    point at the end of ``mid_block``, traces the model into a ``Pipe``, wraps
    the result in a ``PipelineStage`` for this rank and invokes the stage once.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``device``, ``rank``, ``chunks`` and ``world_size``.

    Raises:
        AssertionError: If the number of traced pipeline stages differs from
            ``args.world_size``.
    """
    print("Using device:", args.device)

    # Create model
    # See https://github.com/huggingface/diffusers?tab=readme-ov-file#quickstart
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        variant="fp16",
        torch_dtype=torch.float16,
    )
    unet.to(args.device)
    unet.eval()
    # Only rank 0 prints the model summary so the output is not repeated per rank.
    if args.rank == 0:
        print(f"Total number of params = {get_number_of_params(unet) // 10 ** 6}M")
        print(unet)

    # Input configs
    # Example inputs used for tracing; all tensors are fp16 on this rank's device.
    sample_size = unet.config.sample_size
    noise = torch.randn((2, 4, sample_size, sample_size), device=args.device, dtype=torch.float16)
    encoder_hidden_states=  torch.randn(2, 77, 768, dtype=torch.float16, device=args.device)
    timestep = 1

    # Split model into two stages:
    #   Stage 0: down_blocks + mid_block
    #   Stage 1: up_blocks
    annotate_split_points(unet, {"mid_block": PipeSplitWrapper.SplitPoint.END})

    # Create pipeline
    # NOTE: the traceback in this report is raised from inside this call,
    # i.e. the script fails during tracing and never reaches the code below.
    unet_pipe = Pipe.from_tracing(
        unet,
        num_chunks=args.chunks,
        example_args=(noise, timestep, encoder_hidden_states),
    )
    # Each child of the split graph module is counted as one pipeline stage;
    # there must be exactly one stage per rank.
    nstages = len(list(unet_pipe.split_gm.children()))
    assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}"
    if args.rank == 0:
        for i, sm in enumerate(unet_pipe.split_gm.children()):
            print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params")

    # Create schedule runtime
    stage = PipelineStage(
        unet_pipe,
        args.rank,
        device=args.device,
    )

    # Run
    # Rank 0 is called with the input; every other rank is called without
    # arguments.
    # NOTE(review): the pipe was traced with three example args
    # (noise, timestep, encoder_hidden_states) but only `noise` is passed
    # here - confirm whether the remaining inputs must be supplied as well.
    if args.rank == 0:
        stage(noise)
    elif args.rank == args.world_size - 1:
        # `out` is assigned but not used afterwards in this script.
        out = stage()
    else:
        stage()

    print(f"Rank {args.rank} completes")

# Script entry point: parse options, pick a device, initialise the process
# group and hand over to run().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Defaults for the distributed settings come from the environment
    # variables WORLD_SIZE, RANK, MASTER_ADDR and MASTER_PORT when they are set.
    parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 2)))
    parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1)))
    parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost'))
    parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500'))
    parser.add_argument('--schedule', type=str, default="FillDrain")
    # --cuda defaults to 1 when CUDA is available and 0 otherwise.
    parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available()))
    parser.add_argument("--chunks", type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--batches', type=int, default=1)
    # NOTE(review): master_addr, master_port, schedule, batch_size and batches
    # are parsed but not read anywhere in the visible script - presumably
    # leftovers from the example this was adapted from; confirm.

    args = parser.parse_args()

    if args.cuda:
        # Map the rank onto a GPU index by wrapping around the GPU count.
        # NOTE(review): with the default rank of -1 (RANK unset) Python's
        # modulo yields the last GPU index - presumably the script is always
        # started by a launcher that sets RANK; confirm.
        dev_id = args.rank % torch.cuda.device_count()
        args.device = torch.device(f"cuda:{dev_id}")
    else:
        args.device = torch.device("cpu")

    # Init process group
    # The backend follows the device choice: "nccl" with CUDA, "gloo" without.
    backend = "nccl" if args.cuda else "gloo"
    dist.init_process_group(
        backend=backend,
        rank=args.rank,
        world_size=args.world_size,
    )

    run(args)

Launched with either `torchrun` or `accelerate launch`.

Stack Trace

(de-duplicated for sanity)

Traceback (most recent call last):
  File "/home/zach/work/PiPPy/examples/huggingface/pippy_unet.py", line 105, in <module>
    run(args)
  File "/home/zach/work/PiPPy/examples/huggingface/pippy_unet.py", line 49, in run
    unet_pipe = Pipe.from_tracing(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/pippy/IR.py", line 1067, in from_tracing
    traced = Pipe._trace_with_export(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/pippy/IR.py", line 1015, in _trace_with_export
    traced: torch.fx.GraphModule = _export_to_torch_ir(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_export/__init__.py", line 516, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1342, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
    super().run()
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1213, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 328, in call_function
    return tx.inline_user_function_return(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
    tracer.run()
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 294, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
    return tx.inline_user_function_return(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
    tracer.run()
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 328, in call_function
    return tx.inline_user_function_return(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
    tracer.run()
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 294, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
    return tx.inline_user_function_return(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
    tracer.run()
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1264, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 328, in call_function
    return tx.inline_user_function_return(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
    tracer.run()
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 294, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
    return tx.inline_user_function_return(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
    tracer.run()
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1264, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 301, in call_function
    return wrap_fx_proxy(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1314, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1399, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1525, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1486, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1027, in wrap_fake_exception
    return fn()
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1487, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1592, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1576, in run_node
    return nnmodule(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/diffusers/models/attention.py", line 372, in forward
    attn_output = self.attn2(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 512, in forward
    return self.processor(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 1231, in __call__
    key = attn.to_k(encoder_hidden_states, *args)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 250, in _fn
    result = fn(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 3952, in matmul
    output = t1_folded.mm(t2).view(output_shape)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1392, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1712, in dispatch
    r = func(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_ops.py", line 513, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 250, in _fn
    result = fn(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_meta_registrations.py", line 1975, in meta_mm
    torch._check(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/__init__.py", line 1081, in _check
    _check_with(RuntimeError, cond, message)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/__init__.py", line 1064, in _check_with
    raise error_type(message_evaluated)
torch._dynamo.exc.TorchRuntimeError: Failed running call_module self_mid_block_mod_attentions_0_transformer_blocks_0(*(FakeTensor(..., device='cuda:1', size=(1, 64, 1280), dtype=torch.float16,
           grad_fn=<ViewBackward0>),), **{'attention_mask': None, 'encoder_hidden_states': None, 'encoder_attention_mask': None, 'timestep': None, 'cross_attention_kwargs': None, 'class_labels': None}):
a and b must have same reduction dim, but got [64, 1280] X [768, 1280].

from user code:
   File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 1160, in forward
    sample = self.mid_block(sample, emb)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/pippy/IR.py", line 1200, in forward
    return self.mod(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 802, in forward
    hidden_states = attn(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py", line 391, in forward
    hidden_states = block(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

[2024-02-28 10:33:41,051] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 12546) of binary: /home/zach/miniconda3/envs/accelerate/bin/python
Traceback (most recent call last):
  File "/home/zach/miniconda3/envs/accelerate/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/zach/work/accelerate/src/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/home/zach/work/accelerate/src/accelerate/commands/launch.py", line 1016, in launch_command
    multi_gpu_launcher(args)
  File "/home/zach/work/accelerate/src/accelerate/commands/launch.py", line 672, in multi_gpu_launcher
    distrib_run.run(args)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
pippy_unet.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2024-02-28_10:33:41
  host      : workhorse
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 12547)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-02-28_10:33:41
  host      : workhorse
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 12546)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
muellerzr commented 4 months ago

cc @sayakpaul. Related accelerate issue: https://github.com/huggingface/accelerate/issues/2497