namedtensor / notation

Derivatives #45

Closed: davidweichiang closed this issue 3 years ago

davidweichiang commented 3 years ago

Do we want a section on how to take derivatives with respect to tensors? It ought to be easy compared to matrices.

srush commented 3 years ago

That's really interesting. In my experience, students find this notational ambiguity very confusing:

[image not preserved]

Do you have thoughts on how our method improves upon this? It might be interesting to do convolution as an example, since it has a simple but non-trivial derivative.

davidweichiang commented 3 years ago

Yeah that table is exactly what I want to forget. Tensors let you treat the entire table uniformly and also fill in the lower half (which is simply missing because those derivatives are order-3 or 4 tensors). I’m trying to write something up; meanwhile you might be interested in

https://papers.nips.cc/paper/2018/file/0a1bf96b7165e962e90cb14648c9462d-Paper.pdf

srush commented 3 years ago

This is extremely cool. I just went through some examples and convinced myself that many of the annoyances go away. Broadcasting is particularly nice here for derivatives of matrix products, and the chain rule also becomes extremely clear.

Another thing that occurs to me is that matrix calculus is often described using traces, which are quite counterintuitive, e.g., in https://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf

However, terms like this are really just very awkward ways of writing contractions.

[image not preserved]
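A small NumPy check of that point (NumPy and `einsum` are our choice for illustration; nothing in the thread uses them): tr(A^T B) is just a contraction of A and B over both of their axes, and the familiar identity d tr(AX)/dX = A^T falls out of differentiating the contraction sum_ij A_ij X_ji entry by entry. The helper `numerical_grad` is hypothetical, a plain central-difference gradient.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((3, 4))

# Trace form, as matrix calculus usually writes it: tr(A^T B)
trace_form = np.trace(A.T @ B)
# Contraction form: elementwise product summed over both shared axes
contraction_form = np.einsum("ij,ij->", A, B)


def numerical_grad(f, X, eps=1e-6):
    """Central-difference gradient of a scalar function f at X (hypothetical helper)."""
    G = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        E = np.zeros_like(X)
        E[idx] = eps
        G[idx] = (f(X + E) - f(X - E)) / (2 * eps)
    return G


# tr(A X) is the contraction sum_ij A[i, j] * X[j, i], so its derivative
# with respect to X[j, i] is A[i, j]: the gradient is A transposed.
X = rng.standard_normal((4, 3))
grad = numerical_grad(lambda X: np.trace(A @ X), X)
```

Seen as a contraction, the "transpose" in the answer is only a statement about which axis of A lines up with which axis of X.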

davidweichiang commented 3 years ago

What do you mean by "Broadcasting is particularly nice here for derivatives of matrix products"?

What does your chain rule look like -- does it use the two-name contraction operator?

[image not preserved]

The annoyance I'm running up against is that the derivative of a function from m axes to n axes is a function from m axes to m+n axes, and you have to choose names for the m new axes in the return type (that's what the primes are in the equation above). I'm trying to see if it can mostly be swept under the rug.
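To make the axis count concrete, here is a minimal NumPy sketch (our illustration, with a hypothetical finite-difference helper `numerical_jacobian`). In positional NumPy the new input axes are told apart only by position (here they come first), which is exactly the bookkeeping the primes do by name. The scalar case at the end previews the next comment: no second set of names is needed.

```python
import numpy as np


def numerical_jacobian(f, X, eps=1e-6):
    """Central-difference derivative of f at X (hypothetical helper).

    The result has X's axes first (the "primed" input axes),
    followed by the axes of f(X).
    """
    J = np.zeros(X.shape + np.shape(f(X)))
    for idx in np.ndindex(X.shape):
        E = np.zeros_like(X)
        E[idx] = eps
        J[idx] = (f(X + E) - f(X - E)) / (2 * eps)
    return J


rng = np.random.default_rng(0)
X = rng.standard_normal((2, 3))  # m = 2 input axes
W = rng.standard_normal((2, 4))

# f contracts over X's first axis and keeps its second: the output has n = 2 axes
f = lambda X: np.tanh(np.einsum("ab,ac->bc", X, W))
J = numerical_jacobian(f, X)
print(J.shape)  # (2, 3, 3, 4): m + n = 4 axes

# A scalar-valued function needs no extra names:
# its derivative has exactly the shape of X.
g = lambda X: np.sum(X ** 2)
G = numerical_jacobian(g, X)
print(G.shape)  # (2, 3)
```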

davidweichiang commented 3 years ago

If the function outputs a scalar then you don’t need two names. But in general you do...

srush commented 3 years ago

Right, I see what you mean now. I think your way is the best we can do. This seems like a real instance of duality and so we should use this approach to handle it.

srush commented 3 years ago

Actually, now I'm curious how you write \del_X X \ndot{ax2} Y for X \in R^{ax1 x ax2} and Y \in R^{ax2 x ax3}.

The type would need to be R^{ax1' x ax2' x ax1 x ax3}, so Y doesn't work, right? So |Y|_{ax2->ax2'}? Or do you need to explicitly add the extra dimension so you can do the contraction?

davidweichiang commented 3 years ago

It would be I_{ax1,ax2} \ndot{ax2} Y, where the first term is like the identity matrix from ax1 to ax1' and from ax2 to ax2'.

The rules used to get this are:

[two images not preserved]
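To sanity-check that answer numerically, here is a sketch in NumPy (our choice; axis names become positions). The identity tensor I_{ax1,ax2} is built as an outer product of two identity matrices with axes (ax1', ax2', ax1, ax2), then contracted with Y over the unprimed ax2. The helper `numerical_jacobian` is hypothetical and puts the primed input axes first, to match the type R^{ax1' x ax2' x ax1 x ax3} above.

```python
import numpy as np


def numerical_jacobian(f, X, eps=1e-6):
    """Central differences; input (primed) axes first, then output axes."""
    J = np.zeros(X.shape + np.shape(f(X)))
    for idx in np.ndindex(X.shape):
        E = np.zeros_like(X)
        E[idx] = eps
        J[idx] = (f(X + E) - f(X - E)) / (2 * eps)
    return J


rng = np.random.default_rng(0)
n1, n2, n3 = 2, 3, 4
X = rng.standard_normal((n1, n2))  # axes (ax1, ax2)
Y = rng.standard_normal((n2, n3))  # axes (ax2, ax3)

# I_{ax1,ax2}[ax1', ax2', ax1, ax2] = delta(ax1', ax1) * delta(ax2', ax2)
I = np.einsum("pi,qj->pqij", np.eye(n1), np.eye(n2))

# I \ndot{ax2} Y: contract over unprimed ax2, leaving (ax1', ax2', ax1, ax3)
analytic = np.einsum("pqij,jk->pqik", I, Y)
numeric = numerical_jacobian(lambda X: X @ Y, X)
```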
srush commented 3 years ago

Nice, that makes sense.

srush commented 3 years ago

One last random thought. I kind of like the "differential form" style in the Matrix Cookbook https://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf and in the "Identities in differential form" section on Wikipedia.

[image not preserved]

It should play nicely with what you did above, and it has a couple of benefits:

1) It's easier to work with multivariate functions. For a function like conv1d, you first just compute the differential and then read off the derivatives w.r.t. W, X, and b.
2) There are fewer \frac's in the derivations.
3) When differentiating a single function, it removes any thoughts about "I_S" (which I find a bit complex to think about). These only come in at the end.

I think to make this work you would just define a conversion rule like this. Something like

dY = A \ndot{ax} dX => dY / dX = I_{S_x} \ndot{ax} A

should work?

[image not preserved]
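A rough sketch of advantage (1) in NumPy, with an affine map Y = XW + b standing in for conv1d (an assumption made only to keep the example short): compute the differential once, dY = dX W + X dW + db, for simultaneous perturbations of all three arguments, and check that it matches the actual change up to the second-order term dX dW.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((2, 3))
W = rng.standard_normal((3, 4))
b = rng.standard_normal(4)


def f(X, W, b):
    # Affine map used as a simple stand-in for conv1d: Y = X W + b
    return X @ W + b


# Small perturbations of all three arguments at once
eps = 1e-3
dX = eps * rng.standard_normal(X.shape)
dW = eps * rng.standard_normal(W.shape)
db = eps * rng.standard_normal(b.shape)

# The differential, computed once: dY = dX W + X dW + db
dY = dX @ W + X @ dW + db

# The actual change differs from dY only by the second-order term dX dW
actual = f(X + dX, W + dW, b + db) - f(X, W, b)
```

Reading off the derivatives then means matching each term against the conversion rule: the coefficient of dX is W (contracted over the shared axis), the coefficient of dW is X, and db enters through the identity.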

davidweichiang commented 3 years ago

OK let's think about this. I hadn't thought before about advantage (1). Number (2) is not a big deal; if we don't like fractions we could use some other notation like D_x. Number (3) is nice (and I agree that the I_S is annoying), but maybe we should practice a bit to see how easy/difficult it is to get an expression into the form A \ndot{ax} dX.

The main con I can think of is that it might be less familiar because it's different from how derivatives are done in high-school / freshman calculus. And it's also less necessary because tensors are less clumsy here than matrices/vectors.

srush commented 3 years ago

Yeah I agree it is a little weird and different (and that 2 is silly). I don't have a strong preference yet.

I think the part that I found a bit hard compared to high-school calculus with the current approach is keeping around both ax' and ax axes in my head when differentiating a function. In differential form, the dX retains its type (it's just an X variable) and you only have n axes in scope. The ax' only come in when converting to a jacobian.

davidweichiang commented 3 years ago

I think the canonical form would have to be dy = A \ndot{S'|S} dx where S is the shape of x. Then dy/dx = A.

Some special cases:

In practice, isn't the current approach not that different?

[image not preserved]

In this derivation, there aren't any primes or I_s's; they come in at the very end, if you let u = x so that dx/dx = I_s.

I feel like even though the various derivatives in the above derivation have an extra ax' axis, you don't have to think about them because the calculation works out the same with or without them, until that final step.
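A numerical sketch of the chain rule in the same spirit (NumPy and the hypothetical helper `numerical_jacobian` are our choices, not the thread's): each intermediate derivative carries its own input axes, and composing them is a contraction over the intermediate axis u only. The input axes of x ride along untouched until the end.

```python
import numpy as np


def numerical_jacobian(f, X, eps=1e-6):
    """Central differences; input axes first, then output axes."""
    J = np.zeros(X.shape + np.shape(f(X)))
    for idx in np.ndindex(X.shape):
        E = np.zeros_like(X)
        E[idx] = eps
        J[idx] = (f(X + E) - f(X - E)) / (2 * eps)
    return J


rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 4))
B = rng.standard_normal((4, 5))

g = lambda x: np.tanh(np.einsum("ab,abu->u", x, A))  # x: (2, 3) -> u: (4,)
h = lambda u: np.sin(u @ B)                          # u: (4,) -> y: (5,)

x = rng.standard_normal((2, 3))
u = g(x)

du_dx = numerical_jacobian(g, x)  # axes (2, 3, 4)
dy_du = numerical_jacobian(h, u)  # axes (4, 5)

# Chain rule: contract over the intermediate axis u and nothing else
chained = np.einsum("abu,uv->abv", du_dx, dy_du)
direct = numerical_jacobian(lambda x: h(g(x)), x)  # axes (2, 3, 5)
```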

srush commented 3 years ago

Yes, I agree this is a very nice derivation.

Also, keeping x abstract prevents you from overthinking the prime axes.

srush commented 3 years ago

I'm still struggling a bit to apply the I's and \delta's in practice. I think this is hard in standard notation too, but I keep feeling like there should be a nice way to do it.

Here's an example. If I want to prove that the backward of a conv1d is also a conv1d, I think I do something like this? However, I feel like I am missing a step to show how the indexing changes when you contract the "seq|seq'" with an I.

[image not preserved]
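The claim itself is easy to check numerically in the single-channel case (a simplification: the thread's conv1d has more axes, and NumPy is our choice). If the forward pass is a "valid" cross-correlation y[i] = sum_k w[k] x[i+k], then the gradient with respect to x is a "full" convolution of the upstream gradient with w (padding plus a flipped kernel), and the gradient with respect to w is again a valid cross-correlation. The names `conv1d` and `numerical_grad` are hypothetical helpers.

```python
import numpy as np


def conv1d(x, w):
    # "Valid" cross-correlation, the deep-learning convention:
    # y[i] = sum_k w[k] * x[i + k], so len(y) = len(x) - len(w) + 1
    return np.correlate(x, w, mode="valid")


def numerical_grad(f, x, eps=1e-6):
    """Central-difference gradient of a scalar function of a 1-D array."""
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = eps
        g[i] = (f(x + e) - f(x - e)) / (2 * eps)
    return g


rng = np.random.default_rng(0)
x = rng.standard_normal(10)
w = rng.standard_normal(3)
ybar = rng.standard_normal(8)  # upstream gradient (adjoint) of y = conv1d(x, w)

# Scalar losses whose gradient w.r.t. y is ybar
loss_x = lambda x: np.dot(ybar, conv1d(x, w))
loss_w = lambda w: np.dot(ybar, conv1d(x, w))

# Gradient w.r.t. x: a "full" convolution of ybar with w
xbar = np.convolve(ybar, w, mode="full")
# Gradient w.r.t. w: a "valid" cross-correlation of x with ybar
wbar = np.correlate(x, ybar, mode="valid")
```

The re-indexing step being asked about shows up here as the switch from `correlate` to `convolve`: summing over pairs (i, k) with i + k = j is what flips the kernel.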

davidweichiang commented 3 years ago

It looks right to me, and I think we'd need a few missing pieces:

I haven't thought much about backwards-mode differentiation, and it does look like the primes pop up a lot more here. Do differentials help?

srush commented 3 years ago

That all makes sense.

I guess the standard notation for backwards-mode differentiation is to use adjoints, which would make the above superficially simpler.

[image not preserved]

I don't really mind the seq' though, they are actually pretty intuitive. It's more contraction over indexing that I find hard to reason about. This feels like a "generalized transpose" and it makes my head hurt.
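One way to see the "generalized transpose" is with the matrix product from earlier in the thread (a NumPy sketch; the variable names are ours). Backward mode contracts the full derivative dY/dX with the adjoint Ybar over all of Y's axes. Because the derivative is I \ndot{ax2} W, the identity collapses that contraction, and what remains is Ybar contracted with W over ax3, which is Ybar W^T in matrix notation.

```python
import numpy as np

rng = np.random.default_rng(0)
n1, n2, n3 = 2, 3, 4
X = rng.standard_normal((n1, n2))     # axes (ax1, ax2)
W = rng.standard_normal((n2, n3))     # axes (ax2, ax3)
Ybar = rng.standard_normal((n1, n3))  # adjoint of Y = X W, i.e. dL/dY for a scalar L

# Full derivative dY/dX with axes (ax1', ax2', ax1, ax3):
# the identity tensor I_{ax1,ax2} contracted with W over ax2
I = np.einsum("pi,qj->pqij", np.eye(n1), np.eye(n2))
dY_dX = np.einsum("pqij,jk->pqik", I, W)

# Backward mode: contract the derivative with Ybar over all of Y's axes (ax1, ax3)
Xbar_via_jacobian = np.einsum("pqik,ik->pq", dY_dX, Ybar)

# The adjoint rule skips the four-axis tensor entirely: contract Ybar with W over ax3
Xbar = Ybar @ W.T
```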

davidweichiang commented 3 years ago

If I write out all the steps, I get:

[image not preserved]
srush commented 3 years ago

Oh wow, thanks for writing that out, it was bothering me. I get what you mean by expanding the sum now. I don't really see any steps that can be simplified. I was hoping there might be a way to do the last three steps in a more general way, but it really seems like you need to do the re-indexing with manual indexes.

Generally, I think (besides the rearrangement) this is a success? This would have been really painful with matrix calculus, as there are at least 3 critical axes before differentiation. Also, seq' and seq are helpful in the V calculation. Even when we lose contractions, the typed indexing still helps with bugs.

Couple thoughts:

davidweichiang commented 3 years ago

How about this?

[image not preserved]

About I^S versus I_S, this could be thought of as an instance of #22.

srush commented 3 years ago

Nice! That's what I was hoping for.

davidweichiang commented 3 years ago

Great. I pushed a commit that includes this, except that Conv1d maps from seq to seq instead of in to out. Not sure if you will like what I did to Conv2d.

davidweichiang commented 3 years ago

I'm not very happy with the fact that the definition of derivative imposes a naming convention for the input axes (stars). It might be nicer if the choice of input axis names was up to the writer.

The idea is to define derivatives only for functions that don't have any shared input/output axes:

[image not preserved]

Note that there are no stars, so this definition is much simpler than the one in the current document.

If someone wants to differentiate a function with shared input/output axes, they have to explicitly rename the input axes. Define

[image not preserved]

Notes:

The rule for dx/dx becomes

[image not preserved]

Example:

[image not preserved]

I have more thoughts but will leave it at that for now!

srush commented 3 years ago

Just out of curiosity, if you are going with this strategy, why not encourage people to rename the output axes, i.e., softmax: R^ax -> R^ax'? That would fit the convention.

davidweichiang commented 3 years ago

I can think of two reasons:

davidweichiang commented 3 years ago

Continued: I think differentials will help here.

The canonical form is: dY = U \ndot{T} dX{S->T} implies dY/dX{S->T} = U, where shape T is orthogonal to the shape of Y.

Example:

[image not preserved]

This simplifies more aggressively than the example that's currently in the document, so I don't think it's truly longer. The nice thing is that there are no stars or renaming.
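The derivation in this comment was an image and is not preserved. For reference, here is a sketch of the standard softmax backprop result that this kind of derivation targets, written in positional NumPy rather than named notation (the function names are hypothetical). From the differential dY = Y ⊙ (dX - Y·dX), the gradient of a scalar f(Y) is Xbar = Y ⊙ (Ybar - Ybar·Y), checked here against finite differences.

```python
import numpy as np


def softmax(x):
    e = np.exp(x - np.max(x))  # shift for numerical stability
    return e / e.sum()


def softmax_backward(y, ybar):
    # From the differential dy = y * (dx - (y . dx)),
    # the adjoint is xbar = y * (ybar - (ybar . y))
    return y * (ybar - np.dot(ybar, y))


rng = np.random.default_rng(0)
x = rng.standard_normal(5)
ybar = rng.standard_normal(5)  # gradient of some scalar f(Y) w.r.t. Y
y = softmax(x)

# Finite-difference check of the backward rule, one input direction at a time
eps = 1e-6
numeric = np.array([
    (np.dot(ybar, softmax(x + eps * e)) - np.dot(ybar, softmax(x - eps * e))) / (2 * eps)
    for e in np.eye(5)
])
xbar = softmax_backward(y, ybar)
```

A side check: the entries of xbar sum to zero, as they must, since softmax is unchanged by adding a constant to every input.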

srush commented 3 years ago

That last line should be df(Y) / dX right?

davidweichiang commented 3 years ago

Oops, yes.

srush commented 3 years ago

And just to understand: if we did want to directly access the Jacobian of the first function, we would need some sort of new name. We would need to rename the dx in order to apply the canonicalization rule.

davidweichiang commented 3 years ago

Right.

srush commented 3 years ago

I think this is neat. I like that you can think of dx as typed like x (until the end); that was the thing I originally found quite appealing about differentials. I also like that it makes the scalar / backprop case convenient, which is the main use for our target domain.

Two small thoughts that come up:

davidweichiang commented 3 years ago

We could write some rules for regrouping of contractions, but it really boils down to writing contractions as sums of \odots, so I started to think that that was the better way to reason about it.
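A quick NumPy illustration of that view (our own sketch, not from the thread): a contraction over an axis is an elementwise product, broadcast over the other axes, summed over the contracted axis. Regrouping two contractions then just means both groupings expand to the same single sum of A ⊙ B ⊙ C.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3))  # axes (a, b)
B = rng.standard_normal((3, 4))  # axes (b, c)
C = rng.standard_normal((4, 5))  # axes (c, d)

# Contraction over b = elementwise product (with broadcasting) summed over b
AB_sum = (A[:, :, None] * B[None, :, :]).sum(axis=1)
AB_einsum = np.einsum("ab,bc->ac", A, B)

# Regrouping: (A.B).C and A.(B.C) both equal one sum over b and c of A*B*C
triple = (A[:, :, None, None] * B[None, :, :, None] * C[None, None, :, :]).sum(axis=(1, 2))
left = (A @ B) @ C
right = A @ (B @ C)
```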

I should update the current document to fully simplify the backprop rule for softmax, for comparison. Even though it uses the starred names, I think the two names might actually make it easier to think about. (In the above example, the step from line 4 to 5 of \partial f(Y) is really tricky because of the two contractions over ax.)

davidweichiang commented 3 years ago

main is updated. The derivations are roughly the same length, I think...

srush commented 3 years ago

Sorry, I don't totally understand. Are you now back to two names? Or are we talking about algebra?

davidweichiang commented 3 years ago

I'm trying to make an apples-to-apples comparison...the derivation above in this thread looked long to me, so I wanted to see if it was because it simplifies more. Now the two derivations arrive at the same answer so they can be compared fairly. I'm not sure which I like better -- I'll have to think about it.

davidweichiang commented 3 years ago

Oh, with differentials, you can rename the numerator as you suggested. That's a little easier to think about and would give a more consistent naming convention (for first derivatives, at least). I'll try that out and see how it looks.

davidweichiang commented 3 years ago

Trying this out in #50.