EnzymeAD / Enzyme-JAX

Tracking issue for missing Batch Op Interface #152

Open Pangoraw opened 3 weeks ago

Pangoraw commented 3 weeks ago

NOTE: Strikethrough ops are deliberately not implemented because the default broadcasting behavior of enzyme batch is enough.

mofeing commented 3 weeks ago

... because the default broadcasting behavior of enzyme batch is enough.

Do you mean unrolling the op, or leaving it unchanged?

Pangoraw commented 3 weeks ago

The batch pass will take the original op:

%0 = stablehlo.add %arg0, %arg1 : tensor<10xf32>

and just prepend the broadcast dimensions (e.g. 20x4) to the tensor type:

%0 = stablehlo.add %arg0, %arg1 : tensor<20x4x10xf32>

This is the default behavior for all ops unless they implement BatchOpInterface.
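To make the shape rewrite above concrete, here is a minimal Python sketch of the type change only. It is not the Enzyme pass itself: `prepend_batch_dims` is a hypothetical helper that works on the textual tensor type, assuming a ranked tensor with static dimensions.

```python
import re
from typing import Sequence


def prepend_batch_dims(tensor_type: str, batch_dims: Sequence[int]) -> str:
    """Hypothetical helper: prepend batch dims to a ranked tensor type string.

    Models only the type rewrite described above; it does not touch any IR.
    """
    # Split "tensor<10xf32>" into its static dims ("10x") and element type ("f32").
    match = re.fullmatch(r"tensor<((?:\d+x)*)(\w+)>", tensor_type)
    if match is None:
        raise ValueError(f"unsupported tensor type: {tensor_type!r}")
    dims, element_type = match.groups()
    # The batch dimensions go in front of the original dimensions.
    prefix = "".join(f"{d}x" for d in batch_dims)
    return f"tensor<{prefix}{dims}{element_type}>"


print(prepend_batch_dims("tensor<10xf32>", (20, 4)))
```

For an elementwise op such as `stablehlo.add`, this type change is all that is needed, which is why such ops do not need their own batching implementation.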

wsmoses commented 2 weeks ago

WhileOp definitely shouldn't be unrolled here in most cases, since its number of iterations is almost always fixed by a constant (i.e. a non-data-dependent value).
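The reasoning can be sketched in plain Python, under assumptions: the loop body `2 * x + 1` and both function names are hypothetical, and this models the idea rather than how Enzyme lowers a batched while op. When the trip count does not depend on the data, every batch element runs the same number of iterations, so a single loop can carry the whole batch as its state.

```python
def loop_per_element(xs, trip_count):
    """Run one separate loop per batch element."""
    results = []
    for x in xs:
        i = 0
        while i < trip_count:
            x = 2 * x + 1  # hypothetical loop body
            i += 1
        results.append(x)
    return results


def loop_batched(xs, trip_count):
    """Run a single loop whose state holds every batch element."""
    state = list(xs)
    i = 0
    while i < trip_count:  # the condition does not depend on the data
        state = [2 * x + 1 for x in state]  # body applied across the batch
        i += 1
    return state


print(loop_batched([0, 1, 2], 3) == loop_per_element([0, 1, 2], 3))
```

Both versions give the same result, but the batched one keeps a single loop with a batched body. If the trip count were data-dependent, batch elements could need different numbers of iterations and this simple rewrite would no longer hold.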