gorold closed this issue 3 months ago.
Good catch @gorold, thanks for reporting! Would you like to open a PR to fix this? The line `if issubclass(expected_type, type_):` should be fixed to work with unions as well. You will probably have to use `typing.get_args` for this.
Sure, I could take a stab at it, but I think there are some undefined behaviours that I'd like to clarify first.
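```python
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class ParentArg: ...
class ChildArg(ParentArg): ...


class Model(
    nn.Module,
    PyTorchModelHubMixin,
    coders={ParentArg: ..., ChildArg: ...},  # (encoder, decoder) pairs elided
):
    def __init__(self, a: ChildArg):
        super().__init__()
        self.a = a
```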
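This could be straightforward to solve: just check whether `expected_type in cls._hub_mixin_coders` before looping through the dict. Here `ChildArg` matches both entries, since `issubclass(ChildArg, ParentArg)` is also true, so a loop over the dict may pick `ParentArg`'s coder first. But an exact-match check may not be correct for complex inheritance structures. I guess the ideal behaviour is to use the nearest ancestor...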
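A minimal sketch of the difference, assuming "nearest ancestor" means the first class in the annotation's method resolution order (`__mro__`) that has a registered coder. `first_match` and `nearest_ancestor` are hypothetical helpers, and the coder values are placeholders rather than real encoder/decoder callables.

```python
class ParentArg: ...
class ChildArg(ParentArg): ...


# Placeholder (encoder, decoder) pairs; real coders would be callables.
coders = {
    ParentArg: (None, None),
    ChildArg: (None, None),
}


def first_match(expected_type, coders):
    # Order-dependent lookup: returns the first registered type that
    # expected_type subclasses, which for ChildArg is ParentArg.
    for type_ in coders:
        if issubclass(expected_type, type_):
            return type_
    return None


def nearest_ancestor(expected_type, coders):
    # Walk the MRO (the class itself first, then its bases) and return
    # the closest class that has a registered coder.
    for base in expected_type.__mro__:
        if base in coders:
            return base
    return None


print(first_match(ChildArg, coders).__name__)       # ParentArg
print(nearest_ancestor(ChildArg, coders).__name__)  # ChildArg
```

With multiple inheritance, `__mro__` follows Python's C3 linearization, so "nearest" is well defined, though it may not always match what a user expects.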
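```python
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class Arg1: ...
class Arg2: ...


class Model(
    nn.Module,
    PyTorchModelHubMixin,
    coders={Arg1: ..., Arg2: ...},  # (encoder, decoder) pairs elided
):
    def __init__(self, a: Arg1 | Arg2 | int):
        super().__init__()
        self.a = a
```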
This case seems to be more challenging, since we can't tell for sure which decoder we should actually use.
My proposal would be to use the first decoder that works:
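```python
import typing


def decode_with_first_match(cls, expected_type, value):
    # Hypothetical helper: try each union member in order and
    # use the first registered decoder that succeeds.
    for etype in typing.get_args(expected_type):
        if etype not in cls._hub_mixin_coders:
            continue
        _, decoder = cls._hub_mixin_coders[etype]
        try:
            return decoder(value)
        except Exception:
            # This decoder can't handle the value; try the next member.
            continue
    # No decoder succeeded: fall back to the raw value.
    return value
```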
This doesn't check for subclasses, though; I need to think about it a little more.
Thanks for asking the right questions @gorold! I've dug a bit more into it and I think we should:
1. Not raise an error for the `NotOKModel` class in your example. The annotation `int | float` is not even related to the encoder/decoder, so it's definitely a bug to raise an error here. I think this can be fixed by first checking whether the annotation is a class: `if inspect.isclass(expected_type) and issubclass(expected_type, type_):`
2. Support `CustomArg`, `Optional[CustomArg]` and `CustomArg | None` (the last two being the same). These can be handled deterministically for both encoding and decoding, and they cover most of the use cases.

What do you think?
That makes sense! I'll make a PR based on this
Fixed in https://github.com/huggingface/huggingface_hub/pull/2291 by @gorold, thanks!
Describe the bug
Using a union type hint in conjunction with a custom encoder/decoder raises the errors below.
Reproduction
Logs