Tractables / pyjuice

Scalable training and inference for Probabilistic Circuits
https://tractables.github.io/pyjuice/
Apache License 2.0

Triton Illegal Memory Access #17

Closed mjojic closed 1 month ago

mjojic commented 2 months ago

I am running the code from example 01_train_pc.py to learn an HCLT on the MNIST dataset. The only thing I have added is a marginal query after the parameter learning. All the code works properly until the marginal query, which throws an error.

query:
data = torch.rand((28,28)).long().to(device)
lls = juice.queries.marginal(pc, data)

error: RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

The sampling query works, but marginal and conditional do not work.

liuanji commented 2 months ago

Hi mjojic,

The PC in that script assumes categorical input variables with values in the range 0-255, so the dtype of data needs to be torch.long. Its shape also needs to be [batch size, # variables].

You can refer to this tutorial for more information https://tractables.github.io/pyjuice/getting-started/tutorials/04_query_pc.html#sphx-glr-getting-started-tutorials-04-query-pc-py.

Best, Anji
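For readers who hit the same error, here is a minimal sketch of building a query batch with the dtype and shape described above. It assumes the MNIST images are flattened to 28 * 28 = 784 variables; the helper names make_query_batch and check_query_batch are hypothetical and not part of pyjuice. Note that torch.rand(...).long() yields all zeros, because torch.rand samples from [0, 1), so torch.randint is used here to get values across 0-255.

```python
import torch

NUM_VARS = 28 * 28   # assumption: one variable per pixel of a flattened MNIST image
NUM_CATS = 256       # categorical values 0-255, as stated in the reply above


def make_query_batch(batch_size, num_vars=NUM_VARS, num_cats=NUM_CATS, device="cpu"):
    """Hypothetical helper: random categorical data of shape [batch size, # variables]."""
    return torch.randint(0, num_cats, (batch_size, num_vars),
                         dtype=torch.long, device=device)


def check_query_batch(data, num_vars=NUM_VARS, num_cats=NUM_CATS):
    """Hypothetical helper: fail early on the CPU side instead of inside a GPU kernel."""
    if data.dtype != torch.long:
        raise TypeError(f"expected torch.long, got {data.dtype}")
    if data.dim() != 2 or data.size(1) != num_vars:
        raise ValueError(f"expected shape [B, {num_vars}], got {tuple(data.shape)}")
    if data.numel() > 0 and (data.min() < 0 or data.max() >= num_cats):
        raise ValueError(f"values must lie in [0, {num_cats - 1}]")
    return data


data = check_query_batch(make_query_batch(16))
# With a trained pc on the same device, the call from the original report would be:
# lls = juice.queries.marginal(pc, data)
```

Passing the original torch.rand((28,28)).long() tensor to check_query_batch raises a ValueError, since its second dimension is 28 rather than 784.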

mjojic commented 2 months ago

Yes, the memory access ended up being a shape issue.

With the correct shape, I ran into an error about TILE_SIZE < 4 (coming from the block-sparse kernel in sum_layer.py). After setting block_size=1 in the HCLT definition, I was able to run a marginal query without the tile size error.

liuanji commented 2 months ago

You may increase the batch size to e.g. 16 to avoid the TILE_SIZE error for now. I will fix it in the near future. Thanks!
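One way to apply the batch-size suggestion when you only have a few query rows is to pad the batch up to 16 and discard the extra results. The sketch below is only an illustration: pad_batch is a hypothetical helper, not a pyjuice function, and the commented-out slicing assumes the marginal query returns one value per input row in the same order.

```python
import torch


def pad_batch(data, min_batch=16):
    """Hypothetical helper: repeat the last row until the batch has min_batch rows.

    Returns the padded tensor and the number of original rows.
    """
    n = data.size(0)
    if n == 0:
        raise ValueError("cannot pad an empty batch")
    if n >= min_batch:
        return data, n
    filler = data[-1:].expand(min_batch - n, -1)   # copies of the last row
    return torch.cat([data, filler], dim=0), n


small = torch.randint(0, 256, (1, 28 * 28), dtype=torch.long)  # a single query
padded, n = pad_batch(small)
# Assumption: one output per row, in input order, so the padding can be sliced off:
# lls = juice.queries.marginal(pc, padded)[:n]
```

Padding with a copy of a real row keeps every value within the valid categorical range, so the filler rows cannot trigger an out-of-range access themselves.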