chalk-diagrams / chalk

A declarative drawing API in Python
MIT License

Semantics for batched Diagrams? #138

Open srush opened 1 week ago

srush commented 1 week ago

This is actually a pretty interesting question that I'm stuck on. In JAX, I'm thinking of a diagram as a pytree (https://jax.readthedocs.io/en/latest/pytrees.html). A pytree is basically a tree of arrays. When you vmap a function that returns a diagram, JAX returns a diagram in which every array has an extra leading (batch) dimension.
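A minimal sketch of that pytree behaviour, using a hypothetical `ToyCircle` NamedTuple as a stand-in for a chalk Primitive (NamedTuples are pytrees out of the box; the field names are made up for illustration):

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyCircle(NamedTuple):
    """Hypothetical stand-in for a chalk Primitive (not chalk's real class)."""
    radius: jax.Array     # scalar radius
    transform: jax.Array  # (3, 3) affine matrix


def toy_circle(r):
    # Translate by r along x so that the transform depends on the input.
    return ToyCircle(radius=r, transform=jnp.eye(3).at[0, 2].set(r))


batched = jax.vmap(toy_circle)(jnp.arange(10.0))
# Every leaf gains a leading batch axis of size 10.
print(jax.tree_util.tree_map(lambda x: x.shape, batched))
# ToyCircle(radius=(10,), transform=(10, 3, 3))
```

The tree structure is unchanged; only the leaves are stacked, which is why the result is still "a Primitive", just with batched arrays inside.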

import jax
import numpy as np
from chalk import circle

@jax.vmap
def draw(i):
    return circle(i)
draw(np.arange(10))

The result is a single Primitive whose transform and style arrays have a leading batch dimension of size 10. By default I interpret this as a concatenation of 10 primitives. However, one could also interpret it as an animation with 10 frames.
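Both readings start from the same operation: peeling the k-th element off every leaf. A sketch with a hypothetical `unbatch` helper over a plain dict pytree (the dict is only a stand-in for a batched diagram):

```python
import jax
import jax.numpy as jnp


def unbatch(tree):
    """Split a pytree whose leaves share a leading batch axis into a list.

    Read in z-order, the list is primitives drawn one after another;
    read in time, it is a sequence of animation frames.
    """
    n = jax.tree_util.tree_leaves(tree)[0].shape[0]
    return [jax.tree_util.tree_map(lambda x: x[k], tree) for k in range(n)]


# Hypothetical batched "diagram": a dict pytree with a batch of 10.
batched = {"radius": jnp.arange(10.0), "transform": jnp.zeros((10, 3, 3))}
items = unbatch(batched)
```

Whether `items` is drawn as 10 stacked primitives or played as 10 frames is then purely a rendering decision.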

The question is what happens when you compose this with another object. I think in this case you simply get 15 composed elements:

draw(np.arange(10)) + draw(np.arange(5))
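Under the flattening reading, `+` on two batched diagrams amounts to concatenating matching leaves. A sketch, assuming both operands have identical tree structure (true for two batched circles, not for arbitrary diagrams); `concat_batched` is a hypothetical name:

```python
import jax
import jax.numpy as jnp


def concat_batched(a, b):
    """Flatten two batched pytrees with the same structure into one batch."""
    return jax.tree_util.tree_map(
        lambda x, y: jnp.concatenate([x, y], axis=0), a, b
    )


a = {"radius": jnp.arange(10.0)}  # stands in for draw(np.arange(10))
b = {"radius": jnp.arange(5.0)}   # stands in for draw(np.arange(5))
both = concat_batched(a, b)       # a single batch of 15 elements
```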

But is that the same as the following case? Here you have a Compose node that itself has a batch dimension, which applies to both of its children.

@jax.vmap
def draw(i):
    return circle(i) + circle(i+10)
draw(np.arange(10))
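To make the structural difference concrete: vmapping a function that builds a Compose returns one Compose node whose children carry the batch, rather than 10 separate Compose nodes. A sketch with a hypothetical `ToyCompose` NamedTuple, where each child is reduced to a radius array:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyCompose(NamedTuple):
    """Hypothetical Compose node; each child is reduced to a radius array."""
    left: jax.Array
    right: jax.Array


batched = jax.vmap(lambda i: ToyCompose(left=i, right=i + 10.0))(jnp.arange(10.0))
# One Compose node whose children each carry the batch of 10.
```

Flattening the leaves of this node naively would draw all 10 left children before any right child, whereas drawing per batch element interleaves them; this is where the two cases diverge.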

There is also the case of multiple (nested) vmaps. Here I think the concatenation should just flatten the batch dimensions and draw the elements in order.

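@jax.vmap
def draw(i):
    @jax.vmap
    def draw_inner(j):
        return circle(i) + circle(j + 10)
    # batch sizes 10 and 5 are illustrative
    return draw_inner(np.arange(5))
draw(np.arange(10))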
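Flattening nested vmaps "in order" can be read as a row-major reshape of the leading batch axes, so element (i, j) lands at position i * n_inner + j. A sketch with a hypothetical `flatten_batch` helper over a dict pytree (the sizes 10 and 5 are illustrative):

```python
import jax
import jax.numpy as jnp


def flatten_batch(tree, num_batch_dims=2):
    """Merge the leading `num_batch_dims` batch axes of every leaf into one.

    Row-major order keeps each inner vmap's elements contiguous.
    """
    return jax.tree_util.tree_map(
        lambda x: x.reshape((-1,) + x.shape[num_batch_dims:]), tree
    )


# Outer vmap over 10 values of i, inner vmap over 5 values of j.
nested = jax.vmap(
    lambda i: jax.vmap(
        lambda j: {"radius": i + j, "center": jnp.stack([i, j])}
    )(jnp.arange(5.0))
)(jnp.arange(10.0))
flat = flatten_batch(nested)  # radius: (50,), center: (50, 2)
```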

Either way, I think it would be nice if whatever we do here works for both animation and drawing, i.e. some notion of a composable sequence of diagrams, in either z-space or time, that corresponds to this tree idea.

srush commented 1 week ago

Hmm, I guess the more correct way to do this is to treat a batched diagram like a list of diagrams. You could then apply cat, hcat, or tcat along the batched axes.

You would then need to do:

concat(draw(np.arange(10)), axis=0) + hcat(draw(np.arange(5)), axis=0)

Internally, though, this could perhaps still be stored as 2 arrays instead of being flattened into 15 elements.
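One way to keep the parts separate is a lazy container that records each batched part and only concatenates on demand. A sketch; `LazyConcat` and its methods are hypothetical names, and it again assumes the parts share tree structure:

```python
from dataclasses import dataclass, field
from typing import Any, List

import jax
import jax.numpy as jnp


@dataclass
class LazyConcat:
    """Hypothetical container: keeps each batched part as its own arrays."""
    parts: List[Any] = field(default_factory=list)

    def size(self) -> int:
        # Total element count, read from the leading axis of each part's first leaf.
        return sum(jax.tree_util.tree_leaves(p)[0].shape[0] for p in self.parts)

    def flatten(self):
        # Concatenate matching leaves across all parts only when asked to.
        return jax.tree_util.tree_map(
            lambda *xs: jnp.concatenate(xs, axis=0), *self.parts
        )


lazy = LazyConcat([{"radius": jnp.arange(10.0)}, {"radius": jnp.arange(5.0)}])
```

The two parts stay as two arrays (10 and 5 elements) until `flatten()` is called.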

I still need to think about whether this works for the other cases.

srush commented 1 week ago

I ended up going with the following design.

import jax
import numpy as np
from chalk import circle, hcat, square, vcat

@jax.vmap
def outer(j):
    @jax.vmap
    def inner(i):
        return (circle(0.3 * i / 6).fill_color(np.ones(3) * i / 6) +
                square(0.1).fill_color("white")).scale(1)
    inside = inner(np.arange(2, 5))
    return vcat(inside).scale(1)
out = outer(np.arange(1, 6))
print("My Size", out.size())
d = hcat(out)

I think this solves my issue without adding too much complexity to the system. The main difference is that diagrams now have a .size(), e.g. (4, 3), reflecting that they may carry internal batch dimensions. Only diagrams of size () can be rendered, meaning a single diagram. You can now call hcat directly on a diagram whose size is not (), and it composes along the innermost dimension: a diagram of size (4, 3) becomes one of size (4,). Internally it does the same thing as before, but in JAX it can do it all on matrices, since JAX has a built-in associative scan: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.associative_scan.html
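The associative-scan point can be illustrated with the layout step of hcat along a batch axis: each diagram's x-offset is a prefix sum of the widths before it. A sketch, assuming per-element left/right envelope extents are already available as arrays; `hcat_offsets` is a hypothetical helper, not chalk's implementation, and it ignores separation/padding:

```python
import jax
import jax.numpy as jnp


def hcat_offsets(left, right):
    """Hypothetical helper: x-translations that lay out a batch side by side.

    left[..., k] / right[..., k] are the non-negative distances from diagram
    k's origin to its left / right envelope edge. The last axis is the one
    being composed.
    """
    widths = left + right
    # Inclusive prefix sum of widths along the composed (last) axis.
    ends = jax.lax.associative_scan(jnp.add, widths, axis=widths.ndim - 1)
    starts = ends - widths             # where each diagram's span begins
    origins = starts + left            # origin sits `left` to the right of that
    return origins - origins[..., :1]  # keep the first diagram at x = 0


offsets = hcat_offsets(jnp.ones((4, 3)), jnp.ones((4, 3)))
```

For a (4, 3) batch the offsets are also (4, 3); translating each element by its offset and composing along the last axis is what would reduce the size to (4,).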

There are a few core changes to the system to make this work. First, a new node, ComposeAxis, which should be identical to Compose but runs along the inner axis of the tree. Second, a Size visitor, which determines the size by walking down the tree and taking ComposeAxis nodes into account. Finally, I think I might try replacing the functional versions of Envelope/Trace with a Visitor pattern over the diagram. I think it's roughly equivalent, but the repeated function calls are hard to debug and don't play nicely with JAX.
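A sketch of the Size visitor described above, over hypothetical, simplified node classes (not chalk's actual Primitive/Compose/ComposeAxis). It assumes a primitive's batch shape is the leading part of a `(*batch, 3, 3)` transform, and that the two children of a Compose broadcast against each other:

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple, Union

import numpy as np

Shape = Tuple[int, ...]


@dataclass
class Primitive:
    transform: np.ndarray  # hypothetical layout: (*batch, 3, 3) affine matrices


@dataclass
class Compose:
    left: Diagram
    right: Diagram


@dataclass
class ComposeAxis:
    diagram: Diagram  # composes its child along the child's last batch axis


Diagram = Union[Primitive, Compose, ComposeAxis]


class SizeVisitor:
    """Hypothetical visitor computing a diagram's batch size()."""

    def visit(self, d: Diagram) -> Shape:
        if isinstance(d, Primitive):
            return tuple(d.transform.shape[:-2])
        if isinstance(d, Compose):
            # Siblings share batch axes (e.g. both built inside one vmap).
            return tuple(np.broadcast_shapes(self.visit(d.left), self.visit(d.right)))
        if isinstance(d, ComposeAxis):
            child = self.visit(d.diagram)
            if not child:
                raise ValueError("ComposeAxis needs a child with a batch axis")
            return child[:-1]  # the composed axis disappears
        raise TypeError(f"unknown node {type(d).__name__}")
```

Sibling sizes that don't broadcast, like (10,) and (5,), raise a ValueError, which matches the earlier point that such diagrams need to be concatenated along their axis first.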

(I'll add this to the PR once it's stable.)

danoneata commented 1 week ago

Okay, this is indeed interesting! Let me see if I understood this right:

One is a new node ComposeAxis which should be identical to compose but runs along the inner axis of the tree.

The original Compose node also caches the envelope of the composition. I wonder whether this is needed for the ComposeAxis, since we treat the diagrams independently.

Finally think might try replacing the functional versions of Envelope/Trace with a Visitor pattern over diagram

I'm not completely sure I understand what this change implies, but it will surely become clear once I see the code 🙂

srush commented 1 week ago

Yup, that's a good summary. Still working out the details, but I think the semantics make sense.

I can even imagine the render method working similarly by saving each diagram in the array.

Yeah! I think the trick here is kind of cool. If you render batched primitives, you get a List[Prim], but they may no longer be in order (elements of prim 2 may need to be drawn in front of those of prim 1). However, you can store an array .order of shape size() on each Primitive that holds the z-order of its elements, and then sort by it before calling Cairo. Constructing order is a really neat algorithm. I'm sure it has some clever functional name, but I hadn't seen it before.
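The order construction itself isn't shown in this thread, so the following is only a sketch of one simple case, not necessarily the algorithm meant above: a vmapped Compose of k primitive children, where all children of batch element b must be drawn before any child of element b + 1. `batched_compose_order` and `render_order` are hypothetical names:

```python
import numpy as np


def batched_compose_order(num_children: int, batch: int) -> np.ndarray:
    """Hypothetical z-order for vmap(lambda i: c_0(i) + ... + c_{k-1}(i)).

    Entry [c, b] is the drawing position of child c's b-th batch element.
    """
    b = np.arange(batch)[None, :]         # batch index, shape (1, batch)
    c = np.arange(num_children)[:, None]  # child index, shape (k, 1)
    # Within one batch element, children keep their left-to-right Compose order.
    return b * num_children + c


def render_order(order: np.ndarray):
    """Sort all (child, batch) pairs by z-order, as one would before calling Cairo."""
    child, batch = np.unravel_index(np.argsort(order, axis=None), order.shape)
    return list(zip(child.tolist(), batch.tolist()))


draw_sequence = render_order(batched_compose_order(2, 3))
```

For two children and a batch of 3, the draw sequence is (child 0, elem 0), (child 1, elem 0), (child 0, elem 1), and so on, which is exactly the interleaving a flat List[Prim] loses.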

  • The original Compose node also caches the envelope of the composition. I wonder whether this is needed for the ComposeAxis, since we treat the diagrams independently.

Right, good point. This was exactly the hard part, since storing functions in the tree is not allowed in JAX. So instead, I was thinking an envelope "function" could just remember which node it is at and then walk the tree from there. Instead of the function being a monoid, EnvDistance itself is the monoid. We don't do anything too crazy with Envelopes/Traces, so this seems fine (I think pad/frame is the only place we use the implicit form).

If you want to override the envelope of a Compose node, you need to pass in a new diagram. I think this already matches the user-facing semantics: diagram.with_envelope(new_diagram_for_envelope).
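A sketch of the envelope-as-tree-walk idea, assuming simplified hypothetical node classes (`Empty`, `Circle`, `Compose`, `WithEnvelope`) rather than chalk's real ones, and leaving out the affine transforms a real version would apply. Only `EnvDistance` is a monoid; `envelope` simply recurses from whichever node it is called on:

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Union

import numpy as np


@dataclass
class EnvDistance:
    """Max-monoid over envelope distances (mirrors the class in this thread)."""
    d: np.ndarray

    def __add__(self, other: EnvDistance) -> EnvDistance:
        return EnvDistance(np.maximum(self.d, other.d))

    @staticmethod
    def empty() -> EnvDistance:
        return EnvDistance(np.asarray(-1e5))


@dataclass
class Empty:
    pass


@dataclass
class Circle:
    center: np.ndarray  # hypothetical primitive: centre (2,) and radius
    radius: float


@dataclass
class Compose:
    left: Diagram
    right: Diagram


@dataclass
class WithEnvelope:
    diagram: Diagram   # what gets drawn
    envelope: Diagram  # whose envelope is used instead


Diagram = Union[Empty, Circle, Compose, WithEnvelope]


def envelope(d: Diagram, direction: np.ndarray) -> EnvDistance:
    """Walk the tree from node `d`; `direction` is assumed to be a unit vector."""
    if isinstance(d, Empty):
        return EnvDistance.empty()
    if isinstance(d, Circle):
        return EnvDistance(np.asarray(np.dot(d.center, direction) + d.radius))
    if isinstance(d, Compose):
        return envelope(d.left, direction) + envelope(d.right, direction)
    if isinstance(d, WithEnvelope):
        return envelope(d.envelope, direction)
    raise TypeError(f"unknown node {type(d).__name__}")
```

Here `WithEnvelope` captures the with_envelope(new_diagram_for_envelope) semantics: the override is just another diagram to walk, so no function ever needs to live in the tree.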

@dataclass
class EnvDistance(Monoid):
    d: Scalars

    def __add__(self, other: Self) -> Self:
        return EnvDistance(tx.X.np.maximum(self.d, other.d))

    @staticmethod
    def empty() -> EnvDistance:
        return EnvDistance(tx.X.np.asarray(-1e5))

    def reduce(self, axis=0):
        return EnvDistance(tx.X.np.max(self.d, axis=axis))

class Envelope(Transformable, Monoid):
    diagram: Diagram
    affine: Affine

    def __call__(self, direction: V2_t) -> Scalars:
        # Walk the diagram with the visitor, passing the query direction.
        return self.diagram.accept(ApplyEnvelope(), direction).d

...
class ApplyEnvelope(DiagramVisitor[EnvDistance, V2_t]):
    A_type = EnvDistance

    def visit_primitive(self, diagram: Primitive, t: V2_t) -> EnvDistance: