EnzymeAD / Reactant.jl

MIT License
26 stars 2 forks source link

xla RuntimeError on differentiating simple neural network #24

Closed avik-pal closed 2 weeks ago

avik-pal commented 2 weeks ago
using Enzyme, Reactant, Lux, Random # , Optimisers

# Scalar loss for the Lux reproducer: run `model` on `data` with parameters `ps`
# and state `st`, then sum every element of the output.
# The updated state `st_new` is unpacked but not returned; the trailing comment
# records a `(loss, st_new, (;))` return form that was disabled for this repro.
function loss_fn(model, ps, st, data)
    y, st_new = model(data, ps, st)
    return sum(y) # , st_new, (;)
end

# 10 -> 5 Dense layer with a `tanh` activation (assumes Lux's `Dense(in, out, σ)`
# argument order).
model = Dense(10, 5, tanh)

# 10x3 Float32 input; columns presumably index samples.
data = rand(Float32, 10, 3)

# Initialise parameters and state with a fixed-seed RNG.
ps, st = Lux.setup(Xoshiro(0), model)

# Convert every array inside ps / st / data into Reactant concrete arrays
# (`ArrayToConcrete`) so the structures can be passed to `Reactant.compile`.
reactant_ps = Reactant.make_tracer(IdDict(), ps, (), Reactant.ArrayToConcrete, nothing)
reactant_st = Reactant.make_tracer(IdDict(), st, (), Reactant.ArrayToConcrete, nothing)
reactant_data = Reactant.make_tracer(IdDict(), data, (), Reactant.ArrayToConcrete, nothing)

# Compile the forward loss for the Reactant-backed inputs (compiled, not invoked
# here). `model` itself is passed as-is, without conversion.
reactant_loss_fn = Reactant.compile(
    loss_fn, (model, reactant_ps, reactant_st, reactant_data))

# Reverse-mode gradient of `loss_fn` with respect to the parameters `ps` only.
# `dps` is a zeroed shadow with the same structure as `ps`; Enzyme writes the
# gradient into it. `model`, `st` and `data` are wrapped in `Const`, so no
# gradients are computed for them. Returns the filled shadow `dps`.
function gradient_loss_fn(model, ps, st, data)
    dps = Enzyme.make_zero(ps)
    Enzyme.autodiff(Enzyme.Reverse, loss_fn, Active, Const(model),
        Duplicated(ps, dps), Const(st), Const(data))
    return dps
end

# Plain-Julia gradient on the host arrays (reported to work).
gradient_loss_fn(model, ps, st, data)  # works

# Compile the same gradient through Reactant. Per the issue title, differentiation
# is what triggers the XlaRuntimeError reported for this snippet.
reactant_gradient_loss_fn = Reactant.compile(
    gradient_loss_fn, (model, reactant_ps, reactant_st, reactant_data))

# Intended end goal (disabled): a Lux training step with the Reactant backend.
# Note that `ts` is not defined in this snippet.
# Lux.Experimental.single_train_step(AutoReactant(), loss_fn, data, ts)
terminate called after throwing an instance of 'xla::XlaRuntimeError'
  what():  UNKNOWN: <unknown>:0: error: Reduction function must return a scalar or tuple of scalars but returns shape: f32[5]: 
<unknown>:0: note: see current operation: "func.return"(%11, %10) : (tensor<1x5xf32>, tensor<10x5xf32>) -> ()
wsmoses commented 2 weeks ago

Are you able to make an MWE (minimal working example) of just the `sum`, or of whatever else causes the failure?

avik-pal commented 2 weeks ago

Let me try a bit more, but stripping things down to a plain matmul and add makes it work.

using Enzyme, Reactant # , Lux, Random # , Optimisers

# function loss_fn(model, ps, st, data)
# y, st_new = model(data, ps, st)
# Stripped-down loss without an activation: the affine map `weight * data` plus
# `bias` broadcast across the columns, summed to a scalar.
function loss_fn(weight, data, bias)
    y = weight * data .+ bias
    # Alternative `muladd` formulation, left disabled.
    # y = muladd(weight, data, bias)
    # The trailing comment below is a leftover from the Lux version; there is
    # no `st_new` in this function.
    return sum(y) # , st_new, (;)
end

# The Lux model is replaced by explicit arrays in this reduced reproducer.
# model = Dense(10, 5, tanh)

# 10x3 input, 10x10 weight and 10-element bias (all Float32).
data = rand(Float32, 10, 3)
weight = rand(Float32, 10, 10)
bias = rand(Float32, 10)
# ps, st = Lux.setup(Xoshiro(0), model)

# reactant_ps = Reactant.make_tracer(IdDict(), ps, (), Reactant.ArrayToConcrete, nothing)
# reactant_st = Reactant.make_tracer(IdDict(), st, (), Reactant.ArrayToConcrete, nothing)
# reactant_data = Reactant.make_tracer(IdDict(), data, (), Reactant.ArrayToConcrete, nothing)

# Plain arrays are wrapped directly with `ConcreteRArray` instead of `make_tracer`.
reactant_weight = Reactant.ConcreteRArray(weight)
reactant_data = Reactant.ConcreteRArray(data)
reactant_bias = Reactant.ConcreteRArray(bias)

# Compile the forward loss for the Reactant-backed inputs (not invoked here).
reactant_loss_fn = Reactant.compile(
    loss_fn, (reactant_weight, reactant_data, reactant_bias))

# Reverse-mode gradient of `loss_fn` with respect to `weight` and `bias`;
# `data` is `Const`. Both shadows are zero-initialised and filled by Enzyme.
# Only `dweight` is returned; `dbias` is computed but discarded.
function gradient_loss_fn(weight, data, bias)
    dweight = Enzyme.make_zero(weight)
    dbias = Enzyme.make_zero(bias)
    Enzyme.autodiff(Enzyme.Reverse, loss_fn, Active, Duplicated(weight, dweight),
        Const(data), Duplicated(bias, dbias))
    return dweight
end

# Plain-Julia gradient on the host arrays.
gradient_loss_fn(weight, data, bias)

# Compile the gradient through Reactant. Per the report, this activation-free
# version compiles successfully.
reactant_gradient_loss_fn = Reactant.compile(
    gradient_loss_fn, (reactant_weight, reactant_data, reactant_bias))

# Lux.Experimental.single_train_step(AutoReactant(), loss_fn, data, ts)
avik-pal commented 2 weeks ago

OK, here we go: it is the activation function.

using Enzyme, Reactant

# Minimal failing reproducer: the same affine map as the working version,
# followed by an elementwise (broadcast) `tanh`, summed to a scalar.
# `ps` is a NamedTuple with `weight` and `bias` fields.
# Per the report, adding the `tanh` is what makes gradient compilation fail.
function loss_fn(ps, data)
    y = tanh.(ps.weight * data .+ ps.bias)
    return sum(y)
end

# 10x3 input, 10x10 weight and 10-element bias (all Float32); the parameters are
# bundled into a NamedTuple.
data = rand(Float32, 10, 3)
weight = rand(Float32, 10, 10)
bias = rand(Float32, 10)
ps = (weight=weight, bias=bias)

# Convert the NamedTuple's arrays with `make_tracer`; the plain input array is
# wrapped directly.
reactant_ps = Reactant.make_tracer(IdDict(), ps, (), Reactant.ArrayToConcrete, nothing)
reactant_data = Reactant.ConcreteRArray(data)

# Compile the forward loss (not invoked here).
reactant_loss_fn = Reactant.compile(loss_fn, (reactant_ps, reactant_data))

# Reverse-mode gradient of `loss_fn` with respect to every array in `ps`;
# `data` is `Const`. Returns the zero-initialised shadow `dps` after Enzyme has
# filled it with the gradient.
function gradient_loss_fn(ps, data)
    dps = Enzyme.make_zero(ps)
    Enzyme.autodiff(Enzyme.Reverse, loss_fn, Active, Duplicated(ps, dps), Const(data))
    return dps
end

# Plain-Julia gradient on the host arrays.
gradient_loss_fn(ps, data)

# Compiling the gradient through Reactant is the step that produces the IR and
# XLA error shown in the thread.
reactant_gradient_loss_fn = Reactant.compile(gradient_loss_fn, (reactant_ps, reactant_data))
wsmoses commented 2 weeks ago

Ah, it looks like it's the `broadcast_in_dim` differentiation using the wrong inner type.

preModule:
module {
  // Primal function traced from `loss_fn`: sum(tanh(weight * data .+ bias)).
  // Arguments arrive transposed and are transposed back on entry (presumably
  // converting between Julia's column-major layout and the IR's layout).
  func.func private @"Const{typeof(loss_fn)}(loss_fn)_autodiff"(%arg0: tensor<3x10xf32>, %arg1: tensor<10x10xf32>, %arg2: tensor<10xf32>) -> (tensor<f32>, tensor<10x10xf32>, tensor<3x10xf32>, tensor<10xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32>
    %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
    %2 = stablehlo.transpose %arg2, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
    %3 = stablehlo.dot_general %1, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf32>, tensor<10x3xf32>) -> tensor<10x3xf32>
    // Bias broadcast to 10x3. Its reverse-mode adjoint is a sum over dim 1; that
    // generated reduction is the one XLA later rejects.
    %4 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<10xf32>) -> tensor<10x3xf32>
    %5 = stablehlo.add %3, %4 : tensor<10x3xf32>
    %6 = stablehlo.tanh %5 : tensor<10x3xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    // Full reduction to a scalar (the `sum(y)`); note this reducer correctly uses
    // a 0-d tensor<f32> init value.
    %7 = stablehlo.reduce(%6 init: %cst) applies stablehlo.add across dimensions = [0, 1] : (tensor<10x3xf32>, tensor<f32>) -> tensor<f32>
    %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32>
    %9 = stablehlo.transpose %1, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
    %10 = stablehlo.transpose %0, dims = [1, 0] : (tensor<10x3xf32>) -> tensor<3x10xf32>
    %11 = stablehlo.transpose %2, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
    return %8, %9, %10, %11 : tensor<f32>, tensor<10x10xf32>, tensor<3x10xf32>, tensor<10xf32>
  }
  // Entry point: wraps the primal function in enzyme.autodiff. The data argument
  // is enzyme_const; weight and bias are enzyme_active.
  func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<10x10xf32>, %arg2: tensor<3x10xf32>) -> (tensor<10xf32>, tensor<3x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
    %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
    %2 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32>
    // Presumably: %cst_1 = 1.0 seeds the scalar loss, while %cst / %cst_0 are the
    // zero-initialised shadows for weight and bias (confirm against enzyme.autodiff docs).
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<10x10xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<10xf32>
    %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<10x3xf32>) -> tensor<3x10xf32>
    %4 = stablehlo.transpose %1, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
    %5 = stablehlo.transpose %0, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
    %6 = stablehlo.transpose %cst_1, dims = [] : (tensor<f32>) -> tensor<f32>
    %7 = stablehlo.transpose %cst, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
    %8 = stablehlo.transpose %cst_0, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
    %9:5 = enzyme.autodiff @"Const{typeof(loss_fn)}(loss_fn)_autodiff"(%3, %4, %5, %6, %7, %8) {activity = [#enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>]} : (tensor<3x10xf32>, tensor<10x10xf32>, tensor<10xf32>, tensor<f32>, tensor<10x10xf32>, tensor<10xf32>) -> (tensor<10x10xf32>, tensor<3x10xf32>, tensor<10xf32>, tensor<10x10xf32>, tensor<10xf32>)
    %10 = stablehlo.transpose %9#0, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
    %11 = stablehlo.transpose %9#1, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32>
    %12 = stablehlo.transpose %9#2, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
    %13 = stablehlo.transpose %9#3, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
    %14 = stablehlo.transpose %9#4, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
    %15 = stablehlo.transpose %14, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
    %16 = stablehlo.transpose %11, dims = [1, 0] : (tensor<10x3xf32>) -> tensor<3x10xf32>
    %17 = stablehlo.transpose %13, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
    %18 = stablehlo.transpose %10, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
    %19 = stablehlo.transpose %12, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
    return %15, %16, %17, %18, %19 : tensor<10xf32>, tensor<3x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10xf32>
  }
}
Module:
module attributes {transform.with_named_sequence} {
  // Module after the passes: the enzyme.autodiff call has been replaced by
  // explicit forward and gradient code.
  func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<10x10xf32>, %arg2: tensor<3x10xf32>) -> (tensor<10xf32>, tensor<3x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10xf32>) {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<10x3xf32>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<10xf32>
    %0 = stablehlo.dot_general %arg1, %arg2, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf32>, tensor<3x10xf32>) -> tensor<10x3xf32>
    %1 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<10xf32>) -> tensor<10x3xf32>
    %2 = stablehlo.add %0, %1 : tensor<10x3xf32>
    %3 = stablehlo.tanh %2 : tensor<10x3xf32>
    // %5 = 1 - tanh(x)^2, the derivative of tanh (presumably with the seed 1.0
    // folded in).
    %4 = stablehlo.multiply %3, %3 : tensor<10x3xf32>
    %5 = stablehlo.subtract %cst, %4 : tensor<10x3xf32>
    // Bias gradient: sum of %5 over dim 1, the adjoint of the broadcast_in_dim in %1.
    // BUG: the init value (%cst_0) and the reducer block use tensor<10xf32>, but a
    // stablehlo.reduce body must take and return 0-d tensors (tensor<f32> here).
    // XLA therefore rejects it with "Reduction function must return a scalar or
    // tuple of scalars but returns shape: f32[10]".
    %6 = stablehlo.reduce(%5 init: %cst_0) across dimensions = [1] : (tensor<10x3xf32>, tensor<10xf32>) -> tensor<10xf32>
     reducer(%arg3: tensor<10xf32>, %arg4: tensor<10xf32>)  {
      %8 = stablehlo.add %arg3, %arg4 : tensor<10xf32>
      stablehlo.return %8 : tensor<10xf32>
    }
    // Weight gradient: contracts the input with %5, producing the 10x10 result.
    %7 = stablehlo.dot_general %arg2, %5, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<3x10xf32>, tensor<10x3xf32>) -> tensor<10x10xf32>
    return %6, %arg2, %7, %arg1, %arg0 : tensor<10xf32>, tensor<3x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10xf32>
  }
}
terminate called after throwing an instance of 'xla::XlaRuntimeError'
  what():  UNKNOWN: <unknown>:0: error: Reduction function must return a scalar or tuple of scalars but returns shape: f32[10]: 
wsmoses commented 2 weeks ago

Fixed by the latest jll bump.