Closed by avik-pal 2 weeks ago
Are you able to make a MWE of just the sum or whatever else causes the failure?
Let me try a bit more, but stripping things down to a plain matmul and add makes it work:
using Enzyme, Reactant # , Lux, Random # , Optimisers
# function loss_fn(model, ps, st, data)
#     y, st_new = model(data, ps, st)
function loss_fn(weight, data, bias)
    y = weight * data .+ bias
    # y = muladd(weight, data, bias)
    return sum(y) # , st_new, (;)
end
# model = Dense(10, 5, tanh)
data = rand(Float32, 10, 3)
weight = rand(Float32, 10, 10)
bias = rand(Float32, 10)
# ps, st = Lux.setup(Xoshiro(0), model)
# reactant_ps = Reactant.make_tracer(IdDict(), ps, (), Reactant.ArrayToConcrete, nothing)
# reactant_st = Reactant.make_tracer(IdDict(), st, (), Reactant.ArrayToConcrete, nothing)
# reactant_data = Reactant.make_tracer(IdDict(), data, (), Reactant.ArrayToConcrete, nothing)
reactant_weight = Reactant.ConcreteRArray(weight)
reactant_data = Reactant.ConcreteRArray(data)
reactant_bias = Reactant.ConcreteRArray(bias)
reactant_loss_fn = Reactant.compile(
    loss_fn, (reactant_weight, reactant_data, reactant_bias))
function gradient_loss_fn(weight, data, bias)
    dweight = Enzyme.make_zero(weight)
    dbias = Enzyme.make_zero(bias)
    Enzyme.autodiff(Enzyme.Reverse, loss_fn, Active, Duplicated(weight, dweight),
        Const(data), Duplicated(bias, dbias))
    return dweight
end
gradient_loss_fn(weight, data, bias)
reactant_gradient_loss_fn = Reactant.compile(
    gradient_loss_fn, (reactant_weight, reactant_data, reactant_bias))
# Lux.Experimental.single_train_step(AutoReactant(), loss_fn, data, ts)
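For reference, both compiled functions run fine on this simplified MWE; a minimal sanity check against the eager results, assuming the Reactant outputs convert back via Array and compare with isapprox, might look like:
# Hedged sketch: the no-activation MWE works, so compiled and eager results should agree.
loss_eager = loss_fn(weight, data, bias)
loss_xla = reactant_loss_fn(reactant_weight, reactant_data, reactant_bias)
dweight_eager = gradient_loss_fn(weight, data, bias)
dweight_xla = reactant_gradient_loss_fn(reactant_weight, reactant_data, reactant_bias)
# Expect loss_xla ≈ loss_eager and Array(dweight_xla) ≈ dweight_eager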
OK, here we go: it is the activation function.
using Enzyme, Reactant
function loss_fn(ps, data)
    y = tanh.(ps.weight * data .+ ps.bias)
    return sum(y)
end
data = rand(Float32, 10, 3)
weight = rand(Float32, 10, 10)
bias = rand(Float32, 10)
ps = (weight=weight, bias=bias)
reactant_ps = Reactant.make_tracer(IdDict(), ps, (), Reactant.ArrayToConcrete, nothing)
reactant_data = Reactant.ConcreteRArray(data)
reactant_loss_fn = Reactant.compile(loss_fn, (reactant_ps, reactant_data))
function gradient_loss_fn(ps, data)
    dps = Enzyme.make_zero(ps)
    Enzyme.autodiff(Enzyme.Reverse, loss_fn, Active, Duplicated(ps, dps), Const(data))
    return dps
end
gradient_loss_fn(ps, data)
reactant_gradient_loss_fn = Reactant.compile(gradient_loss_fn, (reactant_ps, reactant_data))
Ah, looks like it's the broadcast_in_dim differentiation using the wrong inner type.
preModule:
module {
func.func private @"Const{typeof(loss_fn)}(loss_fn)_autodiff"(%arg0: tensor<3x10xf32>, %arg1: tensor<10x10xf32>, %arg2: tensor<10xf32>) -> (tensor<f32>, tensor<10x10xf32>, tensor<3x10xf32>, tensor<10xf32>) {
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32>
%1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
%2 = stablehlo.transpose %arg2, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
%3 = stablehlo.dot_general %1, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf32>, tensor<10x3xf32>) -> tensor<10x3xf32>
%4 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<10xf32>) -> tensor<10x3xf32>
%5 = stablehlo.add %3, %4 : tensor<10x3xf32>
%6 = stablehlo.tanh %5 : tensor<10x3xf32>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%7 = stablehlo.reduce(%6 init: %cst) applies stablehlo.add across dimensions = [0, 1] : (tensor<10x3xf32>, tensor<f32>) -> tensor<f32>
%8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32>
%9 = stablehlo.transpose %1, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
%10 = stablehlo.transpose %0, dims = [1, 0] : (tensor<10x3xf32>) -> tensor<3x10xf32>
%11 = stablehlo.transpose %2, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
return %8, %9, %10, %11 : tensor<f32>, tensor<10x10xf32>, tensor<3x10xf32>, tensor<10xf32>
}
func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<10x10xf32>, %arg2: tensor<3x10xf32>) -> (tensor<10xf32>, tensor<3x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10xf32>) {
%0 = stablehlo.transpose %arg0, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
%1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
%2 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<10x10xf32>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<10xf32>
%cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<10x3xf32>) -> tensor<3x10xf32>
%4 = stablehlo.transpose %1, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
%5 = stablehlo.transpose %0, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
%6 = stablehlo.transpose %cst_1, dims = [] : (tensor<f32>) -> tensor<f32>
%7 = stablehlo.transpose %cst, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
%8 = stablehlo.transpose %cst_0, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
%9:5 = enzyme.autodiff @"Const{typeof(loss_fn)}(loss_fn)_autodiff"(%3, %4, %5, %6, %7, %8) {activity = [#enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>]} : (tensor<3x10xf32>, tensor<10x10xf32>, tensor<10xf32>, tensor<f32>, tensor<10x10xf32>, tensor<10xf32>) -> (tensor<10x10xf32>, tensor<3x10xf32>, tensor<10xf32>, tensor<10x10xf32>, tensor<10xf32>)
%10 = stablehlo.transpose %9#0, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
%11 = stablehlo.transpose %9#1, dims = [1, 0] : (tensor<3x10xf32>) -> tensor<10x3xf32>
%12 = stablehlo.transpose %9#2, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
%13 = stablehlo.transpose %9#3, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
%14 = stablehlo.transpose %9#4, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
%15 = stablehlo.transpose %14, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
%16 = stablehlo.transpose %11, dims = [1, 0] : (tensor<10x3xf32>) -> tensor<3x10xf32>
%17 = stablehlo.transpose %13, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
%18 = stablehlo.transpose %10, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
%19 = stablehlo.transpose %12, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
return %15, %16, %17, %18, %19 : tensor<10xf32>, tensor<3x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10xf32>
}
}
Module:
module attributes {transform.with_named_sequence} {
func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<10x10xf32>, %arg2: tensor<3x10xf32>) -> (tensor<10xf32>, tensor<3x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10xf32>) {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<10x3xf32>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<10xf32>
%0 = stablehlo.dot_general %arg1, %arg2, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf32>, tensor<3x10xf32>) -> tensor<10x3xf32>
%1 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<10xf32>) -> tensor<10x3xf32>
%2 = stablehlo.add %0, %1 : tensor<10x3xf32>
%3 = stablehlo.tanh %2 : tensor<10x3xf32>
%4 = stablehlo.multiply %3, %3 : tensor<10x3xf32>
%5 = stablehlo.subtract %cst, %4 : tensor<10x3xf32>
%6 = stablehlo.reduce(%5 init: %cst_0) across dimensions = [1] : (tensor<10x3xf32>, tensor<10xf32>) -> tensor<10xf32>
reducer(%arg3: tensor<10xf32>, %arg4: tensor<10xf32>) {
%8 = stablehlo.add %arg3, %arg4 : tensor<10xf32>
stablehlo.return %8 : tensor<10xf32>
}
%7 = stablehlo.dot_general %arg2, %5, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<3x10xf32>, tensor<10x3xf32>) -> tensor<10x10xf32>
return %6, %arg2, %7, %arg1, %arg0 : tensor<10xf32>, tensor<3x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10xf32>
}
}
terminate called after throwing an instance of 'xla::XlaRuntimeError'
what(): UNKNOWN: <unknown>:0: error: Reduction function must return a scalar or tuple of scalars but returns shape: f32[10]:
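For reference, the reduce that triggers this error is the bias gradient: the pullback of the bias broadcast is a sum over the batch dimension, and that sum has to lower to a stablehlo.reduce with a scalar f32 init and reducer rather than the tensor<10xf32> ones generated above. A minimal eager sketch of what that part of the module should compute (hypothetical variable names):
# Eager reference gradients for the tanh MWE.
z = ps.weight * data .+ ps.bias    # 10×3 pre-activation
dz = 1 .- tanh.(z) .^ 2            # pullback of sum ∘ tanh
dweight_ref = dz * data'           # 10×10, matches the dot_general in the module
dbias_ref = vec(sum(dz; dims=2))   # length-10 sum over the batch: the reduce in question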
Fixed by latest jll bump.
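With the updated jll, the tanh MWE above should compile and run; a hedged check that the compiled gradient matches eager Enzyme (again assuming Array conversion of the Reactant outputs):
dps_eager = gradient_loss_fn(ps, data)
dps_xla = reactant_gradient_loss_fn(reactant_ps, reactant_data)
# Expect Array(dps_xla.weight) ≈ dps_eager.weight and Array(dps_xla.bias) ≈ dps_eager.bias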