microsoft / otdd

Optimal Transport Dataset Distance
MIT License
151 stars 48 forks source link

IndexError: list index out of range while computing distance for datasets #25

Closed prabhant closed 1 year ago

prabhant commented 2 years ago

Code:

import numpy as np
from sklearn.impute import SimpleImputer
# With incomparable dataset distance cost
import torch
import openml
from otdd.pytorch.distance import IncomparableDatasetDistance
from torchvision.models import resnet18
from otdd.pytorch.datasets import load_torchvision_data
import numpy as np
from torch.utils.data import TensorDataset

imp = SimpleImputer(missing_values=np.nan, strategy='mean')
d1 = openml.datasets.get_dataset(8)
d2 = openml.datasets.get_dataset(39)

def dataset_from_numpy(X, Y, classes = None):
    targets =  torch.LongTensor(list(Y))
    ds = TensorDataset(torch.from_numpy(X).type(torch.FloatTensor),targets)
    ds.targets =  targets
    ds.classes = classes if classes is not None else [i for i in range(len(np.unique(Y)))]
    return ds

x1,y1,_,_ = d1.get_data(dataset_format="array", target=d1.default_target_attribute)
x2,y2,_,_ = d2.get_data(dataset_format="array", target=d2.default_target_attribute)
x1 = imp.fit_transform(x1)
x2 = imp.fit_transform(x2)
ds1 = dataset_from_numpy(x1,y1)
ds2 = dataset_from_numpy(x2,y2)
print('datasets created')

dist = IncomparableDatasetDistance(ds1, ds2,
                          debiased_loss = False,
                          inner_ot_method = 'exact',
                          p = 5, entreg = 10e-1,
                          device='cpu')

d = dist.distance()
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
[<ipython-input-3-1e152b2bf0fd>](https://localhost:8080/#) in <module>()
     37                           inner_ot_method = 'exact',
     38                           p = 5, entreg = 10e-1,
---> 39                           device='cpu')
     40 
     41 d = dist.distance()

3 frames
[/usr/local/lib/python3.7/dist-packages/otdd/pytorch/distance.py](https://localhost:8080/#) in __init__(self, *args, **kwargs)
   1093     """
   1094     def __init__(self, *args, **kwargs):
-> 1095         super(IncomparableDatasetDistance, self).__init__(*args, **kwargs)
   1096         if self.debiased_loss:
   1097             raise ValueError('Debiased GWOTDD not implemented yet')

[/usr/local/lib/python3.7/dist-packages/otdd/pytorch/distance.py](https://localhost:8080/#) in __init__(self, D1, D2, method, symmetric_tasks, feature_cost, src_embedding, tgt_embedding, ignore_source_labels, ignore_target_labels, loss, debiased_loss, p, entreg, λ_x, λ_y, inner_ot_method, inner_ot_loss, inner_ot_debiased, inner_ot_p, inner_ot_entreg, diagonal_cov, min_labelcount, online_stats, sqrt_method, sqrt_niters, sqrt_pref, nworkers_stats, coupling_method, nworkers_dists, eigen_correction, device, precision, verbose, *args, **kwargs)
    232 
    233         if self.D1 is not None and self.D2 is not None:
--> 234             self._init_data(self.D1, self.D2)
    235         else:
    236             logger.warning('DatasetDistance initialized with empty data')

[/usr/local/lib/python3.7/dist-packages/otdd/pytorch/distance.py](https://localhost:8080/#) in _init_data(self, D1, D2)
    320 
    321 
--> 322         self.classes1 = [classes1[i] for i in self.V1]
    323         self.classes2 = [classes2[i] for i in self.V2]
    324 

[/usr/local/lib/python3.7/dist-packages/otdd/pytorch/distance.py](https://localhost:8080/#) in <listcomp>(.0)
    320 
    321 
--> 322         self.classes1 = [classes1[i] for i in self.V1]
    323         self.classes2 = [classes2[i] for i in self.V2]
    324 

IndexError: list index out of range
sachitsaksena commented 1 year ago

Hi @prabhant, I am facing a similar issue. Did you have any luck with getting around this?

dmelis commented 1 year ago

Hi @sachitsaksena. The problem here is that otdd expects the targets to be contiguous integers with no gaps between them. In this case, there's lots of gaps:

print(ds1.targets.unique())
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 12, 15, 16, 20])

This is easily solved by reindexing, e.g., map the labels to {0,..,14} in this case.

sachitsaksena commented 1 year ago

Thanks for the help!