baaivision / tokenize-anything

[ECCV 2024] Tokenize Anything via Prompting
Apache License 2.0

Inference from multiple points #15

Closed lxiaohaung closed 5 months ago

lxiaohaung commented 5 months ago

When I change the number of points in Inference.ipynb:

```python
inputs["points"] = np.array(
    [
        [[1050.1, 900, 1], [0, 0, 4]],
        [[900, 800, 1], [0, 0, 4]],
        [[230, 900, 1], [0, 0, 4]],
        [[1120, 920, 1], [0, 0, 4]],
        [[1030, 900, 1], [0, 0, 4]],
        [[112, 93, 1], [0, 0, 4]],
        [[1022, 904, 1], [0, 0, 4]],
        [[1050, 123, 1], [0, 0, 4]],
        [[50, 900, 1], [0, 0, 4]],
    ],
    "float32",
)
```

something goes wrong; the details are shown below:

```
RuntimeError                              Traceback (most recent call last)
Cell In[7], line 24
     22 concepts, scores = model.predict_concept(sem_embeds[mask_index])
     23 print(sem_tokens[mask_index][:, None, :].shape)
---> 24 captions = model.generate_text(sem_tokens[mask_index][:, None, :])
     25 print(captions)
     26 # Display comprehensive visual understanding.

File ~/anaconda3/envs/tap/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/anaconda3/envs/tap/lib/python3.10/site-packages/tokenize_anything/modeling/image_tokenizer.py:189, in ImageTokenizer.generate_text(self, visual_tokens, max_gen_len, temperature)
    187 decode_seq_len = cur_pos - prev_pos
    188 x = torch.as_tensor(tokens[:, prev_pos:cur_pos], device=prompts.device)
--> 189 logits = self.text_decoder.transformer(prompts, x, prev_pos)
    190 next_logits = logits[: x.size(0), decode_seq_len - 1]
    191 if temperature > 0:

File ~/anaconda3/envs/tap/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
...
   1215     num_splits,
   1216 )
   1217 return out

RuntimeError: seqlens_k must have shape (batch_size)
```

PhyscalX commented 5 months ago

Hi, @lxiaohaung

  1. We run `model.text_decoder.reset_cache(max_batch_size=8)` in the Build section. This function creates a static KV cache that holds at most 8 prompts per batch.

  2. The shape of your `inputs["points"]` is (9, 2, 3), i.e. a batch size of 9, which exceeds that capacity. See the sketch below for two ways to resolve it.
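
A minimal sketch of the two fixes implied above, assuming the notebook's `model` object and the `reset_cache` call from the Build section; the `generate_in_chunks` helper is hypothetical, not part of the repo:

```python
# Option A: rebuild the static KV cache so it covers all 9 prompts.
# reset_cache(max_batch_size=...) is the same call used in the Build section.
model.text_decoder.reset_cache(max_batch_size=9)

# Option B (hypothetical helper): keep max_batch_size=8 and feed the
# prompts through the decoder in chunks that fit the cache.
def generate_in_chunks(model, sem_tokens, max_batch_size=8):
    captions = []
    for start in range(0, len(sem_tokens), max_batch_size):
        chunk = sem_tokens[start:start + max_batch_size]
        # generate_text expects tokens shaped (batch, 1, dim), matching
        # the sem_tokens[mask_index][:, None, :] call in the notebook.
        captions.extend(model.generate_text(chunk[:, None, :]))
    return captions
```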

lxiaohaung commented 5 months ago

Problem solved, thank you very much for your help!