Closed by zou3519.
Previously we had a fast batching rule for vmap-over-jvp on torch.linalg.det (https://github.com/pytorch/pytorch/commit/67b104af02584ef60b8d26a0392947367767fbf1). We no longer have it, as demonstrated by https://github.com/pytorch/pytorch/pull/85141, so this is a performance regression.
We should try to fix it, though I don't know how difficult that will be.
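For readers unfamiliar with the composition, the sketch below shows what "vmap over jvp on torch.linalg.det" means in practice. It assumes PyTorch 2.x, where `vmap` and `jvp` live in `torch.func` (at the time of this issue they were in functorch). The helper name `det_and_tangent` is hypothetical. The sketch only exercises the composition and checks the result against Jacobi's formula. It does not show which batching rule PyTorch dispatches to, and it does not measure performance.

```python
import torch
from torch.func import jvp, vmap

torch.manual_seed(0)

def det_and_tangent(A, dA):
    # Hypothetical helper: forward-mode AD (jvp) of torch.linalg.det at a
    # single matrix A along the tangent direction dA.
    # Returns (det(A), d det(A)[dA]).
    return jvp(torch.linalg.det, (A,), (dA,))

# A batch of 8 well-conditioned 3x3 matrices (shifted by 3*I) and tangents.
eye = torch.eye(3, dtype=torch.float64)
A = torch.randn(8, 3, 3, dtype=torch.float64) + 3 * eye
dA = torch.randn(8, 3, 3, dtype=torch.float64)

# vmap over jvp: the composition whose batching rule this issue is about.
dets, tangents = vmap(det_and_tangent)(A, dA)

# Reference via Jacobi's formula: d det(A)[dA] = det(A) * tr(A^{-1} dA),
# computed with batched ops (solve gives A^{-1} dA; diagonal().sum() is trace).
expected = torch.linalg.det(A) * torch.linalg.solve(A, dA).diagonal(dim1=-2, dim2=-1).sum(-1)
print(torch.allclose(tangents, expected))
```

A fast batching rule would let the vmapped call run as one batched kernel rather than falling back to a slower per-example path. The numerical result is the same either way, which is why a regression like this one shows up in timing rather than in correctness checks.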
Fixed in https://github.com/pytorch/pytorch/pull/85175