Closed lxiaohaung closed 5 months ago
Hi, @lxiaohaung
We run `model.text_decoder.reset_cache(max_batch_size=8)` in the Build section. This function creates a static KVCache for at most 8 prompts in a batch. The shape of your `inputs["points"]` is (9, 2, 3), which means the batch size is 9.
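Assuming the `reset_cache(max_batch_size=...)` call from the Build section works as described above, a minimal sketch of the fix is to size the cache from the prompt batch itself before generating (the `model` call is shown as a comment since it is notebook-specific):

```python
import numpy as np

# Nine point prompts, two points per prompt, (x, y, label) per point,
# mirroring the shape of the inputs["points"] array in the question below.
points = np.array([[[1050.1, 900, 1], [0, 0, 4]]] * 9, "float32")

batch_size = points.shape[0]  # leading dimension = number of prompts = 9

# The static KVCache must be built for at least this many prompts.
# In the notebook this would be (hypothetical usage, not executed here):
# model.text_decoder.reset_cache(max_batch_size=batch_size)
print(batch_size)  # 9
```

With the cache rebuilt for `max_batch_size >= 9`, the batch of 9 prompts no longer exceeds the cache's static batch dimension.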
Problem solved, thank you very much for your help!
When I change the number of points in `Inference.ipynb`:

```python
inputs["points"] = np.array(
    [
        [[1050.1, 900, 1], [0, 0, 4]],
        [[900, 800, 1], [0, 0, 4]],
        [[230, 900, 1], [0, 0, 4]],
        [[1120, 920, 1], [0, 0, 4]],
        [[1030, 900, 1], [0, 0, 4]],
        [[112, 93, 1], [0, 0, 4]],
        [[1022, 904, 1], [0, 0, 4]],
        [[1050, 123, 1], [0, 0, 4]],
        [[50, 900, 1], [0, 0, 4]],
    ],
    "float32",
)
```
Something goes wrong; the details are shown as follows:

```
RuntimeError                              Traceback (most recent call last)
Cell In[7], line 24
     22 concepts, scores = model.predict_concept(sem_embeds[mask_index])
     23 print(sem_tokens[mask_index][:, None, :].shape)
---> 24 captions = model.generate_text(sem_tokens[mask_index][:, None, :])
     25 print(captions)
     26 # Display comprehensive visual understanding.

File ~/anaconda3/envs/tap/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/anaconda3/envs/tap/lib/python3.10/site-packages/tokenize_anything/modeling/image_tokenizer.py:189, in ImageTokenizer.generate_text(self, visual_tokens, max_gen_len, temperature)
    187 decode_seq_len = cur_pos - prev_pos
    188 x = torch.as_tensor(tokens[:, prev_pos:cur_pos], device=prompts.device)
--> 189 logits = self.text_decoder.transformer(prompts, x, prev_pos)
    190 next_logits = logits[: x.size(0), decode_seq_len - 1]
    191 if temperature > 0:

File ~/anaconda3/envs/tap/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
...
   1215     num_splits,
   1216 )
   1217 return out

RuntimeError: seqlens_k must have shape (batch_size)
```