huggingface / huggingface_hub

The official Python client for the Huggingface Hub.
https://huggingface.co/docs/huggingface_hub
Apache License 2.0

Union typehint in conjunction with coders causes error for PyTorchModelHubMixin #2283

Closed: gorold closed this issue 3 months ago

gorold commented 4 months ago

Describe the bug

Using a union type hint on an `__init__` argument of a model that also defines custom encoders/decoders via `coders` raises the error below.

Reproduction

```python
from torch import nn
from huggingface_hub import PyTorchModelHubMixin

class CustomArg:
    @classmethod
    def encode(cls, arg): return "custom"

    @classmethod
    def decode(cls, arg): return CustomArg()

class OKModel(
    nn.Module,
    PyTorchModelHubMixin,
    coders={CustomArg: (CustomArg.encode, CustomArg.decode)}
):
    def __init__(self, a: int):
        super().__init__()
        self.a = a

class NotOKModel(
    nn.Module,
    PyTorchModelHubMixin,
    coders={CustomArg: (CustomArg.encode, CustomArg.decode)}
):
    def __init__(self, a: int | float):
        super().__init__()
        self.a = a

ok_model = OKModel(1)
ok_model.save_pretrained("model")
ok_model = OKModel.from_pretrained("model")

not_ok_model = NotOKModel(1)
not_ok_model.save_pretrained("model")
not_ok_model = NotOKModel.from_pretrained("model")
```

Logs

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[8], line 26
     24 not_ok_model = NotOKModel(1)
     25 not_ok_model.save_pretrained("model")
---> 26 not_ok_model = NotOKModel.from_pretrained("model")

File .../lib/python3.10/site-packages/huggingface_hub/utils/_validators.py:114, in validate_hf_hub_args.<locals>._inner_fn(*args, **kwargs)
    111 if check_use_auth_token:
    112     kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
--> 114 return fn(*args, **kwargs)

File .../lib/python3.10/site-packages/huggingface_hub/hub_mixin.py:472, in ModelHubMixin.from_pretrained(cls, pretrained_model_name_or_path, force_download, resume_download, proxies, token, cache_dir, local_files_only, revision, **model_kwargs)
    470         expected_type = cls._hub_mixin_init_parameters[key].annotation
    471         if expected_type is not inspect.Parameter.empty:
--> 472             config[key] = cls._decode_arg(expected_type, value)
    474 # Populate model_kwargs from config
    475 for param in cls._hub_mixin_init_parameters.values():

File .../lib/python3.10/site-packages/huggingface_hub/hub_mixin.py:317, in ModelHubMixin._decode_arg(cls, expected_type, value)
    315 """Decode a JSON serializable value into an argument."""
    316 for type_, (_, decoder) in cls._hub_mixin_coders.items():
--> 317     if issubclass(expected_type, type_):
    318         return decoder(value)
    319 return value

TypeError: issubclass() arg 1 must be a class
```

### System info

```shell
Copy-and-paste the text below in your GitHub issue.

- huggingface_hub version: 0.23.0
- Platform: Linux-6.1.58+-x86_64-with-glibc2.31
- Python version: 3.10.11
- Running in iPython ?: No
- Running in notebook ?: No
- Running in Google Colab ?: No
- Token path ?: /root/.cache/huggingface/token
- Has saved token ?: True
- Who am I ?: gorold
- Configured git credential helpers: 
- FastAI: N/A
- Tensorflow: N/A
- Torch: 2.2.0
- Jinja2: 3.1.3
- Graphviz: N/A
- keras: N/A
- Pydot: N/A
- Pillow: 10.2.0
- hf_transfer: N/A
- gradio: N/A
- tensorboard: N/A
- numpy: 1.26.4
- pydantic: 2.6.1
- aiohttp: 3.9.3
- ENDPOINT: https://huggingface.co
- HF_HUB_CACHE: /root/.cache/huggingface/hub
- HF_ASSETS_CACHE: /root/.cache/huggingface/assets
- HF_TOKEN_PATH: /root/.cache/huggingface/token
- HF_HUB_OFFLINE: False
- HF_HUB_DISABLE_TELEMETRY: False
- HF_HUB_DISABLE_PROGRESS_BARS: None
- HF_HUB_DISABLE_SYMLINKS_WARNING: False
- HF_HUB_DISABLE_EXPERIMENTAL_WARNING: False
- HF_HUB_DISABLE_IMPLICIT_TOKEN: False
- HF_HUB_ENABLE_HF_TRANSFER: False
- HF_HUB_ETAG_TIMEOUT: 10
- HF_HUB_DOWNLOAD_TIMEOUT: 10
```

Wauplin commented 4 months ago

Good catch @gorold, thanks for reporting! Would you like to open a PR to fix this? The line `if issubclass(expected_type, type_):` should be fixed to work with unions as well. You will probably have to use `typing.get_args` for this.

gorold commented 4 months ago

Sure, I could take a stab at it, but I think there are some undefined behaviours that I'd like to clarify first.

  1. If `coders` contains both a parent and a child class, and an argument is type-hinted as the child class, won't the parent decoder always be used?
```python
from torch import nn
from huggingface_hub import PyTorchModelHubMixin

class ParentArg: ...

class ChildArg(ParentArg): ...

class Model(
    nn.Module,
    PyTorchModelHubMixin,
    coders={ParentArg: ..., ChildArg: ...},  # (encoder, decoder) pairs elided
):
    def __init__(self, a: ChildArg):
        super().__init__()
        self.a = a
```

This could be straightforward to solve: maybe just check `expected_type in cls._hub_mixin_coders` before looping through the dict? But that may not be correct for complex inheritance structures. I guess the ideal behaviour is to use the nearest ancestor...

  2. How should we handle unions that have multiple candidate coders?
```python
class Arg1: ...

class Arg2: ...

class Model(
    nn.Module,
    PyTorchModelHubMixin,
    coders={Arg1: ..., Arg2: ...},
):
    def __init__(self, a: Arg1 | Arg2 | int):
        super().__init__()
        self.a = a
```

This case seems more challenging, since we can't tell for sure which decoder to use.
My proposal would be to use the first one that works:
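```python
import typing

# Proposed replacement for ModelHubMixin._decode_arg (a classmethod on the mixin)
@classmethod
def _decode_arg(cls, expected_type, value):
    # Try each union member that has a registered coder, in declaration order
    for etype in typing.get_args(expected_type):
        if etype in cls._hub_mixin_coders:
            _, decoder = cls._hub_mixin_coders[etype]
            try:
                return decoder(value)
            except Exception:
                # This decoder failed: move on to the next candidate
                continue
    # No coder applied or succeeded: keep the raw value
    return value
```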

This doesn't check for subclassing, though. I need to think about it a little more.
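
The "nearest ancestor" behaviour from the first question could be sketched by walking the annotated class's method resolution order (`__mro__`), which starts with the class itself and then lists its ancestors in C3 linearization order. This is a hypothetical sketch, not huggingface_hub code: `find_coder` is an illustrative helper and the lambdas are placeholder encoders/decoders.

```python
from typing import Any, Callable, Dict, Optional, Tuple

# Hypothetical sketch, not huggingface_hub's implementation.
# Assumed shape of the coders mapping: {type: (encoder, decoder)}
Coders = Dict[type, Tuple[Callable[[Any], Any], Callable[[Any], Any]]]


class ParentArg: ...


class ChildArg(ParentArg): ...


class GrandchildArg(ChildArg): ...


def find_coder(expected_type: type, coders: Coders) -> Optional[type]:
    """Return the registered type closest to expected_type in its MRO, or None."""
    # __mro__ starts with the class itself, then its ancestors in resolution order
    for candidate in expected_type.__mro__:
        if candidate in coders:
            return candidate
    return None


coders: Coders = {
    ParentArg: (lambda arg: "parent", lambda value: ParentArg()),
    ChildArg: (lambda arg: "child", lambda value: ChildArg()),
}

print(find_coder(ChildArg, coders))       # exact match wins over ParentArg
print(find_coder(GrandchildArg, coders))  # nearest registered ancestor: ChildArg
print(find_coder(int, coders))            # no coder registered
```

Unlike the current loop in `_decode_arg`, which returns the first `issubclass` match in the order `coders` was declared, this lookup prefers an exact match and otherwise the closest registered ancestor. For diamond-shaped hierarchies, "closest" here simply means "first in MRO order".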

Wauplin commented 4 months ago

Thanks for asking the right questions @gorold! I've dug a bit more into it and I think we should:

  1. Fix the `NotOKModel` case from your example. The annotation `int | float` is not even related to the encoder/decoder, so raising an error here is definitely a bug. I think this can be fixed by first checking whether the annotation is a class: `if inspect.isclass(expected_type) and issubclass(expected_type, type_):`
  2. Handle optional type annotations, e.g. `CustomArg`, `Optional[CustomArg]` and `CustomArg | None` (the last two are equivalent). These can be handled deterministically for both encoding and decoding, and they cover most of the use cases.
  3. Not handle more complex types. Union/list/tuple/... annotations should simply be ignored, even when the encoder/decoder could have been used. If we later get feedback that this would be useful in practice, we can reassess, but let's avoid adding complex logic if it's not necessary.

What do you think?

gorold commented 4 months ago

That makes sense! I'll make a PR based on this.

Wauplin commented 3 months ago

Fixed in https://github.com/huggingface/huggingface_hub/pull/2291 by @gorold, thanks!