microsoft / otdd

Optimal Transport Dataset Distance
MIT License

Unexpected results on Apple M1 processor #31

Open joannacknight opened 1 year ago

joannacknight commented 1 year ago

The following code computes the exact OTDD between a dataset and itself, so the expected distance d is zero. When the code is run on an Apple M1 processor with device="cpu", it does not return zero; for the given random seed it returns 2.99 (to 2 d.p.). One workaround on an M1 is to set device="mps", in which case the expected result of zero is returned.

from otdd.pytorch.distance import DatasetDistance
import torch
from torch.utils.data import TensorDataset

n_samples = 1000
n_feats = 100000
n_labels = 10
max_samples = n_samples

torch.manual_seed(42)
data_A = torch.randn(n_samples, n_feats)
labels_A = torch.randint(low=0, high=n_labels, size=(n_samples,))

ds_A = TensorDataset(data_A, labels_A)

dist = DatasetDistance(
    ds_A,
    ds_A,  # compare to itself
    inner_ot_method="exact",
    debiased_loss=True,
    p=2,
    entreg=1e-1,
    inner_ot_debiased=True,
    device="cpu",
)
d = dist.distance(maxsamples=max_samples)
print(f"distance: {d}")
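
To make the CPU vs. MPS comparison easier to reproduce, here is a minimal sketch of a self-distance check that runs the same computation on several devices and flags any result that is not close to zero. The helper names `self_distance_by_device` and `otdd_self_distance` and the tolerance `tol` are hypothetical and not part of otdd; the `DatasetDistance` arguments are copied from the snippet above.

```python
def self_distance_by_device(compute_distance, devices, tol=1e-6):
    """Call compute_distance(device) for each device and record
    (distance, is_close_to_zero) per device.

    compute_distance is any callable that takes a device string and
    returns a number or a zero-dimensional tensor.
    """
    report = {}
    for device in devices:
        d = float(compute_distance(device))
        report[device] = (d, abs(d) <= tol)
    return report


def otdd_self_distance(dataset, device, maxsamples):
    """Distance between a dataset and itself, using the settings from
    the snippet above (hypothetical wrapper, not an otdd function)."""
    # Imported lazily so the helper above can be used without otdd installed.
    from otdd.pytorch.distance import DatasetDistance

    dist = DatasetDistance(
        dataset,
        dataset,  # compare to itself
        inner_ot_method="exact",
        debiased_loss=True,
        p=2,
        entreg=1e-1,
        inner_ot_debiased=True,
        device=device,
    )
    return dist.distance(maxsamples=maxsamples)
```

With `ds_A` and `max_samples` defined as above, calling `self_distance_by_device(lambda dev: otdd_self_distance(ds_A, dev, max_samples), ["cpu", "mps"])` should report which device deviates from zero. The tolerance is an assumption and may need loosening for float32 results.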