microsoft / otdd

Optimal Transport Dataset Distance
MIT License
156 stars 48 forks source link

When I execute the function `pwdist_exact`, I get an error #30

Open xlcbingo1999 opened 1 year ago

xlcbingo1999 commented 1 year ago

When I try to compute the distance between two subsets randomly sampled from the EMNIST dataset, using the 'exact' method and following the format of example.py, execution always ends up in the `except` branch of the function `pwdist_exact`.

import json

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, Dataset
from torchvision.datasets import EMNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Grayscale

from otdd.pytorch.datasets import load_torchvision_data
from otdd.pytorch.distance import DatasetDistance

class CustomDataset(Dataset):
    """View of ``dataset`` restricted to the samples at ``indices``.

    ``targets`` holds only the labels of the selected samples, in the same
    order as the items, so ``targets[k]`` is the label of ``self[k]``.
    ``classes`` is the parent dataset's full class-name list.
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        # Coerce e.g. numpy integers / floats coming from JSON to plain ints.
        self.indices = [int(i) for i in indices]
        # Keep only the labels of the selected samples so ``targets`` lines up
        # with ``__getitem__``. Consumers that read ``targets`` to infer the
        # label set (presumably including OTDD's DatasetDistance) would
        # otherwise see labels with zero instances in this subset.
        # ``as_tensor`` also accepts list-valued targets (e.g. CIFAR-style).
        self.targets = torch.as_tensor(dataset.targets)[self.indices]
        self.classes = dataset.classes  # keep the parent's class names

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, item):
        """Return the ``(x, y)`` pair of the ``item``-th selected sample."""
        x, y = self.dataset[self.indices[item]]
        return x, y

    def get_class_distribution(self):
        """Return ``(labels, counts)`` of the labels present in this subset."""
        return self.targets.unique(return_counts=True)

# Driver: compute the OTDD between one EMNIST training subset and one EMNIST
# test subset, whose sample indices are listed in two JSON config files.
raw_data_path = '/mnt/linuxidc_client/dataset/Amazon_Review_split/EMNIST'
sub_train_config_path = '/mnt/linuxidc_client/dataset/Amazon_Review_split/sub_train_datasets_config.json'
sub_test_config_path = '/mnt/linuxidc_client/dataset/Amazon_Review_split/test_dataset_config.json'
train_id = 0
test_id = 0
dataset_name = "EMNIST"
sub_train_key = 'train_sub_{}'.format(train_id)
sub_test_key = 'test_sub_{}'.format(test_id)
BATCH_SIZE = 2048

# The configs are only read: open them read-only ('r+' needlessly demands
# write permission), and let ``with`` close the file.
with open(sub_train_config_path, 'r', encoding='utf-8') as f:
    current_subtrain_config = json.load(f)
with open(sub_test_config_path, 'r', encoding='utf-8') as f:
    current_subtest_config = json.load(f)

real_train_index = sorted(current_subtrain_config[dataset_name][sub_train_key]["indexes"])
print("check last real_train_index: ", real_train_index[-1])
print(len(real_train_index))
real_test_index = sorted(current_subtest_config[dataset_name][sub_test_key]["indexes"])
print("check last real_test_index: ", real_test_index[-1])
print(len(real_test_index))

transform = Compose([
    Grayscale(3),
    ToTensor(),
    Normalize((0.1307,), (0.3081,)),
])
train_dataset = EMNIST(
    root=raw_data_path,
    split="bymerge",
    download=False,
    train=True,
    transform=transform,
)
test_dataset = EMNIST(
    root=raw_data_path,
    split="bymerge",
    download=False,
    train=False,
    transform=transform,
)
print("begin train: {} test: {}".format(train_id, test_id))
print("check all size: train[{}] and test[{}]".format(len(train_dataset), len(test_dataset)))
train_dataset = CustomDataset(train_dataset, real_train_index)
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE)
test_dataset = CustomDataset(test_dataset, real_test_index)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE)
print("Finished split datasets!")
# Report the actual number of samples; len(loader) * BATCH_SIZE rounds the
# last, partial batch up and over-states the size.
print("check train_loader: {}".format(len(train_loader.dataset)))
print("check test_loader: {}".format(len(test_loader.dataset)))

# Instantiate distance
dist = DatasetDistance(train_loader, test_loader,
                       inner_ot_method='exact',
                       debiased_loss=True,
                       p=2, entreg=1e-1,
                       device='cuda:3')

d = dist.distance(maxsamples=1000)
print(f'OTDD-EMNIST(train,test)={d:8.2f}')

Note that the label distribution of each EMNIST subset differs from that of the full EMNIST dataset: some labels have no instances at all in a subset.

`pwdist_exact` does seem to return a correct result when I take evenly spaced samples instead. Here is the code:

# Evenly spaced indices used in place of the JSON-configured "indexes" lists:
# every 10th sample of the EMNIST training split, every 3rd of the test split.
real_train_index = [*range(0, 697932, 10)]
real_test_index = [*range(0, 116323, 3)]

You can download my sub_train_datasets_config.json and test_dataset_config.json from Google Drive. Link: https://drive.google.com/drive/folders/1r_vyLJ-RmuuNZqneBP3meexrEZvgc_Ce?usp=sharing

xlcbingo1999 commented 1 year ago

Error log:

"This is awkward. Distance computation failed. Geomloss is hard to debug" \
"But here's a few things that might be happening: "\
" 1. Too many samples with this label, causing memory issues" \
" 2. Datatype errors, e.g., if the two datasets have different type"
Distance computation failed. Aborting.
xlcbingo1999 commented 1 year ago

When the product of the numbers of samples in the two point sets used to compute the cost exceeds 5000**2, GeomLoss switches to `backend='online'`, and your code does not provide a cost that the online backend can use, which causes the failure. I worked around this in the code below:

# Fragment of the body of otdd's ``pwdist_exact``. The names cost_function, p,
# loss, debias, entreg, pairs, X1/Y1/c1, X2/Y2/c2, n1, n2, symmetric, device,
# logger, geomloss, ot, torch and tqdm come from the enclosing function/module.

# Label pairs with more point pairs than this use GeomLoss's online (KeOps)
# backend, whose ``cost`` must be a formula string; smaller pairs use the
# tensorized backend with a Python callable. The backends are pinned
# explicitly below, so the choice never depends on GeomLoss's "auto" rule.
online_threshold = 5000 ** 2

if cost_function == 'euclidean':
    if p == 1:
        small_cost_function = geomloss.utils.distances
        big_cost_function = "Norm2(X-Y)"
    elif p == 2:
        # Both forms must compute the SAME cost. The tensorized callable is the
        # plain squared distance, so the formula must not be halved (GeomLoss's
        # default "SqDist(X,Y) / IntCst(2)" would make large pairs 2x smaller).
        small_cost_function = geomloss.utils.squared_distances
        big_cost_function = "SqDist(X,Y)"
    else:
        raise ValueError(f"Unsupported p={p} for the euclidean cost; expected 1 or 2")
else:
    raise ValueError(f"Unsupported cost_function={cost_function!r}; expected 'euclidean'")

if loss == 'sinkhorn':
    small_distance = geomloss.SamplesLoss(
        loss=loss, p=p,
        cost=small_cost_function,
        debias=debias,
        blur=entreg**(1 / p),
        backend='tensorized',
    )
    big_distance = geomloss.SamplesLoss(
        loss=loss, p=p,
        cost=big_cost_function,
        debias=debias,
        blur=entreg**(1 / p),
        backend='online',
    )
elif loss == 'wasserstein':
    def small_distance(Xa, Xb):
        """Exact OT (POT's emd2) between uniform measures on Xa and Xb."""
        C = small_cost_function(Xa, Xb).cpu()
        return torch.tensor(ot.emd2(ot.unif(Xa.shape[0]), ot.unif(Xb.shape[0]), C))
    # ot.emd2 always needs the dense cost matrix, so there is no "online"
    # variant; big_cost_function is a KeOps formula string and cannot be called.
    big_distance = small_distance
else:
    raise ValueError('Wrong loss')

logger.info('Computing label-to-label (exact) wasserstein distances...')
pbar = tqdm(pairs, leave=False)
pbar.set_description('Computing label-to-label distances')
D = torch.zeros((n1, n2), device=device, dtype=X1.dtype)
for i, j in pbar:
    try:
        temp_left = X1[Y1 == c1[i]].to(device)
        temp_right = X2[Y2 == c2[j]].to(device)
        if temp_left.shape[0] * temp_right.shape[0] > online_threshold:
            D[i, j] = big_distance(temp_left, temp_right).item()
        else:
            D[i, j] = small_distance(temp_left, temp_right).item()
    except Exception as err:  # not bare: let KeyboardInterrupt propagate
        print("This is awkward. Distance computation failed. Geomloss is hard to debug. "
              "But here's a few things that might be happening:\n"
              " 1. Too many samples with this label, causing memory issues\n"
              " 2. Datatype errors, e.g., if the two datasets have different type")
        print(f"Failed on label pair ({c1[i]}, {c2[j]}): {err!r}")
        # Same exception type as sys.exit(msg), but keeps the original cause.
        raise SystemExit('Distance computation failed. Aborting.') from err
    if symmetric:
        D[j, i] = D[i, j]