learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License

Hello, I can't run the KroneckerLSTM demo directly. #368

Closed: 0xzhouchenyu closed this issue 1 year ago

0xzhouchenyu commented 1 year ago

I tried to run the KroneckerLSTM example, but it failed. Here is what I ran:

import torch
from learn2learn.nn import KroneckerLSTM

m, n = 2, 3
x = torch.randn(6)
h = torch.randn(6)
kronecker = KroneckerLSTM(n, m)
y, new_h = kronecker(x, h)
y.shape  # (6, )

And this is the error:

Traceback (most recent call last):
  File "/Users/zhouchenyu/opt/anaconda3/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3369, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-49bff790c494>", line 1, in <cell line: 1>
    runfile('/Users/zhouchenyu/PycharmProjects/pythonProject/model/debug.py', wdir='/Users/zhouchenyu/PycharmProjects/pythonProject/model')
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/zhouchenyu/PycharmProjects/pythonProject/model/debug.py", line 9, in <module>
    y, new_h = kronecker(x, h)
  File "/Users/zhouchenyu/opt/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/zhouchenyu/opt/anaconda3/lib/python3.9/site-packages/learn2learn/nn/kroneckers.py", line 197, in forward
    h, c = hidden
ValueError: too many values to unpack (expected 2)

Can you help me fix it? Thank you!

seba-1511 commented 1 year ago

Hello @0xzhouchenyu,

This example is indeed wrong (it will be fixed in #400). It should read:

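import torch
from learn2learn.nn import KroneckerLSTM

n, m = 2, 3
x = torch.randn(n, m)
h = torch.randn(n, m)
c = torch.zeros(n, m)
kronecker = KroneckerLSTM(n, m)
# The hidden state is passed as an (h, c) tuple, since forward() unpacks it with `h, c = hidden`
y, new_h = kronecker(x, (h, c))
y.shape  # (2, 3)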

Thanks for pointing it out and apologies for the delay!
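Why the original call failed: the traceback ends at `h, c = hidden` in `kroneckers.py`, so `forward()` expects its second argument to be a pair. A bare tensor is iterated along its first dimension, and a tensor of shape `(6,)` yields six elements, which cannot be unpacked into two names. The sketch below reproduces only that unpacking step with plain PyTorch. `unpack_hidden` is a hypothetical stand-in, not part of learn2learn.

```python
import torch


def unpack_hidden(hidden):
    """Hypothetical stand-in for the `h, c = hidden` line in the traceback."""
    h, c = hidden
    return h, c


h = torch.randn(6)
c = torch.zeros(6)

# A bare tensor is iterated along its first dimension (6 elements here),
# so unpacking it into two names raises ValueError.
try:
    unpack_hidden(h)
    bare_tensor_error = None
except ValueError as err:
    bare_tensor_error = str(err)  # "too many values to unpack ..."

# An explicit (h, c) tuple has exactly two elements and unpacks cleanly.
h_out, c_out = unpack_hidden((h, c))
```

Also note that a hidden tensor whose first dimension happens to be 2 would unpack without any error into two slices and silently produce wrong states. Passing the `(h, c)` tuple explicitly avoids that.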