titu1994 / tfdiffeq

Tensorflow implementation of Ordinary Differential Equation Solvers with full GPU support
MIT License

Neural ODE to estimate dynamics of forced systems #4

Closed saihv closed 4 years ago

saihv commented 4 years ago

I've been experimenting with the neural ODE demo by trying more complex systems inspired by your ode_usage.ipynb. For example, if I wanted to train an ODE for an inverted pendulum on a cart, under a certain force F that's applied at time t_k, but in a way that it generalizes over any value of force, I'd have to also pass the value of force to the NN but I wouldn't want F in the full system state (i.e., the integrator doesn't have to care about F). What's a good way to achieve this?

titu1994 commented 4 years ago

I don't particularly know a good way to accomplish that. The model function can use a small NN to approximate a value of F itself to pass into the ODE, and if the loss is sufficiently low, then you could make the integrator invariant to external F, but I'm not sure how well that would work.

saihv commented 4 years ago

The model function can use a small NN to approximate a value of F itself to pass into the ODE

Sorry, I may not have made myself clear - I already know what F is, and I can use this knowledge at runtime. It's only the dynamics that I'm treating as a black box; instead of x_dot = F(x), I want to parameterize it as x_dot = F(x, u). At test time, both x and u will be known. I am just curious about the implementation of this: I want the value of u to be used by the ODEFunc() neural network, but I don't want it to be a part of the state going into odeint().

titu1994 commented 4 years ago

If you want to parametrize x_dot = F(x, u), then you can pass on your known x_t and u_t value as inputs to an NN (or any other explicit model) that emits the output of just x_dot.

If u_t is changing, but you don't want it to be a variable of the system itself but rather an external variable, you can have an explicit formulation of u_t = U(t) or U(x, t) or any other such combination as an external method, and simply call that to compute u_t at every step without passing it as a state variable to odeint.

If you look at the Lotka-Volterra equations: if we make the four parameters functions of time, then we can modulate all four with respect to time even though they aren't explicitly being passed to the odefunc.
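The pattern described above can be sketched as follows. This is a minimal NumPy stand-in (a fixed-step Euler loop in place of tfdiffeq's odeint, and an assumed forcing function U) showing how u_t is computed inside the ODE function rather than carried in the state:

```python
import numpy as np

# Hypothetical external control signal; NOT part of the ODE state.
def U(t):
    return np.sin(t)  # assumed forcing, for illustration only

# ODE func with the f(t, y) signature: u_t is computed internally,
# so the integrator only ever sees the state y.
def forced_dynamics(t, y):
    A = np.array([[0.0, 1.0], [-1.0, -0.1]])  # illustrative linear dynamics
    B = np.array([0.0, 1.0])
    return A @ y + B * U(t)

# Minimal fixed-step Euler integrator standing in for odeint.
def integrate(f, y0, ts):
    ys = [y0]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        ys.append(ys[-1] + (t1 - t0) * f(t0, ys[-1]))
    return np.stack(ys)

ts = np.linspace(0.0, 1.0, 101)
traj = integrate(forced_dynamics, np.array([1.0, 0.0]), ts)
print(traj.shape)  # (101, 2): state stays length 2, u never enters it
```

Swapping `forced_dynamics` in as the model function passed to odeint follows the same structure; only the integrator differs.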

saihv commented 4 years ago

Right, that makes sense. Just a followup question though: when we do a forward pass through the model, the call() function receives a batch of states, but t seems to be a scalar? How do the received states correspond to t?

titu1994 commented 4 years ago

When we do call, it receives a single timestep t and the state y_t as input. y_t is a TensorFlow vector of length equal to the number of states passed as the initial state y0 to odeint.

saihv commented 4 years ago

Correct me if I'm wrong: if I have a state vector of length k and a batch size of m, the y_t received by call() will be m x k. And as you mentioned before,

you can have an explicit formulation of u_t = U(t) or U(x, t) or any other such combination as an external method and simply call that to compute u_t at every step

If the state variables are functions of time (or need to be augmented with some time-dependent input), I would need the corresponding 'times' for each of those states in order to modify them accordingly.

titu1994 commented 4 years ago

There are no batch operations inside an odeint call. Your state gets built per timestep, and the function is evaluated a fixed/dynamic number of times until the acceptance criterion is met to take a single step.

After T (total number of timesteps) accepted steps, the result is (T x K). If you want an NN to perform ODE calculations within a layer, rather than use an NN inside an ODE, then yes, you can pass a tensor of states as input (see the ODEModel and ConvODE models in the models portion of the lib).

titu1994 commented 4 years ago

Oh, one way you could do batches is, as you said, to have a state matrix of shape Y_t = (M, K); at each step t, you will get an M x K matrix as the state. This is performed T times, so you get a (T, M, K) tensor as your result.
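The shape bookkeeping above can be checked with a small NumPy sketch (again using a fixed-step Euler loop as an illustrative stand-in for odeint): an (M, K) state goes through the ODE function unchanged in shape, and stacking T saved steps yields (T, M, K).

```python
import numpy as np

def f(t, y):
    # y has shape (M, K); the same scalar t applies to every row.
    return -y  # toy decay dynamics, for shape illustration only

# Minimal fixed-step Euler integrator standing in for odeint.
def integrate(f, y0, ts):
    ys = [y0]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        ys.append(ys[-1] + (t1 - t0) * f(t0, ys[-1]))
    return np.stack(ys)

M, K = 20, 2                      # batch of 20 independent length-2 states
y0 = np.ones((M, K))
ts = np.linspace(0.0, 1.0, 11)    # T = 11 saved time points
out = integrate(f, y0, ts)
print(out.shape)  # (11, 20, 2), i.e. (T, M, K)
```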

saihv commented 4 years ago

So what does the batch_size parameter do? I was confused about the batch supposedly sent to the network, as received from get_batch() (https://github.com/titu1994/tfdiffeq/blob/master/examples/ode_demo.py#L44)

titu1994 commented 4 years ago

Look at lines 27 to 40 of that example. You pass a (1, 2) sized vector as input and get an output of shape (data_size, 1, 2). That's the general execution of odeint.

Then later, you can perform a batch computation on (M, 2) via the odefunc model to get out (M, 2), and this is run for T accepted steps to get out (T, M, 2). Here T is data_size and M is batch_size.

Then you minimise the loss w.r.t. the ground truth, which is also of shape (T, M, 2).

Basically, at every forward call you get the state at time step t only. It doesn't matter if your state is a single value, a vector, a matrix, or even a full tensor itself. As long as the output is the same shape as the state, you will get a result of shape (T, ...), where ... represents the dimensions of the state.

saihv commented 4 years ago

Then later, you can perform a batch computation on (M, 2) via the odefunc model to get out (M, 2)

Right, that's what I wanted to clarify. So if I had M states (which is what happens now in that script: with batch size set to 20 by default, the call() function receives a 20x1x2 tensor), and I wanted to augment each of those 1x2 state vectors with a time-dependent u value (making the batch 20x1x3), I would need to know which time steps all those states came from - information that is not part of t, hence my confusion.

I'm only concerned with the input I am sending to the neural network, because the output of the model would still be 20x1x2, so odeint() shouldn't have any issues integrating it over T steps and computing the subsequent loss. But I think I'd need a vector of t going into call(), basically like what s is doing here

titu1994 commented 4 years ago

Do one thing. Print out the shape of y inside call. For this problem, all M are independent from each other, as one would expect in a batch. But they globally work on t.

There is no explicit t_m for every single y_m. There is only one global t for the entire state. If you need time-dependent variable behaviour, you need to augment the state with another state variable whose value increments by a predefined step size (a constant value, when integrated, gives constant * t). Basically your state becomes (20, 3), with the last axis holding a different step counter for each of the (20, 2) other states.
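The augmentation described above can be sketched in NumPy (function names, the control schedule, and the dynamics are all illustrative assumptions, not tfdiffeq's API). Each batch element carries its own counter as an extra state component whose derivative is 1, so it integrates to per-sample elapsed time even though the solver's t is global:

```python
import numpy as np

def augmented_f(t, y):
    # y: (M, 3) = (M, 2) original state + 1 per-sample time counter.
    state, counter = y[:, :2], y[:, 2:]
    u = np.where(counter < 2.0, 1.0, 0.0)   # assumed per-sample control schedule
    A = np.array([[0.0, 1.0], [-1.0, 0.0]]) # illustrative linear dynamics
    d_state = state @ A.T + u               # u broadcasts over the state columns
    d_counter = np.ones_like(counter)       # d(counter)/dt = 1 -> counter tracks time
    return np.concatenate([d_state, d_counter], axis=1)

# Each sample starts its counter at a different offset, so "time"
# differs per batch element even though t is a single global scalar.
y0 = np.concatenate([np.ones((4, 2)), np.arange(4.0).reshape(4, 1)], axis=1)
dy = augmented_f(0.0, y0)
print(dy.shape)  # (4, 3): derivative has the same shape as the augmented state
```

Note the step-function u here is discontinuous, which (as discussed below the quoted reply) can itself cause trouble for adaptive solvers; a smoothed schedule may be needed in practice.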

saihv commented 4 years ago

Do one thing. Print out the shape of y inside call. For this problem, all M are independent from each other, as one would expect in a batch. But they globally work on t.

y.shape for that example is [20, 1, 2]. But I think I see it now: the states being passed in the batch are not really representing x(t_1), x(t_k), etc., but rather all correspond to a single time instant t.

As an example, what I want to replicate is this (taking the Lotka-Volterra system for example): x_dot = A*x + u, where u is [1, 1] for t < 2.0 and [0, 0] for t > 2.0. To infer this u-x relationship, my thought was to augment y with another row, where each element in this new row is 1 for t < 2.0 and 0 for t > 2.0, so the network can figure out how x_dot relates to x. But I can see that probably won't work because of the way t is treated. I will think about it a bit more.

titu1994 commented 4 years ago

That's a discontinuous function. I think by definition they can't be integrated at all (though I'm probably mistaken).

Or they can be integrated, but need solvers like BS3 because they are stiff ODEs.

saihv commented 4 years ago

That may have been a particularly bad example, you make a good point :)

But yes, if you look at a cartpole system, the input force typically shows up as an additive term in the second derivative of the position. I guess the applicability of ODE solvers to an actuated system like that is interesting.

titu1994 commented 4 years ago

I have seen some demos of cartpole being solved using an ODE solver via language-level autodiff in Julia. Practically speaking, it should be possible to do so even in TF, though the way to do it may be more involved.