Closed · 980202006 closed this issue 3 years ago
You can use the following code for a simple fix. Add this code before line 38 in train.py:
```python
# simple fix for DataParallel that allows access to class attributes
class MyDataParallel(torch.nn.DataParallel):
    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)
```
and insert the following code between lines 104 and 105:

```python
for key in model:
    model[key] = MyDataParallel(model[key], device_ids=[0, 1, 2, 3])
for key in model_ema:
    model_ema[key] = MyDataParallel(model_ema[key], device_ids=[0, 1, 2, 3])
```
What an amazing fix, thank you. In practice, I found that GPU memory usage increases with the number of domains. Is there a solution for thousands of speakers? Even with 4 GPUs, the model cannot be trained with a batch size of 5.
Since the original StarGAN v2 wasn't designed to handle thousands of domains, you will need to find a way to work around it. The reason memory consumption increases with the number of domains is that in model.py, the unshared part of the mapping network consists of an individual small subnetwork per domain, and the style encoder is built the same way. You could, for example, remove the mapping network entirely.
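To see why memory scales with the domain count, here is a back-of-the-envelope parameter count for the unshared heads. The layer sizes and depth are illustrative assumptions, not the exact StarGAN v2 configuration:

```python
# Each domain gets its own small MLP head in the unshared part of the
# mapping network (and similarly in the style encoder), so the parameter
# count there grows linearly with the number of domains.
hidden = 512       # assumed hidden width (illustrative)
style_dim = 64     # assumed style-code size (illustrative)
layers = 4         # assumed depth of each per-domain head (illustrative)

def head_params(hidden, style_dim, layers):
    # (layers - 1) hidden->hidden linear layers plus a final
    # hidden->style_dim layer, each with a bias term.
    p = (layers - 1) * (hidden * hidden + hidden)
    p += hidden * style_dim + style_dim
    return p

per_head = head_params(hidden, style_dim, layers)
for num_domains in (10, 100, 1000):
    print(f"{num_domains} domains -> {num_domains * per_head:,} unshared parameters")
# With these assumed sizes, 1000 domains already cost ~0.8B parameters
# in the unshared heads alone, before optimizer state is counted.
```

Sharing a single domain-agnostic head (or dropping the mapping network, as suggested above) removes this linear term.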
The style encoder can also be changed for zero-shot conversion, though more study is needed because the results I have obtained so far were not ideal. You may want to open a new issue to discuss how to make it work with thousands of speakers, or zero-shot conversion in general.
Thank you!
When I ran DataParallel on a single machine with multiple GPUs, I encountered a lot of problems.