pymc-devs / pytensor

PyTensor allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays.
https://pytensor.readthedocs.io

Missed scan rewrites #787

Open aseyboldt opened 4 months ago

aseyboldt commented 4 months ago

Description

There are two issues with the code generated by this snippet:

```python
import pytensor
import pytensor.tensor as pt

def update(x):
    return pt.exp(x) - 5

x_init = pt.vector("x_init", shape=(7,))
x_init_tangent = pt.vector("x_init_tangent", shape=(7,))
seq, updates = pytensor.scan(update, outputs_info=[x_init], n_steps=10)
outputs = seq[-1]
output_tangent = pytensor.Rop(outputs, x_init, eval_points=x_init_tangent)

with pytensor.config.change_flags(optimizer_verbose=False):
    func = pytensor.function(
        [x_init, x_init_tangent],
        [outputs, output_tangent],
        mode=pytensor.compile.mode.get_mode("FAST_RUN"),
    )

pytensor.dprint(func, print_type=True, print_destroy_map=True)
```
```
Subtensor{i} [id A] 13
 ├─ Scan{scan_fn&rop_of_scan_fn, while_loop=False, inplace=all}.0 [id B] 12
 │  ├─ 10 [id C]
 │  ├─ SetSubtensor{:stop} [id D] 11
 │  │  ├─ AllocEmpty{dtype='float64'} [id E] 10
 │  │  │  ├─ 2 [id F]
 │  │  │  └─ 7 [id G]
 │  │  ├─ SpecifyShape [id H] 7
 │  │  │  ├─ Unbroadcast{0} [id I] 6
 │  │  │  │  └─ ExpandDims{axis=0} [id J] 5
 │  │  │  │     └─ x_init [id K]
 │  │  │  ├─ 1 [id L]
 │  │  │  └─ 7 [id M]
 │  │  └─ 1 [id N]
 │  ├─ SetSubtensor{:stop} [id O] 9
 │  │  ├─ AllocEmpty{dtype='float64'} [id P] 8
 │  │  │  ├─ 1 [id Q]
 │  │  │  └─ 7 [id G]
 │  │  ├─ SpecifyShape [id H] 7
 │  │  │  └─ ···
 │  │  └─ 1 [id N]
 │  └─ SetSubtensor{:stop} [id R] 4
 │     ├─ AllocEmpty{dtype='float64'} [id S] 3
 │     │  ├─ 2 [id T]
 │     │  └─ 7 [id G]
 │     ├─ SpecifyShape [id U] 2
 │     │  ├─ Unbroadcast{0} [id V] 1
 │     │  │  └─ ExpandDims{axis=0} [id W] 0
 │     │  │     └─ x_init_tangent [id X]
 │     │  ├─ 1 [id L]
 │     │  └─ 7 [id M]
 │     └─ 1 [id N]
 └─ 1 [id Y]
Subtensor{i} [id Z] 14
 ├─ Scan{scan_fn&rop_of_scan_fn, while_loop=False, inplace=all}.2 [id B] 12
 │  └─ ···
 └─ 1 [id Y]

Inner graphs:

Scan{scan_fn&rop_of_scan_fn, while_loop=False, inplace=all} [id B]
 ← Composite{(exp(i0) - 5.0)} [id BA]
    └─ *0- [id BB] -> [id D]
 ← Composite{...}.0 [id BC]
    ├─ *1- [id BD] -> [id O]
    └─ *2- [id BE] -> [id R]
 ← Composite{...}.1 [id BC]
    └─ ···

Composite{(exp(i0) - 5.0)} [id BA]
 ← sub [id BF] 'o0'
    ├─ exp [id BG]
    │  └─ i0 [id BH]
    └─ 5.0 [id BI]

Composite{...} [id BC]
 ← sub [id BJ] 'o0'
    ├─ exp [id BK] 't3'
    │  └─ i0 [id BL]
    └─ 5.0 [id BM]
 ← mul [id BN] 'o1'
    ├─ exp [id BK] 't3'
    │  └─ ···
    └─ i1 [id BO]
```
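For reference, the R-op of this scan only needs one forward state and one tangent: with x_{k+1} = exp(x_k) - 5, the tangent follows t_{k+1} = exp(x_k) * t_k, which matches the `mul` of `exp(i0)` and `i1` in Composite [id BC]. The printed inner graph, by contrast, appears to update two forward states (Composite [id BA] on `*0-` and output 0 of Composite [id BC] on `*1-`), both seeded from SpecifyShape [id H]. The plain-Python sketch below uses a hypothetical `scan_jvp` helper (not a PyTensor API) that carries a single state and reuses `exp(x_k)` for both updates, so it can serve as a numerical reference for `func`.

```python
import math


def scan_jvp(x_init, x_init_tangent, n_steps=10):
    """Hypothetical forward-mode (R-op) reference for x_{k+1} = exp(x_k) - 5.

    Returns the final state (what seq[-1] holds) and its tangent,
    carrying exactly one forward state and one tangent per element.
    """
    x = [float(v) for v in x_init]
    t = [float(v) for v in x_init_tangent]
    for _ in range(n_steps):
        ex = [math.exp(v) for v in x]  # shared subexpression exp(x_k)
        # Both right-hand sides are built from the old x and t before assignment.
        x, t = [e - 5.0 for e in ex], [e * dt for e, dt in zip(ex, t)]
    return x, t
```

Called with two length-7 lists, `scan_jvp(x0, t0)` should agree with `func(x0, t0)` up to floating-point error, as long as the recurrence does not overflow (large positive `x_init` makes `exp` blow up).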

cc @ricardoV94

ricardoV94 commented 3 months ago

Regarding the (2, 7) buffer instead of (1, 7): my guess is that this may be to facilitate inplace rewrites without the need for a deepcopy? For instance, the inplace logic for Composite Ops with multiple outputs is not trivial, because we have to make sure an input is not modified in place while it is still needed to compute other outputs (#138).
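One way to picture the trade-off guessed at above is double buffering. The plain-Python sketch below is an illustration under that assumption, not PyTensor's Scan implementation, and both function names are hypothetical. With a single slot, writing `exp(x) - 5` over `x` in place corrupts the tangent update, which still needs the old `x`; with two slots (analogous to a `(2, 7)` buffer), each step reads one slot and writes the other, so no copy is needed.

```python
import math


def run_single_slot(x0, t0, n_steps):
    """Hypothetical in-place update with one state slot (the aliasing hazard)."""
    x, t = list(x0), list(t0)
    for _ in range(n_steps):
        for i in range(len(x)):
            x[i] = math.exp(x[i]) - 5.0  # o0 overwrites its input i0 in place
        for i in range(len(t)):
            t[i] = math.exp(x[i]) * t[i]  # o1 now reads the *new* x: wrong
    return x, t


def run_two_slots(x0, t0, n_steps):
    """Hypothetical ping-pong update: read slot k % 2, write slot (k + 1) % 2."""
    slots = [list(x0), [0.0] * len(x0)]  # analogous to a (2, n) buffer
    t = list(t0)
    for k in range(n_steps):
        src, dst = slots[k % 2], slots[(k + 1) % 2]
        for i in range(len(src)):
            dst[i] = math.exp(src[i]) - 5.0  # write into the other slot
        for i in range(len(t)):
            t[i] = math.exp(src[i]) * t[i]  # src still holds the old state
    return slots[n_steps % 2], t
```

The single-slot version is only correct if every consumer of the old state reads it before it is overwritten, which is the ordering an inplace rewrite for a multi-output Composite would have to guarantee.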