FluxML / Torch.jl

Sensible extensions for exposing torch in Julia.

gradient() throws `ERROR: DimensionMismatch("cannot broadcast array to have fewer dimensions")` #46

michelangelo21 commented 3 years ago

When executing the example from the README, `gs = gradient(x -> sum(tresnet(x)), tip);` throws `ERROR: DimensionMismatch("cannot broadcast array to have fewer dimensions")`.

Minimal working example (a simple `Chain` substituted for the ResNet from the README):

```julia
using Flux
using Torch
using Torch: torch

# simple Chain standing in for the README's ResNet
net = Chain(
    Dense(10, 5, σ),
    Dense(5, 2),
    softmax)

# move the model to Torch (GPU device 0)
tnet = net |> torch

ip = rand(Float32, 10, 1)
tip = tensor(ip, dev = 0)  # input as a Torch Tensor on device 0

gs = gradient(x -> sum(tnet(x)), tip)
```
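If it helps triage: the failure seems to reduce to the in-place broadcast that `NNlib.∇softmax!` performs (frames [4]–[7] of the stacktrace below). A minimal sketch under that assumption, reusing `tnet` and `tip` from the MWE, with `Δ` standing in for the `FillArrays.Fill` seed that Zygote produces for `sum`:

```julia
using FillArrays
using Torch  # assumes the MWE above has already run

y = tnet(tip)                        # Tensor{Float32, 2}: the softmax output
Δ = FillArrays.Fill(1.0f0, size(y))  # the seed gradient Zygote uses for sum
out = similar(y)                     # per frame [7] below, this comes back 0-dimensional
out .= Δ .* y                        # same DimensionMismatch as in the trace
```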

Result:

```
ERROR: DimensionMismatch("cannot broadcast array to have fewer dimensions")
Stacktrace:
  [1] check_broadcast_shape(#unused#::Tuple{}, Ashp::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}})
    @ Base.Broadcast ./broadcast.jl:518
  [2] check_broadcast_axes
    @ ./broadcast.jl:523 [inlined]
  [3] check_broadcast_axes
    @ ./broadcast.jl:526 [inlined]
  [4] instantiate(bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{}, typeof(*), Tuple{FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Tensor{Float32, 2}}})
    @ Base.Broadcast ./broadcast.jl:269
  [5] materialize!
    @ ./broadcast.jl:894 [inlined]
  [6] materialize!
    @ ./broadcast.jl:891 [inlined]
  [7] ∇softmax!(out::Tensor{Float32, 0}, Δ::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, x::Tensor{Float32, 2}, y::Tensor{Float32, 2}; dims::Int64)
    @ NNlib ~/.julia/packages/NNlib/TOStL/src/softmax.jl:70
  [8] ∇softmax(Δ::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, x::Tensor{Float32, 2}, y::Tensor{Float32, 2}; dims::Int64)
    @ NNlib ~/.julia/packages/NNlib/TOStL/src/softmax.jl:62
  [9] softmax_pullback
    @ ~/.julia/packages/NNlib/TOStL/src/softmax.jl:81 [inlined]
 [10] ZBack
    @ ~/.julia/packages/Zygote/RxTZu/src/compiler/chainrules.jl:77 [inlined]
 [11] Pullback
    @ ~/.julia/packages/Flux/goUGu/src/layers/basic.jl:36 [inlined]
 [12] (::typeof(∂(applychain)))(Δ::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/Flux/goUGu/src/layers/basic.jl:36 [inlined]
 [14] (::typeof(∂(applychain)))(Δ::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/Flux/goUGu/src/layers/basic.jl:36 [inlined]
 [16] (::typeof(∂(applychain)))(Δ::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [17] Pullback
    @ ~/.julia/packages/Flux/goUGu/src/layers/basic.jl:38 [inlined]
 [18] (::typeof(∂(λ)))(Δ::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [19] Pullback
    @ ./REPL[20]:1 [inlined]
 [20] (::typeof(∂(#3)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [21] (::Zygote.var"#41#42"{typeof(∂(#3))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:41
 [22] gradient(f::Function, args::Tensor{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:59
 [23] top-level scope
```
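For what it's worth, frame [7] shows `out::Tensor{Float32, 0}` even though `y` is a `Tensor{Float32, 2}`. This NNlib version appears to allocate that output via `similar` on the softmax result (softmax.jl:62), so one hypothesis is that `similar` on a Torch `Tensor` drops the array's dimensions. A quick check, assuming the MWE above:

```julia
y = tnet(tip)       # Tensor{Float32, 2}
ndims(similar(y))   # expected 2; the stacktrace suggests this returns 0
```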