sustainability-lab / ASTRA

"AI for Sustainability" Toolkit for Research and Analysis

`torch.set_diff1d` comparable to NumPy's `setdiff1d` #1

Open nipunbatra opened 1 year ago

nipunbatra commented 1 year ago

https://numpy.org/doc/stable/reference/generated/numpy.setdiff1d.html

import numpy as np

a = np.array([1, 2, 3, 2, 4, 1])
b = np.array([3, 4, 5, 6])
np.setdiff1d(a, b)

array([1, 2])
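
import torch

def torch_set_diff1d(a, b):
    # Convert input tensors to Python sets (flattened, like np.setdiff1d; copies data to the CPU)
    set_a = set(a.flatten().tolist())
    set_b = set(b.flatten().tolist())

    # Compute the set difference
    result = set_a - set_b

    # Sort to match np.setdiff1d, since set iteration order is not guaranteed,
    # and keep the input's dtype and device (an empty list would otherwise become float32)
    result_tensor = torch.tensor(sorted(result), dtype=a.dtype, device=a.device)

    return result_tensor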
a = torch.tensor(a)
b = torch.tensor(b)
torch_set_diff1d(a, b)

tensor([1, 2])

Other variants are possible: a masking-based approach would likely be faster than converting to Python sets. Another variant converts the tensors to NumPy, calls `np.setdiff1d`, and converts the result back.
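A minimal sketch of these variants, assuming PyTorch >= 1.10 (where `torch.isin` is available). The function names below are hypothetical helpers, not part of PyTorch's API. The masking variants deduplicate and sort `a` with `torch.unique`, then build a boolean mask of the elements that also occur in `b`. The NumPy variant simply round-trips through `np.setdiff1d`.

```python
import numpy as np
import torch


def torch_setdiff1d_mask(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Sorted unique values of `a` that are not in `b` (hypothetical helper).

    Masking variant: no Python sets, and the data stays on the tensors' device.
    Assumes torch.isin exists (PyTorch >= 1.10).
    """
    # Deduplicate and sort, matching np.setdiff1d's output order
    unique_a = torch.unique(a.flatten(), sorted=True)
    # True where an element of unique_a also appears in b
    in_b = torch.isin(unique_a, b.flatten())
    # Keep only the elements NOT found in b
    return unique_a[~in_b]


def torch_setdiff1d_broadcast(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Same result without torch.isin (hypothetical helper).

    Builds a (len(unique_a), len(b)) comparison matrix, so memory grows with
    the product of the input sizes; only suitable for small inputs.
    """
    unique_a = torch.unique(a.flatten(), sorted=True)
    # Compare every element of unique_a against every element of b
    in_b = (unique_a.unsqueeze(1) == b.flatten().unsqueeze(0)).any(dim=1)
    return unique_a[~in_b]


def torch_setdiff1d_numpy(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Round trip through NumPy (hypothetical helper).

    Copies data to the CPU and detaches it from any autograd graph.
    """
    result = np.setdiff1d(a.detach().cpu().numpy(), b.detach().cpu().numpy())
    # Move the result back to the device of the first input
    return torch.from_numpy(result).to(a.device)


if __name__ == "__main__":
    a = torch.tensor([1, 2, 3, 2, 4, 1])
    b = torch.tensor([3, 4, 5, 6])
    print(torch_setdiff1d_mask(a, b))  # tensor([1, 2])
```

The `torch.isin` version is the natural first choice since it avoids the per-element Python work of the set version. The broadcast version trades memory for not depending on `torch.isin`, and the NumPy round trip is the simplest to trust but always forces a CPU copy.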