Open rlouf opened 2 years ago
More generally, many issues with the JAX backend could be resolved by rewrites that simplify slicing operations in the graph and by constant-folding slice elements.
Possibly related to https://github.com/aesara-devs/aesara/issues/940
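The following plain-Python sketch illustrates the two simplifications on ordinary `slice` objects. It is not Aesara code and does not show how `local_subtensor_merge` is implemented. `merge_constant_slices` and `canonicalize_step` are hypothetical helpers, and the sketch assumes constant slice bounds, a known sequence length, and positive steps. Merging `x[2:8][1:4]` yields an explicit step of 1, which a canonicalization pass can then drop.

```python
def merge_constant_slices(length: int, outer: slice, inner: slice) -> slice:
    """Compose ``x[outer][inner]`` into one slice for a sequence of known ``length``.

    Hypothetical helper. It uses Python's ``range`` slicing, which resolves
    omitted and negative bounds once ``length`` is known (i.e. it
    constant-folds the slice elements). Only positive steps are supported,
    because then every resulting bound is non-negative and maps directly
    back onto slice semantics.
    """
    for s in (outer, inner):
        if (s.step or 1) <= 0:
            raise ValueError("this sketch only handles positive steps")
    r = range(length)[outer][inner]
    return slice(r.start, r.stop, r.step)


def canonicalize_step(s: slice) -> slice:
    """Drop a step equal to 1: ``x[a:b:1]`` and ``x[a:b]`` select the same elements."""
    if s.step == 1:
        return slice(s.start, s.stop, None)
    return s


merged = merge_constant_slices(10, slice(2, 8), slice(1, 4))
print(merged)                     # slice(3, 6, 1): the merge leaves an explicit step of 1
print(canonicalize_step(merged))  # slice(3, 6, None)
```

Keeping the two steps separate mirrors the issue: merging alone leaves a redundant step index behind, and a second simplification is needed to remove it.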
Working on #1202, I noticed that `Subtensor` `Op`s with a step index equal to 1 do not get simplified:

This node is created by the `local_subtensor_merge` rewrite:

The `Elemwise{mul}` `Op` is also doing wasteful computation, but I do not quite understand what sequence of rewrites led to this: