rballester / tntorch

Tensor Network Learning with PyTorch
https://tntorch.readthedocs.io/
GNU Lesser General Public License v3.0

For using default dtype torch.float32 #2

Closed: gngdb closed this 5 years ago

gngdb commented 5 years ago

Almost everything I needed for TT and Tucker decompositions already worked when I changed PyTorch's default dtype to torch.float32 after importing tntorch. Tucker had some problems because NumPy returned float64 output from a sqrt used when calling truncated_svd. This fixes that problem and shouldn't affect float64.
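A minimal sketch, not taken from tntorch, of how a NumPy sqrt can reintroduce float64 after the default dtype has been set to float32. The array values are arbitrary, and truncated_svd's internals are not reproduced here.

```python
import numpy as np
import torch

# Make float32 the default for new floating-point tensors, as in the issue
torch.set_default_dtype(torch.float32)

core = torch.rand(3, 2)  # created with the default dtype: float32

# NumPy path: np.sqrt on float64 data returns float64,
# and torch.from_numpy keeps that dtype
scale_np = torch.from_numpy(np.sqrt(np.array([0.5, 0.25])))

# Torch path: Python floats become default-dtype tensors, and torch.sqrt keeps it
scale_torch = torch.sqrt(torch.tensor([0.5, 0.25]))

# Type promotion silently upcasts the product when float64 leaks in
mixed = core * scale_np
clean = core * scale_torch

print(scale_np.dtype, scale_torch.dtype)  # torch.float64 torch.float32
print(mixed.dtype, clean.dtype)           # torch.float64 torch.float32
```

Elementwise ops promote silently as shown above; operations such as torch.matmul typically raise a RuntimeError on mixed float32/float64 inputs instead, which is often where such a leak first becomes visible.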

rballester commented 5 years ago

Gavin,

I've pushed a different solution that goes more to the root of the problem and avoids using NumPy's sqrt().

Thanks for spotting this issue, nonetheless!
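The actual commit isn't shown in this thread. As a hedged illustration of the general idea of avoiding NumPy's sqrt(), the hypothetical helper below computes its truncation threshold with torch ops only, so the outputs stay in the input's dtype for both float32 and float64. The name truncated_svd_sketch and its relative-threshold rule are assumptions for illustration, not tntorch's API.

```python
import torch

def truncated_svd_sketch(M: torch.Tensor, eps: float = 1e-3):
    """Hypothetical helper (not tntorch's truncated_svd): keep singular values
    larger than eps times the Frobenius norm of M, using only torch ops."""
    U, S, Vh = torch.linalg.svd(M, full_matrices=False)
    # torch.sqrt on a tensor preserves its dtype, unlike routing through NumPy
    threshold = eps * torch.sqrt(torch.sum(S ** 2))
    # Always keep at least one singular triplet
    rank = max(1, int(torch.sum(S > threshold)))
    return U[:, :rank], S[:rank], Vh[:rank, :]

for dtype in (torch.float32, torch.float64):
    M = torch.randn(6, 4, dtype=dtype)
    U, S, Vh = truncated_svd_sketch(M)
    print(dtype, U.dtype, S.dtype, Vh.dtype, S.shape[0])
```

Because the Python float eps multiplies a 0-dimensional tensor, it does not upcast the threshold, so float32 inputs stay float32 end to end.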

gngdb commented 5 years ago

Thanks, yeah, my solution was a bit of a hack.