rballester / tntorch

Tensor Network Learning with PyTorch
https://tntorch.readthedocs.io/
GNU Lesser General Public License v3.0

Fix incorrect parameter passing, shape is not kwargs now #44

Closed: tczhangzhi closed this 1 year ago

tczhangzhi commented 1 year ago

The existing example fails at runtime with the following error:

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 1
----> 1 t = tn.rand(shape=[nticks]*N + [C], ranks_tt=10, ranks_tucker=6, requires_grad=True)
      2 t.set_factors('dct', dim=range(N))
      3 t

File ~/Data/anaconda3/envs/mental_wellness/lib/python3.10/site-packages/tntorch/create.py:40, in rand(*shape, **kwargs)
     22 def rand(*shape, **kwargs):
     23     """
     24     Generate a :class:`Tensor` with random cores (and optionally factors), whose entries are uniform in :math:`[0, 1]`.
     25 
   (...)
     37     :return: a random tensor
     38     """
---> 40     return _create(torch.rand, *shape, **kwargs)

TypeError: _create() got an unexpected keyword argument 'shape'
```
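Why this happens: per the traceback, `rand` is declared as `rand(*shape, **kwargs)`. Positional arguments are collected into `shape`, but a keyword `shape=...` is not; it lands in `kwargs` and is forwarded unchanged to `_create(torch.rand, *shape, **kwargs)`. A var-positional parameter such as `*shape` can never be filled by keyword, so Python rejects it. The sketch below reproduces the mechanism without tntorch. `rand_sketch` and `_create_sketch` are hypothetical stand-ins that mimic only the signatures visible in the traceback; the full parameter list of tntorch's real `_create`, and the assumption that it accepts a single list as the shape, are not confirmed by this thread. The values of `nticks`, `N` and `C` are illustrative only.

```python
import random


def _create_sketch(function, *shape, ranks_tt=None, ranks_tucker=None, requires_grad=False):
    # Hypothetical stand-in for tntorch's internal _create (NOT its real code).
    # Assumption: a single list/tuple argument is unwrapped into the shape.
    if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
        shape = tuple(shape[0])
    return {
        "shape": tuple(shape),
        "ranks_tt": ranks_tt,
        "ranks_tucker": ranks_tucker,
        "requires_grad": requires_grad,
        "fill": function.__name__,  # which sampler would fill the cores
    }


def rand_sketch(*shape, **kwargs):
    # Same pattern as rand(*shape, **kwargs) in the traceback: a `shape=`
    # keyword is not captured by *shape, so it stays in **kwargs and is
    # forwarded as an unexpected keyword argument.
    return _create_sketch(random.random, *shape, **kwargs)


nticks, N, C = 4, 3, 2  # illustrative values only

# Failing call: shape passed as a keyword argument.
try:
    rand_sketch(shape=[nticks] * N + [C], ranks_tt=10, ranks_tucker=6, requires_grad=True)
    err_msg = None
except TypeError as err:
    err_msg = str(err)
print(err_msg)  # _create_sketch() got an unexpected keyword argument 'shape'

# Working call: shape passed positionally.
t = rand_sketch([nticks] * N + [C], ranks_tt=10, ranks_tucker=6, requires_grad=True)
print(t["shape"])  # (4, 4, 4, 2)
```

Following the pull request's title ("shape is not kwargs now"), the corrected example presumably passes the shape positionally, e.g. `tn.rand([nticks]*N + [C], ranks_tt=10, ranks_tucker=6, requires_grad=True)`, while the remaining options stay as keywords.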