After an issue was raised on the accelerate repo I tried adapting https://github.com/pytorch/PiPPy/pull/943 to work with stable-diffusion-v1-5. However I found that we can't trace anymore and dynamo finds an error what seems to be around the split point.
# Copyright (c) Meta Platforms, Inc. and affiliates
# Minimum effort to run this example:
# $ torchrun --nproc-per-node 2 pippy_unet.py
import argparse
import os
import torch
import torch.distributed as dist
from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
from pippy.PipelineStage import PipelineStage
from diffusers import UNet2DConditionModel
from hf_utils import get_number_of_params
def run(args):
print("Using device:", args.device)
# Create model
# See https://github.com/huggingface/diffusers?tab=readme-ov-file#quickstart
unet = UNet2DConditionModel.from_pretrained(
"runwayml/stable-diffusion-v1-5",
subfolder="unet",
variant="fp16",
torch_dtype=torch.float16,
)
unet.to(args.device)
unet.eval()
if args.rank == 0:
print(f"Total number of params = {get_number_of_params(unet) // 10 ** 6}M")
print(unet)
# Input configs
sample_size = unet.config.sample_size
noise = torch.randn((2, 4, sample_size, sample_size), device=args.device, dtype=torch.float16)
encoder_hidden_states= torch.randn(2, 77, 768, dtype=torch.float16, device=args.device)
timestep = 1
# Split model into two stages:
# Stage 0: down_blocks + mid_block
# Stage 2: up_blocks
annotate_split_points(unet, {"mid_block": PipeSplitWrapper.SplitPoint.END})
# Create pipeline
unet_pipe = Pipe.from_tracing(
unet,
num_chunks=args.chunks,
example_args=(noise, timestep, encoder_hidden_states),
)
nstages = len(list(unet_pipe.split_gm.children()))
assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}"
if args.rank == 0:
for i, sm in enumerate(unet_pipe.split_gm.children()):
print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params")
# Create schedule runtime
stage = PipelineStage(
unet_pipe,
args.rank,
device=args.device,
)
# Run
if args.rank == 0:
stage(noise)
elif args.rank == args.world_size - 1:
out = stage()
else:
stage()
print(f"Rank {args.rank} completes")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 2)))
parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1)))
parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost'))
parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500'))
parser.add_argument('--schedule', type=str, default="FillDrain")
parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available()))
parser.add_argument("--chunks", type=int, default=2)
parser.add_argument('--batch_size', type=int, default=2)
parser.add_argument('--batches', type=int, default=1)
args = parser.parse_args()
if args.cuda:
dev_id = args.rank % torch.cuda.device_count()
args.device = torch.device(f"cuda:{dev_id}")
else:
args.device = torch.device("cpu")
# Init process group
backend = "nccl" if args.cuda else "gloo"
dist.init_process_group(
backend=backend,
rank=args.rank,
world_size=args.world_size,
)
run(args)
Launching either with torchrun or accelerate launch
Stack Trace
(de-duplicated for sanity)
Traceback (most recent call last):
File "/home/zach/work/PiPPy/examples/huggingface/pippy_unet.py", line 105, in <module>
run(args)
File "/home/zach/work/PiPPy/examples/huggingface/pippy_unet.py", line 49, in run
unet_pipe = Pipe.from_tracing(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/pippy/IR.py", line 1067, in from_tracing
traced = Pipe._trace_with_export(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/pippy/IR.py", line 1015, in _trace_with_export
traced: torch.fx.GraphModule = _export_to_torch_ir(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_export/__init__.py", line 516, in _export_to_torch_ir
gm_torch_level, _ = torch._dynamo.export(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1342, in inner
result_traced = opt_f(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
compiled_product = _compile(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
r = func(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
transformations(instructions, code_options)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
return fn(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 527, in transform
tracer.run()
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
super().run()
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
and self.step()
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
getattr(self, inst.opname)(inst)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
return inner_fn(self, inst)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1213, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 328, in call_function
return tx.inline_user_function_return(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
tracer.run()
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
and self.step()
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
getattr(self, inst.opname)(inst)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
return inner_fn(self, inst)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars.items)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 294, in call_function
return super().call_function(tx, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
return super().call_function(tx, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
return tx.inline_user_function_return(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
tracer.run()
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
and self.step()
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
getattr(self, inst.opname)(inst)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
return inner_fn(self, inst)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars.items)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 328, in call_function
return tx.inline_user_function_return(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
tracer.run()
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
and self.step()
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
getattr(self, inst.opname)(inst)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
return inner_fn(self, inst)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars.items)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 294, in call_function
return super().call_function(tx, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
return super().call_function(tx, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
return tx.inline_user_function_return(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
tracer.run()
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
and self.step()
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
getattr(self, inst.opname)(inst)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
return inner_fn(self, inst)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1264, in CALL_FUNCTION_KW
self.call_function(fn, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 328, in call_function
return tx.inline_user_function_return(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
tracer.run()
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
and self.step()
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
getattr(self, inst.opname)(inst)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
return inner_fn(self, inst)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars.items)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 294, in call_function
return super().call_function(tx, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
return super().call_function(tx, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
return tx.inline_user_function_return(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
tracer.run()
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
and self.step()
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
getattr(self, inst.opname)(inst)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
return inner_fn(self, inst)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1264, in CALL_FUNCTION_KW
self.call_function(fn, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 301, in call_function
return wrap_fx_proxy(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1314, in wrap_fx_proxy
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1399, in wrap_fx_proxy_cls
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1525, in get_fake_value
raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1486, in get_fake_value
ret_val = wrap_fake_exception(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1027, in wrap_fake_exception
return fn()
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1487, in <lambda>
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1592, in run_node
raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1576, in run_node
return nnmodule(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/diffusers/models/attention.py", line 372, in forward
attn_output = self.attn2(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 512, in forward
return self.processor(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 1231, in __call__
key = attn.to_k(encoder_hidden_states, *args)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
return F.linear(input, self.weight, self.bias)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 250, in _fn
result = fn(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 3952, in matmul
output = t1_folded.mm(t2).view(output_shape)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1392, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1712, in dispatch
r = func(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_ops.py", line 513, in __call__
return self._op(*args, **(kwargs or {}))
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 250, in _fn
result = fn(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/_meta_registrations.py", line 1975, in meta_mm
torch._check(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/__init__.py", line 1081, in _check
_check_with(RuntimeError, cond, message)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/__init__.py", line 1064, in _check_with
raise error_type(message_evaluated)
torch._dynamo.exc.TorchRuntimeError: Failed running call_module self_mid_block_mod_attentions_0_transformer_blocks_0(*(FakeTensor(..., device='cuda:1', size=(1, 64, 1280), dtype=torch.float16,
grad_fn=<ViewBackward0>),), **{'attention_mask': None, 'encoder_hidden_states': None, 'encoder_attention_mask': None, 'timestep': None, 'cross_attention_kwargs': None, 'class_labels': None}):
a and b must have same reduction dim, but got [64, 1280] X [768, 1280].
from user code:
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 1160, in forward
sample = self.mid_block(sample, emb)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/pippy/IR.py", line 1200, in forward
return self.mod(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 802, in forward
hidden_states = attn(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py", line 391, in forward
hidden_states = block(
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
[2024-02-28 10:33:41,051] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 12546) of binary: /home/zach/miniconda3/envs/accelerate/bin/python
Traceback (most recent call last):
File "/home/zach/miniconda3/envs/accelerate/bin/accelerate", line 8, in <module>
sys.exit(main())
File "/home/zach/work/accelerate/src/accelerate/commands/accelerate_cli.py", line 47, in main
args.func(args)
File "/home/zach/work/accelerate/src/accelerate/commands/launch.py", line 1016, in launch_command
multi_gpu_launcher(args)
File "/home/zach/work/accelerate/src/accelerate/commands/launch.py", line 672, in multi_gpu_launcher
distrib_run.run(args)
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run
elastic_launch(
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
pippy_unet.py FAILED
------------------------------------------------------------
Failures:
[1]:
time : 2024-02-28_10:33:41
host : workhorse
rank : 1 (local_rank: 1)
exitcode : 1 (pid: 12547)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2024-02-28_10:33:41
host : workhorse
rank : 0 (local_rank: 0)
exitcode : 1 (pid: 12546)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
After an issue was raised on the accelerate repo I tried adapting https://github.com/pytorch/PiPPy/pull/943 to work with
stable-diffusion-v1-5
. However I found that we can't trace anymore and dynamo finds an error what seems to be around the split point.Environment:
Reproduction
Script:
Launching either with
torchrun
oraccelerate launch
Stack Trace
(de-duplicated for sanity)