Closed: kmkurn closed this 4 years ago.
Almost perfect. You just have to update _unconvert as well, since it gets called when arranging the marginals. If you could make it squeeze the last dimension as well, that would keep the old code working.
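The squeeze suggestion can be sketched without PyTorch. Here tensors are nested lists so the snippet runs with the standard library alone, and the helper names shape and squeeze_last are hypothetical (they are not torch-struct's API). In PyTorch the same step would be x.squeeze(-1), which removes the last dimension only when its size is 1 and otherwise returns the tensor unchanged. That is what keeps old unlabeled callers seeing (B, N, N) while labeled output stays (B, N, N, L).

```python
def shape(x):
    """Shape of a rectangular nested list, e.g. (B, N, N, L)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


def squeeze_last(x):
    """Hypothetical helper: drop the trailing dimension only if it has size 1,
    mirroring torch's x.squeeze(-1)."""
    dims = shape(x)
    if not dims or dims[-1] != 1:
        return x  # labeled output (L > 1): keep (B, N, N, L) untouched

    def strip(y):
        # Recurse down to the innermost length-1 lists, then unwrap them.
        if isinstance(y[0], list):
            return [strip(z) for z in y]
        return y[0]

    return strip(x)


unlabeled = [[[[0.5] for _ in range(3)] for _ in range(3)] for _ in range(2)]  # (2, 3, 3, 1)
labeled = [[[[0.2, 0.3, 0.5] for _ in range(3)] for _ in range(3)] for _ in range(2)]  # (2, 3, 3, 3)
print(shape(squeeze_last(unlabeled)))  # (2, 3, 3)
print(shape(squeeze_last(labeled)))    # (2, 3, 3, 3)
```

Squeezing conditionally, rather than always dropping the last axis, is the design choice that makes the single-label case backward compatible without a separate code path.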
I've updated _unconvert as suggested, but both argmax and marginals still have 3 dimensions :-(
UPDATE: I think I've got it.
Awesome. Adding
This is a work in progress and isn't ready to merge yet.
This seems to work for partition, but argmax and marginals don't return what I expect. Both return tensors of shape (B, N, N); I'd expect them to return (B, N, N, L) tensors instead. Any advice?
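One way to see why (B, N, N, L) is the expected shape: if marginals are computed as the gradient of log Z with respect to the log potentials (as autograd-based implementations typically do), they must have the same shape as the potentials, so a trailing L axis only disappears if something sums or squeezes it out afterwards. The same reasoning would apply to argmax if it is obtained as the gradient of the max score. The sketch below checks this on a toy head-selection model, an assumption chosen only because its log Z is a simple per-dependent logsumexp. It is not the structured model in this PR, it drops the batch dimension, and it does not mask self-loops. The names log_partition and marginals are hypothetical.

```python
import math


def _column_scores(phi, j):
    # All (head i, label l) scores for dependent j, flattened in (i, l) order.
    return [phi[i][j][l] for i in range(len(phi)) for l in range(len(phi[i][j]))]


def log_partition(phi):
    """Toy model: each dependent j independently picks one (head, label) pair,
    so log Z = sum_j logsumexp_{i,l} phi[i][j][l]."""
    total = 0.0
    for j in range(len(phi)):
        s = _column_scores(phi, j)
        m = max(s)  # subtract the max for numerical stability
        total += m + math.log(sum(math.exp(v - m) for v in s))
    return total


def marginals(phi):
    """Exact marginals d log Z / d phi[i][j][l]; same (N, N, L) shape as phi."""
    n, num_labels = len(phi), len(phi[0][0])
    mu = [[[0.0] * num_labels for _ in range(n)] for _ in range(n)]
    for j in range(n):
        s = _column_scores(phi, j)
        m = max(s)
        z = sum(math.exp(v - m) for v in s)
        for i in range(n):
            for l in range(num_labels):
                mu[i][j][l] = math.exp(phi[i][j][l] - m) / z
    return mu


N, L = 3, 2
phi = [[[0.1 * (i + 2 * j + 3 * l) for l in range(L)] for j in range(N)] for i in range(N)]
mu = marginals(phi)
print(len(mu), len(mu[0]), len(mu[0][0]))  # 3 3 2, i.e. (N, N, L) like phi
# Summing out the label axis gives (N, N) unlabeled arc marginals.
arcs = [[sum(mu[i][j]) for j in range(N)] for i in range(N)]
```

If the labeled code path returns (B, N, N), a likely place to look is any step after the gradient that reduces or squeezes the last axis, since the gradient itself cannot lose it.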