These changes address the issue "All array should be from the same type/backend".

The change in `distance.py` fixes the flow example provided in the README. Note that the example also needs `symsqrt_v2(A, func='symeig')` changed to `symsqrt_v2(A, func='svd')` before it will run.

The change in `wasserstein.py` handles the case where the user switches from `inner_ot_loss = 'sinkhorn'` to `inner_ot_loss = 'wasserstein'`.