In the sunny light of the future, writing raw CUDA is no longer the right choice for a library like this. The whole system could be replaced with roughly a tenth of the code by rewriting it in JAX, which is basically NumPy for GPUs. A new JAX-based cutde (or whatever you want to rename it!) would be:
- much simpler
- probably faster
- cross-platform, running on CPU, GPU, and TPU from a single piece of code
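
To give a flavor of what that could look like, here is a minimal sketch of an all-pairs evaluation in JAX. The `point_kernel` function is a hypothetical stand-in (a plain inverse-distance kernel), not cutde's triangular dislocation formulas, and the names `pairwise`, `obs_pts`, and `src_pts` are made up for illustration. The point is only that `jax.vmap` and `jax.jit` take the place of a hand-written CUDA kernel and its launch code.

```python
import jax
import jax.numpy as jnp


def point_kernel(obs_pt, src_pt):
    # Hypothetical stand-in kernel: 1 / distance between two 3D points.
    # This is NOT cutde's triangular dislocation formula.
    r = jnp.linalg.norm(obs_pt - src_pt)
    return 1.0 / r


# Inner vmap maps over sources, outer vmap maps over observation points,
# giving an (n_obs, n_src) matrix. jit compiles the result for whichever
# backend JAX finds (CPU, GPU, or TPU).
pairwise = jax.jit(
    jax.vmap(jax.vmap(point_kernel, in_axes=(None, 0)), in_axes=(0, None))
)

obs_pts = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
src_pts = jnp.array([[0.0, 0.0, 2.0], [0.0, 3.0, 4.0], [1.0, 0.0, 1.0]])

matrix = pairwise(obs_pts, src_pts)
print(matrix.shape)           # (2, 3)
print(jax.default_backend())  # which backend the code actually ran on
```

The same script runs unchanged on any backend JAX supports. A real port would swap the stand-in kernel for the actual dislocation math and would probably need some care around memory use for large all-pairs matrices.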
I have no intention of ever doing this rewrite myself, since I work on other stuff now. But I thought I'd put this issue here in case anyone stumbles on it!
I could also potentially be paid to make this kind of upgrade.