Bihaqo / t3f

Tensor Train decomposition on TensorFlow
https://t3f.readthedocs.io/en/latest/index.html
MIT License

Round can't accept array of ranks because of missing .any() #188

Closed: faysou closed this issue 5 years ago

faysou commented 5 years ago

There's a small bug in this line of _round_tt that makes the code break when the max_tt_rank argument is an array or a list: https://github.com/Bihaqo/t3f/blob/develop/t3f/decompositions.py#L251

if max_tt_rank < 1:

should be replaced by

if (max_tt_rank < 1).any():
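Here is a minimal sketch of why the bare comparison fails and how the `.any()` version behaves. It assumes `max_tt_rank` is first converted to a NumPy array with `np.asarray`. The suggested fix relies on that conversion, because a plain Python list has no elementwise `<` and a plain `bool` has no `.any()`. Both helper names (`validate_max_tt_rank`, `buggy_check`) are hypothetical and are not part of t3f's API.

```python
import numpy as np


def validate_max_tt_rank(max_tt_rank):
    """Hypothetical helper (not t3f's API): validate a scalar or per-core rank spec."""
    # Convert ints, lists and arrays to a NumPy array so that `<` is elementwise.
    ranks = np.asarray(max_tt_rank)
    # `ranks < 1` is a boolean array (or NumPy bool scalar); `.any()` reduces it
    # to a single truth value, which is what `if` needs.
    if (ranks < 1).any():
        raise ValueError('Maximum TT-rank should be greater or equal to 1.')
    return ranks


def buggy_check(max_tt_rank):
    """The original pattern: a bare array comparison inside `if`."""
    ranks = np.asarray(max_tt_rank)
    # For arrays with more than one element this raises
    # "The truth value of an array ... is ambiguous".
    if ranks < 1:
        raise ValueError('Maximum TT-rank should be greater or equal to 1.')
    return ranks


# A scalar rank works with both versions.
validate_max_tt_rank(4)
buggy_check(4)

# A per-core list of ranks only works with the `.any()` version.
validate_max_tt_rank([1, 4, 4, 1])
try:
    buggy_check([1, 4, 4, 1])
except ValueError as err:
    print('buggy_check failed:', err)
```

Using `.any()` rejects the spec if any single rank is below 1, which matches the intent of the original scalar check.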

Bihaqo commented 5 years ago

Thanks! I created pull request #189. Does this make sense?