huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Incorrect output when using accelerate in a PyTorch UNet model #2849

Open cporrasn opened 3 weeks ago

cporrasn commented 3 weeks ago

System Info

Good morning! I'm trying to use accelerate to distribute a UNet model that has already been trained. I need model and tensor parallelism because a single image does not fit on a single GPU, so running inference on a single GPU raises an out-of-memory error.

I load the .pth file like this:

model = UNet(3,1)
model = model.to(memory_format=torch.channels_last)
state_dict = torch.load("model.pth", map_location="cpu")
del state_dict['mask_values']  # drop the extra, non-parameter key saved alongside the weights
model.load_state_dict(state_dict)
model.eval()

Afterwards, I use:

model = prepare_pippy(model, example_args=(input,))
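
For context, prepare_pippy traces the model with example_args to split it into pipeline stages, so the example input should match the shape, dtype, and memory format of the real inference input. A minimal sketch of that call, with a hypothetical 512x512 input size:

from accelerate.inference import prepare_pippy
import torch

# Hypothetical example shape; it should mirror the real inference input exactly.
example_input = torch.randn(1, 3, 512, 512).to(memory_format=torch.channels_last)
model = prepare_pippy(model, example_args=(example_input,))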

Then, I try to run inference like this:

with torch.no_grad():
    output = model(input)

I have 2 GPUs, and I can see that when I run it, the load is distributed across both GPUs.

The model performs image segmentation and should return a segmentation mask; however, the output is completely wrong: it returns an image with seemingly random blank pixels.

The same model run on a single GPU with a small image produces the correct result.

If you have any idea how to solve this, please let me know.

Reproduction

import torch

from accelerate.inference import prepare_pippy
from unet import UNet  # hypothetical import path; UNet is the reporter's own model class

model = UNet(3, 1)
model = model.to(memory_format=torch.channels_last)
state_dict = torch.load("model.pth", map_location="cpu")
del state_dict['mask_values']  # drop the extra, non-parameter key saved alongside the weights
model.load_state_dict(state_dict)
model.eval()

# `input` is the preprocessed image tensor (defined elsewhere by the reporter)
model = prepare_pippy(model, example_args=(input,))

with torch.no_grad():
    output = model(input)
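
Pipeline-parallel inference runs one process per GPU, so the script above would typically be started with a multi-process launcher; for example, assuming it is saved as repro.py (a hypothetical filename):

accelerate launch --num_processes 2 repro.py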

Expected behavior

A correctly segmented image

muellerzr commented 3 weeks ago

Are you making sure to gather the results at the end (or look at the last process only)? Otherwise you'll only have the intermediate results on each GPU. Please see the chunk in the examples here: https://github.com/huggingface/accelerate/blob/main/examples/inference/pippy/bert.py#L74-L76

from accelerate import PartialState

# The outputs are only on the final process by default
if PartialState().is_last_process:
    output = torch.stack(tuple(output[0]))
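
Applied to the UNet repro above, this would look something like the following sketch; the unpacking of output and the thresholding step are assumptions modeled on the linked BERT example, not confirmed against the reporter's model:

with torch.no_grad():
    output = model(input)

if PartialState().is_last_process:
    # Unpacking assumed to match the linked example, where the first element
    # holds the logits; adjust for whatever the pipelined UNet actually returns.
    logits = torch.stack(tuple(output[0]))
    # Hypothetical post-processing for a single-channel UNet: threshold to a binary mask
    mask = torch.sigmoid(logits) > 0.5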