microsoft / otdd

Optimal Transport Dataset Distance
MIT License

fix computing backend issue #20

Open: ChenChengKuan opened this pull request 2 years ago

ChenChengKuan commented 2 years ago

These changes address the issue "All array should be from the same type/backend".
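
The error text appears to come from the backend dispatch in the POT (Python Optimal Transport) library, which expects every array passed to one solver call to come from the same array library. The sketch below only illustrates that kind of check and one generic fix: convert every input to a single array type before the OT call. The helper names backend_name, same_backend and to_numpy are hypothetical. They are neither POT's nor otdd's code, and the choice of NumPy as the common type is an assumption, not necessarily what this PR does.

```python
import numpy as np


def backend_name(x):
    """Return the top-level module that owns x's type, e.g. "numpy" or "torch"."""
    return type(x).__module__.split(".")[0]


def same_backend(*arrays):
    """True if every input comes from the same array library.

    Hypothetical stand-in for a backend-consistency check; a real library
    would raise an error instead of returning False.
    """
    return len({backend_name(a) for a in arrays}) <= 1


def to_numpy(*arrays):
    """Convert every input to a float64 NumPy array so all share one backend."""
    out = []
    for a in arrays:
        if hasattr(a, "detach"):  # torch.Tensor: drop the autograd graph, move to CPU
            a = a.detach().cpu().numpy()
        out.append(np.asarray(a, dtype=np.float64))
    return out


# Hypothetical inputs: two marginals (one ndarray, one plain list) and a cost matrix.
a = np.full(3, 1 / 3)
b = [0.5, 0.5]
M = np.arange(6, dtype=np.float64).reshape(3, 2)

print(same_backend(a, b, M))             # False: "numpy" vs "builtins"
print(same_backend(*to_numpy(a, b, M)))  # True
```

Converting to NumPy discards autograd information. In a PyTorch code base the opposite direction, turning the NumPy inputs into tensors with torch.as_tensor on the same device, keeps gradients available.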

The change in distance.py targets the flow example provided in the README. Note that, for that example to work, it is also necessary to change symsqrt_v2(A, func='symeig') to symsqrt_v2(A, func='svd'). The change in wasserstein.py addresses the case where the user switches inner_ot_loss = 'sinkhorn' to inner_ot_loss = 'wasserstein'.
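
For a symmetric positive semidefinite matrix, such as a covariance matrix, the SVD route and the symmetric-eigendecomposition route yield the same matrix square root. For such inputs, switching func='symeig' to func='svd' should therefore not change the result beyond round-off. One plausible reason the 'symeig' path breaks is that torch.symeig is deprecated in favor of torch.linalg.eigh, though the PR does not say so. The sketch uses NumPy rather than otdd's own symsqrt_v2, whose internals are not shown here. The names symsqrt_svd and symsqrt_eigh are hypothetical.

```python
import numpy as np


def symsqrt_svd(A):
    """Square root of a symmetric PSD matrix via SVD.

    For symmetric PSD A, A = U diag(s) U^T, so sqrt(A) = U diag(sqrt(s)) U^T.
    """
    U, s, _ = np.linalg.svd(A)
    return (U * np.sqrt(s)) @ U.T  # scale column j of U by sqrt(s[j])


def symsqrt_eigh(A):
    """Same square root via the symmetric eigendecomposition."""
    w, V = np.linalg.eigh(A)
    w = np.clip(w, 0.0, None)  # clamp tiny negative eigenvalues caused by round-off
    return (V * np.sqrt(w)) @ V.T


rng = np.random.default_rng(0)
X = rng.standard_normal((5, 5))
A = X @ X.T  # symmetric positive semidefinite, like a covariance matrix

R_svd = symsqrt_svd(A)
R_eigh = symsqrt_eigh(A)
print(np.allclose(R_svd @ R_svd, A))  # True
print(np.allclose(R_svd, R_eigh))     # True
```

The equivalence relies on A being symmetric PSD. For an indefinite symmetric matrix the singular values are the absolute values of the eigenvalues, so the SVD version would no longer return a square root of A.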

ghost commented 2 years ago

CLA assistant check
All CLA requirements met.