Open wiseodd opened 2 months ago
Currently, we still assume that `logits.shape == (batch_size, n_classes)`. For example: https://github.com/aleximmer/Laplace/blob/a4d3ed631c87ddb55a1bc7aa053da249180480f7/laplace/curvature/curvature.py#L332-L334
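A minimal sketch of one way to relax this assumption at the shape level. It assumes that logits with extra dimensions, such as `(batch_size, seq_len, n_classes)`, can be treated as independent classification outputs and flattened to 2D before reaching code like the lines linked above. `flattened_logits_shape` is a hypothetical helper for illustration, not part of the Laplace API. In PyTorch, the corresponding tensor operation would be `logits.reshape(-1, logits.shape[-1])`.

```python
from math import prod
from typing import Tuple


def flattened_logits_shape(shape: Tuple[int, ...]) -> Tuple[int, int]:
    """Hypothetical helper (not Laplace API).

    Maps a logits shape (batch_size, *extra_dims, n_classes) to the 2D
    shape (batch_size * prod(extra_dims), n_classes) that the current
    code assumes, i.e. every leading position becomes one "sample".
    """
    if len(shape) < 2:
        raise ValueError(f"expected at least 2 dims, got shape {shape}")
    *leading, n_classes = shape
    # Collapse all leading dims; the class dim stays last.
    return prod(leading), n_classes


# Already 2D: unchanged.
print(flattened_logits_shape((32, 10)))
# Sequence-like output: 4 sequences of length 7 -> 28 rows.
print(flattened_logits_shape((4, 7, 10)))
```

Flattening keeps the existing `(batch_size, n_classes)` code paths intact. It does mean that the effective number of data points seen by the curvature backends becomes `batch_size * prod(extra_dims)`, which may or may not be the intended scaling.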