google-research / scenic

Scenic: A Jax Library for Computer Vision Research and Beyond
Apache License 2.0

[OWL] Extremely Slow #1103

Closed percypeng5221 closed 2 months ago

percypeng5221 commented 2 months ago

Hi, I have successfully installed this repo and I can run the clip_vit_l14_with_masks_6c17944 checkpoint on my own images, but the speed is extremely slow: about 2 minutes for one image. My GPU is a laptop RTX 4090 and it's running at 100% utilization. Also, the speed is not very sensitive to the image resolution. I've attached the image I'm using (room.jpg). Am I doing something wrong? Here's my code:

import os
_CUR_DIR = os.path.dirname(os.path.realpath(__file__))

import sys
sys.path.append(_CUR_DIR+'/big_vision/')

import jax
from matplotlib import pyplot as plt
import numpy as np
from scenic.projects.owl_vit import configs
from scenic.projects.owl_vit import models
from scipy.special import expit as sigmoid
import pprint
from skimage import io as skimage_io
from flax import linen as nn
from skimage import transform as skimage_transform

config = configs.clip_l14_with_masks.get_config(init_mode='canonical_checkpoint')
module = models.TextZeroShotDetectionModule(
    body_configs=config.model.body,
    mask_head_configs=config.model.mask_head,
    normalize=config.model.normalize,
    box_bias=config.model.box_bias)

variables = module.load_variables(_CUR_DIR+"/clip_vit_l14_with_masks_6c17944")

# Load example image:
filename = os.path.join(_CUR_DIR, 'images', 'room.jpg')
image_uint8 = skimage_io.imread(filename)
image = image_uint8.astype(np.float32) / 255.0

# Factor down the image size:
scaling_factor = 1  # Define the scaling factor (e.g., 0.5 to reduce the resolution by half)
h, w, _ = image.shape
new_h, new_w = int(h * scaling_factor), int(w * scaling_factor)
print(new_h, new_w)
# Resize the image to the new dimensions:
image_rescaled = skimage_transform.resize(
    image, (new_h, new_w), anti_aliasing=True
)

# Pad the rescaled image to square with gray pixels on bottom and right:
size = max(new_h, new_w)
image_padded = np.pad(
    image_rescaled, ((0, size - new_h), (0, size - new_w), (0, 0)), constant_values=0.5)

# Resize to model input size:
input_image = skimage_transform.resize(
    image_padded,
    (config.dataset_configs.input_size, config.dataset_configs.input_size),
    anti_aliasing=True
)
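# Note: after this resize the model always sees a fixed
# input_size x input_size image, which is why runtime is largely
# insensitive to the original image resolution.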

text_queries = ['table']
tokenized_queries = np.array([
    module.tokenize(q, config.dataset_configs.max_query_length)
    for q in text_queries
])

# Pad tokenized queries to avoid recompilation if number of queries changes:
tokenized_queries = np.pad(
    tokenized_queries,
    pad_width=((0, 100 - len(text_queries)), (0, 0)),
    constant_values=0)

jitted = jax.jit(module.apply, static_argnames=('train',))
# Note: The model expects a batch dimension.
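# The first call to the jitted function triggers tracing and XLA compilation
# (including cuDNN algorithm autotuning), so the timing below includes
# compilation time, not just inference.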
import time
time1 = time.time()
predictions = jitted(
    variables,
    input_image[None, ...],
    tokenized_queries[None, ...],
    train=False)
time2 = time.time()
print(time2 - time1)
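
# (Illustrative addition, not part of the original report.) Timing a second
# call measures steady-state speed: the compiled executable is reused, and
# jax.block_until_ready() waits for the asynchronously dispatched result.
time3 = time.time()
predictions = jax.block_until_ready(
    jitted(
        variables,
        input_image[None, ...],
        tokenized_queries[None, ...],
        train=False))
time4 = time.time()
print('Steady-state inference time:', time4 - time3)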
# Remove batch dimension and convert to numpy:

predictions = jax.tree_util.tree_map(lambda x: np.array(x[0]), predictions)

score_threshold = 0.3

logits = predictions['pred_logits'][..., :len(text_queries)]  # Remove padding.
scores = sigmoid(np.max(logits, axis=-1))
labels = np.argmax(logits, axis=-1)  # Use unpadded logits so labels index into text_queries.
boxes = predictions['pred_boxes']

masks = [None] * len(boxes)
if 'pred_masks' in predictions:
  masks = sigmoid(predictions['pred_masks'])

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.imshow(input_image, extent=(0, 1, 1, 0))
ax.set_axis_off()

for score, box, label, mask in zip(scores, boxes, labels, masks):
  if score < score_threshold:
    continue
  cx, cy, w, h = box
  ax.plot([cx - w / 2, cx + w / 2, cx + w / 2, cx - w / 2, cx - w / 2],
          [cy - h / 2, cy - h / 2, cy + h / 2, cy + h / 2, cy - h / 2], 'r')

  if mask is not None:
    mask_img = plt.cm.viridis(mask)
    mask_img[..., -1] = (mask > 0.5) * 0.8
    extent = np.array((cx - w / 2, cx + w / 2, cy + h / 2, cy - h / 2))
    ax.imshow(mask_img, extent=np.clip(extent, 0, 1))

  ax.text(
      cx - w / 2,
      cy + h / 2 + 0.015,
      f'{text_queries[label]}: {score:1.2f}',
      ha='left',
      va='top',
      color='red',
      bbox={
          'facecolor': 'white',
          'edgecolor': 'red',
          'boxstyle': 'square,pad=.3'
      })

ax.set_xlim(0, 1)
ax.set_ylim(1, 0)

plt.show()

And here's the output of my terminal:

3024 4032
100%|█████████████████████████████████████| 1.29M/1.29M [00:00<00:00, 25.0MiB/s]
2024-09-03 11:46:05.026150: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng0{} for conv (f32[2304,256,32,32]{3,2,1,0}, u8[0]{0}) custom-call(f32[2304,132,32,32]{3,2,1,0}, f32[256,132,1,1]{3,2,1,0}), window={size=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-09-03 11:46:08.940115: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 4.914129581s
Trying algorithm eng0{} for conv (f32[2304,256,32,32]{3,2,1,0}, u8[0]{0}) custom-call(f32[2304,132,32,32]{3,2,1,0}, f32[256,132,1,1]{3,2,1,0}), window={size=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
116.44812893867493
eliabntt commented 2 months ago

hey @percypeng5221 did you solve this?

percypeng5221 commented 2 months ago

> hey @percypeng5221 did you solve this?

Oh, I think the time is mainly spent on JAX warm-up (compilation on the first call). The actual run for a single image takes around 1 second, which is okay. I will close this issue.
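
For reference, the one-time compilation cost can also be kept out of later runs of the script by enabling JAX's persistent compilation cache, so the compiled executable is reloaded from disk instead of being rebuilt. A minimal sketch, assuming a recent JAX version that supports the jax_compilation_cache_dir option (the cache directory path here is just an example):

import jax

# Persist compiled executables to disk so subsequent runs of the script skip
# most of the XLA compilation / autotuning work (availability and exact
# option name depend on the JAX version; the directory is arbitrary).
jax.config.update('jax_compilation_cache_dir', '/tmp/jax_cache')

# ... then build and call the jitted function as in the script above.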