cj-mills / pytorch-yolox-object-detection-tutorial-code

This repository contains the training code for my PyTorch YOLOX object detection tutorial.
https://christianjmills.com/series/tutorials/pytorch-train-object-detector-yolox-series.html
MIT License

AssertionError: Loss is NaN or infinite at epoch 0, batch 712. Stopping training. #5

Open agoransson opened 1 month ago

agoransson commented 1 month ago

I get this error at different batches in the first epoch; it is not always the same batch, but it seems to happen around 70-80% of the way through the first epoch.

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[48], line 1
----> 1 train_loop(model=model, 
      2            train_dataloader=train_dataloader,
      3            valid_dataloader=valid_dataloader,
      4            optimizer=optimizer, 
      5            loss_func=yolox_loss, 
      6            lr_scheduler=lr_scheduler, 
      7            device=torch.device(device), 
      8            epochs=epochs, 
      9            checkpoint_path=checkpoint_path,
     10            use_scaler=True)

Cell In[42], line 120, in train_loop(model, train_dataloader, valid_dataloader, optimizer, loss_func, lr_scheduler, device, epochs, checkpoint_path, use_scaler)
    117 # Loop over the epochs
    118 for epoch in tqdm(range(epochs), desc="Epochs"):
    119     # Run a training epoch and get the training loss
--> 120     train_loss = run_epoch(model, train_dataloader, optimizer, lr_scheduler, loss_func, device, scaler, epoch, is_training=True)
    121     # Run an evaluation epoch and get the validation loss
    122     with torch.no_grad():

Cell In[42], line 77, in run_epoch(model, dataloader, optimizer, lr_scheduler, loss_func, device, scaler, epoch_id, is_training)
     75         if math.isfinite(loss_item):
     76             print(finate_training_message)
---> 77         assert not math.isnan(loss_item) and math.isfinite(loss_item), stop_training_message
     79 # Cleanup and close the progress bar 
     80 progress_bar.close()

AssertionError: Loss is NaN or infinite at epoch 0, batch 712. Stopping training.
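The assertion in the traceback fires whenever the loss value is NaN or infinite. Because `math.isfinite` already returns `False` for NaN as well as for `+inf` and `-inf`, the condition `not math.isnan(loss_item) and math.isfinite(loss_item)` is equivalent to `math.isfinite(loss_item)` alone. A minimal sketch of that guard, assuming only the message format shown above (`check_loss` is a hypothetical helper, not the tutorial's function):

```python
import math


def check_loss(loss_item: float, epoch: int, batch: int) -> None:
    """Stop training if the loss is NaN, +inf, or -inf (hypothetical helper)."""
    # math.isfinite(x) is False for NaN and for both infinities,
    # so a separate math.isnan check is redundant.
    if not math.isfinite(loss_item):
        # Raise explicitly instead of using `assert`, which is skipped under `python -O`.
        raise AssertionError(
            f"Loss is NaN or infinite at epoch {epoch}, batch {batch}. Stopping training."
        )
```

Since the failing batch varies between runs, this check only reports where the loss diverged, not why.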

I continued running the notebook after this failure, hoping to get somewhere with the model. When I eventually ran the following code block:

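# Imports assumed from earlier notebook cells (torchvision's v2 transforms API)
import torch
from torchvision.transforms import v2 as transforms

# Ensure the model and input data are on the same device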
print(device)
wrapped_model.to(device)
input_tensor = transforms.Compose([transforms.ToImage(), 
                                   transforms.ToDtype(torch.float32, scale=True)])(input_img)[None].to(device)

# Make a prediction with the model
with torch.no_grad():
    model_output = wrapped_model(input_tensor)

model_output.shape

I get the error below. It seems to be caused by the model and the input not both being on the same device (mps), but as far as I can see, the code already moves both of them to the device.
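One way to double-check that everything really is on the device is to list where each tensor a module holds actually lives. `Module.to()` moves registered parameters and buffers, but it does not move tensors stored as ordinary attributes (e.g. `self.some_tensor = torch.tensor(...)`), so those can stay on the CPU after `wrapped_model.to(device)`. A minimal sketch using only standard `torch.nn.Module` behavior; `report_devices` is a hypothetical helper, not part of `cjm_yolox_pytorch`:

```python
import torch
from torch import nn


def report_devices(model: nn.Module) -> dict[str, str]:
    """Map each parameter, buffer, and plain tensor attribute to its device."""
    devices = {}
    for name, p in model.named_parameters():
        devices[f"param:{name}"] = str(p.device)
    for name, b in model.named_buffers():
        devices[f"buffer:{name}"] = str(b.device)
    # Tensors assigned as plain attributes are not registered as parameters
    # or buffers, so Module.to() leaves them where they were created.
    for mod_name, mod in model.named_modules():
        registered = {n for n, _ in mod.named_parameters(recurse=False)}
        registered |= {n for n, _ in mod.named_buffers(recurse=False)}
        for attr, value in vars(mod).items():
            if isinstance(value, torch.Tensor) and attr not in registered:
                prefix = f"{mod_name}." if mod_name else ""
                devices[f"attr:{prefix}{attr}"] = str(value.device)
    return devices
```

For example, calling `print(report_devices(wrapped_model))` and `print(input_tensor.device)` right after `wrapped_model.to(device)` shows whether anything still reports `cpu` while `device` is `mps`. This is only a diagnostic; it does not by itself explain the error.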

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[74], line 9
      7 # Make a prediction with the model
      8 with torch.no_grad():
----> 9     model_output = wrapped_model(input_tensor)
     11 model_output.shape

File ~/miniforge3/envs/pytorch-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/pytorch-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/miniforge3/envs/pytorch-env/lib/python3.10/site-packages/cjm_yolox_pytorch/inference.py:151, in YOLOXInferenceWrapper.forward(self, x)
    147 x = self.process_output(x)
    149 if self.run_box_and_prob_calculation:
    150     # Generate output grids
--> 151     output_grids = generate_output_grids(*input_dims, self.strides).to(x.device)
    152     # Calculate the bounding boxes and their probabilities
    153     x = self.calculate_boxes_and_probs(x, output_grids)

File ~/miniforge3/envs/pytorch-env/lib/python3.10/site-packages/cjm_yolox_pytorch/utils.py:56, in generate_output_grids(height, width, strides)
     52 # We will use a loop but it won't affect the exportability of the model to ONNX 
     53 # as the loop is not dependent on the input data (height, width) but on the 'strides' which is model parameter.
     54 for i, stride in enumerate(strides):
     55     # Calculate the grid height and width
---> 56     grid_height = height // stride
     57     grid_width = width // stride
     59     # Generate grid coordinates

File ~/miniforge3/envs/pytorch-env/lib/python3.10/site-packages/torch/_tensor.py:41, in _handle_torch_function_and_wrap_type_error_to_not_implemented.<locals>.wrapped(*args, **kwargs)
     39     if has_torch_function(args):
     40         return handle_torch_function(wrapped, args, *args, **kwargs)
---> 41     return f(*args, **kwargs)
     42 except TypeError:
     43     return NotImplemented

File ~/miniforge3/envs/pytorch-env/lib/python3.10/site-packages/torch/_tensor.py:999, in Tensor.__rfloordiv__(self, other)
    997 @_handle_torch_function_and_wrap_type_error_to_not_implemented
    998 def __rfloordiv__(self, other):
--> 999     return torch.floor_divide(other, self)

RuntimeError: Placeholder storage has not been allocated on MPS device!

Would you know how to handle such an issue? Or is it again a problem with MPS support?

agoransson commented 1 month ago

I don't get the same issues when running this on Colab. It seems to run perfectly there, so it is definitely related to Apple Silicon.
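To confirm on the same Mac that the failure is specific to the MPS backend, one option is to rerun the notebook on the CPU and compare. A minimal sketch of a device picker, assuming PyTorch 1.12 or later (where `torch.backends.mps` exists); `pick_device` and its `prefer_mps` flag are hypothetical, not part of the tutorial code:

```python
import torch


def pick_device(prefer_mps: bool = True) -> torch.device:
    """Pick CUDA if available, then (optionally) MPS, then fall back to CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # MPS is only considered when requested, so it can be switched off for comparison.
    if prefer_mps and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```

On Apple Silicon, `device = pick_device(prefer_mps=False)` selects `cpu`; on a Colab GPU runtime it still selects `cuda`.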

cj-mills commented 1 month ago

Yeah, that error message seems to be specific to the mps (Metal Performance Shaders) backend. Unfortunately, I won't be much help there. The last Mac I owned was from 2009.