KumarLabJax / MouseTracking

Neural network tracking code applied to single mice in open field assays.
MIT License

Training Advice #4

DeclanE101 opened this issue 2 years ago

DeclanE101 commented 2 years ago

I've been trying to use this code for inference on top-view mouse videos, using the pretrained Full Model (415,000 training steps). I've also tried other pretrained models to see if I could get better results, but the Full Model has performed best. I'm looking for advice on how to improve the results: is it better to train the model beyond 415,000 steps, or is it more worthwhile to use the provided tools to create a dataset from our own data? Our arena contains a few objects along with the mouse, and this program looks really promising because it has worked on data that includes mice and extra objects. Currently it predicts the mouse's position correctly about 30% of the time, but I haven't been able to figure out how to do better. Would it be possible to discuss this over Zoom or some other medium? I would really appreciate any help or advice!

SkepticRaven commented 2 years ago

Generally, adding your own data should give the best results. This is a rather small network compared to others (say, U-Net), and it learns quickly.

This is most likely a distribution shift in visual appearance between the training data (ours) and the inference data (yours). Adding your data to the training set should allow the network to adapt, either by training exclusively on your data or on a mix of ours and yours.
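As a rough illustration of the "mix of ours and your data" option, here is a hypothetical Python helper, `mix_training_sets`. It is not part of this repository, whose actual training-data format isn't shown in this thread. It builds a shuffled list of training examples with a chosen share drawn from your own labeled frames; `fraction_yours=1.0` corresponds to training exclusively on your data.

```python
import random


def mix_training_sets(ours, yours, fraction_yours=0.5, total=None, seed=0):
    """Return a shuffled list of training examples drawn from two sources.

    Hypothetical helper: `ours` and `yours` are lists of example identifiers
    (e.g. paths to labeled frames). `fraction_yours` sets the share of the
    mix drawn from the new data. Sampling is with replacement, so a small
    new dataset can be upweighted relative to a large existing one.
    """
    if not 0.0 <= fraction_yours <= 1.0:
        raise ValueError("fraction_yours must be in [0, 1]")
    if total is None:
        total = len(ours) + len(yours)
    n_yours = round(total * fraction_yours)
    n_ours = total - n_yours
    if (n_yours and not yours) or (n_ours and not ours):
        raise ValueError("cannot sample from an empty source")
    rng = random.Random(seed)
    mixed = [rng.choice(yours) for _ in range(n_yours)]
    mixed += [rng.choice(ours) for _ in range(n_ours)]
    rng.shuffle(mixed)  # interleave the two sources
    return mixed


ours = [f"ours_{i:03d}.png" for i in range(400)]
yours = [f"yours_{i:03d}.png" for i in range(50)]
mixed = mix_training_sets(ours, yours, fraction_yours=0.5, total=200)
print(len(mixed), sum(p.startswith("yours") for p in mixed))  # 200 100
```

Sampling with replacement is a deliberate choice here: with only a few dozen labeled frames of your own, it lets them make up a meaningful share of each epoch without discarding most of the existing data.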

There are 2 other things you could try to see whether they improve performance:

  1. Adjust training augmentation. This can be hit-or-miss. I conducted a small hyperparameter scan of the augmentation strengths, so the default settings are optimal for our data. I didn't expose these values as command-line arguments, but they're located here. Also, since I originally wrote this code, TensorFlow and other deep learning libraries have added many newer augmentation approaches, such as input masking or color inversion, which could be added.
  2. Alter the inference input to more closely match the averages in the training data. We happen to record using a target brightness of 190 instead of a typical camera default of 128. If you're using the segmentation-based approach, you could do this by changing this line to something like frame = np.uint8(np.clip(next(im_iter)*190/128, 0, 255)). The network does apply per-image standardization, so this may not do much. However, other methods of altering the visual appearance to be closer to ours could work. The 2 distinct arena appearances present in the Full Model's training data are a white floor with gray walls and a gray floor with white walls.
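Input masking and color inversion, mentioned above as candidate additions, could look roughly like the NumPy sketch below. It is independent of the repository's own augmentation code (which is only linked in the original comment), the function names are hypothetical, and it assumes 8-bit grayscale frames. Inversion maps a white floor to a dark one and vice versa, which might help the network cope with arena colors outside the two appearances in the training data.

```python
import numpy as np


def invert_colors(frame):
    """Invert an 8-bit grayscale frame (bright floor <-> dark floor)."""
    return 255 - frame


def random_mask(frame, rng, max_frac=0.25, fill=0):
    """Blank out one random rectangle (cutout-style input masking).

    The rectangle is at most `max_frac` of the frame height and width.
    Returns a copy; the input frame is left untouched.
    """
    h, w = frame.shape[:2]
    max_h = max(1, int(h * max_frac))
    max_w = max(1, int(w * max_frac))
    mh = int(rng.integers(1, max_h + 1))
    mw = int(rng.integers(1, max_w + 1))
    top = int(rng.integers(0, h - mh + 1))
    left = int(rng.integers(0, w - mw + 1))
    out = frame.copy()
    out[top:top + mh, left:left + mw] = fill
    return out


def augment(frame, rng, p_invert=0.5, p_mask=0.5):
    """Apply each augmentation independently with its own probability."""
    out = frame
    if rng.random() < p_invert:
        out = invert_colors(out)
    if rng.random() < p_mask:
        out = random_mask(out, rng)
    return out


rng = np.random.default_rng(0)
frame = np.full((480, 640), 200, dtype=np.uint8)  # stand-in for a bright floor
augmented = augment(frame, rng)
```

Both transforms change pixel values only, not geometry, so segmentation labels can stay as they are. Whether a mask that happens to hide the mouse should also change the label is a design choice this sketch does not address.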
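The brightness change above amounts to a linear rescale of pixel values toward the training data's level, and the sketch below shows why per-image standardization may undo most of it. It assumes the network's standardization follows the formula documented for TensorFlow's `tf.image.per_image_standardization`, `(x - mean) / max(std, 1/sqrt(N))`; the thread only says "per-image standardization". The function names and the `from_level`/`to_level` parameters are hypothetical. To move a frame recorded at a camera-default level of about 128 toward the 190 level of the training data, use `from_level=128, to_level=190`.

```python
import numpy as np


def rescale_brightness(frame, from_level=128.0, to_level=190.0):
    """Linearly rescale an 8-bit frame from one brightness level to another.

    Values are rounded and clipped to [0, 255] before casting back to uint8,
    because casting out-of-range floats to uint8 does not saturate reliably.
    """
    if from_level <= 0:
        raise ValueError("from_level must be positive")
    scaled = frame.astype(np.float64) * (to_level / from_level)
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)


def per_image_standardization(image):
    """NumPy version of the formula documented for
    tf.image.per_image_standardization: (x - mean) / max(std, 1/sqrt(N))."""
    x = image.astype(np.float64)
    adjusted_std = max(x.std(), 1.0 / np.sqrt(x.size))
    return (x - x.mean()) / adjusted_std


rng = np.random.default_rng(0)
# Synthetic stand-in for a camera-default frame. No pixel exceeds 170,
# so brightening by 190/128 never clips and the rescale stays linear.
frame = rng.integers(0, 171, size=(480, 640), dtype=np.uint8)
brighter = rescale_brightness(frame, from_level=128, to_level=190)
diff = np.abs(per_image_standardization(frame)
              - per_image_standardization(brighter)).max()
print(brighter.mean() > frame.mean(), diff < 0.05)  # True True
```

The standardized outputs differ only by rounding error here. With real frames, bright regions such as a white floor saturate at 255 after brightening, so the rescale is no longer purely linear; that residual is the part per-image standardization would not cancel.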

DeclanE101 commented 2 years ago

Thank you for the response! Let me try these things, and I'll comment again when I have some results.