mila-iqia / myia

Myia prototyping
MIT License

Inplace operations #36

Open breuleux opened 6 years ago

breuleux commented 6 years ago

This is a follow-up to a discussion in #30; I'm making an issue out of it so that it isn't all buried over there (feel free to copy your points from there to here).

Syntax support for this feature has already been implemented, but is currently isolated in the augassign branch.

The issue is whether we want to support the following operations and, if we do, how they should be done:

x[i] = y  # index/slice assignment
x.y = z   # attribute assignment
x += y    # augmented assignment

In Python these are destructive operations, meaning that this should succeed:

x = [1, 2, 3]
y = x
x[0] = 11
assert y[0] == 11

Incidentally, this will also work in Myia because of copy propagation, but there are many other ways x could end up aliased elsewhere (through an argument to a function, for example). To avoid bad surprises for the user, we should avoid any situation where Myia and Python semantics diverge: the same program should produce the same results in both. In divergent cases, Myia should raise a compile-time error, although @abergeron raised the point that the rules should be tractable: it should be easy for the user to understand why their program does not compile.
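To make the aliasing hazard concrete, here is a plain-Python illustration of aliasing through a function argument, the kind of case copy propagation alone cannot see (names are invented for illustration):

```python
def bump(lst):
    # Mutates the caller's list through the parameter alias.
    lst[0] = 11

x = [1, 2, 3]
bump(x)            # inside bump, lst and x alias the same list
assert x[0] == 11  # the mutation is visible through x
```

A purely functional rewrite of bump would instead return a fresh list, and the caller would only see the change if it rebinds x to the result.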

If we want to support this in the future, I think it'd be worth taking a look at linear types or uniqueness types, which essentially constrain values to only be used once. We could also look at Clojure's transient feature, which allows you to create mutable structures out of persistent ones, but doesn't let you "leak" them out.

I haven't read about it in detail yet, but a sketch of how I think a uniqueness type approach would work (I'll call it the Transient type here) is that the user can annotate a variable as Transient, at which point the type system has to ensure that it is only used once:

x: Transient = zeros(10)
x += y           # 1. OK: x: Transient = x + y
x = f(x)         # 2. OK: x: Transient = f(x)
for i in range(10):
    x[i] = i     # 3. OK: x: Transient = setitem(x, i, i)
    x[i] = f(x)  # 4. NO: x: Transient = setitem(x, i, f(x))
    x[i] += 3    # 5. NO: x: Transient = setitem(x, i, getitem(x, i) + 3)

The first three cases are fine, because each definition of x is used only once. The last two are trickier, because there are two uses of x. Case 5 could be patched with an augitem primitive, but in general I don't know whether the restriction can easily be relaxed: we want to forbid any use of a transient variable after the first use that might mutate it, and that requires reasoning about operation order.
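A minimal sketch of such a use-once check, over a toy statement list rather than Myia's real IR (all names here are invented for illustration):

```python
from collections import Counter

def check_linear(stmts, transient):
    """Reject any transient variable that is used more than once.

    stmts is a toy IR: a list of (target, op, args) tuples.
    transient is the set of variable names marked Transient.
    """
    uses = Counter(a for _, _, args in stmts for a in args)
    for var in transient:
        if uses[var] > 1:
            raise TypeError(f"transient variable {var!r} used {uses[var]} times")

# Case 3 above: the single definition x0 is used exactly once.
ok = [("x1", "setitem", ("x0", "i", "i"))]
check_linear(ok, {"x0"})

# Case 4: x0 appears both as the container being written and inside f(x),
# so it has two uses and the check rejects it.
bad = [("fx", "call", ("f", "x0")),
       ("x1", "setitem", ("x0", "i", "fx"))]
try:
    check_linear(bad, {"x0"})
except TypeError as e:
    print(e)  # → transient variable 'x0' used 2 times
```

This counts uses globally, so it is even more conservative than necessary; a real pass would count uses per definition site.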

Regarding free variables, I think a use of a transient free variable would entail that the closure itself is transient and can only be called once. That doesn't seem super useful (and the variable cannot be used outside of the closure at all), so it could also just be forbidden.

Overall, though, I think "a variable that may be mutated can only be used once" may be an acceptable compromise. It's restrictive, but easy to reason about, probably easy to implement, and a typical error would basically be: "this use of this variable might mutate it, so you can't use it in this and that other call," along with a code listing and multiple carets.

I'll note that this will interact poorly with CSE, since CSE may merge two expressions into a single expression with two uses (notably, it would merge all calls to zeros(shape), so good luck initializing different mutable arrays to zero). Thus we might have to run CSE after inference.
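A plain-Python illustration of the CSE hazard, assuming zeros returns a mutable buffer: merging the two calls turns two independent arrays into one shared one.

```python
def zeros(n):
    # Stand-in for an array constructor returning a fresh mutable buffer.
    return [0] * n

# Two separate calls: mutating one buffer does not touch the other.
a = zeros(3)
b = zeros(3)
a[0] = 1
assert b[0] == 0

# What CSE would produce: one allocation with two uses.
shared = zeros(3)
a2, b2 = shared, shared
a2[0] = 1
assert b2[0] == 1  # the two "independent" arrays now alias
```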

abergeron commented 6 years ago

Also note that another approach to this problem is to rewrite code that performs in-place modifications so that it returns the modified values, and to make sure the new values are used.

Something like:

def f(a, i):
    y = a[i]
    a[i] = 33
    return y

def main():
    x = [1, 2, 3]
    assert f(x, 1) == 2
    assert x[1] == 33

This would get rewritten to:

def f(a_0, i):
    y = a_0[i]
    a_1 = setitem(a_0, i, 33)
    return a_1, y

def main():
    x_0 = [1, 2, 3]
    x_1, _tmp = f(x_0, 1)
    assert _tmp == 2
    assert x_1[1] == 33

This should preserve the semantics of Python and support all the cases of augmented assignment with no surprises for the user. It should also allow the gradient pass to work through this with no modification, since the graph is fully functional at that point.
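For concreteness, here is a runnable plain-Python sketch of this rewrite, using a hypothetical pure setitem that copies instead of mutating (note that f(x_0, 1) sets the element at index 1):

```python
def setitem(seq, i, v):
    # Pure stand-in for the setitem primitive: copy, then write.
    out = list(seq)
    out[i] = v
    return out

def f(a_0, i):
    y = a_0[i]
    a_1 = setitem(a_0, i, 33)
    return a_1, y

x_0 = [1, 2, 3]
x_1, tmp = f(x_0, 1)
assert tmp == 2
assert x_1[1] == 33
assert x_0 == [1, 2, 3]  # the original value is untouched
```

The caller observes the mutation only through the rebinding to x_1, which is exactly what the transform relies on.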

As for the analysis needed, we can just always transform functions that act in place as above, and maybe flag them in some way so the parser (or a pass soon after it) knows to transform the call sites as above. This may get hairy in the presence of higher-order functions, but it should be workable with some conditionals in the graph.

breuleux commented 6 years ago

@abergeron That would work in relatively trivial cases, but it's pretty brittle, I think. Take this case, for example:

def f(a, i):
    y = a[i]
    a[i] = 33
    return y

def main(which):
    t = ([1, 2, 3], [1, 2, 3])
    x = t[which]
    first = t[0]
    second = t[1]
    assert f(x, 1) == 2
    assert x[1] == 33
    assert t[which][1] == 33
    assert first[1] != second[1]

Your transform would replace x, but it also has to replace t, and either first or second (not both, naturally), in order to preserve Python semantics. This kind of situation is going to be pretty common, e.g. when the user calls update(parameters.weights). To solve this, I suppose we could backtrack through every aliasing operation such as getitem or getattr and generate corresponding setitem/setattr instructions. We would also need to duplicate alias-creating operations on changed nodes. Thus we might generate something like this:

def f(a_0, i):
    y = a_0[i]
    a_1 = setitem(a_0, i, 33)
    return a_1, y

def main(which):
    t_0 = ([1, 2, 3], [1, 2, 3])
    x_0 = getitem(t_0, which)
    first_0 = getitem(t_0, 0)
    second_0 = getitem(t_0, 1)
    x_1, _tmp = f(x_0, 1)  # update x
    t_1 = setitem(t_0, which, x_1)  # update t
    first_1 = getitem(t_1, 0)  # re-acquire first from new t
    second_1 = getitem(t_1, 1)  # re-acquire second from new t
    assert _tmp == 2
    assert x_1[1] == 33
    assert t_1[which][1] == 33
    assert first_1[1] != second_1[1]

So, I don't know. It might get a bit heavy, and we still have to demonstrate that this would always preserve semantics, even in situations where first or second are independently modified.
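A runnable plain-Python sketch of the rewritten program above, again with a hypothetical pure setitem and with which fixed to 0 for the demonstration:

```python
def setitem(seq, i, v):
    # Pure stand-in for the setitem primitive: copy, write, preserve type.
    out = list(seq)
    out[i] = v
    return tuple(out) if isinstance(seq, tuple) else out

def f(a_0, i):
    y = a_0[i]
    a_1 = setitem(a_0, i, 33)
    return a_1, y

which = 0
t_0 = ([1, 2, 3], [1, 2, 3])
x_0 = t_0[which]
x_1, tmp = f(x_0, 1)
t_1 = setitem(t_0, which, x_1)  # propagate the update back into the tuple
first_1 = t_1[0]                # re-acquire aliases from the new tuple
second_1 = t_1[1]
assert tmp == 2
assert x_1[1] == 33
assert t_1[which][1] == 33
assert first_1[1] != second_1[1]
```

The re-acquired first_1/second_1 are what make the final assertion hold; forgetting that refresh step is exactly the brittleness discussed above.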

The uniqueness type approach is simpler: f requires x to be transient because it modifies it, so f must be the only use of x. Since x is obtained by indexing a tuple, we can only demonstrate that the value we get out is transient if the tuple itself is transient. However, the tuple is used a whopping four times, so this is an error. Even if the tuple were transient, x is used twice, so that's also an error. It's extremely restrictive, but at least I think it's pretty foolproof.

breuleux commented 6 years ago

Regarding HOF and rewriting, one approach would be to indiscriminately modify the return value of every function to return the new value of each parameter. For example,

a = f(x, y)

would be systematically rewritten as:

a, f, x, y = f(x, y)

If f modifies its closure, it would return a new f. Unmodified closure/arguments would be returned as they are.

The rest of the transform would consist of propagating changes through every operation that does aliasing (which includes closure creation). That, too, is a big problem, but there could be a generic way to handle it: upon setting x_1, the definition x_0 = f(t_0) could entail t_1 = inv(f)(t_0, x_1), where e.g. inv(getitem) == setitem; then y_0 = g(t_0) would entail y_1 = realias(g)(t_1), where realias(g) is a pure version of g that refreshes aliasing (realias(getitem) == getitem). In other words, we propagate changes backward through alias-creating operations (if we update a value we got from a tuple, we update that tuple), and then forward again (if we get other values from that tuple, we need to refresh them).

To be honest, I don't know if this procedure works. I don't know if it's consistent (update rules could clash), and I don't know if it even reaches an equilibrium; at a glance it looks like it could loop forever. But the idea is to devise a transform that's purely syntactic and doesn't require analysis. It would result in utterly monstrous graphs (worse than grad), but that might be fine if the optimizer can simplify them.
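A tiny plain-Python sketch of this calling convention, with setitem as a hypothetical pure helper: every call returns its result together with (possibly updated) versions of the callee and arguments, and the call site rebinds all of them.

```python
def setitem(seq, i, v):
    # Pure stand-in for the setitem primitive: copy, then write.
    out = list(seq)
    out[i] = v
    return out

def f(x, y):
    # Original body would be: x[0] = y; return x[0] + 1
    x_1 = setitem(x, 0, y)  # functional stand-in for the mutation
    a = x_1[0] + 1
    return a, f, x_1, y     # result, plus new f, new x, new y

x = [0]
a, f, x, y = f(x, 10)       # the call site rebinds every name it passed in
assert a == 11
assert x == [10]
```

Here f and y come back unchanged; the point of the convention is that the call site does not need to know which of them were modified.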

breuleux commented 6 years ago

Hmm, my idea won't work, because it assumes parameters are not aliased at the beginning. If the body of f(x, y) modifies x, but it turns out y contains x, I have no way to track that.

Edit: unless changing x inserts the node y = copy_with_substitution(y, {x_0: x_1}). But that's getting sort of ridiculous, and I'm sure there are other cases I'm not considering.