srush / Tensor-Puzzles

Solve puzzles. Improve your pytorch.
MIT License

Upgrade Tensor Puzzles from torchtyping to jaxtyping. #25

Open: davideger opened this pull request 6 months ago

davideger commented 6 months ago

jaxtyping is a new, improved, and maintained successor to torchtyping by the same author (Patrick Kidger). When combined with beartype, jaxtyping reports shape mismatch errors to the user at runtime.

Other minor formatting issues were fixed in the Tensor Puzzles notebook.

srush commented 6 months ago

Amazing! I was just planning on doing this.

srush commented 6 months ago

Just curious, why do you need "{j j}"?

davideger commented 6 months ago

My understanding (@patrick-kidger, correct me if I am wrong) is that if you want jaxtyping to check that a return-type dimension matches a scalar function parameter, you need f-string syntax (braces) to reference that parameter. At least, that is how I read https://github.com/google/jaxtyping/blob/main/docs/api/array.md.