balbasty / torch-interpol

High-order spline interpolation in PyTorch
MIT License

RuntimeError when importing interpol #3

Closed: Aria-K-Alethia closed this issue 1 year ago

Aria-K-Alethia commented 1 year ago

Hi,

I tried to use your package in my code, but I get a RuntimeError as soon as I import it. Here is the output when I run import interpol:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-a972fa03d554> in <module>
      1 import torch
----> 2 import interpol
      3 import matplotlib.pyplot as plt

~/.local/lib/python3.7/site-packages/interpol/__init__.py in <module>
----> 1 from .api import *
      2 from .resize import *
      3 from . import _version
      4 __version__ = _version.get_versions()['version']

~/.local/lib/python3.7/site-packages/interpol/api.py in <module>
      2 import torch
      3 from .utils import expanded_shape, matvec
----> 4 from .jit_utils import movedim1
      5 from .autograd import (GridPull, GridPush, GridCount, GridGrad,
      6                        SplineCoeff, SplineCoeffND)

~/.local/lib/python3.7/site-packages/interpol/jit_utils.py in <module>
     70 
     71 @torch.jit.script
---> 72 def list_prod_tensor(x: List[Tensor]) -> Tensor:
     73     if len(x) == 0:
     74         return torch.ones([])

~/.local/lib/python3.7/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb, example_inputs)
   1309             _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
   1310         fn = torch._C._jit_script_compile(
-> 1311             qualified_name, ast, _rcb, get_default_args(obj)
   1312         )
   1313         # Forward docstrings

RuntimeError: 
Arguments for call are not valid.
The following variants are available:

  aten::ones.names(int[] size, *, str[]? names, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor):
  Expected a value of type 'List[int]' for argument 'size' but instead found type 'List[Tensor]'.
  Empty lists default to List[Tensor]. Add a variable annotation to the assignment to create an empty list of another type (torch.jit.annotate(List[T, []]) where T is the type of elements in the list for Python 2)

  aten::ones(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor):
  Expected a value of type 'List[int]' for argument 'size' but instead found type 'List[Tensor]'.
  Empty lists default to List[Tensor]. Add a variable annotation to the assignment to create an empty list of another type (torch.jit.annotate(List[T, []]) where T is the type of elements in the list for Python 2)

  aten::ones.out(int[] size, *, Tensor(a!) out) -> (Tensor(a!)):
  Expected a value of type 'List[int]' for argument 'size' but instead found type 'List[Tensor]'.
  Empty lists default to List[Tensor]. Add a variable annotation to the assignment to create an empty list of another type (torch.jit.annotate(List[T, []]) where T is the type of elements in the list for Python 2)

The original call is:
  File "/home/.local/lib/python3.7/site-packages/interpol/jit_utils.py", line 74
def list_prod_tensor(x: List[Tensor]) -> Tensor:
    if len(x) == 0:
        return torch.ones([])
               ~~~~~~~~~~ <--- HERE
    x0 = x[0]
    for x1 in x[1:]:
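The end of the error message explains the cause. Inside a TorchScript function, an empty list literal is typed as List[Tensor] unless it is annotated, so torch.ones([]) matches none of the aten::ones overloads, all of which expect int[] for size. The following is a minimal sketch of one possible workaround, assuming torch 1.10 semantics, and it is not necessarily what the maintainer's fix does. It annotates the empty size list as List[int]. Only the first lines of list_prod_tensor appear in the traceback, so the loop body below (multiplying the tensors together, as the name suggests) is an assumption.

```python
from typing import List

import torch
from torch import Tensor


@torch.jit.script
def list_prod_tensor(x: List[Tensor]) -> Tensor:
    if len(x) == 0:
        # Annotate the empty list so TorchScript types it as List[int];
        # a bare [] literal defaults to List[Tensor] on torch 1.10.
        empty_size: List[int] = []
        return torch.ones(empty_size)  # 0-dim tensor holding 1.0
    # Assumed body: multiply all tensors together (with broadcasting).
    x0 = x[0]
    for x1 in x[1:]:
        x0 = x0 * x1
    return x0


print(list_prod_tensor([]))  # tensor(1.)
print(list_prod_tensor([torch.tensor([2.0, 3.0]), torch.tensor([4.0, 5.0])]))
```

Writing torch.ones(torch.jit.annotate(List[int], [])) is an equivalent inline form of the same annotation, as the error message itself suggests.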

My environment:

python==3.7.5
torch==1.10.1
balbasty commented 1 year ago

Hi,

Many thanks for reporting this. Could you please try the branch fix-issue-3 and tell me whether it fixes things for you? It works locally for me. I'll merge it into master if it does.

Cheers,
Yael
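One way to test a branch is pip's VCS syntax, for example pip install "git+https://github.com/balbasty/torch-interpol.git@fix-issue-3". The repository URL here is assumed from the project name at the top of the page. After installing, the import can be re-run. The sketch below is a hypothetical helper, not part of interpol, that reports whether the import now succeeds together with the versions worth quoting back in the thread.

```python
import importlib
import sys

import torch


def check_interpol_import() -> str:
    """Try to import interpol and summarise the outcome with version info."""
    try:
        interpol = importlib.import_module("interpol")
    except RuntimeError as err:
        # The TorchScript compilation error from this report surfaces here;
        # keep only the first non-empty line of its (multi-line) message.
        lines = [line for line in str(err).splitlines() if line.strip()]
        first = lines[0] if lines else repr(err)
        return f"FAILED on torch {torch.__version__}: {first}"
    except ImportError:
        return "interpol is not installed"
    return (
        f"OK: interpol {interpol.__version__} on torch {torch.__version__}, "
        f"Python {sys.version.split()[0]}"
    )


if __name__ == "__main__":
    print(check_interpol_import())
```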