databricks / megablocks

Apache License 2.0

Sum missing axis arg in kernels.py #102

Closed: jambo6 closed this issue 5 months ago

jambo6 commented 5 months ago

I get the following error when I run the backward pass with moe_expert_model_parallelism=False. For example, running moe_test.py fails with this error:

    # Reduce to get the final result and store.
    out = tl.sum(acc).to(wgrad.dtype.element_ty)
                 ^
    TypeError("sum() missing 1 required positional argument: 'axis'")

My versions:

    megablocks                   0.5.1
    stanford-stk                 0.7.0
    triton                       2.1.0

jambo6 commented 5 months ago

It looks like I somehow have an old build of Triton that does not include this commit:

https://github.com/openai/triton/pull/1712/files

I'm not sure how, since we are on 2.1.0, but anyway.
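Because both builds report version 2.1.0, a version string alone may not tell them apart. A minimal sketch of a different check, assuming the linked commit is the change that gave `tl.sum` an optional `axis`: inspect whether the installed function's `axis` parameter has a default. The helper name `axis_is_required` and the stand-ins `old_sum`/`new_sum` are hypothetical, and the assumption that a `@triton.jit` function exposes its wrapped Python function as `.fn` is unverified here.

```python
import inspect
from typing import Optional


def axis_is_required(fn) -> Optional[bool]:
    """True if `fn` has an `axis` parameter without a default, False if it
    has a default or no `axis` at all, None if the signature is unreadable.

    Assumption: a @triton.jit function exposes the wrapped function as `.fn`;
    functools.wraps wrappers are followed by inspect.signature automatically.
    """
    target = getattr(fn, "fn", fn)
    try:
        params = inspect.signature(target).parameters
    except (TypeError, ValueError):
        return None  # cannot tell; do not guess
    axis = params.get("axis")
    return axis is not None and axis.default is inspect.Parameter.empty


# Hypothetical stand-ins mimicking the two signature shapes in this thread.
def old_sum(input, axis, _builder=None):
    ...


def new_sum(input, axis=None):
    ...


if __name__ == "__main__":
    print("old-style sum requires axis:", axis_is_required(old_sum))
    print("new-style sum requires axis:", axis_is_required(new_sum))
    try:
        import triton.language as tl
    except ImportError:
        print("triton is not installed")
    else:
        print("installed tl.sum requires axis:", axis_is_required(tl.sum))
```

If the installed `tl.sum` reports that `axis` is required, the environment is most likely on an older build, matching the error above.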

jambo6 commented 5 months ago

In case anyone else runs into this: it seems to be related to this issue and to using a PyTorch Docker image that does not ship the correct version of the code.

pharaouk commented 5 months ago

Have you been able to fix this error?

pharaouk commented 5 months ago

Never mind. For future reference: your Docker container may come with triton==2.1.0 installed, but that build may not match the released triton 2.1.0 (the package was likely updated). So make sure to force a reinstall:

    pip install --force-reinstall triton==2.1.0

This replaces the container's copy so that your triton 2.1.0 matches the actual 2.1.0 release from the openai repo.