baal-org / baal

Bayesian active learning library for research and industrial use cases.
https://baal.readthedocs.io
Apache License 2.0
868 stars 86 forks

update Model Wrapper to work with pytorch detection models #188

Open noknok00 opened 2 years ago

noknok00 commented 2 years ago

Is your feature request related to a problem? Please describe.

As explained in Slack, I am trying to use BaaL with an object detection model from PyTorch. Since the Model Wrapper was written with classification models in mind, it doesn't work well with object detection models. PyTorch uses the following training loop for its detection models, so we can take inspiration from it: https://github.com/pytorch/vision/blob/main/references/detection/engine.py (function "train_one_epoch").
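For reference, the core of that loop passes a list of images and a list of target dicts to the model, which in training mode returns a dict of named losses; these are summed before backpropagation. The sketch below reproduces only that step (no mixed precision, warmup, or logging). `detection_train_step` is a hypothetical name, and the model and optimizer are duck-typed, assuming a torchvision-style detection model and a `torch.optim` optimizer.

```python
from typing import Any, Dict, List


def detection_train_step(model: Any, optimizer: Any,
                         images: List[Any], targets: List[Dict[str, Any]]) -> float:
    """One optimisation step for a torchvision-style detection model (sketch).

    Assumption: in training mode the model takes a list of images plus a list
    of target dicts (e.g. "boxes", "labels") and returns a dict of losses.
    """
    model.train()
    loss_dict = model(images, targets)
    # Sum the individual losses (classification, box regression, ...).
    losses = sum(loss for loss in loss_dict.values())
    optimizer.zero_grad()
    losses.backward()
    optimizer.step()
    return float(losses)
```

Compared with a classification step, the criterion lives inside the model, which is why a wrapper that always calls `criterion(model(data), target)` does not fit.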

Describe the solution you'd like

I believe we only need to update the "train_on_batch" function to work with detection models, perhaps with a flag that tells the Model Wrapper class whether it needs to use the classification or the detection training loop.
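One way to read the proposed flag is a wrapper whose loss computation branches on the task type. This is a hypothetical sketch: `TaskAwareTrainer`, its `is_detection` flag, and the method signatures are illustrative only and are not BaaL's actual Model Wrapper API.

```python
from typing import Any, Callable, Optional


class TaskAwareTrainer:
    """Hypothetical wrapper that picks the training loss based on a flag."""

    def __init__(self, model: Any, criterion: Optional[Callable] = None,
                 is_detection: bool = False):
        if not is_detection and criterion is None:
            raise ValueError("classification requires a criterion")
        self.model = model
        self.criterion = criterion
        self.is_detection = is_detection

    def compute_loss(self, data: Any, target: Any) -> Any:
        if self.is_detection:
            # Detection models compute their own losses from the targets.
            loss_dict = self.model(data, target)
            return sum(loss_dict.values())
        # Classification: the model returns predictions, the wrapper applies the criterion.
        return self.criterion(self.model(data), target)

    def train_on_batch(self, data: Any, target: Any, optimizer: Any) -> Any:
        self.model.train()
        optimizer.zero_grad()
        loss = self.compute_loss(data, target)
        loss.backward()
        optimizer.step()
        return loss
```

Keeping the branch inside a single `compute_loss` method means the rest of the loop (zeroing gradients, backward, optimizer step) stays shared between both tasks.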

Describe alternatives you've considered

I believe the Active Learning Dataset has a similar problem: it was written with classification in mind and doesn't play well with a detection dataset, where each sample has bounding boxes and labels rather than just a label as in classification. I worked around this by using my own Dataset class and collate function. I'm not sure whether I should submit a separate feature request for this.
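For context, the collate function in torchvision's detection reference scripts is essentially `tuple(zip(*batch))`: instead of stacking samples into one tensor, it keeps images and targets as parallel tuples, because images can differ in size and each target holds a variable number of boxes. A minimal equivalent (the name `detection_collate` is illustrative), with strings standing in for image tensors:

```python
from typing import Any, Dict, List, Tuple


def detection_collate(batch: List[Tuple[Any, Dict[str, Any]]]) -> Tuple[Tuple[Any, ...], ...]:
    """Turn [(img1, tgt1), (img2, tgt2), ...] into ((img1, img2, ...), (tgt1, tgt2, ...))."""
    return tuple(zip(*batch))


batch = [
    ("img_a", {"boxes": [[0, 0, 10, 10]], "labels": [1]}),
    ("img_b", {"boxes": [[5, 5, 20, 20], [1, 1, 4, 4]], "labels": [2, 3]}),
]
images, targets = detection_collate(batch)
print(images)                               # ('img_a', 'img_b')
print([len(t["boxes"]) for t in targets])   # [1, 2]
```

With PyTorch it would be passed as `DataLoader(dataset, batch_size=..., collate_fn=detection_collate)`.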

Additional context

With these changes, I believe we could support semantic and instance segmentation too, since their training loops (and datasets) are similar to those of object detection.

Dref360 commented 2 years ago

Hello,

I tried to make this work in this gist, but predict_on_batch doesn't work yet.

If we run model(data) n times, we would get the following output:

[ [{'boxes': tensor([[...]]), 'scores': tensor([...]), 'labels': tensor([...])}],
 [{'boxes': tensor([[...]]), 'scores': tensor([...]), 'labels': tensor([...])}],
...
 [{'boxes': tensor([[...]]), 'scores': tensor([...]), 'labels': tensor([...])}]
]

How should we stack these to feed them to the heuristics? The number of boxes can also vary between inferences.
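To make the problem concrete: the per-iteration score lists below have different lengths, so they cannot be stacked directly. One naive option, shown as a hypothetical helper `pad_scores`, pads each iteration to the largest box count and returns a validity mask. The caveat matters: box `j` in one iteration is not necessarily the same object as box `j` in another, so a per-column statistic over this matrix is only meaningful after some box-matching step (e.g. by IoU), which this sketch does not perform.

```python
from typing import List, Tuple


def pad_scores(per_iter_scores: List[List[float]],
               fill: float = 0.0) -> Tuple[List[List[float]], List[List[bool]]]:
    """Pad variable-length per-iteration box scores into a rectangular matrix.

    Returns (padded, mask), each shaped [n_iterations][max_boxes]; mask[i][j]
    is True where iteration i actually predicted a j-th box. Columns are NOT
    aligned across iterations: boxes would need to be matched first.
    """
    max_boxes = max((len(s) for s in per_iter_scores), default=0)
    padded, mask = [], []
    for scores in per_iter_scores:
        missing = max_boxes - len(scores)
        padded.append(list(scores) + [fill] * missing)
        mask.append([True] * len(scores) + [False] * missing)
    return padded, mask


# Three stochastic forward passes on one image, each keeping a different number of boxes.
scores = [[0.9, 0.4], [0.85], [0.8, 0.5, 0.1]]
padded, mask = pad_scores(scores)
print(padded)  # [[0.9, 0.4, 0.0], [0.85, 0.0, 0.0], [0.8, 0.5, 0.1]]
```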

In our experiments, we computed the variance of the raw output directly from the heads.
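By contrast, raw head outputs taken before post-processing (score thresholding and NMS) have the same shape on every forward pass, for example one score per anchor for an anchor-based head, so they stack cleanly and a per-element variance across iterations is well defined. A minimal pure-Python sketch of that idea, assuming a hypothetical `[n_iterations][n_anchors]` layout; it is not the code used in those experiments:

```python
from statistics import pvariance
from typing import List


def per_anchor_variance(head_outputs: List[List[float]]) -> List[float]:
    """Population variance across stochastic iterations for each output position.

    head_outputs[i][k] is the raw score of anchor k at iteration i. Every
    iteration must have the same length, which holds for pre-NMS head outputs
    but not for post-NMS detections.
    """
    if not head_outputs:
        return []
    n_anchors = len(head_outputs[0])
    if any(len(row) != n_anchors for row in head_outputs):
        raise ValueError("all iterations must have the same number of outputs")
    # zip(*rows) yields one tuple per anchor containing its value at every iteration.
    return [pvariance(values) for values in zip(*head_outputs)]


outputs = [[0.2, 0.9], [0.4, 0.9], [0.6, 0.9]]
print(per_anchor_variance(outputs))
```

The per-anchor variances could then be reduced (for example, averaged) into a single uncertainty score per image.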

Dref360 commented 2 years ago

Anything I can help with? Happy to schedule a meeting as needed.