pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

vmapping `index_put_` doesn't handle scalar values properly (results in device mismatch error). #130225

Closed: Chillee closed this issue 4 days ago.

Chillee commented 2 months ago

🐛 Describe the bug

import torch
torch.set_default_device('cuda')

num_rows = 128
num_cols = 128
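device = 'cuda'

# Builds one dense (num_rows, num_cols) int32 mask from block-sparse inputs:
#   kv_num_blocks: (num_rows,) number of valid entries in each row of kv_indices
#   kv_indices:    (num_rows, num_cols) column indices; only the first
#                  kv_num_blocks[r] entries of row r are valid
def create_dense_one(kv_num_blocks, kv_indices):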
    dense_mask = kv_indices.new_zeros(num_rows, num_cols + 1, dtype=torch.int32)

    row_indices = torch.arange(
        num_rows, dtype=torch.int, device=device
    ).unsqueeze(-1)
    col_indices = torch.arange(num_cols, dtype=torch.int, device=device)
    index_mask = col_indices < kv_num_blocks.unsqueeze(-1)

    # Invalid entries are redirected to the spare column num_cols (one past the
    # last real column), which is sliced off before returning
    valid_indices = torch.where(index_mask, kv_indices, num_cols)

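    # Set dense_mask to 1 at every (row, index) pair. This indexed assignment of a
    # Python scalar goes through index_put_, which is what triggers the error under vmap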
    dense_mask[row_indices, valid_indices] = 1
    return dense_mask[:, :num_cols]

# A batch of 3 inputs; vmap maps create_dense_one over dim 0 of both arguments
kv_num_blocks = torch.zeros(3, 128, device='cuda', dtype=torch.int)
kv_indices = torch.zeros(3, 128, 128, device='cuda', dtype=torch.int)
out = torch.vmap(create_dense_one, in_dims=(0, 0))(kv_num_blocks, kv_indices)

This raises:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
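For context, here is a minimal CPU-only sketch (no vmap) of the operation involved. The indexed assignment `dense_mask[row_indices, valid_indices] = 1` is carried out by `index_put_`, whose value argument must be a tensor, so the Python scalar `1` is first wrapped into a 0-dim tensor. Presumably (an assumption, not confirmed in this thread) that wrapped value is the CPU tensor the error complains about once vmap is involved. The second form below passes the value explicitly as a 0-dim tensor with the mask's own dtype and device; it may serve as a workaround on affected versions, but that has not been tested against the buggy build. The shapes and index values are made up for illustration.

```python
import torch

# Small CPU-only example; shapes and index values are illustrative only.
num_rows, num_cols = 4, 4
dense_a = torch.zeros(num_rows, num_cols + 1, dtype=torch.int32)
dense_b = torch.zeros_like(dense_a)

row_indices = torch.arange(num_rows).unsqueeze(-1)  # (num_rows, 1), int64
# (num_rows, 2) column indices; num_cols (= 4) is the spare "out of bounds" column
valid_indices = torch.tensor([[0, 4], [1, 2], [4, 4], [3, 0]])

# Form used in the reproducer: indexed assignment of a Python scalar.
dense_a[row_indices, valid_indices] = 1

# Explicit index_put_: the value is a 0-dim tensor that already has the same
# dtype and device as dense_b.
one = torch.ones((), dtype=dense_b.dtype, device=dense_b.device)
dense_b.index_put_((row_indices, valid_indices), one)

# Drop the spare column, as create_dense_one does.
result_a = dense_a[:, :num_cols]
result_b = dense_b[:, :num_cols]
print(result_a.tolist())  # [[1, 0, 0, 0], [0, 1, 1, 0], [0, 0, 0, 0], [1, 0, 0, 1]]
```

Both forms write the same entries; the only difference is who creates the value tensor and where it lives.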

cc: @zou3519 @guilhermeleobas

Versions

N/A

cc @zou3519 @samdow @kshitij12345 @janeyx99

guilhermeleobas commented 1 month ago

Partially fixed in https://github.com/pytorch/pytorch/pull/130479. The reproducer should work on main now.
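For versions that predate that PR, one way to sidestep vmap entirely is to batch the indexing by hand: add an explicit batch index that broadcasts against the row and column indices, so a single `index_put_` writes every mask at once. This is a vectorized rewrite, not the fix in the PR. The sketch assumes the reproducer's layout (`kv_num_blocks` of shape `(B, num_rows)`, `kv_indices` of shape `(B, num_rows, num_cols)`) and casts the indices to int64; `create_dense_batched` is a hypothetical helper name, not a PyTorch API. With the reproducer's inputs, `create_dense_batched(kv_num_blocks, kv_indices, num_cols=128)` should compute the same masks as the vmap call.

```python
import torch


def create_dense_batched(kv_num_blocks, kv_indices, num_cols):
    """Hypothetical vmap-free helper: builds all B dense masks with one index_put_.

    kv_num_blocks: (B, num_rows) number of valid entries per row
    kv_indices:    (B, num_rows, C) column indices; only the first
                   kv_num_blocks[b, r] entries of each row are valid
    Returns an int32 tensor of shape (B, num_rows, num_cols).
    """
    B, num_rows = kv_num_blocks.shape
    device = kv_indices.device
    # One spare column at index num_cols absorbs the writes for invalid entries.
    dense_mask = torch.zeros(
        B, num_rows, num_cols + 1, dtype=torch.int32, device=device
    )

    col_indices = torch.arange(kv_indices.shape[-1], device=device)  # (C,)
    index_mask = col_indices < kv_num_blocks.unsqueeze(-1)  # (B, num_rows, C)
    valid_indices = torch.where(index_mask, kv_indices.long(), num_cols)

    # Batch and row indices broadcast against valid_indices to (B, num_rows, C).
    batch_indices = torch.arange(B, device=device).view(B, 1, 1)
    row_indices = torch.arange(num_rows, device=device).view(1, num_rows, 1)
    dense_mask[batch_indices, row_indices, valid_indices] = 1
    return dense_mask[..., :num_cols]


# Tiny example: B = 2 batches, 2 rows, 4 columns.
kv_num_blocks = torch.tensor([[2, 0], [1, 3]])
kv_indices = torch.tensor([[[1, 3, 0, 0], [2, 2, 2, 2]],
                           [[2, 0, 0, 0], [0, 1, 2, 3]]])
out = create_dense_batched(kv_num_blocks, kv_indices, num_cols=4)
```

Every tensor here is allocated on `kv_indices.device`, and outside vmap an indexed assignment of a Python scalar into a CUDA tensor works normally, so the same helper should run unchanged on GPU inputs.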