Open bw4sz opened 11 months ago
@bw4sz, how can this error be reproduced?
@bw4sz, can you post a short reprex (minimal reproducible example) for this? It would be a good issue for others to work on, and @Om-Doiphode is interested in it.
# Multi-class models: reproduce the late KeyError raised when a label in the
# training CSV is not a key of label_dict.
from deepforest import main
from deepforest import get_data
import pandas as pd
import os
import tempfile

m = main.deepforest(num_classes=3, label_dict={"label1": 0, "label2": 1, "label3": 2})

# Build annotations for a single image with three boxes. "label1" and "label2"
# are valid keys of label_dict, but the third label is a misspelling of "label3".
# The class count (3) still matches, so the mistake is not obvious up front.
image_path = get_data("OSBS_029.png")
basename = os.path.basename(image_path)
dirname = os.path.dirname(image_path)
df = pd.DataFrame({
    # The CSV stores the bare file name; root_dir (set below) points at the
    # directory that actually contains the image.
    "image_path": [basename] * 3,
    "xmin": [0, 0, 0],
    "ymin": [0, 0, 0],
    "xmax": [100, 100, 100],
    "ymax": [100, 100, 100],
    "label": ["label1", "label2", "label3misspelled"],
})

# Build the CSV path once, portably, and reuse it for both writing and config.
csv_file = os.path.join(tempfile.gettempdir(), "sample.csv")
df.to_csv(csv_file, index=False)
m.config["train"]["csv_file"] = csv_file
m.config["train"]["root_dir"] = dirname
m.create_trainer()

# Fails only once the first training batch is fetched, with
# KeyError: 'label3misspelled' (see the traceback below).
m.trainer.fit(m)
yields a long, unhelpful stack trace:
>>> m.trainer.fit(m)
| Name | Type | Params
-----------------------------------------------------
0 | model | RetinaNet | 32.2 M
1 | iou_metric | IntersectionOverUnion | 0
2 | mAP_metric | MeanAveragePrecision | 0
-----------------------------------------------------
32.0 M Trainable params
222 K Non-trainable params
32.2 M Total params
128.758 Total estimated model params size (MB)
Epoch 0: 0%| | 0/1 [00:00<?, ?it/s]Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
call._call_and_handle_interrupt(
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 989, in _run
results = self._run_stage()
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1035, in _run_stage
self.fit_loop.run()
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 202, in run
self.advance()
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 359, in advance
self.epoch_loop.run(self._data_fetcher)
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 136, in run
self.advance(data_fetcher)
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 202, in advance
batch, _, __ = next(data_fetcher)
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/pytorch_lightning/loops/fetchers.py", line 127, in __next__
batch = super().__next__()
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/pytorch_lightning/loops/fetchers.py", line 56, in __next__
batch = next(self.iterator)
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/pytorch_lightning/utilities/combined_loader.py", line 326, in __next__
out = next(self._iterator)
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/pytorch_lightning/utilities/combined_loader.py", line 74, in __next__
out[i] = next(self.iterators[i])
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 681, in __next__
data = self._next_data()
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1376, in _next_data
return self._process_data(data)
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1402, in _process_data
data.reraise()
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/torch/_utils.py", line 461, in reraise
raise exception
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
data = fetcher.fetch(index)
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/Users/benweinstein/Documents/DeepForest/deepforest/dataset.py", line 106, in __getitem__
targets["labels"] = image_annotations.label.apply(
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/pandas/core/series.py", line 4356, in apply
return SeriesApply(self, func, convert_dtype, args, kwargs).apply()
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/pandas/core/apply.py", line 1036, in apply
return self.apply_standard()
File "/Users/benweinstein/.conda/envs/DeepForest/lib/python3.9/site-packages/pandas/core/apply.py", line 1092, in apply_standard
mapped = lib.map_infer(
File "pandas/_libs/lib.pyx", line 2859, in pandas._libs.lib.map_infer
File "/Users/benweinstein/Documents/DeepForest/deepforest/dataset.py", line 107, in <lambda>
lambda x: self.label_dict[x]).values.astype(np.int64)
KeyError: 'label3misspelled'
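Here the number of classes is correct, but "label3" is misspelled in the CSV. This should be checked right away.
For multi-class models, the non-matching label is not reported until late. The KeyError is raised only when the DataLoader first fetches a batch during training, inside `dataset.__getitem__`. It is not raised when the CSV and `label_dict` are first loaded. Validating up front that every label in the CSV is a key of `label_dict` would surface the mistake immediately, with a clear message.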