FluxML / FluxJS.jl

I heard you like compile times

Added support for LSTM and BatchNorm #24

Closed: Roboneet closed this 6 years ago

Roboneet commented 6 years ago

StagedArray uses val field instead of dims.

MikeInnes commented 6 years ago

Looks good to me, though I guess it needs rebasing over the NCHW changes.

MikeInnes commented 6 years ago

It would also be good to start having some basic tests for the generated code here.

MikeInnes commented 6 years ago

Looks like this has conflicts from the tensorflow.js update.

MikeInnes commented 6 years ago

:+1: Thanks!

amellnik commented 6 years ago

I'm late to the party, but I'm still seeing Unsupported type Tuple{TrackedArray{…,Array{Float64,1}},TrackedArray{…,Array{Float64,1}}} when trying to export a model that follows the char-rnn example from the model zoo. Is there something I'm missing?

Roboneet commented 6 years ago

Can you please post the code? I can't seem to reproduce the error with the char-rnn model.

amellnik commented 6 years ago

Here's a quick but not very minimal example:

using Flux
using Flux: onehot, argmax, chunk, batchseq, throttle, crossentropy
using FluxJS
using StatsBase: wsample
using Base.Iterators: partition
using BSON: @save, @load

text = collect(readstring("hp.txt")) # Or some other text file
alphabet = [unique(text)..., '_']
text = map(ch -> onehot(ch, alphabet), text)
stop = onehot('_', alphabet)
N = length(alphabet)

m = Chain(
  LSTM(N, 128),
  LSTM(128, 128),
  Dense(128, N),
  softmax)

m(trues(94)) # Returns a Tracked 94-element Array{Float64,1}:

@code_js m(trues(94)) # Unsupported type Tuple{TrackedArray{…,Array{Float64,1}},TrackedArray{…,Array{Float64,1}}}

I get this error no matter what I call m with in @code_js.

Roboneet commented 6 years ago

That error seems to come up on commits before 4afb9e7. Can you check the git logs and see if you have the latest version of master?

amellnik commented 6 years ago

Ugh, my mistake -- thanks. On actual master I am still getting the following:

typeof(f) = Base.#one
one
Tuple{TrackedArray{…,Array{Float64,1}}}
no method found for the specified argument types

Roboneet commented 6 years ago

I'm not sure why this is happening; I can't reproduce it on my end. Now that we have tests for the primitives, can you check which ones aren't working on yours?

amellnik commented 6 years ago

I was using ASTInterpreter2 master, which was also causing the Flux tests to fail. I freed it back to 0.1.1, and while the Flux tests now pass, the FluxJS tests (and exporting models) fail with the following:

FluxJS: Error During Test
  Got an exception of type UndefRefError outside of a @test
  UndefRefError: access to undefined reference
  Stacktrace:
   [1] lookup_var_if_var at /home/alex/.julia/v0.6/ASTInterpreter2/src/ASTInterpreter2.jl:147 [inlined]
   [2] lookup(::ASTInterpreter2.JuliaStackFrame, ::SSAValue) at /home/alex/.julia/v0.6/Vinyl/src/interpret.jl:21
   [3] broadcast_t(::Function, ::Type{Any}, ::Tuple{Base.OneTo{Int64}}, ::CartesianRange{CartesianIndex{1}}, ::ASTInterpreter2.JuliaStackFrame, ::Array{Any,1}) at ./broadcast.jl:258
   [4] broadcast_c at ./broadcast.jl:321 [inlined]
   [5] broadcast(::Function, ::ASTInterpreter2.JuliaStackFrame, ::Array{Any,1}) at ./broadcast.jl:455
   [6] callargs(::DebuggerFramework.DebuggerState) at /home/alex/.julia/v0.6/Vinyl/src/interpret.jl:29
   [7] runall(::FluxJS.BTrace, ::DebuggerFramework.DebuggerState) at /home/alex/.julia/v0.6/Vinyl/src/interpret.jl:34
   [8] overdub(::FluxJS.BTrace, ::Function) at /home/alex/.julia/v0.6/Vinyl/src/interpret.jl:53
   [9] primitive(::FluxJS.Trace, ::Base.#broadcast, ::Function, ::Tuple{TrackedArray{…,Array{Float64,1}}}) at /home/alex/.julia/v0.6/Vinyl/src/hooks.jl:24
   [10] runall(::FluxJS.Trace, ::DebuggerFramework.DebuggerState) at /home/alex/.julia/v0.6/Vinyl/src/interpret.jl:37
   [11] overdub(::FluxJS.Trace, ::Flux.Dense{Base.#identity,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}}, ::FluxJS.StagedArray{Float64,1}, ::Vararg{FluxJS.StagedArray{Float64,1},N} where N) at /home/alex/.julia/v0.6/Vinyl/src/interpret.jl:53
   [12] #trace#7(::FluxJS.Trace, ::Function, ::Flux.Dense{Base.#identity,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}}, ::FluxJS.StagedArray{Float64,1}, ::Vararg{FluxJS.StagedArray{Float64,1},N} where N) at /home/alex/.julia/v0.6/FluxJS/src/trace.jl:46
   [13] (::FluxJS.#kw##trace)(::Array{Any,1}, ::FluxJS.#trace, ::Flux.Dense{Base.#identity,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}}, ::FluxJS.StagedArray{Float64,1}, ::Vararg{FluxJS.StagedArray{Float64,1},N} where N) at ./<missing>:0
   [14] #_traceλ#10(::FluxJS.Trace, ::Function, ::Flux.Dense{Base.#identity,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}}, ::Array{Float64,1}, ::Vararg{Array{Float64,1},N} where N) at /home/alex/.julia/v0.6/FluxJS/src/trace.jl:55
   [15] (::FluxJS.#kw##_traceλ)(::Array{Any,1}, ::FluxJS.#_traceλ, ::Flux.Dense{Base.#identity,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}}, ::Array{Float64,1}, ::Vararg{Array{Float64,1},N} where N) at ./<missing>:0
   [16] #traceλ#11 at /home/alex/.julia/v0.6/FluxJS/src/trace.jl:60 [inlined]
   [17] traceλ at /home/alex/.julia/v0.6/FluxJS/src/trace.jl:60 [inlined]
   [18] macro expansion at /home/alex/.julia/v0.6/FluxJS/test/runtests.jl:12 [inlined]
   [19] macro expansion at ./test.jl:860 [inlined]
   [20] anonymous at ./<missing>:?
   [21] include_from_node1(::String) at ./loading.jl:576
   [22] include(::String) at ./sysimg.jl:14
   [23] process_options(::Base.JLOptions) at ./client.jl:305
   [24] _start() at ./client.jl:371
Test Summary: | Error  Total
FluxJS        |     1      1
ERROR: LoadError: Some tests did not pass: 0 passed, 0 failed, 1 errored, 0 broken.
while loading /home/alex/.julia/v0.6/FluxJS/test/runtests.jl, in expression starting on line 9
FluxJS had test errors

Stacktrace:
 [1] #test#62(::Bool, ::Function, ::Array{AbstractString,1}) at ./pkg/entry.jl:759
 [2] (::Base.Pkg.Entry.#kw##test)(::Array{Any,1}, ::Base.Pkg.Entry.#test, ::Array{AbstractString,1}) at ./<missing>:0
 [3] (::Base.Pkg.Dir.##4#7{Array{Any,1},Base.Pkg.Entry.#test,Tuple{Array{AbstractString,1}}})() at ./pkg/dir.jl:36
 [4] cd(::Base.Pkg.Dir.##4#7{Array{Any,1},Base.Pkg.Entry.#test,Tuple{Array{AbstractString,1}}}, ::String) at ./file.jl:70
 [5] #cd#1(::Array{Any,1}, ::Function, ::Function, ::Array{AbstractString,1}, ::Vararg{Array{AbstractString,1},N} where N) at ./pkg/dir.jl:36
 [6] (::Base.Pkg.Dir.#kw##cd)(::Array{Any,1}, ::Base.Pkg.Dir.#cd, ::Function, ::Array{AbstractString,1}, ::Vararg{Array{AbstractString,1},N} where N) at ./<missing>:0
 [7] #test#3(::Bool, ::Function, ::String, ::Vararg{String,N} where N) at ./pkg/pkg.jl:276
 [8] test(::String, ::Vararg{String,N} where N) at ./pkg/pkg.jl:276

Which version of ASTInterpreter2 are you on?

Roboneet commented 6 years ago

I'm on master as well. You might want to try out this branch of Vinyl to get a stacktrace of the error.

amellnik commented 6 years ago

UndefRefError: access to undefined reference
Stacktrace of evaluated expression:
[1] #23() at :0

Stacktrace of evaluated expression:
[1] LSTMCell(h_, x) at /home/alex/.julia/v0.6/Flux/src/layers/recurrent.jl:132

Stacktrace of evaluated expression:
[1] #82(x, m) at /home/alex/.julia/v0.6/Flux/src/layers/basic.jl:31

Stacktrace of evaluated expression:
[1] mapfoldl_impl(f, op, v0, itr, i) at reduce.jl:39

Stacktrace of evaluated expression:
[1] mapfoldl(f, op, v0, itr) at reduce.jl:58

Stacktrace of evaluated expression:
[1] foldl(op, v0, itr) at reduce.jl:87

Stacktrace of evaluated expression:
[1] Chain(x) at /home/alex/.julia/v0.6/Flux/src/layers/basic.jl:31