NoahVl / Explaining-In-Style-Reproducibility-Study

Re-implementation of the StylEx paper, training a GAN to explain a classifier in StyleSpace, paper by Lang et al. (2021).

Trouble running with `multi_gpus=True` #14

Closed by bisraelsen 2 years ago

bisraelsen commented 2 years ago

Hello,

I'm fairly new to PyTorch, so I'm not sure whether this is a bug in the code or something I'm doing incorrectly.

I have 8 GPUs available (world_size=8), and to make use of them I passed `multi_gpus=True` to `cli.py`. When I do that I get the error shown below. I expected the code to work with multiple GPUs, since I've seen others mention multi-GPU training in other issues.
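
For reference, this is roughly the multi-GPU spawn pattern I believe `cli.py` follows when `multi_gpus=True` (pieced together from the traceback below; the function signature and arguments here are my guesses, not the repo's actual code):

```python
# Rough sketch of the multi-GPU path, based on the mp.spawn call in the traceback.
# Argument names other than run_training are my assumptions.
import torch
import torch.multiprocessing as mp

def run_training(rank, world_size):
    # Each spawned process gets its own rank (0..world_size-1) and should
    # pin all of its tensors and modules to cuda:<rank>.
    torch.cuda.set_device(rank)
    # ... build the model, data loader, and losses on this device and train ...

if __name__ == "__main__":
    world_size = torch.cuda.device_count()  # 8 in my case
    mp.spawn(run_training, args=(world_size,), nprocs=world_size)
```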

[UPDATE]: Some other info that might be useful: a) I'm training with `stylex_train_new` because my images are 256x256 and training would not converge with `stylex_train`; b) the GPUs are A100s.

Traceback (most recent call last):
  File "cli.py", line 263, in <module>
    main()
  File "cli.py", line 259, in main
    fire.Fire(train_from_folder)
  File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 466, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "cli.py", line 252, in train_from_folder
    mp.spawn(run_training,
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 
-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/cnvrg/stylex/cli.py", line 72, in run_training
    retry_call(model.train, tries=3, exceptions=NanException)
  File "/opt/conda/lib/python3.8/site-packages/retry/api.py", line 101, in retry_call
    return __retry_internal(partial(f, *args, **kwargs), exceptions, tries, delay, max_delay, backoff, jitter, logger)
  File "/opt/conda/lib/python3.8/site-packages/retry/api.py", line 33, in __retry_internal
    return f()
  File "/cnvrg/stylex/stylex_train_new.py", line 1478, in train
    rec_loss = self.rec_scaling * reconstruction_loss(image_batch, generated_images,
  File "/cnvrg/stylex/stylex_train_new.py", line 415, in reconstruction_loss
    loss = 0.1 * lpips_loss(encoder_batch_norm, generated_images_norm).mean() +\
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/lpips/lpips.py", line 118, in forward
    in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/lpips/lpips.py", line 154, in forward
    return (inp - self.shift) / self.scale
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!
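
If I read the traceback correctly, the LPIPS scaling layer's `shift`/`scale` buffers live on cuda:0 while the tensors in process 1 are on cuda:1, so the LPIPS module was presumably created once on the default device rather than inside each spawned process. As an untested sketch (assuming `stylex_train_new.py` keeps a single `lpips_loss` instance; the helper name and the `net` choice are my own), the usual remedy would look something like this:

```python
# Untested sketch: construct (or move) the LPIPS module on the device of the
# current process so its shift/scale buffers match the images it compares.
import lpips
import torch

def make_lpips_for_rank(rank: int) -> lpips.LPIPS:
    device = torch.device(f"cuda:{rank}")
    # lpips.LPIPS registers the scaling-layer buffers at construction time,
    # so .to(device) has to happen inside the spawned process.
    return lpips.LPIPS(net="alex").to(device)
```

Alternatively, calling `lpips_loss.to(image_batch.device)` once inside each spawned process before training starts should have the same effect.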
NoahVl commented 2 years ago

Hey! Thanks for taking an interest in our project.

Unfortunately we haven't been able to test multi-GPU support: during the conversion of lucidrains' code we didn't take it into account, as we only had single-GPU machines available to us. You could try contacting the people who did get the code working with multiple GPUs, but officially we don't support it. You could also check whether the original lucidrains code works for you with multiple GPUs; if it does, you'll know the problem is in our code.
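
If you want to narrow it down on your side, one quick check would be to print where the LPIPS buffers and your image batch end up in each spawned process, something like the untested snippet below (`report_devices` is just a throwaway helper, not part of the repo):

```python
# Throwaway diagnostic: print which device the LPIPS buffers and the image
# batch live on in each spawned process.
def report_devices(rank, lpips_module, image_batch):
    buffers = {name: str(buf.device) for name, buf in lpips_module.named_buffers()}
    print(f"[rank {rank}] images on {image_batch.device}, LPIPS buffers on {buffers}")
```

If the buffers report cuda:0 on every rank, that would confirm the loss module is only ever placed on the first GPU.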

My apologies that I can't help more with this specific issue, but I hope this points you in the right direction!