davidtvs / pytorch-lr-finder

A learning rate range test implementation in PyTorch
MIT License

TypeError: forward() missing 1 required positional argument: 'labels' #82

Closed: adnan119 closed this issue 2 years ago.

adnan119 commented 2 years ago

I've been following the instructions and making all the changes required to run lr_finder.range_test(), but I'm still getting this error. Here's the code defining my Dataset class and data loaders:


import cv2
import torch
from torch.utils.data import DataLoader, Dataset


class HappyWhaleDataset(Dataset):
    def __init__(self, df, transforms=None):
        self.df = df
        self.file_names = df['file_path'].values
        self.labels = df['individual_id'].values
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        img_path = self.file_names[index]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        label = self.labels[index]

        if self.transforms:
            img = self.transforms(image=img)["image"]

        return {
            'image': img,
            'label': torch.tensor(label, dtype=torch.long)
        }

def prepare_loaders(df, fold):
    df_train = df[df.kfold != fold].reset_index(drop=True)
    df_valid = df[df.kfold == fold].reset_index(drop=True)

    train_dataset = HappyWhaleDataset(df_train, transforms=data_transforms["train"])
    valid_dataset = HappyWhaleDataset(df_valid, transforms=data_transforms["valid"])

    train_loader = DataLoader(train_dataset, batch_size=CONFIG['train_batch_size'], 
                              num_workers=2, shuffle=True, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset, batch_size=CONFIG['valid_batch_size'], 
                              num_workers=2, shuffle=False, pin_memory=True)

    return train_loader, valid_loader

train_loader, valid_loader = prepare_loaders(df, fold=0)

Note: training runs without errors when I use this train_loader in a normal training loop. Here is the LR finder code that fails:

from torch_lr_finder.lr_finder import LRFinder, TrainDataLoaderIter


class CustomTrainIter(TrainDataLoaderIter):
    def inputs_labels_from_batch(self, batch_data):
        return batch_data["image"], batch_data["label"]

custom_loader = CustomTrainIter(train_loader)

lr_finder = LRFinder(model, optimizer, criterion, device=CONFIG['device'])
lr_finder.range_test(custom_loader, end_lr=1, num_iter=100, step_mode="linear")
lr_finder.plot(log_lr=False)
lr_finder.reset()
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_34/1446799792.py in <module>
      6 
      7 lr_finder = LRFinder(model, optimizer, criterion, device=CONFIG['device'])
----> 8 lr_finder.range_test(custom_loader, end_lr=1, num_iter=100, step_mode="linear")
      9 lr_finder.plot(log_lr=False)
     10 lr_finder.reset()

/opt/conda/lib/python3.7/site-packages/torch_lr_finder/lr_finder.py in range_test(self, train_loader, val_loader, start_lr, end_lr, num_iter, step_mode, smooth_f, diverge_th, accumulation_steps, non_blocking_transfer)
    318                 train_iter,
    319                 accumulation_steps,
--> 320                 non_blocking_transfer=non_blocking_transfer,
    321             )
    322             if val_loader:

/opt/conda/lib/python3.7/site-packages/torch_lr_finder/lr_finder.py in _train_batch(self, train_iter, accumulation_steps, non_blocking_transfer)
    375 
    376             # Forward pass
--> 377             outputs = self.model(inputs)
    378             loss = self.criterion(outputs, labels)
    379 

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() missing 1 required positional argument: 'labels'

NaleRaphael commented 2 years ago

Hi @adnan119,

Could you share the signature of forward() in your model class? It looks like forward() is defined with at least 2 positional arguments (excluding self), e.g., def forward(self, inputs, labels):

By default, LRFinder assumes that forward() accepts only 1 argument, i.e., def forward(self, inputs):. That's why the traceback shows the model being called with a single argument:

/opt/conda/lib/python3.7/site-packages/torch_lr_finder/lr_finder.py in _train_batch(self, train_iter, accumulation_steps, non_blocking_transfer)
    375 
    376             # Forward pass
--> 377             outputs = self.model(inputs)
    378             loss = self.criterion(outputs, labels)
    379 

If that's the case, this approach may help: https://github.com/davidtvs/pytorch-lr-finder/issues/61#issuecomment-692715109
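For illustration, here is a minimal pure-Python sketch (no PyTorch needed) that reproduces the same TypeError. FakeModule and TwoInputModel are hypothetical: FakeModule only mimics the way nn.Module.__call__ forwards its arguments to forward(), and TwoInputModel has the two-argument signature described above.

```python
class FakeModule:
    """Hypothetical stand-in for nn.Module: calling the object calls forward()."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class TwoInputModel(FakeModule):
    """Hypothetical model whose forward() needs both images and labels."""

    def forward(self, images, labels):
        return [(img, lbl) for img, lbl in zip(images, labels)]


model = TwoInputModel()
images, labels = [0.1, 0.2], [3, 7]

print(model(images, labels))  # two arguments: works

try:
    model(images)  # one argument, like `outputs = self.model(inputs)` in LRFinder
except TypeError as err:
    print(err)  # the message ends with "missing 1 required positional argument: 'labels'"
```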

adnan119 commented 2 years ago

Yes, actually! The model architecture does require the labels as an input during training: outputs = model(images, labels).

NaleRaphael commented 2 years ago

Yeah! Glad that information helps.

Feel free to update this thread if you have further questions about implementing that approach; I will get back to you when I have time.

adnan119 commented 2 years ago

Thanks for the help, @NaleRaphael! However, there's still an issue. As I mentioned earlier, the architecture requires the labels themselves to be fed into the network along with the input images. The solution you linked in issue #61 would work if the extra input were an embedding, an image, or some other value, but in my case it is the label itself that is fed into the network during training. Using the script you provided, I get the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_34/929039027.py in <module>
     20 
     21 lr_finder = LRFinder(model_wrap, optimizer, criterion, device=CONFIG['device'])
---> 22 lr_finder.range_test(custom_loader, end_lr=1, num_iter=100, step_mode="linear")
     23 lr_finder.plot(log_lr=False)
     24 lr_finder.reset()

/opt/conda/lib/python3.7/site-packages/torch_lr_finder/lr_finder.py in range_test(self, train_loader, val_loader, start_lr, end_lr, num_iter, step_mode, smooth_f, diverge_th, accumulation_steps, non_blocking_transfer)
    318                 train_iter,
    319                 accumulation_steps,
--> 320                 non_blocking_transfer=non_blocking_transfer,
    321             )
    322             if val_loader:

/opt/conda/lib/python3.7/site-packages/torch_lr_finder/lr_finder.py in _train_batch(self, train_iter, accumulation_steps, non_blocking_transfer)
    375 
    376             # Forward pass
--> 377             outputs = self.model(inputs)
    378             loss = self.criterion(outputs, labels)
    379 

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_34/929039027.py in forward(self, data)
     11     def forward(self, data):
     12         # Unpack data to the format you need
---> 13         img, labels = data
     14         return self.model(img, labels)
     15 

ValueError: not enough values to unpack (expected 2, got 1)

I've tried a few different things to fix this, but nothing seems to work. Still got a lot to learn, I guess 😊

NaleRaphael commented 2 years ago

Got it! This can actually be achieved by slightly modifying your CustomTrainIter.

Here is the explanation. This is how the forward pass is implemented in LRFinder._train_batch(): https://github.com/davidtvs/pytorch-lr-finder/blob/acc5e7ee7711a460bf3e1cc5c5f05575ba1e1b4b/torch_lr_finder/lr_finder.py#L371-L378

If we rename those variables to simpler ones, the forward pass boils down to:

# Desired format:
X, Y = next(train_iter)
outputs = self.model(X)
loss = self.criterion(outputs, Y)

Since your model takes 2 input arguments, (images, labels), X is actually a 2-tuple (images, labels). So the code above can be rewritten as:

# Actual representation:
# Replace X with `(images, labels)`, and replace Y with `labels`
(images, labels), labels = next(train_iter)
outputs = self.model((images, labels))
loss = self.criterion(outputs, labels)
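Both forms above run the same three lines; only the structure of X changes. Here is a pure-Python sketch of that contract, where forward_pass, count_mismatches, one_input_model and tuple_input_model are hypothetical stand-ins for LRFinder's forward pass, the criterion and the model (they are not part of torch_lr_finder):

```python
def forward_pass(model, criterion, batch):
    # Simplified forward pass: X, Y = batch; outputs = model(X); criterion(outputs, Y)
    X, Y = batch
    outputs = model(X)
    return criterion(outputs, Y)


def count_mismatches(outputs, targets):
    # Stand-in "loss": number of predictions that differ from the targets
    return sum(o != t for o, t in zip(outputs, targets))


def one_input_model(images):
    # Desired format: X is just the images
    return [round(img) for img in images]


def tuple_input_model(inputs):
    # Actual representation: X is the tuple (images, labels). A real model would
    # use the labels inside forward(); here we only check that they line up.
    images, labels = inputs
    assert len(images) == len(labels)
    return [round(img) for img in images]


images, labels = [0.2, 0.9, 1.4], [0, 1, 0]

print(forward_pass(one_input_model, count_mismatches, (images, labels)))
print(forward_pass(tuple_input_model, count_mismatches, ((images, labels), labels)))
```

The loop itself never changes; the batch format produced by the iterator and the signature of the model have to agree with each other.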

So train_iter has to return (images, labels), labels in each iteration, which means you can modify your CustomTrainIter as follows:

class CustomTrainIter(TrainDataLoaderIter):
    def inputs_labels_from_batch(self, batch_data):
        images = batch_data["image"]
        labels = batch_data["label"]
        return (images, labels), labels
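To see what this iterator yields, here is a simplified pure-Python imitation. MiniTrainIter is hypothetical, and it assumes TrainDataLoaderIter behaves roughly like this (wrap the loader, restart it when it is exhausted, and pass each raw batch through inputs_labels_from_batch); it is not the library's actual code. A list of dicts stands in for the DataLoader.

```python
class MiniTrainIter:
    """Hypothetical, simplified imitation of a train-loader iterator wrapper."""

    def __init__(self, loader):
        self.loader = loader
        self._iterator = iter(loader)

    def inputs_labels_from_batch(self, batch_data):
        # Same logic as the CustomTrainIter above
        images = batch_data["image"]
        labels = batch_data["label"]
        return (images, labels), labels

    def __iter__(self):
        return self

    def __next__(self):
        try:
            batch = next(self._iterator)
        except StopIteration:
            # Restart the loader so more iterations than one epoch can be drawn
            self._iterator = iter(self.loader)
            batch = next(self._iterator)
        return self.inputs_labels_from_batch(batch)


# A list of dicts stands in for the DataLoader over HappyWhaleDataset
fake_loader = [
    {"image": [0.1, 0.2], "label": [3, 7]},
    {"image": [0.3, 0.4], "label": [5, 1]},
]
train_iter = MiniTrainIter(fake_loader)
for _ in range(3):
    (images, labels), targets = next(train_iter)
    print(images, labels, targets)
```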

But since your model takes 2 input arguments rather than 1, the call self.model(X) in that forward pass now passes a single tuple, which does not match your model's forward() signature. Therefore, you need a wrapper around the model that unpacks the tuple (images, labels) into 2 variables:

import torch.nn as nn


class ModelWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inputs):
        images, labels = inputs     # unpack
        outputs = self.model(images, labels)
        return outputs
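A quick way to check the wrapper in isolation is to run it on random data. TwoInputNet below is a hypothetical model with the same (images, labels) signature; its label-dependent step (subtracting a fixed margin from the true-class logit) is purely illustrative. ModelWrapper is repeated so the block is self-contained.

```python
import torch
import torch.nn as nn


class TwoInputNet(nn.Module):
    """Hypothetical model whose forward() needs (images, labels)."""

    def __init__(self, in_features=8, num_classes=4):
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, images, labels):
        logits = self.fc(images.flatten(1))
        # Illustrative label-dependent step: lower the true-class logit by 0.5
        margin = 0.5 * nn.functional.one_hot(labels, logits.size(1)).to(logits.dtype)
        return logits - margin


class ModelWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inputs):
        images, labels = inputs  # unpack the (images, labels) tuple
        return self.model(images, labels)


torch.manual_seed(0)
images = torch.randn(5, 2, 2, 2)  # 5 tiny "images", 8 features each after flattening
labels = torch.randint(0, 4, (5,))
wrapper = ModelWrapper(TwoInputNet())
criterion = nn.CrossEntropyLoss()

outputs = wrapper((images, labels))  # a single argument, as in LRFinder's forward pass
loss = criterion(outputs, labels)
loss.backward()
print(outputs.shape, loss.item())
```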

That's it! Putting it all together, this is how LRFinder should be run in your case:

custom_loader = CustomTrainIter(train_loader)
model_wrapper = ModelWrapper(your_model)

lr_finder = LRFinder(model_wrapper, optimizer, criterion, device=CONFIG['device'])
lr_finder.range_test(custom_loader, end_lr=1, num_iter=100, step_mode="linear")
lr_finder.plot(log_lr=False)
lr_finder.reset()
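Before calling range_test, it can help to pull one batch by hand and check its structure, which catches unpacking errors like the ValueError above early. This minimal sketch assumes PyTorch's default collation, which stacks each value of the per-sample dicts into a batched tensor; ToyDictDataset is a hypothetical stand-in for HappyWhaleDataset.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDictDataset(Dataset):
    """Hypothetical stand-in for HappyWhaleDataset: returns {'image', 'label'} dicts."""

    def __len__(self):
        return 6

    def __getitem__(self, index):
        return {
            "image": torch.full((3, 4, 4), float(index)),
            "label": torch.tensor(index % 3, dtype=torch.long),
        }


loader = DataLoader(ToyDictDataset(), batch_size=4, shuffle=False, drop_last=True)
batch_data = next(iter(loader))  # default collation gives a dict of batched tensors

# Same conversion as CustomTrainIter.inputs_labels_from_batch
images, labels = batch_data["image"], batch_data["label"]
X, Y = (images, labels), labels

print(images.shape, labels.tolist())
```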

adnan119 commented 2 years ago

Thanks a lot for the awesome explanation, @NaleRaphael! This really helped me get more clarity; return (images, labels), labels was the part I was missing. You did more than just help me solve the error. Best wishes!