openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Add more patterns to unstacking pass. #14886

Closed: copybara-service[bot] closed this 1 month ago

copybara-service[bot] commented 1 month ago


Added the following patterns to the unstacking pass:

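  1. GetReduceFusionPattern: a fusion that dynamic-slices the current iteration's slice out of the stacked operand and reduces it.

         fusion(stacked, loop_iteration_var)
         computation {
           p0 = parameter(0)
           p1 = parameter(1)
           slice = dynamic_slice(p0, p1, zero, ...)
           ROOT reduce = reduce(slice, constant)
         }

  2. GetDUSFusionWithPadPattern: a fusion that pads the update, bitcasts the result, and writes it into the stacked operand at the current iteration with dynamic_update_slice.

         fusion(stacked, update, loop_iteration_var)
         computation {
           p0 = parameter(0)
           p1 = parameter(1)
           p2 = parameter(2)
           pad = pad(p1, ...)
           update = bitcast(pad)
           ROOT dus = dynamic_update_slice(p0, update, p2, zero, ...)
         }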
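To make the two patterns concrete, the following minimal pure-Python model shows what each matched fusion computes and what the equivalent per-iteration code looks like after unstacking. It is a sketch under stated assumptions, not XLA's implementation. It assumes that unstacking replaces the stacked operand with one array per loop iteration, that dim 0 of `stacked` is the loop dimension, that the dynamic slice has size 1 along that dimension, that `reduce` sums every element with `constant` as its init value, and that `pad` only pads on the high side. All function names (`reduce_fusion`, `dus_pad_fusion`, `pad_high`, and so on) are hypothetical.

```python
from typing import List

Layer = List[List[float]]  # one 2-D slice of the stacked operand (hypothetical shape)
Stacked = List[Layer]      # dim 0 is the loop dimension indexed by loop_iteration_var


def dynamic_slice_dim0(stacked: Stacked, i: int) -> Stacked:
    """dynamic_slice(p0, i, zero, ...) with slice size 1 along dim 0.

    As in XLA, the start index is clamped so the slice stays in bounds.
    """
    i = max(0, min(i, len(stacked) - 1))
    return [stacked[i]]


def reduce_all(x: Stacked, init: float) -> float:
    """reduce(x, init) over every dimension, with addition as the (assumed) reducer."""
    total = init
    for layer in x:
        for row in layer:
            total += sum(row)
    return total


def reduce_fusion(stacked: Stacked, i: int, init: float = 0.0) -> float:
    # Pattern 1: slice = dynamic_slice(p0, p1, zero, ...); ROOT reduce(slice, constant)
    return reduce_all(dynamic_slice_dim0(stacked, i), init)


def reduce_unstacked(unstacked: List[Layer], i: int, init: float = 0.0) -> float:
    # After unstacking, iteration i reads its own array; no dynamic_slice is needed.
    return reduce_all([unstacked[i]], init)


def pad_high(update: Layer, rows_hi: int, cols_hi: int, value: float = 0.0) -> Layer:
    """pad(update, value) with edge padding on the high side only (assumed config)."""
    width = len(update[0]) + cols_hi
    out = [list(row) + [value] * cols_hi for row in update]
    out += [[value] * width for _ in range(rows_hi)]
    return out


def dus_pad_fusion(stacked: Stacked, update: Layer, i: int,
                   rows_hi: int, cols_hi: int) -> Stacked:
    # Pattern 2: pad = pad(p1, ...); update = bitcast(pad);
    #            ROOT dus = dynamic_update_slice(p0, update, p2, zero, ...)
    padded = pad_high(update, rows_hi, cols_hi)
    reshaped = [padded]  # bitcast [R, C] -> [1, R, C]: same data, new leading dim
    i = max(0, min(i, len(stacked) - len(reshaped)))  # XLA-style index clamping
    out = [[list(row) for row in layer] for layer in stacked]  # DUS yields a new array
    out[i] = reshaped[0]
    return out


def dus_pad_unstacked(unstacked: List[Layer], update: Layer, i: int,
                      rows_hi: int, cols_hi: int) -> List[Layer]:
    # After unstacking, iteration i replaces its own array with the padded update.
    out = [[list(row) for row in layer] for layer in unstacked]
    out[i] = pad_high(update, rows_hi, cols_hi)
    return out


if __name__ == "__main__":
    stacked = [[[1.0, 2.0], [3.0, 4.0]],
               [[5.0, 6.0], [7.0, 8.0]],
               [[9.0, 10.0], [11.0, 12.0]]]
    unstacked = [layer for layer in stacked]
    for i in range(len(stacked)):
        assert reduce_fusion(stacked, i) == reduce_unstacked(unstacked, i)
        assert (dus_pad_fusion(stacked, [[100.0]], i, 1, 1)
                == dus_pad_unstacked(unstacked, [[100.0]], i, 1, 1))
    print(reduce_fusion(stacked, 1))  # prints 26.0
```

The equivalence checks in the `__main__` block are the property such a rewrite needs to preserve. Inside the loop, the fused form and the unstacked form produce the same values, so the dynamic_slice / dynamic_update_slice over the stacked buffer can be replaced by a direct read or write of that iteration's own array.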