talmolab / sleap

A deep learning framework for multi-animal pose tracking.
https://sleap.ai

Hard limit for instance count in multi-instance models inside of inference graph #1011

Open talmo opened 1 year ago

talmo commented 1 year ago

It's sometimes desirable to cap the number of instances that a multi-instance model returns. Currently this is implemented through tracking, but sometimes we want to do it without having to run the tracker.

Use cases:

The problem is that there are several strategies for selecting among N instance detections.
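To make the ambiguity concrete, here is a minimal sketch in plain Python. The `Detection` class, its fields, and the two ranking criteria are hypothetical illustrations, not SLEAP's API or the strategies the project has settled on. It only shows that two reasonable rankings can keep different instances from the same set of detections.

```python
from dataclasses import dataclass


@dataclass
class Detection:
    """Hypothetical stand-in for one predicted instance."""

    score: float  # overall confidence of the instance
    n_visible: int  # number of body parts that were detected


def select_top(dets, max_instances, key):
    """Keep at most max_instances detections, ranked by key (highest first)."""
    return sorted(dets, key=key, reverse=True)[:max_instances]


dets = [Detection(0.9, 3), Detection(0.6, 8), Detection(0.7, 5)]

# Strategy A (assumed): rank by confidence score.
by_score = select_top(dets, 2, key=lambda d: d.score)

# Strategy B (assumed): rank by how many body parts were found.
by_visible = select_top(dets, 2, key=lambda d: d.n_visible)
```

With these made-up numbers the two strategies disagree about which pair to keep, which is why a single hard limit needs an explicit ranking rule.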

#717 should solve this in a more general form by providing standalone filtering functions that operate on single LabeledFrames.

This issue proposes a smaller, less general solution that will work for some of the use cases.

The idea is to implement this with TensorFlow graph-compatible ops within the InferenceLayer/InferenceModel subclasses. It's less general, but compatible with exported models.

This could go here: https://github.com/talmolab/sleap/blob/6cac6519208dbc77a89a1e7fb019fed03d9514ac/sleap/nn/inference.py#L2034

Or even better, here during centroid detection/cropping: https://github.com/talmolab/sleap/blob/6cac6519208dbc77a89a1e7fb019fed03d9514ac/sleap/nn/inference.py#L1659-L1660

Where we could use tf.math.top_k on the peak values like:

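max_instances: Optional[int] = None
# ...
# in call() method:
if self.max_instances is not None:
    # top_k errors out if k exceeds the number of candidates, so clamp k.
    k = tf.minimum(self.max_instances, tf.size(centroid_vals))
    top_points = tf.math.top_k(centroid_vals, k=k)
    top_inds = top_points.indices

    # Keep only the k highest-scoring centroids.
    # Note: if these tensors are flattened across the batch, this caps the
    # total for the whole batch rather than per frame.
    centroid_vals = tf.gather(centroid_vals, top_inds)
    centroid_points = tf.gather(centroid_points, top_inds)
    crop_sample_inds = tf.gather(crop_sample_inds, top_inds)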
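
If the limit should apply per frame rather than across the whole batch, one possible variant is sketched below. It assumes that `centroid_vals` is a flat vector of scores and that `crop_sample_inds` gives the batch index of each centroid. The helper name `keep_top_k_per_sample` is hypothetical and not part of SLEAP. It uses only graph-compatible ops, so it should trace under `tf.function`, but it has not been checked against the actual inference layers.

```python
import tensorflow as tf


@tf.function
def keep_top_k_per_sample(vals, sample_inds, k):
    """Return indices of the k highest-valued entries within each sample.

    vals: (n,) float scores. sample_inds: (n,) int32 sample index per entry.
    """
    n = tf.size(vals)
    # Sort by descending score, then stably by sample index, so entries are
    # grouped by sample and ordered best-first within each group.
    by_score = tf.argsort(vals, direction="DESCENDING", stable=True)
    by_sample = tf.argsort(tf.gather(sample_inds, by_score), stable=True)
    order = tf.gather(by_score, by_sample)
    sorted_samples = tf.gather(sample_inds, order)
    # Rank within a group = position minus the position where the group starts.
    group_start = tf.searchsorted(sorted_samples, sorted_samples, side="left")
    rank = tf.range(n) - group_start
    return tf.boolean_mask(order, rank < k)


# Usage: two frames with two centroids each, keeping one centroid per frame.
vals = tf.constant([0.1, 0.9, 0.5, 0.7])
sample_inds = tf.constant([0, 0, 1, 1])
keep = keep_top_k_per_sample(vals, sample_inds, 1)
top_vals = tf.gather(vals, keep)
top_sample_inds = tf.gather(sample_inds, keep)
```

The returned indices can be passed to `tf.gather` for `centroid_points` in the same way. The two sorts cost more than a single `tf.math.top_k`, so the simpler global version may be preferable when the batch size is 1.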

roomrys commented 1 year ago

We still need to allow setting the max instances for bottom-up models.