pluskid / Mocha.jl

Deep Learning framework for Julia

Faster prediction #228

Open dryden3 opened 7 years ago

dryden3 commented 7 years ago

Is there any way to make a prediction with a network other than defining a new data layer, creating a network, loading the network from a snapshot, then forwarding it? A function that takes the input and a net and returns the output without having to read from disk would be ideal. Thanks for your help!
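For reference, that disk-based loop looks roughly like this in Mocha (a sketch only; the snapshot path, data source, and layer list are placeholders):

    using Mocha

    backend = CPUBackend()
    init(backend)

    # hypothetical trained layers; data comes from an HDF5 file on disk
    data_layer = HDF5DataLayer(name="data", source="test-data.txt", batch_size=1)
    net = Net("predict", backend, [data_layer, common_layers...])

    load_snapshot(net, "snapshots/snapshot-010000.jld")  # disk read on every rebuild
    forward(net)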

greenflash1357 commented 6 years ago

I think what you are looking for is MemoryDataLayer.
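For example, a MemoryDataLayer serves samples straight from an in-memory array instead of reading from disk (a sketch; the array shape and batch size are illustrative):

    using Mocha

    X = rand(28, 28, 1, 100)  # (width, height, channels, batch) layout
    data_layer = MemoryDataLayer(name="data", tops=[:data], batch_size=100, data=Array[X])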

Jacob-Stevens-Haas commented 6 years ago

I don't think MemoryDataLayer solves the problem. If I understand @dryden3 correctly, I'm having the same issue. I'm using a neural net for regression of dynamics, so my training inputs and outputs are observed system values shifted by one time step.

To evaluate the model's performance in recreating dynamics, I need to feed initial system conditions forward, then take those outputs and feed them forward, eventually building up a trial path of the system.

Currently, I'm doing

for t in 1:timesteps
    # rebuild the net every timestep just to swap the input data
    memdata = MemoryDataLayer(..., data=current)
    memout = MemoryOutputLayer(...)
    pred = Net(..., [memdata, common_layers, memout])
    load_snapshot(pred)
    forward(pred, ...)
    current = to_array(pred.output_blobs[:lastlayer])
end

If I build pred in advance and leave the call to Net() out of the loop, the old data gets run through again. Without building a recurrent neural net, is there a way to change the data in the MemoryDataLayer once the net is built?

greenflash1357 commented 6 years ago

You can access the data of a MemoryDataLayer directly and assign new input at each iteration, so the Net only needs to be built once:

# build the network and load the snapshot once, outside the loop
memdata = MemoryDataLayer(..., data=current)
memout = MemoryOutputLayer(...)
pred = Net(..., [memdata, common_layers, memout])
load_snapshot(pred)

for t in 1:timesteps
    forward(pred, ...)
    current = to_array(pred.output_blobs[:lastlayer])
    # feed the latest output back in as the next input
    memdata.data = current
end

Make sure indices, dimensions, etc. match.
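If assigning to memdata.data directly causes a type mismatch (the data property is a Vector{Array}, one array per top), an alternative sketch is to overwrite the existing input array in place; this assumes the layer reads from that same array on each forward pass:

    # assumes the layer was built with data=Array[input], input an Array
    copy!(memdata.data[1], current)  # overwrite contents in place; sizes must match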