locuslab / deq

[NeurIPS'19] Deep Equilibrium Models
MIT License

Broyden defeats the purpose of DEQs? #28

Open polo5 opened 2 years ago

polo5 commented 2 years ago

Heya,

Thanks for your continued work in building better DEQs.

The main selling point of DEQs is that the solver can take as many steps as required to converge without increasing memory. This isn't true of your implementation of Broyden, which starts off with:

Us = torch.zeros(bsz, total_hsize, seq_len, max_iters).to(dev)
VTs = torch.zeros(bsz, max_iters, total_hsize, seq_len).to(dev)

and therefore has a memory cost linear in max_iters, even though the ops aren't tracked. Anderson also keeps the previous m states in memory, where m is usually larger than the number of solver iterations needed anyway. Don't these solvers contradict the claim of constant memory cost?

On a related note, I've found it quite hard to modify these solvers even after going over the theory. Are there any notes or resources you could point me to that would help people understand your implementation? Thanks!

jerrybai1995 commented 2 years ago

Hello @polo5,

Thanks for your interest in our repo and DEQ!

To begin with, we want to caution that "constant memory cost" means constant w.r.t. the number of layers. That is, we have only one layer (e.g., one Transformer layer), and the memory consumption is not that of 2, 3, etc. layers. That said, you are absolutely right that both Broyden and Anderson need to store some past fixed-point estimates. In fact, we analyzed this "hidden cost" in the Jacobian regularization paper (see Sec. 3.4).

However, we do want to point out that this memory cost is very minimal, because we only need to store the activation tensor. For example, in the DEQ-Transformer case (which is a pretty large-scale model), the hidden units could have shape [bsz x seq_len x n_hid] = [15 x 150 x 700]. That is 15*150*700 = 1,575,000 floating-point numbers in total, and thus 1,575,000*4 = 6,300,000 bytes (each float is 4 bytes). This is only 6.3MB (or 0.0063GB) of memory per solver iteration. Therefore, even if we run Broyden for 10-20 iterations, it adds very little memory cost in itself.
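For reference, here is that arithmetic as a quick back-of-the-envelope snippet (shapes as above, float32 assumed):

```python
# Per-iteration solver memory for the DEQ-Transformer shapes quoted above.
bsz, seq_len, n_hid = 15, 150, 700
bytes_per_estimate = bsz * seq_len * n_hid * 4                      # float32 = 4 bytes
print(f"{bytes_per_estimate / 1e6:.1f} MB per stored estimate")     # -> 6.3 MB
print(f"{20 * bytes_per_estimate / 1e9:.3f} GB for 20 iterations")  # -> 0.126 GB
```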

In contrast, conventional neural networks are costly because the layers are complex (each layer could cost hundreds or thousands of MBs). For example, in a Transformer layer, not only do we have to memorize the layer's output (which is all a DEQ needs), but also everything that happened within the layer (e.g., the activation after self-attention, the one after LayerNorm, etc.).


On your second question: my implementation of Broyden's method is based solely on its Wikipedia page: https://en.wikipedia.org/wiki/Broyden%27s_method (look at "good Broyden's method").

However, in order to make it GPU-efficient, there are two things that may make this resemblance a bit obscure to see, and I'm happy to explain:

1) Note that we can write the good Broyden update in the form J_n^{-1} = J_{n-1}^{-1} + u_n v_n^T, where u_n = \frac{\Delta x_n - J_{n-1}^{-1} \Delta f_n}{\Delta x_n^T J_{n-1}^{-1} \Delta f_n} (a dx1 vector) and v_n^T = \Delta x_n^T J_{n-1}^{-1} (a 1xd vector). Moreover, I initialized J_0^{-1} to be -I (i.e., the negative identity matrix). Therefore, for any n, I can essentially write J_n^{-1} as: J_n^{-1} = -I + u_1 v_1^T + u_2 v_2^T + ...

This means that instead of actually keeping a large d x d matrix J^{-1} around in memory, I can simply store the u vectors and the v vectors. This amounts to keeping a big U matrix and a big V matrix whose columns are u1, u2, ... and v1, v2, ... In other words, we can write J_n^{-1} = -I + UV^T. At each Broyden iteration, I append the new u and v as new columns to these matrices (the `Us`/`VTs` update in the code; see also the sketch after point (2)). So after L steps of Broyden iterations, U has shape dxL and V^T has shape Lxd.

2) Note that in both Newton and quasi-Newton methods, we don't actually need J^{-1} itself. What matters is J^{-1} g(x), where g(x) is a vector (recall that the update rule for solving g(x)=0 essentially has the form new_x = x - \alpha * J^{-1} g(x)). Therefore, together with point (1) above, we can further write J_n^{-1} g = (-I + UV^T)g = -g + UV^T g.

Since U has dimension dxL and V^T has dimension Lxd, where L is the number of past Broyden steps, and g has dimension dx1, it is much more efficient to compute UV^T g as U(V^T g), because V^T g is simply a matrix-vector product yielding an Lx1 vector. This matters especially when the dimension d is large; the matvec operation in the code is exactly this step and is the key to making it efficient.

Similarly, the update rule itself contains terms like J_{n-1}^{-1} \Delta f_n, which can be computed in the same fashion: first V^T \Delta f_n, then U (V^T \Delta f_n).
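To make (1) and (2) concrete, here is a self-contained sketch with made-up sizes (my own illustration, not the repo's exact code):

```python
import torch

d, max_iters = 1024, 30           # illustrative state size / step budget
U  = torch.zeros(d, max_iters)    # columns u_1, u_2, ...
VT = torch.zeros(max_iters, d)    # rows v_1^T, v_2^T, ...

def inv_jacobian_matvec(U, VT, n, g):
    # Compute J_n^{-1} g = (-I + U V^T) g as -g + U (V^T g).
    # VT[:n] @ g is an (n,)-vector, so the product costs O(n*d); the
    # d x d matrix J_n^{-1} itself is never materialized, so memory is
    # O(n*d) instead of O(d^2).
    return -g + U[:, :n] @ (VT[:n] @ g)

# Each Broyden step n computes a new pair (u, v) from the update formula
# and appends them as a new column/row:
#   U[:, n] = u
#   VT[n]   = v
```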

I hope this answers your question!

jerrybai1995 commented 2 years ago

Also, I want to add that in Anderson we usually keep m=5 or m=6, which is usually significantly smaller than the number of solver iterations (e.g., 25 in DEQ-Transformer).

polo5 commented 2 years ago

Thanks a lot @jerrybai1995 ! Your explanations are very clear and I understand your broyden code much better :)

  1. I take your point about solver memory. In my setting I do have large activations, which makes those U and V matrices large for broyden. However, I've found that when f(x) is normalized just right you can get away with very few solver iterations, and the gradient remains useful even if you're still pretty far from an actual fixed point. I found the same thing for MDEQ on CIFAR-10 experiments a while back (you can simply use a fixed-point-iteration solver for this task with one iteration and reach the same test accuracy as broyden with 15 iterations).

  2. The reason I'm asking about broyden is that I've been testing its robustness in different settings, e.g. g(x) = |f(x) - x| + |\phi(x)| instead of g(x) = f(x) - x, by simply changing that one line in the solver. Since the performance was quite poor (it doesn't converge), I was wondering if I needed to change something else in the broyden code. After your explanations I don't think I do, so it's probably just too hard a problem for broyden...

  3. I would try anderson instead, but it seems to me this solver is only applicable to fixed-point problems of the form g(x) = f(x) - x. Is that correct? Some quick resources/pointers to understand the use of the H matrix would also be greatly appreciated, as the explanations I could find were sadly very light. I would definitely need to change more code in anderson for my goal, as G = F - X is hard-coded in a few places. Btw, I'm surprised you didn't use line search for anderson as well, though perhaps you tried it and it didn't help.

Apologies for abusing your kindness:)

P

jerrybai1995 commented 2 years ago

Hi @polo5,

  1. Interesting observation on MDEQ...! I didn't know that you could achieve the same accuracy on CIFAR-10 with just 1 iteration, but it's likely closely related to the inexact/phantom gradient effect described in this paper. It turns out that you don't always need the exact IFT gradient. (Personally, though, I would guess that one Broyden iteration is far from enough at large scale, like on ImageNet.)

  2. There is a slight difference, and I'm not sure how much it matters. By using Broyden's method, we are NOT minimizing |f(x)-x|; we are finding the root of f(x)-x. Similarly, in your case, if you want to minimize |f(x)-x| + |\phi(x)| (which is strictly non-negative), you'd better: i) check that a minimum exists; and ii) use an optimization procedure like L-BFGS or Newton.

  3. See my answer to (2) above. It is a different problem per se. Anderson and Broyden are for root-finding. If your goal is to minimize some function g, then the root equation you want to solve is \nabla g(x) = 0 (i.e., the KKT stationarity condition). Regarding the use of H: that is simply the closed-form solution of a (regularized) least-squares problem; see the sketch below. Recall that in order to minimize ||G \alpha|| over \alpha, you compute (G^T G)^{-1} .... In some cases G^T G is only PSD but not PD, so we introduce a small \lambda so that G^T G + \lambda I is invertible. Nothing too fancy here. Anderson worked very well without line search because the algorithm itself is greedy: at each step it takes the best linear combination of the past few fixed-point estimates.
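For concreteness, here is a minimal sketch of that regularized least-squares step (my own illustration with assumed shapes; if I read the code right, the H matrix in the repo is a bordered version of this same system):

```python
import torch

def anderson_step(X, FX, lam=1e-4):
    # X, FX: the past m iterates and f evaluated at them, each of shape (m, d).
    # Minimize ||alpha^T G|| subject to sum(alpha) = 1, where G = FX - X,
    # via the regularized normal equations.
    G = FX - X                                   # residuals g(x) = f(x) - x
    m = G.shape[0]
    A = G @ G.T + lam * torch.eye(m)             # (m, m); lam*I keeps it invertible
    y = torch.linalg.solve(A, torch.ones(m))     # unnormalized weights
    alpha = y / y.sum()                          # enforce the sum-to-one constraint
    return alpha @ FX                            # next iterate x_{k+1}
```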

Hope this helps.

Gsunshine commented 2 years ago

Hi @polo5 ,

In addition to @jerrybai1995 ,

  1. Yes. For MDEQ-Tiny, different inexact gradients can work properly. But for DEQ-Transformer, it would be much harder to directly employ inexact gradients. (Instead, we need more delicate methods, like regularizing the training by combining phantom gradients with Jacobian regularization; see the paper.) It's a question of scaling that I've been thinking about recently: for different model designs and model sizes, the approximation error, or even the gradient noise from the solver itself (you're solving for the gradient, and the solver naturally introduces noise!), can have quite different impacts. (Consider a Pareto surface over performance, model size, and gradient exactness.)

    More observations: you can even train MDEQ-Large using a truncated Broyden method in the backward pass while keeping the forward-pass Broyden iterations untouched. But you have to keep the pretraining stage for MDEQ; otherwise this simple backward-solver truncation does not work. We've tried to provide a theoretical analysis of Broyden's truncated inexact gradient but found it nontrivial, because the truncated gradient might not always yield a descent direction. In contrast, phantom gradients of any order can be descent directions, even when the damping factor is taken into account.

  2. Agree with @jerrybai1995. In addition, you may try to rewrite some form of your \phi(x) into f(x) as a new F(x), instead of directly adding it. For example, if you are optimizing a LASSO problem, rewriting the regularization term \phi(x) leads to the proximal gradient method (i.e., projecting the result back) rather than the original gradient update rule f(x); see the sketch below. That is to say, we define a new fixed-point function F(x) to iterate that takes the sparsity into account. Note that you can treat g(x) = f(x) - x as the gradient of an energy function, because solving for the root of g(x) means finding a stationary point of the "invisible" energy landscape. (Of course, in most cases it's definitely not easy to recover the energy from f(x) if you only know f(x).) But if you happen to know the energy and want to add a new term to it, try adding that term's **gradient** to the DEQ function, or derive a new F(x) based on it. I think this might help.
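To make the LASSO example concrete, here is a minimal sketch (my own illustration; A, b, lam, and eta are hypothetical) of how the l1 regularizer gets folded into a new fixed-point map F(x) via its proximal operator, i.e., the ISTA update:

```python
import torch

def soft_threshold(z, t):
    # Proximal operator of t * ||.||_1
    return torch.sign(z) * torch.clamp(z.abs() - t, min=0.0)

def F(x, A, b, lam, eta):
    # ISTA fixed-point map for min_x 0.5*||Ax - b||^2 + lam*||x||_1:
    # the l1 term is handled by its proximal operator rather than being
    # added to the residual g(x) = f(x) - x.
    grad = A.T @ (A @ x - b)                      # gradient of the smooth part
    return soft_threshold(x - eta * grad, eta * lam)
```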

> In my setting I do have large activations which makes those U and V matrices large for broyden.

I wonder how large the activations have to be before the Broyden solver becomes impractical. In some cases, after properly normalizing the DEQ function f(x) using weight norm or spectral norm, naive fixed-point iteration can work pretty well (I mean in terms of convergence); see the sketch below. But do pay attention to training if you combine IFT training with the naive forward solver: the IFT makes the network "deeper" and needs more solver steps to reach a reasonably low error to keep the training stable, while the naive solver is less efficient at solving such hard dynamics. Using the naive solver with the phantom gradient can extricate you from this stability problem, but may slightly sacrifice performance (in the DEQ-Transformer case).
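A minimal sketch of that naive forward solver with a relative-residual stopping criterion (names and tolerances are illustrative):

```python
import torch

def naive_solver(f, x0, max_iters=50, tol=1e-4):
    # Plain fixed-point iteration x <- f(x), stopping on the relative
    # residual. Assumes f has been normalized (e.g., weight norm or
    # torch.nn.utils.spectral_norm on its weights) to be roughly contractive.
    x = x0
    for _ in range(max_iters):
        x_next = f(x)
        rel_err = (x_next - x).norm() / (1e-8 + x_next.norm())
        x = x_next
        if rel_err < tol:
            break
    return x
```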

If you encounter stability issues under IFT + the naive forward solver, consider inexact gradients, Jacobian regularization, fixed-point correction, or a combination of them. DEQ-Flow might be a helpful reference here: using the naive solver, phantom gradients, and fixed-point correction, we train the DEQ to satisfying forward-pass convergence and quite strong performance compared to a model trained solely by IFT or by very expensive dense BPTT. The method also scales up well to larger models. See the latest implementation.

Hope this might help.

Zhengyang

polo5 commented 2 years ago

Great points @jerrybai1995 @Gsunshine !

  1. I agree that root-finding and minimization aren't the same thing here, but the line can be quite blurry in some problems. In the vanilla DEQ setting we find the root x* of g(x) = f(x) - x, but of course this is equivalent to x* = argmin ||g(x)||. In fact, I can train my model fine by using basic gradient descent on this objective, but in my implementation the convergence error is usually much larger than with broyden (@jerrybai1995 I should try L-BFGS though, thx; see the sketch after this list). Turning a root-finding problem into an optimization problem can often make things easier. In my case I want to do root-finding for g(x) = f(x) - x subject to a regularization/argmin condition on the root, x = argmin \phi(x) (which only makes sense because a DEQ has several roots x*). This can easily be written as an optimization problem (using Lagrange multipliers). But in practice, since broyden is so good at converging fast, I'm trying to cast my objective as an efficient root-finding problem. @Gsunshine in this case one cannot solve for the root of g(x) = f(x) - x + \nabla \phi(x) (which could otherwise be nicely rewritten as a FP problem, as you suggest), because its roots aren't roots of f(x) - x. Instead one would need to solve for roots of g(x) = |f(x) - x| + |\nabla \phi(x)| (assuming \phi(x) has a single extremum). Annoyingly, this isn't a FP problem anymore.

  2. The second issue is that solvers like broyden seem to struggle with absolute values. The function g(x) = |f(x) - x| has exactly the same roots as g(x) = f(x) - x, but broyden struggles to find roots of the former (which I guess is what you'd expect from simpler solvers like bisection, where the sign is informative?). This can be annoying in DEQ variants where one may want to find simultaneous roots of two functions f(x) and h(x). As in (1), you cannot just write g(x) = (f(x) - x) + (h(x) - x); instead you'd need something like g(x) = |f(x) - x| + |h(x) - x|, which broyden would struggle with.
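For what it's worth, here is a minimal sketch of the optimization route Jerry suggested (f and phi are stand-ins for my functions; I use squared norms rather than absolute values, since the objective stays smooth and quasi-Newton methods behave better on it):

```python
import torch

def minimize_residual(f, phi, x0, max_iter=50):
    # Minimize ||f(x) - x||^2 + ||phi(x)||^2 with L-BFGS instead of
    # root-finding; the minimum value is 0 exactly at a fixed point of f
    # that also satisfies phi(x) = 0 (assuming such a point exists).
    x = x0.clone().requires_grad_(True)
    opt = torch.optim.LBFGS([x], max_iter=max_iter)

    def closure():
        opt.zero_grad()
        loss = (f(x) - x).pow(2).sum() + phi(x).pow(2).sum()
        loss.backward()
        return loss

    opt.step(closure)
    return x.detach()
```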

Thanks a lot for the discussion! I think DEQs are very promising and don't get the attention they deserve ;)