Open SJLMax opened 2 years ago
Hi, this change should help.
value_vectors_only_global = torch.zeros(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim, device="cpu")
value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[is_index_global_attn_nonzero].detach().cpu()
value_vectors_only_global = value_vectors_only_global.numpy()
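For comparison, here is a minimal standalone sketch of the same indexing step that stays in torch on whatever device the inputs are on, and only moves to the CPU for the final NumPy conversion. The toy shapes, the mask values, and the way the two index tuples are built are assumptions made up for illustration (they are not copied from the model code), and this may or may not be the fix for the error reported here.

```python
import torch

# Toy sizes; every value here is an illustrative assumption.
batch_size, seq_len, num_heads, head_dim = 2, 5, 3, 4
device = "cuda" if torch.cuda.is_available() else "cpu"

value_vectors = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device)

# Boolean mask of tokens with global attention, shape (batch_size, seq_len).
is_index_global_attn = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device)
is_index_global_attn[0, 0] = True
is_index_global_attn[1, 0] = True
is_index_global_attn[1, 2] = True

# Number of global tokens per sample and the maximum over the batch.
num_global = is_index_global_attn.long().sum(dim=1)
max_num_global_attn_indices = int(num_global.max())

# Positions of global tokens in the full sequence.
is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True)

# Matching slots in the padded (batch_size, max_num_global) layout.
is_local_index_global_attn = (
    torch.arange(max_num_global_attn_indices, device=device) < num_global.unsqueeze(-1)
)
is_local_index_global_attn_nonzero = is_local_index_global_attn.nonzero(as_tuple=True)

# new_zeros keeps the dtype and device of value_vectors, so both sides of the
# assignment live on the same device.
value_vectors_only_global = value_vectors.new_zeros(
    batch_size, max_num_global_attn_indices, num_heads, head_dim
)
value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[
    is_index_global_attn_nonzero
]

# Convert to NumPy only at the very end, after moving the result to the CPU.
as_numpy = value_vectors_only_global.detach().cpu().numpy()
```

Keeping the intermediate result as a tensor means any later torch operations on value_vectors_only_global still work; the .detach().cpu().numpy() call is only needed at the point where a NumPy array is actually required.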
Maybe you should check your _tensor.py file: where the code reads "return self.numpy()", change it to "return self.cpu().numpy()".
value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[is_index_global_attn_nonzero].detach().cpu().numpy()
Still didn't work.