shawntan / scattermoe

Triton-based implementation of Sparse Mixture of Experts.
Apache License 2.0

pytest fail #3

Closed Eutenacity closed 6 months ago

Eutenacity commented 6 months ago

Sorry, I am not familiar with Triton.

After running pytest, an AssertionError occurs:

../../miniconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../miniconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _call_impl
    return forward_call(*args, **kwargs)
scattermoe/mlp.py:83: in forward
    h = self.experts(
../../miniconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../miniconda3/envs/dsmii/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _call_impl
    return forward_call(*args, **kwargs)
scattermoe/parallel_experts.py:142: in forward
    results = ParallelLinear.apply(
../../miniconda3/envs/dsmii/lib/python3.10/site-packages/torch/autograd/function.py:539: in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
scattermoe/parallel_experts.py:14: in forward
    output = kernels.ops.scatter2scatter(
scattermoe/kernels/ops.py:146: in scatter2scatter
    _scatter2scatter[grid](
../../miniconda3/envs/dsmii/lib/python3.10/site-packages/triton/runtime/autotuner.py:114: in run
    ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
../../miniconda3/envs/dsmii/lib/python3.10/site-packages/triton/runtime/autotuner.py:232: in run
    return self.fn.run(*args, **kwargs)
<string>:63: in _scatter2scatter
    ???
../../miniconda3/envs/dsmii/lib/python3.10/site-packages/triton/compiler/compiler.py:430: in compile
    fn_cache_manager = get_cache_manager(make_hash(fn, arch, **kwargs))
../../miniconda3/envs/dsmii/lib/python3.10/site-packages/triton/compiler/compiler.py:253: in make_hash
    key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}-{arch}"
../../miniconda3/envs/dsmii/lib/python3.10/site-packages/triton/runtime/jit.py:445: in cache_key
    dependencies_finder.visit(self.parse())
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = JITFunction(scattermoe.kernels.ops:_scatter2scatter)

    def parse(self):
        tree = ast.parse(self.src)
        assert isinstance(tree, ast.Module)
>       assert len(tree.body) == 1
E       AssertionError

Need help.

Eutenacity commented 6 months ago

I know the reason. Just update Python to a version newer than 3.10.10.
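
A minimal sketch of a guard for that, assuming the 3.10.10 threshold reported in this comment is the right cutoff (the thread does not say which exact release contains the fix). `python_is_new_enough` is a hypothetical helper, not part of ScatterMoE or Triton:

```python
import sys

# Threshold taken from the comment above; treated as exclusive
# ("newer than 3.10.10"). This is an assumption, not a documented requirement.
MIN_EXCLUSIVE = (3, 10, 10)


def python_is_new_enough(version_info=sys.version_info, threshold=MIN_EXCLUSIVE):
    """Return True if (major, minor, micro) is strictly newer than threshold."""
    return tuple(version_info[:3]) > threshold


current = ".".join(str(part) for part in sys.version_info[:3])
if python_is_new_enough():
    print(f"Python {current} is newer than 3.10.10")
else:
    print(f"Python {current} may hit the AssertionError above; consider upgrading")
```

Something like this could run at the top of a test session so the failure is reported as a version problem rather than as a bare AssertionError from inside Triton.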

Eutenacity commented 6 months ago

But now there is another error:

Traceback (most recent call last):
  File "/home/wenxianglin/scattermoe/scattermoe-main/demo.py", line 25, in <module>
    Y = mlp(
  File "/home/wenxianglin/miniconda3/envs/lwxtr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wenxianglin/miniconda3/envs/lwxtr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wenxianglin/scattermoe/scattermoe-main/scattermoe/mlp.py", line 83, in forward
    h = self.experts(
  File "/home/wenxianglin/miniconda3/envs/lwxtr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wenxianglin/miniconda3/envs/lwxtr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wenxianglin/scattermoe/scattermoe-main/scattermoe/parallel_experts.py", line 142, in forward
    results = ParallelLinear.apply(
  File "/home/wenxianglin/miniconda3/envs/lwxtr/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/wenxianglin/scattermoe/scattermoe-main/scattermoe/parallel_experts.py", line 14, in forward
    output = kernels.ops.scatter2scatter(
  File "/home/wenxianglin/scattermoe/scattermoe-main/scattermoe/kernels/ops.py", line 146, in scatter2scatter
    _scatter2scatter[grid](
  File "/home/wenxianglin/miniconda3/envs/lwxtr/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 114, in run
    ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
  File "/home/wenxianglin/miniconda3/envs/lwxtr/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 232, in run
    return self.fn.run(*args, **kwargs)
  File "<string>", line 65, in _scatter2scatter
  File "/home/wenxianglin/miniconda3/envs/lwxtr/lib/python3.10/site-packages/triton/compiler/compiler.py", line 579, in __getattribute__
    self._init_handles()
  File "/home/wenxianglin/miniconda3/envs/lwxtr/lib/python3.10/site-packages/triton/compiler/compiler.py", line 568, in _init_handles
    raise OutOfResources(self.shared, max_shared, "shared memory")
triton.runtime.autotuner.OutOfResources: out of resource: shared memory, Required: 131072, Hardware limit: 101376. Reducing block sizes or num_stages may help.

Eutenacity commented 6 months ago

I tested on an A6000.
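
For readers hitting the same OutOfResources message, here is a rough back-of-the-envelope sketch of why "reducing block sizes or num_stages may help". It assumes a simplified model in which a pipelined matmul-style kernel keeps one A tile and one B tile in shared memory per stage. The tile sizes, the fp32 element size, and the `estimated_shared_bytes` helper are all illustrative assumptions, not ScatterMoE's actual configuration or Triton's real allocator. Only the 131072 and 101376 figures come from the error message.

```python
# Hardware limit in bytes, copied from the error message in this thread.
HARDWARE_LIMIT = 101376


def estimated_shared_bytes(block_m, block_n, block_k, num_stages, bytes_per_elem=4):
    """Simplified estimate: one (M x K) and one (K x N) tile buffered per stage."""
    tile_a = block_m * block_k  # elements of the A tile
    tile_b = block_k * block_n  # elements of the B tile
    return (tile_a + tile_b) * bytes_per_elem * num_stages


# Hypothetical tile sizes chosen only to illustrate the arithmetic.
for stages in (4, 3, 2):
    need = estimated_shared_bytes(128, 128, 32, stages)
    verdict = "fits" if need <= HARDWARE_LIMIT else "exceeds limit"
    print(f"num_stages={stages}: {need} bytes ({verdict})")
```

Under these assumptions, four stages need 131072 bytes, which is over the limit reported here, while three stages need 98304 bytes and would fit. The real requirement depends on the kernel and the Triton version, so treat this only as intuition for the error text.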

Eutenacity commented 6 months ago

Updating the version of Triton solved all my problems.

shawntan commented 6 months ago

Good to know your problem's fixed. I did test mostly on an A100 and a Titan RTX, so it will be good to know about device issues, but they might largely be Triton problems.

findmyway commented 6 months ago

This might be related: https://github.com/openai/triton/issues/1589

I can confirm that upgrading Python fixed this issue. It would be good to add a note about this to the README.