facebookresearch / higher

higher is a PyTorch library allowing users to obtain higher-order gradients over losses spanning training loops rather than individual training steps.
Apache License 2.0

higher for dpt architectures? #132

Open Ainaz99 opened 2 years ago

Ainaz99 commented 2 years ago

Hi,

Thanks for the great library! Does higher support passing a DPT-based module to higher.innerloop_ctx? I'm getting the following error:

  File "/home/rbachman/miniconda/envs/py38/lib/python3.8/contextlib.py", line 113, in __enter__
    return next(self.gen)
  File "/home/rbachman/miniconda/envs/py38/lib/python3.8/site-packages/higher/__init__.py", line 85, in innerloop_ctx
    fmodel = monkeypatch(
  File "/home/rbachman/miniconda/envs/py38/lib/python3.8/site-packages/higher/patch.py", line 542, in monkeypatch
    fmodule = make_functional(module, encapsulator=encapsulator)
  File "/home/rbachman/miniconda/envs/py38/lib/python3.8/site-packages/higher/patch.py", line 435, in make_functional
    _, fmodule, MonkeyPatched = _make_functional(module, params_box, 0)
  File "/home/rbachman/miniconda/envs/py38/lib/python3.8/site-packages/higher/patch.py", line 348, in _make_functional
    child_params_offset, fchild, _ = _make_functional(
  File "/home/rbachman/miniconda/envs/py38/lib/python3.8/site-packages/higher/patch.py", line 218, in _make_functional
    class MonkeyPatched(_ModuleType, _MonkeyPatchBase):  # type: ignore
  File "/home/rbachman/miniconda/envs/py38/lib/python3.8/abc.py", line 85, in __new__
    cls = super().__new__(mcls, name, bases, namespace, **kwargs)
TypeError: Cannot create a consistent method resolution
order (MRO) for bases Module, _MonkeyPatchBase

Thank you for your response!

HamedHemati commented 2 years ago

Hi! I don't know what a DPT architecture is, but if it contains any sort of RNN-based module inside, you may also need to use the context manager below:

with torch.backends.cudnn.flags(enabled=False):
    # Your meta-learning code here...

The Known/Possible Issues section of the README lists other potential problems: https://github.com/facebookresearch/higher#knownpossible-issues

HamedHemati commented 2 years ago

@Ainaz99

Here is a follow-up to my previous comment. Today I got the same error for a simple MLP model that inherited from multiple classes:

class MyModel(nn.Module, AnotherClass):
    def __init__(self):
        ...

I re-implemented the model as a class with single inheritance and added the methods from the second base (super) class directly to my new model:

class MyModel(nn.Module):
    def __init__(self):
        ...

    # Methods copied over from AnotherClass
    def method_from_another_class(self):
        ...

This solved the issue for me. Let me know if this was also the case for you.

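Note: this is probably triggered by conflicting inheritance between the model's base classes and the class that higher builds internally when higher.patch.monkeypatch(model, ...) is called. The traceback shows higher defining class MonkeyPatched(_ModuleType, _MonkeyPatchBase), so both sets of base classes have to fit into a single method resolution order.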
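
The TypeError from the original traceback can be reproduced in plain Python, without higher or PyTorch. The sketch below is only an illustration: `Module` and `PatchBase` are hypothetical stand-ins, and it assumes the patching base class is itself a subclass of `Module`, which the error message suggests but this thread does not confirm. It builds a class dynamically from two bases, much as the `class MonkeyPatched(_ModuleType, _MonkeyPatchBase)` line in the traceback does.

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module."""


class PatchBase(Module):
    """Hypothetical stand-in for a patch base class.

    Assumption: the real patch base also derives from Module.
    """


class MyModel(Module):
    """An ordinary single-inheritance model class."""


def try_combine(*bases):
    """Build a class from `bases`; return the class, or the TypeError raised."""
    try:
        return type("Patched", bases, {})
    except TypeError as exc:
        return exc


# A proper subclass of Module combines cleanly with PatchBase.
ok = try_combine(MyModel, PatchBase)

# Module itself cannot be listed before its own subclass PatchBase:
# Python's C3 linearization finds no consistent ordering.
conflict = try_combine(Module, PatchBase)

print([cls.__name__ for cls in ok.__mro__])
print(type(conflict).__name__, conflict)
```

If the model is a regular `nn.Module`, printing `type(m).__mro__` for each `m` in `model.modules()` is one way to look for the class whose bases trigger the conflict.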