srush / Tensor-Puzzles

Solve puzzles. Improve your pytorch.
MIT License

Add error message for shape mismatch #18

Closed eatPizza311 closed 4 months ago

eatPizza311 commented 1 year ago

Addresses #14: catches the shape mismatch before PyTorch throws a RuntimeError, and returns an error message like the one below:

AssertionError: Two tensors have a different shapes
 Spec: 
    Expect: torch.Size([5]) 
    Got: torch.Size([1, 5])

srush commented 1 year ago

Sorry, this unfortunately breaks broadcasting. You could try catching the runtime error.
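To illustrate the broadcasting concern, here is a minimal sketch in plain Python of the standard broadcasting rule (dimensions are compared from the trailing end, and a dimension of size 1 or a missing dimension is compatible with anything). The helper name `broadcastable` is hypothetical and not part of the repository; the shapes are the ones from the error message above.

```python
from itertools import zip_longest


def broadcastable(shape_a, shape_b):
    """Return True if two shapes are compatible under broadcasting rules.

    Hypothetical helper for illustration only. Dimensions are aligned from
    the right; a missing dimension is treated as size 1.
    """
    pairs = zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1)
    return all(x == y or x == 1 or y == 1 for x, y in pairs)


expected, got = (5,), (1, 5)

# A strict equality check rejects this pair...
print(expected == got)

# ...even though the two shapes broadcast together without error.
print(broadcastable(expected, got))

# A genuinely incompatible pair fails both checks.
print(broadcastable((5,), (4,)))
```

A strict `shape == shape` assertion therefore rejects inputs like `[5]` vs `[1, 5]` that PyTorch would otherwise accept, which appears to be the problem described here.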

eatPizza311 commented 1 year ago

Sorry, I didn't notice this would break broadcasting. I've changed it so the assertion is raised only for this particular runtime error.
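One way this could look is sketched below; it is not the actual code from this PR. The helper name `check_equal` is hypothetical, and it assumes the comparison is an elementwise `torch.eq`. The comparison runs normally, so broadcasting still works, and the `RuntimeError` is converted into an `AssertionError` only when the two shapes actually differ.

```python
import torch


def check_equal(expected: torch.Tensor, got: torch.Tensor) -> bool:
    """Hypothetical helper: elementwise comparison that allows broadcasting."""
    try:
        # torch.eq broadcasts its arguments, so [5] vs [1, 5] is accepted.
        return bool(torch.eq(expected, got).all())
    except RuntimeError as err:
        # Only translate the error when the shapes differ; any other
        # RuntimeError is re-raised untouched.
        if expected.shape != got.shape:
            raise AssertionError(
                "Two tensors have different shapes\n"
                f"    Expect: {expected.shape}\n"
                f"    Got: {got.shape}"
            ) from err
        raise


# Broadcast-compatible shapes pass through to the normal comparison.
print(check_equal(torch.zeros(5), torch.zeros(1, 5)))

# Incompatible shapes raise the friendlier AssertionError.
try:
    check_equal(torch.zeros(5), torch.zeros(4))
except AssertionError as e:
    print(e)
```

Checking the shapes inside the `except` block, rather than matching on the text of PyTorch's error message, avoids depending on wording that may change between versions.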