huggingface / huggingface_hub

The official Python client for the Huggingface Hub.
https://huggingface.co/docs/huggingface_hub
Apache License 2.0

[PyTorchModelHubMixin] argparse.Namespace config does not seem to be pushed #2334

Closed. NielsRogge closed this issue 2 weeks ago.

NielsRogge commented 2 weeks ago

Describe the bug

When a PyTorch model takes an argparse.Namespace as its first keyword argument, the config is normally pushed as shown here (assuming a coders argument is passed): https://huggingface.co/docs/huggingface_hub/v0.23.3/en/guides/integrations#config. I've run that code snippet and it works: the config gets pushed.
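For reference, the coders pattern in the linked guide pairs an encoder (Namespace to a JSON-compatible value) with a decoder (dict back to Namespace). Below is a minimal stdlib-only sketch of such a pair. It assumes a flat Namespace, and the names encode_namespace/decode_namespace and the fields lr/hidden_size are hypothetical. It round-trips the Namespace through JSON, roughly the way a saved config would be written and read back:

```python
import json
from argparse import Namespace


def encode_namespace(ns: Namespace) -> dict:
    # Encoder: Namespace -> JSON-compatible dict (a shallow copy of its attributes).
    return dict(vars(ns))


def decode_namespace(data: dict) -> Namespace:
    # Decoder: dict -> Namespace, the inverse of encode_namespace for flat configs.
    return Namespace(**data)


args = Namespace(lr=1e-3, hidden_size=128)  # hypothetical example fields
payload = json.dumps(encode_namespace(args))  # e.g. the text of a saved config file
restored = decode_namespace(json.loads(payload))
print(payload)  # {"lr": 0.001, "hidden_size": 128}
print(restored == args)  # True
```

argparse.Namespace defines equality by comparing attributes, so the final check confirms that the round trip is lossless for this flat case.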

However, when I try the same for https://github.com/hamadichihaoui/BIRD/compare/master...NielsRogge:BIRD:update_mixin?expand=1, the config (which is also an instance of argparse.Namespace) does not get pushed: https://huggingface.co/nielsr/bird-celeba-hq/tree/main.

Reproduction

Here's a notebook to reproduce this: https://colab.research.google.com/drive/1iRj6cJtWlLJQEkWkjfZxYCIeC-xHmjuN?usp=sharing.

System info

huggingface_hub version: 0.23.3
Environment: Google Colab

not-lain commented 2 weeks ago

@NielsRogge I believe this can be broken down into 2 parts:

```python
from argparse import Namespace

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


def serialize(config):
    print("serialize has been called successfully")
    # Create a copy so the original Namespace is not modified,
    # including during the recursive calls below.
    v = dict(vars(config))
    for key, value in v.items():
        if isinstance(value, Namespace):
            v[key] = serialize(value)
    return v


def dict2namespace(config):
    # Recursively rebuild nested dicts as Namespace objects.
    namespace = Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace


config = {"i": 2, "j": 3}
config = dict2namespace(config)


class Model(
    nn.Module,
    PyTorchModelHubMixin,
    coders={
        Namespace: (
            lambda x: serialize(x),
            lambda data: dict2namespace(data),
        )
    },
):
    def __init__(self, config: Namespace):
        super().__init__()
        self.layer = nn.Linear(config.i, config.j)


model = Model(config)
print("config params : ", model._hub_mixin_config)
```

```
serialize has been called successfully
config params : None
```
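The recursive helpers matter when the Namespace is nested: a shallow vars(x) encoder leaves inner Namespace objects that json cannot serialize. The stdlib-only sketch below needs neither torch nor huggingface_hub. It drops the debug print, and its nested fields (model, seed) are made up for illustration. It contrasts the shallow and recursive encodings:

```python
import json
from argparse import Namespace


def serialize(config: Namespace) -> dict:
    # Copy, then recursively convert nested Namespace values to dicts.
    out = dict(vars(config))
    for key, value in out.items():
        if isinstance(value, Namespace):
            out[key] = serialize(value)
    return out


def dict2namespace(config: dict) -> Namespace:
    # Recursively rebuild Namespace objects from nested dicts.
    ns = Namespace()
    for key, value in config.items():
        setattr(ns, key, dict2namespace(value) if isinstance(value, dict) else value)
    return ns


nested = Namespace(model=Namespace(i=2, j=3), seed=0)

try:
    json.dumps(vars(nested))  # shallow: the inner Namespace is not JSON-serializable
except TypeError as exc:
    print("shallow encoding fails:", exc)

payload = json.dumps(serialize(nested))
print(payload)  # {"model": {"i": 2, "j": 3}, "seed": 0}
print(dict2namespace(json.loads(payload)) == nested)  # True
```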

Wauplin commented 2 weeks ago

Hey there! Thanks to @not-lain's reproducible example, I have been able to identify and fix the issue. I just opened https://github.com/huggingface/huggingface_hub/pull/2337 for review. Once merged, I'll make a hot-fix release so that you can use it right away @NielsRogge.

The TL;DR: config was treated separately and did not benefit from the custom encoders/decoders. That's now fixed.
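The TL;DR can be illustrated with a toy model, which is not huggingface_hub's actual code. The function names capture_config_buggy/capture_config_fixed and the "keep only dicts" rule are assumptions made for illustration. The sketch contrasts a config argument that bypasses the custom coders with one that goes through them:

```python
from argparse import Namespace
from typing import Any, Callable, Dict, Optional, Tuple, Type

# Hypothetical coder table: type -> (encoder, decoder), mirroring the shape
# of the coders mapping passed to the mixin.
Coders = Dict[Type, Tuple[Callable[[Any], Any], Callable[[Any], Any]]]


def encode(value: Any, coders: Coders) -> Any:
    # Apply the first matching custom encoder, else return the value unchanged.
    for typ, (encoder, _decoder) in coders.items():
        if isinstance(value, typ):
            return encoder(value)
    return value


def capture_config_buggy(config: Any, coders: Coders) -> Optional[dict]:
    # Toy model of the reported behavior: config bypasses the coders,
    # so only a plain dict is kept and a Namespace is silently dropped.
    return config if isinstance(config, dict) else None


def capture_config_fixed(config: Any, coders: Coders) -> Optional[dict]:
    # Toy model of the fix: run config through the same coders as other args.
    encoded = encode(config, coders)
    return encoded if isinstance(encoded, dict) else None


coders: Coders = {Namespace: (lambda x: dict(vars(x)), lambda data: Namespace(**data))}
cfg = Namespace(i=2, j=3)
print(capture_config_buggy(cfg, coders))  # None
print(capture_config_fixed(cfg, coders))  # {'i': 2, 'j': 3}
```

The buggy path mirrors the "config params : None" output from the reproduction above. The fixed path reuses the same encoder that other arguments already went through.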

Wauplin commented 2 weeks ago

@NielsRogge @not-lain The fix has been shipped as a hot-fix release https://github.com/huggingface/huggingface_hub/releases/tag/v0.23.4. You can use it for your PRs :)

NielsRogge commented 2 weeks ago

Thanks a lot for looking into this and fixing it, it's now working as expected :) see https://github.com/hamadichihaoui/BIRD/pull/10