Too much RAM usage by ImageClassificationData #1450

Open ethanwharris opened 2 years ago

ethanwharris commented 2 years ago

Discussed in

Originally posted by **Hravan** September 1, 2022

I'm setting up training for this Kaggle competition dataset: (here I'm only using samples with a single label, to keep the problem simpler)

The problem is that ImageClassificationData uses too much RAM and the GPU is underutilized. To confirm that the problem lies somewhere within ImageClassificationData, I wrote the same training in plain PyTorch for comparison.

Code shared by both training versions:

```py
import pandas as pd
from skimage import io
from sklearn.preprocessing import OneHotEncoder
import torch
from torch.utils.data import Dataset
from torchvision import transforms as T


class PlantDataset(Dataset):
    def __init__(self, df, transform=None) -> None:
        super().__init__()
        self.img_paths = df["image"].tolist()
        self.transform = transform
        self.encoder = OneHotEncoder()
        self.labels = (
            self.encoder.fit_transform(df["label"].values.reshape(-1, 1))
            .todense()
            .A
        )

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img = io.imread(self.img_paths[idx])
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[idx]
        # return {
        #     "input": img,
        #     "target": torch.tensor(label, dtype=torch.uint8),
        # }
        return img, torch.tensor(label, dtype=torch.float32)


def preprocess_df(csv_path, images_root):
    df = pd.read_csv(csv_path)
    # Keep only single-label rows (multi-label rows contain a space).
    df = df[~df["labels"].str.contains(" ")]
    df["image"] = images_root + df["image"]
    df = df.rename(columns={"labels": "label"})
    return df


def split_df(df, train_pct):
    df = df.sample(frac=1)  # shuffle the rows
    n_train = int(train_pct * len(df))
    train_df = df.iloc[:n_train].reset_index()
    val_df = df.iloc[n_train:].reset_index()
    return train_df, val_df


def create_dataloader(df):
    train_compose = T.Compose(
        [
            T.ToPILImage(),
            T.Resize((224, 224)),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
    dataloader = torch.utils.data.DataLoader(
        PlantDataset(df, transform=train_compose),
        batch_size=32,
        num_workers=8,
        prefetch_factor=8,
    )
    return dataloader
```

Training in plain PyTorch:

```py
import argparse
from time import perf_counter

import torch
import torchvision
import tqdm

# preprocess_df, split_df and create_dataloader come from the shared code above.


def train(model, data_loader, n_epochs):
    model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters())
    loss_fn = torch.nn.CrossEntropyLoss()
    for i in range(n_epochs):
        for images, labels in tqdm.tqdm(data_loader):
            images = images.cuda()
            preds = model(images)
            loss = loss_fn(preds, labels.cuda())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print(f"End of epoch {i}")


def main():
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("csv_path")
    arg_parser.add_argument("images_root")
    args = arg_parser.parse_args()

    model = torchvision.models.resnet18()
    model.fc = torch.nn.Linear(512, 6)

    df = preprocess_df(args.csv_path, args.images_root)
    train_df, val_df = split_df(df, 0.1)
    train_loader = create_dataloader(train_df)

    time0 = perf_counter()
    train(model, train_loader, 2)
    print(f"Time elapsed: {perf_counter() - time0}")


if __name__ == "__main__":
    main()
```

Training in Lightning Flash:

```py
import argparse
from time import perf_counter

import flash
import pytorch_lightning as pl
import torch
import torchvision
from flash.image import ImageClassificationData

# preprocess_df and split_df come from the shared code above.


class Resnet18(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet18()
        self.model.fc = torch.nn.Linear(512, 6)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        # Flash yields dict batches with "input" and "target" keys.
        x, y = batch["input"], batch["target"]
        y_hat = self.model(x)
        loss = self.loss_fn(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters())


def main():
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("csv_path")
    arg_parser.add_argument("images_root")
    args = arg_parser.parse_args()

    model = Resnet18()

    df = preprocess_df(args.csv_path, args.images_root)
    train_df, val_df = split_df(df, 0.1)
    datamodule = ImageClassificationData.from_data_frame(
        "image",
        "label",
        train_data_frame=train_df,
        batch_size=32,
        transform_kwargs=dict(image_size=(224, 224)),
        num_workers=8,
        persistent_workers=True,
        pin_memory=False,
    )

    time0 = perf_counter()
    trainer = flash.Trainer(max_epochs=2, gpus=torch.cuda.device_count())
    trainer.fit(model, datamodule=datamodule)
    print(f"Time elapsed: {perf_counter() - time0}")


if __name__ == "__main__":
    main()
```

When I increase batch_size to 64 or num_workers to 16 in ImageClassificationData, I start running into RAM problems, which don't happen with the plain PyTorch version. Any ideas what might be causing this? I tried profiling but couldn't reach any clear conclusion, except that I suspect the problem is in BaseDataFetcher in DataModule.
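To put numbers on the RAM comparison, one option is to pull a fixed number of batches from each loader and record the resident memory of the main process plus every DataLoader worker it has spawned. Below is a minimal stdlib-only sketch, assuming Linux (it scans `/proc/<pid>/status`). `tree_rss_mb` and `peak_rss_while_iterating` are hypothetical helpers, not part of Flash or PyTorch. Summing RSS counts pages that forked workers still share with the parent once per process, so treat the result as an upper bound; PSS from `/proc/<pid>/smaps_rollup` would be more precise.

```py
import os


def _status_fields(pid):
    """Parse /proc/<pid>/status into a {field: value} dict (Linux only)."""
    fields = {}
    with open(f"/proc/{pid}/status") as fh:
        for line in fh:
            key, _, value = line.partition(":")
            fields[key] = value.strip()
    return fields


def tree_rss_mb(root_pid=None):
    """Hypothetical helper: sum VmRSS (in MB) over root_pid and its descendants.

    Shared copy-on-write pages are counted once per process, so the total
    overestimates the real footprint of the process tree.
    """
    root_pid = os.getpid() if root_pid is None else root_pid
    children, rss_kb = {}, {}
    for entry in os.listdir("/proc"):
        if not entry.isdigit():
            continue
        pid = int(entry)
        try:
            fields = _status_fields(pid)
        except OSError:
            continue  # the process exited (or is unreadable) mid-scan
        ppid = int(fields.get("PPid", "0"))
        children.setdefault(ppid, []).append(pid)
        # Kernel threads and zombies have no VmRSS line; count them as 0.
        rss_kb[pid] = int(fields.get("VmRSS", "0 kB").split()[0])

    total_kb, stack = 0, [root_pid]
    while stack:  # depth-first walk of the process tree
        pid = stack.pop()
        total_kb += rss_kb.get(pid, 0)
        stack.extend(children.get(pid, []))
    return total_kb / 1024


def peak_rss_while_iterating(loader, max_batches, measure=tree_rss_mb):
    """Hypothetical helper: iterate up to max_batches items, return peak measure()."""
    peak = measure()  # before iteration: DataLoader workers are not started yet
    for i, _batch in enumerate(loader):
        peak = max(peak, measure())
        if i + 1 >= max_batches:
            break
    return peak
```

For example, `peak_rss_while_iterating(train_loader, 200)` on the plain PyTorch loader from the post, repeated after changing `num_workers` in `create_dataloader`, and then the same on the loader the Flash DataModule builds, would give directly comparable peaks. Because DataLoader starts its workers only when iteration begins, the measurements taken inside the loop are the ones that include them.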
Atharva-Phatak commented 1 year ago

@ethanwharris, I can take a look if this is still open. It's interesting that there is such a bottleneck. Could you give me a few more details?

Maybe we can test this on a smaller dataset like CIFAR and see if that's the case.

ethanwharris commented 1 year ago

Hey @Atharva-Phatak, thanks for the offer! Please feel free to take a look 😃 I think a great starting point would be to train a model in Flash (e.g. on CIFAR-10, as you suggested) and the equivalent model using just Lightning, and see whether the maximum batch size you can reach differs between them. If it does, that would confirm we have a leak.
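The maximum-batch-size comparison can be automated with a doubling-then-binary search. Below is a minimal sketch, assuming a `fits(batch_size)` callable that you write yourself (hypothetical: for instance, build the loader and model with that batch size, run a few training steps, and return False if memory runs out), and assuming memory use grows monotonically with batch size, which is what makes the search valid. `max_batch_size` is an illustrative helper, not a Flash or Lightning API.

```py
def max_batch_size(fits, start=1, limit=4096):
    """Hypothetical helper: largest batch size in [start, limit] for which fits() is True.

    Assumes fits is monotonic: if a size fits, every smaller size fits too.
    Returns 0 if even `start` does not fit.
    """
    if not fits(start):
        return 0

    # Phase 1: keep doubling until a size fails or the limit is reached.
    good, bad, size = start, None, start
    while size < limit:
        size = min(size * 2, limit)
        if fits(size):
            good = size
        else:
            bad = size
            break
    if bad is None:
        return good  # every probed size up to the limit fits

    # Phase 2: binary search between the last success and the first failure.
    while bad - good > 1:
        mid = (good + bad) // 2
        if fits(mid):
            good = mid
        else:
            bad = mid
    return good
```

Running the same search once with the Flash pipeline and once with the plain Lightning one gives two numbers to compare. Exhausting host RAM often gets the process killed by the OS rather than raising an exception, so `fits` may need to run each trial in a separate subprocess and treat a non-zero exit code as a failure.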

Atharva-Phatak commented 1 year ago

@ethanwharris Sorry, I was busy with college and working on a PR for Bolts. I will look at it this week, and let's see where we can go from here :)

Borda commented 1 year ago

@Atharva-Phatak that would be great if you can still have a look at it... :rabbit: