wiseodd opened 2 months ago
Test cases are still WIP.
Ready to review!
Discussion points:
`logits = logits.view(-1, logits.size(logit_class_dim))`: does this even compute the correct quantities (e.g. when `logit_class_dim = 1`)? Do we want to make them respect `logit_class_dim`? I don't think flattening as above is the correct approach, right?

Merged with main and ready for review!
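To illustrate why flattening with `view(-1, logits.size(logit_class_dim))` is suspect when the class axis is not last, here is a minimal sketch. It uses NumPy rather than PyTorch (an assumption for portability); NumPy's default C-order `reshape` lays elements out the same way `torch.Tensor.view` does on a contiguous tensor. The toy array is a hypothetical NCHW output with 2 classes and 3 pixels.

```python
import numpy as np

# Hypothetical NCHW logits: batch=1, n_classes=2, height=1, width=3.
# Class 0 scores for the three pixels are 0, 1, 2; class 1 scores are 10, 11, 12.
logits = np.array([[[[0, 1, 2]], [[10, 11, 12]]]])  # shape (1, 2, 1, 3)
logit_class_dim = 1
n_classes = logits.shape[logit_class_dim]

# Naive flattening, mirroring logits.view(-1, logits.size(logit_class_dim)).
# Rows end up mixing scores from different pixels and classes.
naive = logits.reshape(-1, n_classes)

# Moving the class axis last first (NCHW -> NHWC) keeps each row one pixel's class scores.
correct = np.moveaxis(logits, logit_class_dim, -1).reshape(-1, n_classes)

print(naive.tolist())    # [[0, 1], [2, 10], [11, 12]]
print(correct.tolist())  # [[0, 10], [1, 11], [2, 12]]
```

The naive first row pairs the class-0 scores of two neighbouring pixels, whereas the permuted version yields one row per pixel. In PyTorch the analogous fix would be something like `logits.movedim(logit_class_dim, -1).reshape(-1, n_classes)`, using `.reshape` rather than `.view` because the permuted tensor is generally not contiguous.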
It seems more complicated than anticipated. This PR is useful for models with image outputs like diffusion models.
Considering v0.2 is all about LLMs, let's defer this to v0.3!
OK, for now maybe we can add a note to the README and the docstrings that clearly states how multi-dimensional outputs are handled?
So what do you have in mind regarding the wording? Something like this in README.md?
## Caveats
- Currently, this library always assumes that the model has an
output tensor of shape `(batch_size, ..., n_classes)`, so in
the case of image outputs, you need to rearrange from NCHW to NHWC.
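A minimal sketch of the NCHW to NHWC rearrangement a user would apply before handing outputs to the library, again in NumPy; in PyTorch the equivalent permutation is `x.permute(0, 2, 3, 1)`. The helper names `nchw_to_nhwc` and `nhwc_to_nchw` are hypothetical and not part of the library.

```python
import numpy as np

def nchw_to_nhwc(x: np.ndarray) -> np.ndarray:
    """Move the class/channel axis from position 1 to the end: (N, C, H, W) -> (N, H, W, C)."""
    if x.ndim != 4:
        raise ValueError(f"expected a 4D NCHW array, got shape {x.shape}")
    return np.transpose(x, (0, 2, 3, 1))

def nhwc_to_nchw(x: np.ndarray) -> np.ndarray:
    """Inverse permutation, e.g. for plotting: (N, H, W, C) -> (N, C, H, W)."""
    return np.transpose(x, (0, 3, 1, 2))

# Hypothetical image-style output: batch of 2, 5 classes, 4x3 pixels.
out_nchw = np.random.default_rng(0).normal(size=(2, 5, 4, 3))
out_nhwc = nchw_to_nhwc(out_nchw)
print(out_nhwc.shape)  # (2, 4, 3, 5)
```

After the permutation the class axis is last, matching the assumed `(batch_size, ..., n_classes)` layout, and `nhwc_to_nchw` recovers the original array exactly.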
Yes, this is exactly what I was thinking!
Closes #163
Please wait until #144 is merged.