You can use dispatch_batches=True as part of your DataLoaderConfiguration to only load the data on the first GPU and then send it off to the rest of them.
Essentially, make a "dummy" dataset on the rest of your GPUs and a real one on GPU 0. We will then only touch the dataloader on GPU 0, and accelerate will send its batches to the rest of the GPUs.
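For reference, a minimal sketch of what that configuration can look like (hedged: exactly where the flag lives depends on your accelerate version; recent releases expose it through DataLoaderConfiguration):

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# With dispatch_batches=True, only the main process's dataloader is iterated;
# its batches are then sliced and dispatched to the other processes.
dataloader_config = DataLoaderConfiguration(dispatch_batches=True)
accelerator = Accelerator(dataloader_config=dataloader_config)

# On older accelerate versions the flag was passed directly:
# accelerator = Accelerator(dispatch_batches=True)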
If I understand correctly, that means I can replace dataset = MyDataset(args) with

if accelerator.is_local_main_process:
    dataset = MyDataset(args)
else:
    dataset = DummyDataset()

and after that, I can pass the DummyDataset into a dataloader and run accelerator.prepare on this dataloader. So all the dataloaders that are not on the main rank will be discarded in the end?
Besides, I would like to know whether it is possible to implement a shared-memory-like mechanism to avoid repeated data loading in situations like this, as it is a really common scenario.
Your solution here is exactly correct.
Taking that out of your control would be overreaching for accelerate (since we are not a Trainer), so this is our solution for you.
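As an aside on the shared-memory question: one common pattern outside of accelerate itself is to keep the large arrays in a memory-mapped file, so every process reads the same pages through the OS page cache instead of holding its own full copy in RAM. A rough sketch, where the file names and dtypes are my own assumptions rather than anything from this thread:

import numpy as np
import torch
from torch.utils.data import Dataset

class MemmapDataset(Dataset):
    # Hypothetical sketch: the big arrays are exported once to disk (here as
    # "features.npy" / "labels.npy") and opened memory-mapped, so all processes
    # share the same cached file pages instead of each loading ~70G into memory.
    def __init__(self, features_path="features.npy", labels_path="labels.npy"):
        self.features = np.load(features_path, mmap_mode="r")
        self.labels = np.load(labels_path, mmap_mode="r")

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        x = torch.from_numpy(np.array(self.features[idx], dtype=np.float32))
        y = torch.from_numpy(np.array(self.labels[idx], dtype=np.float32))
        return x, y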
@Imsovegetable Hello, I am facing the same problem right now. Do you happen to have a minimal working example for such the above-mentioned approach?
Does this also work in situations where num_machines > 1?
Hi, I modified the example provided in accelerate; hope this is helpful.
import torch
import copy
from accelerate import Accelerator
from accelerate.utils import set_seed
from torch.utils.data import TensorDataset, DataLoader

# seed
set_seed(0)

# define toy inputs and labels
x = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.])
y = torch.tensor([2., 4., 6., 8., 10., 12., 14., 16.])
gradient_accumulation_steps = 4
batch_size = len(x) // gradient_accumulation_steps

accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)

# define dataset and dataloader
if accelerator.is_local_main_process:
    dataset = TensorDataset(x, y)
else:
    dataset = DummyDataset()
dataloader = DataLoader(dataset, batch_size=batch_size)

# define model, optimizer and loss function
model = torch.zeros((1, 1), requires_grad=True)
model_clone = copy.deepcopy(model)
criterion = torch.nn.MSELoss()
model_optimizer = torch.optim.SGD([model], lr=0.02)
model, model_optimizer, dataloader = accelerator.prepare(model, model_optimizer, dataloader)
model_clone_optimizer = torch.optim.SGD([model_clone], lr=0.02)

print(f"initial model weight is {model.mean().item():.5f}")
print(f"initial model weight is {model_clone.mean().item():.5f}")

for i, (inputs, labels) in enumerate(dataloader):
    with accelerator.accumulate(model):
        inputs = inputs.view(-1, 1)
        print(i, inputs.flatten())
        labels = labels.view(-1, 1)
        outputs = inputs @ model
        loss = criterion(outputs, labels)
        accelerator.backward(loss)
        model_optimizer.step()
        model_optimizer.zero_grad()

# single full-batch step without accumulation, for comparison
loss = criterion(x.view(-1, 1) @ model_clone, y.view(-1, 1))
model_clone_optimizer.zero_grad()
loss.backward()
model_clone_optimizer.step()

print(f"w/ accumulation, the final model weight is {model.mean().item():.5f}")
print(f"w/o accumulation, the final model weight is {model_clone.mean().item():.5f}")
Does this also work in situations where num_machines > 1?
Actually, I haven't tested it, because I wrote an IterableDataset to bypass this problem. Perhaps you could give it a try :)
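For anyone curious, a sharded IterableDataset along those lines might look roughly like this (my own sketch, not the code used above): each process streams only the files assigned to it instead of loading everything in __init__, so memory use does not scale with the number of processes.

import torch
from torch.utils.data import IterableDataset
from accelerate import Accelerator

class ShardedIterableDataset(IterableDataset):
    # Hypothetical example: "files" is a list of per-chunk files on disk; each
    # process keeps only its round-robin share and loads one chunk at a time.
    def __init__(self, files, rank, world_size):
        self.files = files[rank::world_size]

    def __iter__(self):
        for path in self.files:
            chunk = torch.load(path)  # load a single chunk, not the whole dataset
            for sample in chunk:
                yield sample

accelerator = Accelerator()
# dataset = ShardedIterableDataset(all_chunk_files, accelerator.process_index, accelerator.num_processes)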
I am using a custom dataset where the data is loaded from disk in the __init__ method of the dataset. But I found that the data is loaded n times if I use n GPUs (which also means num_processes=n in the accelerate config). The size of the data is ~70G, so when I use 8 GPUs for training it causes an OOM error, since my machine only has 500G of memory.
My dataset is defined as follows:
The accelerate config is as follows:
This happens before the dataloader is prepared by accelerate, and it looks like a really common situation. I would like to know if there is a solution, or guidance towards a possible tutorial, on how to deal with it.
Thanks!