backprop64 / DAMM

A codebase for tracking laboratory mice in videos

Error when running the mouse_tracker #6

Status: Open. Andrianarivelo opened this issue 1 month ago

Andrianarivelo commented 1 month ago

Hello,

Thanks for this very promising tool! We are working on social behavior, and DeepLabCut multi-animal is still very bad at telling two animals apart when they look too similar...

First, I had a lot of trouble installing detectron2 on Windows: there is no prebuilt version and I could not build it from source, so I installed it on Linux instead, where the installation worked fine.

Now, when I run my script, I get an error that seems to be related to NumPy. Do you have a fix? I'm on Ubuntu with an RTX 3090 and followed the installation guide (although I had to uninstall torchvision 0.19.1 after installing sam2 to avoid a version conflict between torch and torchvision).
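Because the problem looks like a version mismatch (torch vs torchvision, and possibly NumPy), it helps to report the exact installed versions. The sketch below is not part of DAMM; it only reads installed package metadata through the standard library, so it still runs when a package is broken. The distribution names in `PACKAGES` are assumptions; adjust them to match what `pip list` shows.

```python
# Minimal environment report for a bug report (sketch, not part of DAMM).
# It reads package metadata only, so it works even if importing torch fails.
from importlib.metadata import PackageNotFoundError, version

# Distribution names to look up (assumed; edit to match `pip list`).
PACKAGES = ["torch", "torchvision", "numpy", "detectron2"]


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name:12s} {ver or 'not installed'}")
```

Pasting this output into the issue shows at a glance whether, for example, torch and torchvision come from matching releases.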

My script:

from DAMM.tracking import PromptableVideoTracker
import pdb

# initialize tracking setup
mouse_tracker = PromptableVideoTracker(
    # sam_config: you don't need to download this or specify a full path; look for the associated config file above
    sam2_model_cfg="sam2_hiera_b+.yaml",
    sam2_checkpoint= "sam2_hiera_base_plus.pt",
    damm_model_cfg="DAMM_config.yaml",
    damm_checkpoint="DAMM_weights.pth"
)

# Track the first 2500 frames of the trial video
# Save the output and visualization to output_dir
mouse_tracker.predict_video(
    video_path='videos/Trial     5-TRIXIE-QUIXIE.mp4',
    output_dir='output/',
    batch_size=64,
    start_frame=0,
    end_frame=2500,
    visualize=True
)
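The log below reports 2500 frames and 40 batches for these arguments. Assuming the tracker splits the range [start_frame, end_frame) into batches of `batch_size` frames with a shorter final batch (inferred from those two numbers, not from DAMM's code), the arithmetic can be sketched with a hypothetical helper:

```python
def batch_plan(start_frame, end_frame, batch_size):
    """Hypothetical helper (not DAMM's API): frames to track, number of
    batches, and size of the last batch, assuming fixed-size batches
    over [start_frame, end_frame) with a shorter final batch."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    n_frames = end_frame - start_frame
    if n_frames <= 0:
        raise ValueError("end_frame must be greater than start_frame")
    # Integer ceiling division: a partial batch still counts as a batch.
    n_batches = (n_frames + batch_size - 1) // batch_size
    last_batch = n_frames - (n_batches - 1) * batch_size
    return n_frames, n_batches, last_batch


print(batch_plan(0, 2500, 64))  # (2500, 40, 4)
```

Under this assumption, 39 full batches of 64 frames cover 2496 frames and a final batch holds the remaining 4.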

The error:

/home/andry/anaconda3/envs/DAMM/bin/python /home/andry/DAMM/DAMM-tracking.py 
/home/andry/anaconda3/envs/DAMM/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/andry/anaconda3/envs/DAMM/lib/python3.10/site-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
Running SAM 2 Video Segmentation on: cuda
ERROR:iopath.common.file_io:An exception occurred in telemetry logging.Disabling telemetry to prevent further exceptions.
Traceback (most recent call last):
  File "/home/andry/anaconda3/envs/DAMM/lib/python3.10/site-packages/iopath/common/file_io.py", line 946, in __log_tmetry_keys
    handler.log_event()
  File "/home/andry/anaconda3/envs/DAMM/lib/python3.10/site-packages/iopath/common/event_logger.py", line 97, in log_event
    del self._evt
AttributeError: _evt
 - //////////////////////////////// -
 - Using sam2 + damm to track video - 
 - # of Frames to Track: 2500
 - Batch Size: 64
 - Number Batches: 40
 - Storing predictions in: output/frame_predictions
 - ||||||||||||||||||||||||||||||| -
Traceback (most recent call last):
  File "/home/andry/DAMM/DAMM-tracking.py", line 16, in <module>
    mouse_tracker.predict_video(
  File "/home/andry/DAMM/DAMM/tracking/promptable_video_tracker.py", line 73, in predict_video
    mice_prompts = self.damm_predictor.get_frame_masks(self.video_path, 0, 3)
  File "/home/andry/DAMM/DAMM/detection/damm_detector.py", line 65, in get_frame_masks
    outputs = self.detector(frame)
  File "/home/andry/anaconda3/envs/DAMM/lib/python3.10/site-packages/detectron2/engine/defaults.py", line 316, in __call__
    image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
RuntimeError: Could not infer dtype of numpy.float32
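One commonly reported cause of `Could not infer dtype of numpy.float32`, not confirmed in this thread, is a PyTorch build compiled against NumPy 1.x running with NumPy 2.x installed, which disables torch's NumPy interoperability. The `undefined symbol` warning from torchvision above also suggests that the torch and torchvision builds do not match. A minimal check, isolated from DAMM, that mirrors the failing conversion in detectron2's `defaults.py` (the function name is hypothetical):

```python
# Reproduce the NumPy -> torch conversion that fails in detectron2's
# DefaultPredictor.__call__, without loading any DAMM or SAM 2 model.
import numpy as np


def check_numpy_torch_bridge():
    """Return (ok, message) for a conversion like detectron2's."""
    try:
        import torch
    except ImportError:
        return False, "torch is not installed"
    # Dummy H x W x C uint8 frame, converted exactly as in defaults.py.
    frame = np.zeros((4, 6, 3), dtype=np.uint8)
    try:
        tensor = torch.as_tensor(frame.astype("float32").transpose(2, 0, 1))
    except RuntimeError as exc:
        return False, f"torch cannot convert NumPy arrays: {exc}"
    return True, f"ok: shape={tuple(tensor.shape)}, dtype={tensor.dtype}"


if __name__ == "__main__":
    print("numpy", np.__version__)
    print(check_numpy_torch_bridge())
```

If this check fails with the same message while NumPy reports a 2.x version, installing a NumPy 1.x release (for example `numpy<2`) or a torch build made for NumPy 2 is a frequently suggested workaround, to be verified against the installed torch version.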

Thanks for your help!