google-research / scenic

Scenic: A Jax Library for Computer Vision Research and Beyond
Apache License 2.0

How to detect more than one prediction for the target image? #1061

Closed DishantMewada closed 1 month ago

DishantMewada commented 2 months ago

In the minimal_example colab, the block 'Get predictions for target image with the query embedding' detects only the single closest match to the source image. Is it possible to detect more than one object, for example every object whose score exceeds a given threshold?

I am referring to the following code from the colab:

feature_map = image_embedder(target_image[None, ...])

b, h, w, d = feature_map.shape
target_boxes = box_predictor(
    image_features=feature_map.reshape(b, h * w, d), feature_map=feature_map
)['pred_boxes']

target_class_predictions = class_predictor(
    image_features=feature_map.reshape(b, h * w, d),
    query_embeddings=query_embedding[None, None, ...],  # [batch, queries, d]
)

# Remove batch dimension and convert to numpy:
target_boxes = np.array(target_boxes[0])
target_logits = np.array(target_class_predictions['pred_logits'][0])

top_ind = np.argmax(target_logits[:, 0], axis=0)
score = sigmoid(target_logits[top_ind, 0])

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.imshow(target_image, extent=(0, 1, 1, 0))
ax.set_axis_off()

cx, cy, w, h = target_boxes[top_ind]
ax.plot(
    [cx - w / 2, cx + w / 2, cx + w / 2, cx - w / 2, cx - w / 2],
    [cy - h / 2, cy - h / 2, cy + h / 2, cy + h / 2, cy - h / 2],
    color='lime',
)

ax.text(
    cx - w / 2 + 0.015,
    cy + h / 2 - 0.015,
    f'Score: {score:1.2f}',
    ha='left',
    va='bottom',
    color='black',
    bbox={
        'facecolor': 'white',
        'edgecolor': 'lime',
        'boxstyle': 'square,pad=.3',
    },
)

ax.set_xlim(0, 1)
ax.set_ylim(1, 0)
ax.set_title(f'Closest match')

Thank you so much.

DishantMewada commented 1 month ago

I have modified the code cell as follows to detect multiple objects.

Mainly, top_ind = np.argmax(target_logits[:, 0], axis=0) returns only the index of the closest match, so I changed it to np.argsort(target_logits[:, 0], axis=0), which ranks all boxes by score, and I read the indices from the highest-scoring end while iterating over the boxes.

DESIRED_SCORE = 0.97
NUMBER_OF_CLOSEST_OBJECTS = 5

feature_map = image_embedder(target_image[None, ...])

b, h, w, d = feature_map.shape
target_boxes = box_predictor(
    image_features=feature_map.reshape(b, h * w, d), feature_map=feature_map
)['pred_boxes']

target_class_predictions = class_predictor(
    image_features=feature_map.reshape(b, h * w, d),
    query_embeddings=query_embedding[None, None, ...],  # [batch, queries, d]
)

# Remove batch dimension and convert to numpy:
target_boxes = np.array(target_boxes[0])
target_logits = np.array(target_class_predictions['pred_logits'][0])

# Box indices sorted from highest to lowest logit for the single query.
# Note: np.argsort(...)[-i] with i == 0 would return the lowest-scoring box.
sorted_inds = np.argsort(target_logits[:, 0], axis=0)[::-1]

dimension_list = []
score_list = []

# Consider only the NUMBER_OF_CLOSEST_OBJECTS best boxes and keep those
# whose score exceeds DESIRED_SCORE.
for top_ind in sorted_inds[:NUMBER_OF_CLOSEST_OBJECTS]:
    score = sigmoid(target_logits[top_ind, 0])
    if score > DESIRED_SCORE:
        cx, cy, w, h = target_boxes[top_ind]
        dimension_list.append([cx, cy, w, h])
        score_list.append(score)

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.imshow(target_image, extent=(0, 1, 1, 0))
ax.set_axis_off()

for i in range(len(dimension_list)):
    cx = dimension_list[i][0]
    cy = dimension_list[i][1]
    w = dimension_list[i][2]
    h = dimension_list[i][3]

    ax.plot(
        [cx - w / 2, cx + w / 2, cx + w / 2, cx - w / 2, cx - w / 2],
        [cy - h / 2, cy - h / 2, cy + h / 2, cy + h / 2, cy - h / 2],
        color='lime',
    )

    ax.text(
        cx - w / 2 + 0.015,
        cy + h / 2 - 0.015,
        f'Score: {score_list[i]:1.2f}',
        ha='left',
        va='bottom',
        color='black',
        bbox={
            'facecolor': 'white',
            'edgecolor': 'lime',
            'boxstyle': 'square,pad=.3',
        },
    )

ax.set_xlim(0, 1)
ax.set_ylim(1, 0)
ax.set_title('Closest matches')

Let me know if you find anything wrong in the code or have tips for writing it better.
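
As one possible tidy-up, the selection step can be separated from the model calls and the plotting. The sketch below is not part of the colab: select_top_matches, its parameters, and the demo arrays are hypothetical names and synthetic data standing in for target_logits[:, 0] and target_boxes. It assumes one logit per predicted box and boxes in (cx, cy, w, h) format, and it uses only NumPy.

```python
import numpy as np


def select_top_matches(logits, boxes, min_score=0.97, max_boxes=5):
    """Return (boxes, scores) of the best matches, highest score first.

    logits: array of shape [num_boxes], one logit per predicted box.
    boxes:  array of shape [num_boxes, 4] in (cx, cy, w, h) format.
    """
    logits = np.asarray(logits, dtype=np.float64)
    boxes = np.asarray(boxes)
    scores = 1.0 / (1.0 + np.exp(-logits))  # sigmoid

    # Sort once, reverse to get descending order, and keep the best max_boxes.
    order = np.argsort(scores)[::-1][:max_boxes]
    # Of those, keep only the boxes whose score exceeds the threshold.
    keep = order[scores[order] > min_score]
    return boxes[keep], scores[keep]


# Synthetic stand-ins for target_logits[:, 0] and target_boxes.
demo_logits = np.array([-2.0, 5.0, 4.0, 0.5, 6.0])
demo_boxes = np.array([
    [0.10, 0.10, 0.10, 0.10],
    [0.30, 0.30, 0.20, 0.20],
    [0.50, 0.50, 0.20, 0.10],
    [0.70, 0.20, 0.10, 0.30],
    [0.80, 0.80, 0.20, 0.20],
])

top_boxes, top_scores = select_top_matches(
    demo_logits, demo_boxes, min_score=0.9, max_boxes=3
)
for (cx, cy, w, h), score in zip(top_boxes, top_scores):
    print(f'box=({cx:.2f}, {cy:.2f}, {w:.2f}, {h:.2f}) score={score:.3f}')
```

In the colab cell, the returned boxes and scores could be iterated with zip in the plotting loop instead of indexing dimension_list and score_list by position.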