Michaelvll / DeepCCA

An implementation of Deep Canonical Correlation Analysis (DCCA or Deep CCA) with pytorch.

Calculation of trace norm when using all singular values #8

Open jameschapman19 opened 4 years ago

jameschapman19 commented 4 years ago

The order of operations when using all singular values is currently sqrt(trace(T'T)) - note this is not the case when using topK.

Equation (10) in Andrew et al.'s original paper is slightly ambiguous because of the notation, but the description (and the derivation of the gradients) suggests that the correct order is trace(sqrt(T'T)).

I've been working on fixing this. The main problem is that T'T here seems to have negative off-diagonal entries, which in PyTorch result in NaN gradients. I think TensorFlow possibly doesn't produce NaNs for the gradient of sqrt(0), which is what allows TensorFlow implementations to work out of the box with trace(sqrt(T'T)). Note: it shouldn't actually matter what the backend produces in either case, because the off-diagonals should be discarded by the trace operation.
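For reference, a minimal sketch of the behaviour described above. The 2x2 matrix `T` is a made-up stand-in for `Tval`, chosen only so that T'T has negative off-diagonal entries; it is not taken from the repository.

```python
import torch

# Hypothetical stand-in for Tval; the values are arbitrary.
T = torch.tensor([[1.0, -2.0],
                  [0.5,  1.0]], requires_grad=True)

M = torch.matmul(T.t(), T)           # T'T = [[1.25, -1.5], [-1.5, 5.0]]
corr = torch.trace(torch.sqrt(M))    # element-wise sqrt, then trace
corr.backward()

# The trace only reads the diagonal, so the forward value is finite,
# but the backward pass through sqrt still visits the negative
# off-diagonal entries and contaminates the gradient with NaN.
forward_is_finite = bool(torch.isfinite(corr))
grad_has_nan = bool(torch.isnan(T.grad).any())
print(forward_is_finite, grad_has_nan)
```

The forward value looks fine here, which is probably why the problem only shows up once training starts.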

Would a short-term fix be to push everything through the top-k route?

jameschapman19 commented 4 years ago

@Michaelvll I think the change you have just made will break people's code, for the reasons I gave above! We need to work around PyTorch's NaN gradient for sqrt(0) to fix the all-singular-values path.

jameschapman19 commented 4 years ago

I have an idea that I'm just testing. Given what I've just said, I think we could do sum(sqrt(diag(T'T))).

jameschapman19 commented 4 years ago

This keeps the efficiency (i.e. it doesn't need another eigendecomposition) and hopefully won't break.

jameschapman19 commented 4 years ago

Yep this looks like the one:

torch.sum(torch.sqrt(torch.diag(torch.matmul(Tval.t(), Tval))))

Should do the job.
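A small sketch to check the suggestion, again with a made-up 2x2 `T` standing in for `Tval`. Taking the diagonal before the square root means the negative off-diagonal entries never reach `torch.sqrt`.

```python
import torch

# Hypothetical stand-in for Tval; the values are arbitrary.
T = torch.tensor([[1.0, -2.0],
                  [0.5,  1.0]], requires_grad=True)

M = torch.matmul(T.t(), T)                    # off-diagonals are negative here
corr = torch.sum(torch.sqrt(torch.diag(M)))   # sqrt only sees the diagonal
corr.backward()

grad_is_finite = bool(torch.isfinite(T.grad).all())

# Same forward value as the trace(sqrt(.)) form, without the NaN gradient.
with torch.no_grad():
    reference = torch.trace(torch.sqrt(M))
same_forward = bool(torch.allclose(corr.detach(), reference))
print(grad_is_finite, same_forward)
```

Note that `torch.sqrt` acts element-wise, so this reproduces the forward value of `torch.trace(torch.sqrt(T'T))`; it is not a matrix square root.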

arminarj commented 4 years ago

Hi @jameschapman19,

Thank you for opening a new issue, and thank you for your suggestion. I have made a new branch based on it; please double-check it before we merge it into master.

ChenyuxinXMU commented 2 years ago

Hi @arminarj @jameschapman19 @Michaelvll,

When I was using this code, the loss became NaN or the algorithm failed to converge. The problem is in the following code:

tmp = torch.matmul(Tval.t(), Tval)
corr = torch.trace(torch.sqrt(tmp))

How can I solve this problem?

arminarj commented 2 years ago

Hi @Eason24,

This type of error usually happens when you feed a zero value to the torch.sqrt() function. As a short-term solution, I suggest you set self.use_all_singular_values to False (there should not be much of a difference in the final values if you use a reasonable self.outdim_size).

You can also add regularization to it, as in this line.
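As a rough sketch of what the top-k route with regularization can look like: the function name `topk_corr` and the parameters `k`, `r` and `eps` below are hypothetical choices for illustration, not the repository's exact code.

```python
import torch

def topk_corr(T, k, r=1e-3, eps=1e-9):
    """Sum of square roots of the k largest eigenvalues of T'T + r*I (sketch)."""
    M = torch.matmul(T.t(), T)
    eye = torch.eye(M.shape[0], dtype=M.dtype, device=M.device)
    M = M + r * eye                            # assumed ridge term
    eigvals = torch.linalg.eigvalsh(M)         # M is symmetric
    eigvals = torch.clamp(eigvals, min=eps)    # keep sqrt away from zero and negatives
    top = torch.topk(eigvals, k).values
    return torch.sum(torch.sqrt(top))

# Made-up input, only to show that the gradient stays finite.
T = torch.tensor([[1.0, -2.0],
                  [0.5,  1.0]], requires_grad=True)
corr = topk_corr(T, k=2)
corr.backward()
grad_is_finite = bool(torch.isfinite(T.grad).all())
print(corr.item(), grad_is_finite)
```

With `k` equal to the full dimension and a small `r`, the result should be close to the sum of the singular values of `T`.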

Best, Armin

ChenyuxinXMU commented 2 years ago

@arminarj, thank you for your suggestion. I've tried your method and others, but the problem still exists. At present, I want to find a method to align multi-view data. I notice that there is an algorithm called DGCCA in your GitHub project, and I will try it next. I would appreciate any suggestions you have on aligning multi-view data.

jameschapman19 commented 2 years ago

One thing to be aware of is that the gradients of this objective become pretty unstable as you reduce the batch size (basically because eigenvalue solvers are inexact, so even if you add regularisation the matrix can appear not to be positive semidefinite). In the original paper they actually propose full-batch optimisation (L-BFGS optimizer).

In any case, it converges pretty badly for small batch sizes. A couple of things you can try (and I promise this is not just a plug from me!) are the DCCA variant based on a stochastic decorrelation loss and the one based on non-linear orthogonal iterations, both of which you can find in one of my repos. They were designed for the stochastic minibatch case.

guomanshan commented 1 year ago

Hi, I used the DCCA loss to train my DNN and got the following error message:

[D1, V1] = torch.symeig(SigmaHat11, eigenvectors=True)
[D2, V2] = torch.symeig(SigmaHat22, eigenvectors=True)
# assert torch.isnan(D1).sum().item() == 0

RuntimeError: symeig_cuda: The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated eigenvalues (error code: 1807).

Do you have any way to fix it? My CUDA version is 11.3.
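One thing that might be worth trying, sketched under assumptions: `torch.symeig` has been deprecated in favour of `torch.linalg.eigh` in newer PyTorch releases, and symmetrising the covariance, adding a small ridge, and decomposing in float64 may help with ill-conditioned inputs. The helper name `safe_eigh`, the ridge value `r` and the random data are all hypothetical, and this is not guaranteed to remove the error.

```python
import torch

def safe_eigh(sigma, r=1e-4):
    """Eigendecomposition of a symmetric matrix after adding r*I (sketch)."""
    sigma = 0.5 * (sigma + sigma.t())          # remove round-off asymmetry
    eye = torch.eye(sigma.shape[0], dtype=sigma.dtype, device=sigma.device)
    # float64 is an assumption: higher precision may help the solver converge
    D, V = torch.linalg.eigh((sigma + r * eye).double())
    return D.to(sigma.dtype), V.to(sigma.dtype)

torch.manual_seed(0)
H = torch.randn(50, 8)                         # made-up features: 50 samples, 8 dims
H = H - H.mean(dim=0, keepdim=True)
SigmaHat11 = torch.matmul(H.t(), H) / (H.shape[0] - 1)

D1, V1 = safe_eigh(SigmaHat11)
print(D1)
```

`torch.linalg.eigh` returns the eigenvalues in ascending order together with the eigenvectors, so `D1, V1` can be used in the same places as the output of the `symeig` call with `eigenvectors=True`.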