chrysts / dsn_fewshot

tensors in double for SVD #2

Closed ZeWang95 closed 3 years ago

ZeWang95 commented 3 years ago

Hi! Thanks for sharing the code. I noticed that you convert the tensor to double before feeding it to torch.svd. May I ask why you have this implementation? Thanks!

chrysts commented 3 years ago

Hi Ze Wang,

The reason is that backpropagation through SVD in PyTorch is quite sensitive to floating-point overflow. This is just to prevent weird behaviour when training the model, especially across different versions of PyTorch.
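
For reference, a minimal sketch of the pattern being discussed (the `svd_double` helper name is illustrative, not from the repo): cast to double before `torch.svd`, then cast the factors back to the original dtype so the rest of the model stays in float32.

```python
import torch

def svd_double(x: torch.Tensor):
    """Illustrative helper: run torch.svd in float64.

    Backprop through SVD in float32 can overflow or become unstable, so the
    input is cast to double for the decomposition and the resulting factors
    are cast back to the original dtype afterwards.
    """
    u, s, v = torch.svd(x.double())
    return u.to(x.dtype), s.to(x.dtype), v.to(x.dtype)

# Gradients still flow to the float32 input through the double-precision SVD.
x = torch.randn(8, 5, 3, requires_grad=True)
u, s, v = svd_double(x)
s.sum().backward()
print(x.grad.shape)  # torch.Size([8, 5, 3])
```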

ZeWang95 commented 3 years ago

Got it! Thanks! I was actually trying a similar idea a long time ago, and did observe that torch.svd is very sensitive. This indeed is a good solution.