crystal-data / num.cr

Scientific computing in pure Crystal
MIT License

Support grad backprop when add/sub use broadcast #87

Open nogginly opened 1 year ago

nogginly commented 1 year ago

@christopherzimmerman, this fixes gradient backpropagation for addition/subtraction when broadcasting occurs because one operand has a different rank than the other.

I've included tests, which are based on results I get from PyTorch.
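For reference, a minimal PyTorch check of the behavior the tests are based on (the shapes here are illustrative, not taken from the PR's test suite):

```python
import torch

# Rank-2 operand plus rank-1 operand: b is broadcast across dim 0.
a = torch.ones(3, 4, requires_grad=True)
b = torch.ones(4, requires_grad=True)

(a + b).sum().backward()

print(a.grad.shape)  # torch.Size([3, 4])
print(b.grad.shape)  # torch.Size([4]) -- aggregated back to b's shape, not [3, 4]
print(b.grad)        # tensor([3., 3., 3., 3.]) -- one contribution per broadcast row
```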

christopherzimmerman commented 1 year ago

@nogginly thanks for the PR. I'm going to need a bit longer to review this, since when I implemented it I guess I made some design decisions around broadcasted operations that differ from other libraries.

I always set the gradient to match the value that was actually used in the calculation, so in this case, since the broadcast happens during the operation, the gradient will match that shape. Are you saying that PyTorch aggregates that back down to match the dimensions of the initial variable, before the operation?

christopherzimmerman commented 1 year ago

Also, for rank-by-rank sums, there is a "view_along_axis" iterator in some version of this code that gives a view into multiple axes; it can probably be used to reduce multiple rank sums. I'll look for it.
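For context, the reduction being discussed can be described language-agnostically: sum the incoming gradient over every axis that broadcasting introduced or stretched, until it matches the original operand's shape. A generic NumPy sketch of that rule (the "unbroadcast" name is hypothetical; this is not num.cr's view_along_axis code):

```python
import numpy as np

def unbroadcast(grad: np.ndarray, shape: tuple) -> np.ndarray:
    """Sum grad down to shape, undoing NumPy-style broadcasting."""
    # Collapse the extra leading axes introduced by a rank mismatch.
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    # Sum over axes the operand had as size 1 but broadcasting stretched.
    for axis, dim in enumerate(shape):
        if dim == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

g = np.ones((2, 3, 4))
print(unbroadcast(g, (4,)).shape)    # (4,)   -- both leading axes summed away
print(unbroadcast(g, (3, 1)).shape)  # (3, 1) -- leading axis and size-1 axis summed
```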

nogginly commented 1 year ago

> I always set the gradient to match the value that was actually used in the calculation, so in this case, since the broadcast happens during the operation, the gradient will match that shape. Are you saying that PyTorch aggregates that back down to match the dimensions of the initial variable, before the operation?

Hola @christopherzimmerman. Yes, that is correct; I wrote a simple test using PyTorch and got exactly that. I ran into this while implementing a two-layer MLP: when I tried to update the biases using the gradient, the shapes were off and matmul didn't work, which is how I discovered the issue. The tests I put in are based on those PyTorch test cases.
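A sketch of that failure scenario in PyTorch (the layer sizes here are illustrative, not from the PR): once the gradient is aggregated down to the bias's own shape, the parameter update is a plain elementwise step.

```python
import torch

x = torch.randn(8, 5)                      # batch of 8 inputs
W = torch.randn(5, 3, requires_grad=True)  # layer weights
bias = torch.zeros(3, requires_grad=True)  # rank-1 bias, broadcast over the batch

out = x @ W + bias                         # bias is broadcast to shape (8, 3)
out.sum().backward()

print(bias.grad.shape)        # torch.Size([3]) -- matches bias, so the update works
with torch.no_grad():
    bias -= 0.01 * bias.grad  # e.g. a plain SGD step
```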