JAYATEJAK / GVAlign

Robust Feature Learning and Global Variance-Driven Classifier Alignment for Long-Tail Class Incremental Learning

run error #3

Closed. YananGu closed this issue 7 months ago.

YananGu commented 7 months ago

Hi, thanks for your great work. I ran your code on another dataset, but it raises the following error. Could you give me some help?

    Traceback (most recent call last):
      File "/media/project/yanan/eccv2024/gvalign/src/main_incremental.py", line 338, in <module>
        main()
      File "/media/project/yanan/eccv2024/gvalign/src/main_incremental.py", line 286, in main
        appr.train(t, trn_loader[t], tst_loader[t], args, tst_loader)
      File "/media/project/yanan/eccv2024/gvalign/src/approach/incremental_learning.py", line 57, in train
        self.train_loop(t, trn_loader, val_loader, args, tst_loader)
      File "/media/project/yanan/eccv2024/gvalign/src/approach/lucir_gvalign_2stage.py", line 288, in train_loop
        distrib = MultivariateNormal(loc=mean, covariance_matrix=cov_cls_ms)
      File "/media/project/yanan/anaconda3/envs/ahead/lib/python3.10/site-packages/torch/distributions/multivariate_normal.py", line 150, in __init__
        super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args)
      File "/media/project/yanan/anaconda3/envs/ahead/lib/python3.10/site-packages/torch/distributions/distribution.py", line 56, in __init__
        raise ValueError(
    ValueError: Expected parameter covariance_matrix (Tensor of shape (50, 512, 512)) of distribution MultivariateNormal(loc: torch.Size([50, 512]), covariance_matrix: torch.Size([50, 512, 512])) to satisfy the constraint PositiveDefinite(), but found invalid values:
    tensor([[[ 2.7521e-03,  1.8102e-03, -9.6236e-04,  ..., -1.2610e-03,
               5.4316e-04, -6.7986e-04],
             [ 1.8102e-03,  1.5385e-03, -9.5112e-04,  ..., -7.8666e-04,
               5.5435e-04, -5.4887e-04],
             [-9.6236e-04, -9.5112e-04,  1.2840e-03,  ...,  2.5209e-04,
              -5.3971e-04,  6.2342e-04],
             ...,
             [-1.2610e-03, -7.8666e-04,  2.5209e-04,  ...,  1.1643e-03,
               3.7204e-05,  9.5319e-05],
             [ 5.4316e-04,  5.5435e-04, -5.3971e-04,  ...,  3.7204e-05,
               4.4342e-04, -3.3423e-04],
             [-6.7986e-04, -5.4887e-04,  6.2342e-04,  ...,  9.5319e-05,
              -3.3423e-04,  4.5726e-04]],

            ...,

            [[ 2.7521e-03,  1.8102e-03, -9.6236e-04,  ..., -1.2610e-03,
               5.4316e-04, -6.7986e-04],
             [ 1.8102e-03,  1.5385e-03, -9.5112e-04,  ..., -7.8666e-04,
               5.5435e-04, -5.4887e-04],
             [-9.6236e-04, -9.5112e-04,  1.2840e-03,  ...,  2.5209e-04,
              -5.3971e-04,  6.2342e-04],
             ...,
             [-1.2610e-03, -7.8666e-04,  2.5209e-04,  ...,  1.1643e-03,
               3.7204e-05,  9.5319e-05],
             [ 5.4316e-04,  5.5435e-04, -5.3971e-04,  ...,  3.7204e-05,
               4.4342e-04, -3.3423e-04],
             [-6.7986e-04, -5.4887e-04,  6.2342e-04,  ...,  9.5319e-05,
              -3.3423e-04,  4.5726e-04]]])
JAYATEJAK commented 7 months ago

Hi @YananGu, how many epochs did you train for?

YananGu commented 7 months ago

@JAYATEJAK I trained for 90 epochs on the DomainNet dataset, which spans multiple domains.

JAYATEJAK commented 7 months ago

@YananGu, I think this issue arises when a proper variance isn't learned from the data. You won't encounter it on CIFAR or ImageNet.
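To check whether that is what is happening with your own features, here is a minimal sketch, assuming the features of one class are available as an (N, D) tensor (D = 512 in the traceback). A sample covariance has rank at most N - 1, so any class with N <= D samples yields a singular matrix that cannot be positive definite; a feature dimension with zero variance has the same effect. The helper name is hypothetical (not part of GVAlign), and the Cholesky test follows the same idea PyTorch uses for its PositiveDefinite check rather than reproducing it exactly.

```python
import torch


def covariance_is_positive_definite(features: torch.Tensor) -> bool:
    """Return True if the sample covariance of `features` (N x D) is positive definite.

    Hypothetical diagnostic helper, not part of GVAlign.
    """
    # torch.cov expects variables in rows, so transpose (N, D) -> (D, N)
    cov = torch.cov(features.to(torch.float64).T)
    # cholesky_ex reports failure through `info` instead of raising;
    # info == 0 means the factorization succeeded, i.e. cov is positive definite
    _, info = torch.linalg.cholesky_ex(cov)
    return info.item() == 0


if __name__ == "__main__":
    torch.manual_seed(0)
    feats = torch.randn(50, 8, dtype=torch.float64)  # many samples, few dims
    print(covariance_is_positive_definite(feats))  # prints True
    feats[:, 3] = 0.0  # simulate a feature dimension with zero variance
    print(covariance_is_positive_definite(feats))  # prints False
```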

Maybe you can try adding a small delta to the diagonal of the covariance matrix, something like:

    class_cov = torch.cov(torch.tensor(vectors, dtype=torch.float64).T) + torch.eye(class_mean.shape[-1], dtype=torch.float64) * 1e-5
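The delta fix can be wrapped in a small helper; this is a minimal sketch in which the function name and signature are hypothetical and `delta=1e-5` is simply the value from this thread, not a tuned setting. Adding delta * I raises every eigenvalue of the covariance by delta, so a covariance that is only positive semi-definite (singular) becomes positive definite, at the cost of a slight bias toward isotropic noise.

```python
import torch
from torch.distributions import MultivariateNormal


def regularized_class_cov(vectors: torch.Tensor, delta: float = 1e-5) -> torch.Tensor:
    """Sample covariance of `vectors` (N x D) with `delta` added to the diagonal.

    Hypothetical helper; delta=1e-5 follows the suggestion in this thread.
    """
    vectors = vectors.to(torch.float64)
    cov = torch.cov(vectors.T)  # (D, D); rows of vectors.T are the D feature dims
    # cov + delta * I has eigenvalues lambda_i + delta, so a singular PSD
    # covariance becomes positive definite
    eye = torch.eye(cov.shape[0], dtype=cov.dtype, device=cov.device)
    return cov + delta * eye


if __name__ == "__main__":
    torch.manual_seed(0)
    # 20 samples in 512 dims: the plain sample covariance has rank <= 19,
    # which MultivariateNormal would typically reject
    feats = torch.randn(20, 512, dtype=torch.float64)
    cov = regularized_class_cov(feats)
    dist = MultivariateNormal(loc=feats.mean(0), covariance_matrix=cov)
    print(dist.sample((4,)).shape)  # prints torch.Size([4, 512])
```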

This is common practice, because a covariance matrix must be positive definite to parameterize a multivariate normal distribution. In our approach, we computed it from the class with the most samples, so we didn't encounter this issue.
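One possible reading of estimating the covariance from the class with the most samples is sketched below: compute a single covariance from the head class and share it across all classes, while keeping per-class means. This is an assumption about the idea, not the GVAlign code, and the function name and arguments are hypothetical. Sharing one matrix would also be consistent with the identical per-class matrices printed in the traceback above.

```python
import torch
from torch.distributions import MultivariateNormal


def head_class_gaussians(features: torch.Tensor, labels: torch.Tensor,
                         num_classes: int, delta: float = 1e-5) -> MultivariateNormal:
    """Per-class Gaussians that share the covariance of the most frequent class.

    Hypothetical sketch, not the GVAlign implementation.
    features: (N, D) tensor; labels: (N,) int64 tensor with values in [0, num_classes).
    """
    features = features.to(torch.float64)
    counts = torch.bincount(labels, minlength=num_classes)
    if (counts == 0).any():
        raise ValueError("every class needs at least one sample to define its mean")
    head = int(torch.argmax(counts))  # class with the most samples
    # Covariance from the head class only; the small delta keeps it positive definite
    cov = torch.cov(features[labels == head].T)
    cov = cov + delta * torch.eye(cov.shape[0], dtype=cov.dtype, device=cov.device)
    # Per-class means: (C, D)
    means = torch.stack([features[labels == c].mean(0) for c in range(num_classes)])
    # Share the single (D, D) covariance across all classes: (C, D, D)
    covs = cov.expand(num_classes, -1, -1)
    return MultivariateNormal(loc=means, covariance_matrix=covs)
```

For example, calling `.sample((5,))` on the returned distribution gives a tensor of shape (5, C, D), i.e. five pseudo-features per class, which could then be used to tune the classifier.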

Another suggestion: if all of your classes have few samples, try estimating the variance from all classes together and using it to tune the classifier. It would be an interesting experiment to see what happens.
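A minimal sketch of that experiment, assuming that a variance from all classes means a pooled within-class covariance: center each sample on its own class mean, then estimate one covariance from all of the residuals. The function name is hypothetical. Because every sample contributes, the pooled estimate can reach rank min(N - C, D) for N samples and C classes, instead of being limited by the size of a single class.

```python
import torch


def pooled_within_class_cov(features: torch.Tensor, labels: torch.Tensor,
                            delta: float = 1e-5) -> torch.Tensor:
    """Within-class covariance pooled over all classes, plus `delta` on the diagonal.

    Hypothetical sketch of the 'variance from all classes' idea.
    features: (N, D) tensor; labels: (N,) int64 tensor.
    """
    features = features.to(torch.float64)
    classes = torch.unique(labels)
    centered = torch.empty_like(features)
    for c in classes:
        mask = labels == c
        # Remove each class's own mean so the spread between classes is not counted
        centered[mask] = features[mask] - features[mask].mean(0)
    dof = features.shape[0] - classes.numel()  # N - C degrees of freedom
    if dof <= 0:
        raise ValueError("need more samples than classes")
    cov = centered.T @ centered / dof
    eye = torch.eye(cov.shape[0], dtype=cov.dtype, device=cov.device)
    return cov + delta * eye
```

The resulting (D, D) matrix could be expanded to (C, D, D) and passed to `MultivariateNormal` together with per-class means, in place of a per-class or head-class covariance.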

YananGu commented 7 months ago

Thanks, I added a delta to the covariance matrix and that fixed the problem. The likely cause is that in DomainNet, even samples of the same class come from different domains, which makes the variance harder to learn.