Changes:
Previously, torch.einsum() caused out-of-memory errors when the diag function was called.
Replacing it with the equivalent opt_einsum implementation resolved these errors and reduced memory consumption at runtime.
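
For illustration, the swap can be sketched as below. This is a minimal sketch, not the actual code in diag: the subscript string `"ni,ij,nj->n"`, the tensor names `J` and `H`, and the helper names `diag_reference` and `diag_opt` are all hypothetical, since the real expression is not shown here. It only shows that `opt_einsum.contract` accepts the same subscripts and torch tensors as `torch.einsum` and returns the same values.

```python
import torch
from opt_einsum import contract


def diag_reference(J: torch.Tensor, H: torch.Tensor) -> torch.Tensor:
    # Hypothetical baseline: one multi-operand torch.einsum call.
    return torch.einsum("ni,ij,nj->n", J, H, J)


def diag_opt(J: torch.Tensor, H: torch.Tensor) -> torch.Tensor:
    # Same contraction through opt_einsum, which picks a pairwise
    # contraction order before dispatching to the torch backend.
    return contract("ni,ij,nj->n", J, H, J)


# Small random inputs (assumed shapes, for demonstration only).
torch.manual_seed(0)
J = torch.randn(8, 5)
H = torch.randn(5, 5)

out_ref = diag_reference(J, H)
out_opt = diag_opt(J, H)
```

Both functions should agree up to floating-point tolerance, so the change is expected to be a drop-in replacement at the call site.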