bdusell / semiring-einsum

Generic PyTorch implementation of einsum that supports different semirings
https://bdusell.github.io/semiring-einsum/
MIT License

Default block_size #6

Closed: davidweichiang closed this issue 2 years ago

davidweichiang commented 2 years ago

Or can the documentation at least recommend a good block_size?

bdusell commented 2 years ago

This question has come up before, and I should finally add a note about it. I purposely required the user to provide a block_size because the right value really depends on their application and resources. It also depends on the number of variables being summed out: if there are n such variables, then the block tensor has block_size ** n elements, so memory use grows exponentially in n. When in doubt, a value of 10 is a good place to start.
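To make the block_size ** n relationship concrete, here is a minimal arithmetic sketch. The helper name `block_tensor_bytes` and the 4-byte (float32) element size are assumptions for illustration and are not part of the library. It counts only the summed-out dimensions, as described above, and ignores any other dimensions the temporary tensor might have.

```python
def block_tensor_bytes(block_size: int, num_summed_vars: int,
                       bytes_per_element: int = 4) -> int:
    """Estimate the memory taken by one block tensor.

    Hypothetical helper: assumes the block tensor has exactly
    block_size ** num_summed_vars elements, each of bytes_per_element bytes.
    """
    return block_size ** num_summed_vars * bytes_per_element


# With block_size = 10, each extra summed-out variable multiplies the cost by 10.
for n in range(1, 5):
    print(n, block_tensor_bytes(10, n))
```

Under these assumptions, a block size of 10 stays small for a handful of summed-out variables, which is consistent with it being a safe starting point.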

davidweichiang commented 2 years ago

Following up on this: maybe it would be more user-friendly if users could specify the temporary space usage directly, either as a constant number of floats or as some multiple of the size of the inputs and outputs?

bdusell commented 2 years ago

I've now implemented a default block size based on available GPU memory. On CPU, it chooses a block size that does not exceed 1 GiB.
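As a rough illustration of how a block size could be derived from a memory budget, the sketch below picks the largest block size whose block tensor fits in a given number of bytes. This is not necessarily how the library computes its default. The function name `largest_block_size`, the float32 element size, and the block_size ** n cost model are all assumptions.

```python
def largest_block_size(budget_bytes: int, num_summed_vars: int,
                       bytes_per_element: int = 4) -> int:
    """Largest block_size with block_size ** n * bytes_per_element <= budget.

    Hypothetical helper, assuming the block tensor has block_size ** n elements.
    """
    if num_summed_vars < 1:
        raise ValueError("need at least one summed-out variable")
    max_elements = budget_bytes // bytes_per_element
    if max_elements < 1:
        raise ValueError("budget too small for even one element")
    # Start from a floating-point estimate of the n-th root, then correct
    # it with exact integer arithmetic to avoid rounding errors.
    guess = max(1, int(round(max_elements ** (1.0 / num_summed_vars))))
    while guess ** num_summed_vars > max_elements:
        guess -= 1
    while (guess + 1) ** num_summed_vars <= max_elements:
        guess += 1
    return guess


ONE_GIB = 2 ** 30
print(largest_block_size(ONE_GIB, 2))  # two summed-out variables
print(largest_block_size(ONE_GIB, 3))  # three summed-out variables
```

The same function covers the "constant number of floats" idea from earlier in the thread: pass the float budget multiplied by the element size as `budget_bytes`.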

davidweichiang commented 2 years ago

Cool! Is it the case that PyTorch runs only one einsum at a time, so it's safe to use all available memory?

bdusell commented 2 years ago

I know that CUDA kernels can run asynchronously with the Python code, but I'm not sure if kernels are run in parallel with each other. But I've been running this code in experiments for a while and have not yet encountered problems.

Querying the amount of available CUDA memory is actually very slow (it slows down my code by about 200%), presumably because it requires synchronization with the GPU. So the default behavior now is to cache the amount of available memory after the first einsum call, and to use only 80% of it in case the amount of free memory later decreases. You still have the option to query it every time, or to select a memory limit yourself. There are more details in the updated docstrings.
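The caching strategy described here can be sketched roughly as follows. The class `CachedMemoryLimit` and its parameters are hypothetical and do not reflect the library's actual API. The query function is injected so that the sketch runs without a GPU; in practice one might pass something that wraps `torch.cuda.mem_get_info`, but that is an assumption rather than a description of the implementation.

```python
from typing import Callable, Optional


class CachedMemoryLimit:
    """Query free memory once, then reuse a fraction of it as the limit.

    Hypothetical sketch: `query_free_bytes` stands in for a slow call that
    asks the device how much memory is free.
    """

    def __init__(self, query_free_bytes: Callable[[], int],
                 fraction: float = 0.8, cache: bool = True) -> None:
        self._query = query_free_bytes
        self._fraction = fraction
        self._cache = cache
        self._cached: Optional[int] = None

    def get(self) -> int:
        # Reuse the stored limit to avoid repeating the slow query.
        if self._cache and self._cached is not None:
            return self._cached
        # Keep headroom in case free memory shrinks after the query.
        limit = int(self._query() * self._fraction)
        if self._cache:
            self._cached = limit
        return limit


calls = []

def fake_query() -> int:
    # Stand-in for a device query; records each call and reports 1000 bytes free.
    calls.append(1)
    return 1000

limit = CachedMemoryLimit(fake_query)
print(limit.get(), limit.get(), len(calls))
```

Setting `cache=False` corresponds to querying every time, and a fixed user-chosen limit would bypass the query altogether.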