SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

Documentation of sciml_train does not match actual arg prototype #162

Closed: cems2 closed this issue 4 years ago

cems2 commented 4 years ago

Problem: The documentation of sciml_train indicates the 4 args are (neuralode, param, data, optimizer), but the only available methods take 3 args: (neuralode, param, optimizer).

Reproducing this: the help text gives a 4-argument required signature:

help?> DiffEqFlux.sciml_train
  train!(loss, params, data, opt; cb)

  For each datapoint d in data computes the gradient of loss(p,d...) through
  backpropagation and calls the optimizer opt. Takes a callback as keyword argument cb.
  For example, this will print "training" every 10 seconds:

  DiffEqFlux.sciml_train(loss, params, data, opt,
              cb = throttle(() -> println("training"), 10))

  The callback can call Flux.stop() to interrupt the training loop. Multiple optimisers
  and callbacks can be passed to opt and cb as arrays.

but this 4-argument signature is not available:

julia> DiffEqFlux.sciml_train()
ERROR: MethodError: no method matching sciml_train()
Closest candidates are:
  sciml_train(::Any, ::Any, ::Optim.AbstractOptimizer; cb, maxiters) at /Users/cems/.julia/packages/DiffEqFlux/UNVXS/src/train.jl:70
  sciml_train(::Any, ::Any, ::Optim.AbstractConstrainedOptimizer; lower_bounds, upper_bounds, cb, maxiters) at /Users/cems/.julia/packages/DiffEqFlux/UNVXS/src/train.jl:95
  sciml_train(::Any, ::Any, ::Any; cb, maxiters) at /Users/cems/.julia/packages/DiffEqFlux/UNVXS/src/train.jl:15
Stacktrace:
 [1] top-level scope at none:0

Discussion: The fact that there is no "data" arg here seems strange, and if it was intentionally omitted then it looks like a design flaw. Evidently one is supposed to provide labels and inputs to the loss function via globals? Yuck! Not only yuck, but it messes with the whole idea of epochs, and with the different sorts of information (not just the labels but other things like initial conditions, constraints, error bars, and masks) that one might want to pass into the loss function and that vary from training example to training example within the epoch dataset.

Also, as long as I'm bringing up documentation issues, I wanted to point out that the readme tutorial at DiffEqFlux.jl has some confusing things.

One of the major confusions is that the use pattern for NeuralODE doesn't seem to follow the use pattern for Flux, nor that of its predecessor neural_ode. With the latter, one had a neural-net model and a set of extracted parameters (Flux.params). One could then, without changing the neural ODE itself, change these parameters in the model. The new pattern is that NeuralODE itself holds the parameters. When you display NeuralODE.p it shows the parameters as just a flat list; they are not labeled as params[]. This change in philosophy isn't documented clearly enough to understand how to use it.

For example, when the user was controlling the params, one could have params that were left out of the optimization, or additional ones added in. But now it seems the only time the params are extracted is when the NeuralODE is constructed, which extracts the params itself and puts them into a flat list, leaving the user somewhat lost on how to do something more controlled.
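
For concreteness, here's a minimal sketch of the new pattern as I understand it (assuming the current DiffEqFlux API; the layer sizes, dudt, and u0 are just illustrative):

using DiffEqFlux, OrdinaryDiffEq, Flux

dudt = Chain(Dense(2, 16, tanh), Dense(16, 2))                    # inner network
n_ode = NeuralODE(dudt, (0.0f0, 1.0f0), Tsit5(), saveat = 0.1f0)

u0 = Float32[2.0, 0.0]
n_ode.p                            # flat parameter vector, no per-layer labels
pred_default = n_ode(u0)           # uses the parameters stored inside the layer
p_custom = copy(n_ode.p)           # a user-controlled copy one could optimize separately
pred_custom = n_ode(u0, p_custom)  # same layer, user-supplied parameters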

cems2 commented 4 years ago

Example of data one might like to pass in.

struct Datum
   u0
   labels0
   mask
end

function loss(p, data)
  pred = n_ode(data.u0, p)                        # training-example initial conditions vary
  sum(abs2, data.mask .* (data.labels0 .- pred))  # training-example labels vary;
        # the mask reweights the points in case the true data has error bars that vary from example to example
end

As a more complex example, notice the above has the additional yuck of having n_ode be a global too. So it would be nice to also pass the n_ode into the loss.
And in fact one way to pass it in is to package it into the Datum. But why you'd want to do that may be more than I can explain here.

There's no need to hardcode Datum in sciml_train. The only place that cares about the data is the loss function, so only it needs to know what to do with it. But because training examples vary, it's better to let sciml_train do the bookkeeping of which Datum is being passed to the loss as it rasters through the epoch of all the batches of training data. This is why you want a data argument in sciml_train: iterations, batches, and epochs are different things.

ChrisRackauckas commented 4 years ago

You can also use a closure or a call overloaded type. What's the feature that is gained here? That would determine the implementation.

cems2 commented 4 years ago

I'm unclear on what your proposal is here as an alternate solution.

In #162 I note that the loss function as currently implemented in sciml_train follows a different input-arg pattern than the Flux.train! loss function: there one passes in the data. In this case you would want to pass in both p and data.

If you don't, then loss(p) has to get access to the data using globals. In fact it also needs to access n_ode (to do the prediction), so that's yet another global loss() needs. So rather than passing in p, one might want to pass in n_ode::NeuralODE, which has a field n_ode.p accessible.

You can see I'm espousing avoiding globals.

In the case when the data is a list of batches, one would like sciml_train to work its way through the batches like they were minibatches. That's different than iterations.

Ergo, passing in the data is a good idea, I think.
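
For reference, this is the Flux.train! calling pattern I mean (standard Flux usage; the toy model and random data are just for illustration):

using Flux

m = Dense(2, 1)
loss(x, y) = Flux.mse(m(x), y)                                       # loss just computes on whatever it is handed
data = [(rand(Float32, 2, 8), rand(Float32, 1, 8)) for _ in 1:10]    # 10 minibatches
Flux.train!(loss, Flux.params(m), data, ADAM(0.01))                  # train! iterates the data, splatting each (x, y) into loss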

cems2 commented 4 years ago

So just to be clear.
In sciml_train(loss, n_ode, array of Datum, number of epochs)

it should iterate over the array of Datum, passing one single Datum struct to each call to loss(). After it has worked its way through the list of Datums, that is a single epoch.

The Datum struct might be a single time-series label or it might be a mini-batch of many labels, each one computed from a different initial condition u0. loss() handles the details of the Datum (single or minibatch); sciml_train handles the bookkeeping of iterating over all the Datums for the epoch. Then it repeats this for however many epochs are specified.

ChrisRackauckas commented 4 years ago

It doesn't have to access globals.

loss = p -> _loss(p, data)

makes data local. What will this give that's not just a closure?
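
(Spelled out, that closure version is roughly the following sketch; _loss, my_data, and the n_ode layer are illustrative names, reusing a NeuralODE like the one sketched earlier, not anything prescribed by the package:)

using DiffEqFlux, Flux

_loss(p, data) = sum(abs2, data.labels0 .- Array(n_ode(data.u0, p)))  # loss takes the data explicitly

my_data = (u0 = Float32[2.0, 0.0], labels0 = rand(Float32, 2, 11))
loss = p -> _loss(p, my_data)                     # my_data is captured by the closure, not read from a global
DiffEqFlux.sciml_train(loss, n_ode.p, ADAM(0.05), maxiters = 100)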

ChrisRackauckas commented 4 years ago

How is it different from https://github.com/JuliaDiffEq/DiffEqFlux.jl/pull/154 ?

I'm not opposed to the idea, just nothing concrete about "why" has been written down, and so there's no guide as to how it should be made (since the "how" should meet the "why")

cems2 commented 4 years ago

Okay. That is a way to hand in the data, yes. But that still pushes the logic of iterating over a set of examples into the loss function rather than the training function. The latter is how most ML frameworks seem to handle this: you give them a set of minibatches, or an iterable of some sort that will cough up one minibatch per invocation. Flux does it that way with train!. Only in the trivial case where one is not using minibatches are the two the same.

cems2 commented 4 years ago

As an example, often this iterable is one that pipelines data off of the disk. Each chunk is a mini-batch.

Another way of saying this is that train!() handles the iteration over minibatches, and loss() just has to be able to compute the loss on any dataset or minibatch that is handed to it. loss() should not have to know where the data is sourced from or how to iterate over it. Separating these concepts lets you plug in different types of stochastic optimizers without messing with the loss function. E.g. maybe you want to do a Brower update of the parameters for each element of the minibatch, or maybe you just want the gradient of the loss for the sum() over all the minibatch data sets. You also want to be able to change the size of the minibatch later in the training as noise becomes important.

If you use the closure, then you have to have a different closure for every different kind of minibatch-processing plan.

So yes, you could push all of that into loss(). It's just very common to separate these things, which is why we have terms like minibatch, epoch, and iteration.

cems2 commented 4 years ago

Replying to the reference to #154: in that GitHub issue it seems to hand off to https://github.com/ali-ramadhan/neural-differential-equation-climate-parameterizations/blob/master/diffusion_equation/Diffusion%20neural%20PDE%20and%20DAE.ipynb

and that is a notebook which is using Flux.train!, not sciml_train.

Flux.train! has the data arg, which is what makes the minibatch logic easy to implement. So... I'm unclear how #154 is a solution here. It seems like it's the same lament I am making: how should one minibatch? As it stands, one has to push that logic into the loss, making a custom loss function for every different sort of minibatching.

cems2 commented 4 years ago

I'll pause here just to note that the documentation for sciml_train includes the data arg. It's just not correct! (There is no data arg in the method.) Thus just implementing it like Flux.train! does would be the "how" and the "why".

ChrisRackauckas commented 4 years ago

There is no docstring on sciml_train yet, you just found old remnants of nothing. We are trying to come up with a solution here, and I don't see something coherent in there.

My implementation would just enclose an iteration of the data in with the loss. Everything would look almost exactly the same to the user. How is it different from that?

ChrisRackauckas commented 4 years ago

Specifically, the solution that I have in mind right now if there was an optional data argument would just be to wrap it:

_loss(p) = loss(p, iterate(data)...)
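
Concretely, something roughly like this (an illustrative sketch only; wrap_loss is a made-up name, not the actual implementation):

# Close over the data collection and cycle through it, so the optimizer
# sees a 1-argument loss while the user writes loss(p, d...).
function wrap_loss(loss, data)
    st = Iterators.Stateful(Iterators.cycle(data))
    return p -> loss(p, popfirst!(st)...)
end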

So, how is this better than the user doing that? If you have something concrete in mind, please spell it out with a nice clear example. It sounds like you have something in mind, but whatever you have in mind, if it's actually fixing something, would need an entirely different implementation. If it needs a different implementation, I need to know what it is in order to implement it correctly!

cems2 commented 4 years ago
  1. Not sure what you mean there is no docstring. If you look in ~/.julia/packages/DiffEqFlux/UNVXS/src/train.jl you see:

"""
train!(loss, params, data, opt; cb)

For each datapoint d in data computes the gradient of loss(p,d...) through
backpropagation and calls the optimizer opt. Takes a callback as keyword argument cb.
For example, this will print "training" every 10 seconds:

    DiffEqFlux.sciml_train(loss, params, data, opt,
            cb = throttle(() -> println("training"), 10))

The callback can call Flux.stop() to interrupt the training loop. Multiple optimisers
and callbacks can be passed to opt and cb as arrays.
"""
function sciml_train(loss, _θ, opt; cb = (args...) -> (false), maxiters)
    θ = copy(_θ)
    ps = Flux.params(θ)

ChrisRackauckas commented 4 years ago

Ignore that.

cems2 commented 4 years ago

There is obviously nothing wrong with your proposed approach.

My "why" here is to normalize it so that it works like Flux.Train! logically does. And the reason copying certain aspects of Flux.Train! is good is because they are the same sort of calling idiom that pytorch and TensorFLow uses.

You push the minibatch handling into the training, and you let the loss be agnostic to how you call it. This division of labor seems like a good one. That's my main point.

But I can also elaborate a little more on the design principle here. In TensorFlow and PyTorch, when you have an input tensor you can turn it into a minibatch just by adding one dimension to the front of the tensor. The idiom there is that this index is always the batch index. It's a heavily vetted idiom, so discarding it should be done carefully.

THE REASON YOU WANT THAT: the loss function doesn't need to know if it's processing a single input or broadcasting over a list of inputs in a minibatch.

In Julia, Flux does the same thing, except instead of being the first index it's the last index (Julia style): the last index is the batch.

And once you adopt that idiom, it's easy to put the logic for that into train! instead of a bespoke loss for every different case.
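
(A toy Flux illustration of the convention; the sizes are arbitrary:)

using Flux

layer = Dense(2, 3)
x_one   = rand(Float32, 2)       # a single example
x_batch = rand(Float32, 2, 16)   # 16 examples; the batch is the last index
size(layer(x_batch))             # (3, 16): same code, batched output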

Part and parcel with this is another issue I filed. When using NeuralODE on a batch, it unfortunately makes the time axis the last dimension. This too clobbers the idiom (the batch index becomes the second-to-last index when time becomes the last index).

To my mind, the time series is a filter layer, not a batch index. The NeuralODE solvers should logically make time the second-to-last index so batch can remain the last index.

That can be munged with permuted indices.
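
(Something like this, assuming a batched solve that comes back with layout (state, batch, time):)

sol = rand(Float32, 2, 16, 50)                # stand-in for a batched solve: (state, batch, time)
sol_batch_last = permutedims(sol, (1, 3, 2))  # move time to second-to-last: (state, time, batch)
size(sol_batch_last)                          # (2, 50, 16)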

Anyhow, the point is, the idiom in ML seems to be that the terminal index is the batch index.

cems2 commented 4 years ago

Hello again. I wanted to revisit this issue about why there should be an optional data argument in sciml_train.

It's true that we can sneak the data into a loss function with a bespoke closure. And this data could even be batch data.

But now let's consider iterating over a series of minibatches during training. How does one do this? Let the loss function have access to a list of data organized as a series of minibatches. How does the loss function know when sciml_train is done with one iteration? It would need to know this in order to go on to the next minibatch.

Bad idea #1: let the loss function manage the rotation through the minibatches. To see the issue this creates, consider the following wrongheaded approach: place a counter in the loss function which increments a static variable each time the loss is called, so it knows which minibatch in the list to use on the next iteration. That's a bad idea because if for any reason the loss function gets called more than once per iteration (say because of a callback, or because the optimization method has to make multiple calls per iteration, or because you intermittently call it on a test set or hold-out set to assess the training), then this fails.

Bad idea #2: let the callback manage the rotation through the minibatches. You could maybe offload the counter incrementing to the callback instead. That's a very yucky solution, both because now you have side effects passing from the callback to the loss function and because if you modify the callback in some way, such as wrapping it in a timer to throttle it, this flunks the strategy.

Bad idea #3: communicating the number of batches back to the epoch counter. Finally, even if you got that wobbly kludge working, you also want to synchronize the number of iterations in the training to the number of minibatches so that things get processed in full epochs (a complete run through all minibatches). But this means sciml_train needs to be called a specific number of times per epoch depending on the number of minibatches.

The clean elegant solution: but hey! Who in this process actually knows exactly what iteration it is on, which of the minibatches in the list it needs to use now, and not to change minibatches if the solver needs to call the loss function multiple times per iteration? Why gosh golly, that would be sciml_train itself. So if you put the list of minibatches in as an optional data parameter, it can just handle the increments from batch to batch, passing the right one to the loss function. And since it knows how long the list is, it can make sure it processes it fully in an epoch.

Thus this is the natural way to handle minibatches.

The other ways require side effects in both the loss and callback functions. Yuck.
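
In rough Julia pseudocode, the division of labor I'm arguing for looks like this (a hypothetical sketch with a made-up name, not the actual sciml_train internals):

using Flux

function train_with_data(loss, p, data, opt; epochs = 1)
    for _ in 1:epochs
        for d in data                                  # one full pass over data = one epoch
            gs = Flux.gradient(θ -> loss(θ, d), p)[1]  # loss only ever sees the minibatch it is handed
            Flux.Optimise.update!(opt, p, gs)
        end
    end
    return p
end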

ChrisRackauckas commented 4 years ago

Yes, that's exactly why I added it last night, and that's your BFGS issue.

cems2 commented 4 years ago

Oh.. From the comment when this issue was closed, I took it to mean you DELETED the optional_data fork entirely, meaning no optional data. So I didn't update my Julia libs to test it. :-( I think you just told me you added this feature. My bad...

I think you are saying that BFGS calls the loss multiple times per iteration (that would be my understanding of how Numerical Recipes does BFGS, though I've never looked at how it might be done in the age of AD methods).

But if you are referring to my other issue where the BFGS optimization failed on a mini-batch example, then no. I'm not rotating the minibatches in that one; it's just one single small batch, fully processed every time. So it should not screw up BFGS.

ChrisRackauckas commented 4 years ago

But if you are referring to my other issue where the BFGS optimization failed on a mini-batch example, then no. I'm not rotating the minibatches in that one; it's just one single small batch, fully processed every time. So it should not screw up BFGS.

Then it would be good to get an MWE for it.

ChrisRackauckas commented 4 years ago

Minimal working example. A minimal example that demonstrates/isolates the issue.

cems2 commented 4 years ago

Yeah. Tomorrow I'll try taking things out of the demo to narrow it down.