kakaobrain / trident

A performance library for machine learning applications.
https://www.kakaobrain.com
Apache License 2.0

Can't import latest version of Trident #159

Closed hypnopump closed 12 months ago

hypnopump commented 1 year ago

🐞 Describe the bug

Can't import the latest version of Trident after following the installation steps, on a Linux machine with an NVIDIA T4 GPU and CUDA version 12.0.

🧑‍🏫 Reproduction

To reproduce:

conda create -n python310 python=3.10 --yes
conda activate python310
pip3 install ipython torch torchvision torchaudio
git clone https://github.com/kakaobrain/trident
cd trident 
bash install_package.sh
cd ..
ipython
>>> import trident as tr

🎯 Expected behavior

Importing it should just work.

🖼️ Screenshots

Here's the stack trace with the error I'm getting:

ipython
Python 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.16.1 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import trident as tr
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 1
----> 1 import trident as tr

File ~/trident/trident/__init__.py:15
      1 # Copyright 2023 ⓒ Kakao Brain Corp.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from . import function, kernel, util
     16 from .config import *
     17 from .module import *

File ~/trident/trident/function/__init__.py:15
      1 # Copyright 2023 ⓒ Kakao Brain Corp.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from .function import *

File ~/trident/trident/function/function.py:21
     17 from typing import Optional, Tuple, Union
     19 import torch
---> 21 from trident import operation
     24 def argmax(input: torch.Tensor, dim: int):
     25     """
     26     Returns the indices of the maximum value of all elements in an input.
     27     """

File ~/trident/trident/operation/__init__.py:15
      1 # Copyright 2023 ⓒ Kakao Brain Corp.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from .argmax import *
     16 from .attention import *
     17 from .batch_norm import *

File ~/trident/trident/operation/argmax.py:20
     17 import torch
     18 import triton
---> 20 from trident import kernel, util
     23 class Argmax(torch.autograd.Function):
     24     @staticmethod
     25     def forward(ctx: Any, *args: Any, **kwargs: Any):

File ~/trident/trident/kernel/__init__.py:15
      1 # Copyright 2023 ⓒ Kakao Brain Corp.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from .argmax import *
     16 from .attention import *
     17 from .batch_norm import *

File ~/trident/trident/kernel/argmax.py:19
     15 import triton
     16 import triton.language as tl
---> 19 class Argmax:
     20     @staticmethod
     21     @triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
     22     @triton.jit
   (...)
     32         require_x_boundary_check: tl.constexpr,
     33     ):
     34         y_offset = tl.program_id(0)

File ~/trident/trident/kernel/argmax.py:23, in Argmax()
     19 class Argmax:
     20     @staticmethod
     21     @triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size
     28         y_stride: tl.int32,
     29         x_stride: tl.int32,
     30         dtype: tl.constexpr,
     31         x_block_size: tl.constexpr,
     32         require_x_boundary_check: tl.constexpr,
     33     ):
     34         y_offset = tl.program_id(0)
     36         output_block_ptr = tl.make_block_ptr(
     37             output_ptr,
     38             shape=(y_size,),
   (...)
     42             order=(0,),
     43         )

File /opt/conda/envs/python310/lib/python3.10/site-packages/triton/runtime/jit.py:542, in jit(fn, version, do_not_specialize, debug, noinline, interpret)
    534         return JITFunction(
    535             fn,
    536             version=version,
   (...)
    539             noinline=noinline,
    540         )
    541 if fn is not None:
--> 542     return decorator(fn)
    544 else:
    545     return decorator

File /opt/conda/envs/python310/lib/python3.10/site-packages/triton/runtime/jit.py:534, in jit.<locals>.decorator(fn)
    532     return GridSelector(fn)
    533 else:
--> 534     return JITFunction(
    535         fn,
    536         version=version,
    537         do_not_specialize=do_not_specialize,
    538         debug=debug,
    539         noinline=noinline,
    540     )

File /opt/conda/envs/python310/lib/python3.10/site-packages/triton/runtime/jit.py:431, in JITFunction.__init__(self, fn, version, do_not_specialize, debug, noinline)
    429 self.__annotations__ = {name: normalize_ty(ty) for name, ty in fn.__annotations__.items()}
    430 # index of constexprs
--> 431 self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
    432 # launcher
    433 self.run = self._make_launcher()

File /opt/conda/envs/python310/lib/python3.10/site-packages/triton/runtime/jit.py:431, in <listcomp>(.0)
    429 self.__annotations__ = {name: normalize_ty(ty) for name, ty in fn.__annotations__.items()}
    430 # index of constexprs
--> 431 self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty]
    432 # launcher
    433 self.run = self._make_launcher()

TypeError: argument of type 'dtype' is not iterable

In [2]: import triton.language as tl

In [3]: tl.load
Out[3]: <function triton.language.core.load(pointer, mask=None, other=None, boundary_check=(), padding_option='', cache_modifier='', eviction_policy='', volatile=False, _builder=None)>
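
For readers following the traceback: the failing expression is `'constexpr' in ty`, and the message says `ty` was an object of a class named `dtype` rather than a string. The sketch below only illustrates how Python produces that kind of error. The `dtype` class and the `find_constexprs` helper are hypothetical stand-ins written for this note, not Triton's actual code.

```python
class dtype:
    """Hypothetical stand-in for a non-string annotation object (not Triton's class)."""

    def __init__(self, name):
        self.name = name


def find_constexprs(annotations):
    # Mirrors the shape of the list comprehension shown in the traceback:
    # a membership test against each annotation value.
    return [name for name, ty in annotations.items() if "constexpr" in ty]


# When every annotation is a string, the membership test is a substring check.
string_annotations = {"x_stride": "int32", "x_block_size": "constexpr"}
print(find_constexprs(string_annotations))

# When an annotation is an arbitrary object, `in` has nothing to iterate over.
object_annotations = {"x_stride": dtype("int32"), "x_block_size": "constexpr"}
try:
    find_constexprs(object_annotations)
except TypeError as error:
    print(type(error).__name__, error)
```

If this reading is right, the installed Triton build is applying a string membership test to annotations such as `tl.int32`, which suggests a mismatch between the Triton version and the kernel code rather than a GPU or CUDA problem.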

💻 Requirements

💬 Additional context

If you could share either a conda yaml or a requirements.txt that would make it just work, that would be awesome.

daemyung commented 1 year ago

@hypnopump Could you install Trident again? There is no issue on my machine.

bc-user@instance-6285:~/projects/trident$ ipython
Python 3.10.11 (main, Apr 20 2023, 19:02:41) [GCC 11.2.0]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.12.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import trident as tr

You can add Trident to requirements.txt:

trident@git+https://github.com/kakaobrain/trident.git@main

hypnopump commented 1 year ago

Fails with the same error. Here's my conda env YAML file in case it's useful for reproduction: my_conda_env.yaml.txt

I also tried a machine with a V100 GPU, in case the compute architecture was the cause, and it still fails.

daemyung commented 1 year ago

@ansteve @mejai1206 Could you try to reproduce this issue?

hypnopump commented 1 year ago

Found a solution:

# follow same steps at the start
conda create -n python310 python=3.10 --yes
conda activate python310
pip3 install ipython torch torchvision torchaudio
pip install git+https://github.com/kakaobrain/trident.git
ipython
>>> import trident as td
# this will raise an error

And now, update triton:

pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
ipython
>>> import trident as td
# this works

It's probably just a bad commit of Triton being installed by default. Could we fix setup.py and install_package.sh so they point to working versions?

daemyung commented 1 year ago

@hypnopump That's very strange, because the latest Triton is installed while installing Trident, which means the installed Triton should be newer than the nightly build. I guess an old version of Triton was somehow installed in your conda environment. Take a look at Trident's Actions: Trident is tested with the latest Triton.
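
One way to check the guess about an old Triton is to list which Triton-related distributions are actually present in the active environment. The sketch below uses only the standard library. The `installed_versions` helper is a name made up for this example, and the package names are the ones mentioned in this thread (`triton`, `triton-nightly`, `torch`).

```python
from importlib import metadata


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


# Run this inside the same conda environment that fails to import Trident.
for name, version in installed_versions(["triton", "triton-nightly", "torch"]).items():
    print(f"{name}: {version if version is not None else 'not installed'}")
```

If both `triton` and `triton-nightly` report a version, the two distributions may be providing the same `triton` import package, so which one gets imported would depend on installation order.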