revalo / tree-diffusion

Diffusion on syntax trees for program synthesis
https://tree-diffusion.github.io
MIT License
386 stars 19 forks source link

Fix end position bug #9

Open nathanvogt opened 1 week ago

nathanvogt commented 1 week ago

If multiple different rules begin at the same start position, the wrong end position can be chosen. This change takes the rule chosen at that position into account when setting the end position.

revalo commented 1 week ago

Thanks for the contribution! Do you have any examples where this is actually a problem, just so I can reproduce the bug on my end?

nathanvogt commented 1 week ago

This is a slightly contrived example, but modifying tinysvgoffset.py to the following will eventually cause an error when running eval_td.py. I added a custom_color rule that can start at the same position as another rule, which causes the conflict.

from typing import Tuple

import iceberg as ice
from lark import Transformer, Tree
from lark.visitors import v_args

from td.environments.environment import Environment
from td.environments.goal_checker import GaussianImageGoalChecker
from td.grammar import Compiler, Grammar

# Grammar specification (Lark-style syntax) that TinySVGOffset hands to
# td.grammar.Grammar with start symbol "s".
#
# `custom_color` accepts either four space-separated `number`s or a named
# `color`. The `move`, `pad` and `compose` rules are commented out inside the
# spec, so `s` only expands to `arrange`, `rect` or `ellipse`.
#
# NOTE: the text between the triple quotes is runtime data, including the `//`
# grammar comments; do not reformat it.
_grammar_spec = r"""
custom_color: number " " number " " number " " number | color // RGBA or color
// s: arrange | move | pad | compose | rect | ellipse
s: arrange | rect | ellipse
direction: "v" -> v | "h" -> h
color: "red" -> red | "green" -> green | "blue" -> blue | "yellow" -> yellow | "purple" -> purple | "orange" -> orange | "black" -> black | "white" -> white | "none" -> none
number: "0" -> zero | "1" -> one | "2" -> two | "3" -> three | "4" -> four | "5" -> five | "6" -> six | "7" -> seven | "8" -> eight | "9" -> nine
sign: "+" -> plus | "-" -> minus
boolean: "true" -> true | "false" -> false

// Rectangle w h fillcolor strokecolor strokewidth ox oy
rect: "(" "Rectangle" " " number " " number " " custom_color " " custom_color " " number " " sign number " " sign number ")"

// Ellipse w h fillcolor strokecolor strokewidth ox oy
ellipse: "(" "Ellipse" " " number " " number " " custom_color " " custom_color " " number " " sign number " " sign number ")"

// Arrange direction left right gap
arrange: "(" "Arrange" " " direction " " s " " s " " number ")"

// Move x y negx negy
// move: "(" "Move" " " s " " number " " number " " boolean " " boolean ")"

// Pad t r b l
// pad: "(" "Pad" " " s " " number " " number " " number " " number ")"

// Compose without arranging
// compose: "(" "Compose" " " s " " s ")"

%ignore /[\t\n\f\r]+/ 
"""

# Canvas dimensions; these are also the first two entries of
# TinySVGOffset.compiled_shape.
_CANVAS_WIDTH = 224
_CANVAS_HEIGHT = 224

# Module-level renderer (created with gpu=False) and a white blank canvas of
# the size above. Both are shared by every TSVGOCompiler.compile call.
_ice_renderer = ice.Renderer(gpu=False)
_ice_canvas = ice.Blank(
    ice.Bounds(size=(_CANVAS_WIDTH, _CANVAS_HEIGHT)), ice.Colors.WHITE
)

class _Move(ice.Drawable):
    """Drawable that draws ``child`` moved by ``(x, y)``.

    ``bounds`` reports the bounds the child had before it was moved, while
    ``draw`` and ``children`` use the result of ``child.move(x, y)``.
    """

    # The drawable being wrapped.
    child: ice.Drawable
    # Offsets passed to ``child.move``.
    x: float
    y: float

    def setup(self):
        """Record the child's bounds, then build the moved drawable."""
        # Read the bounds first so they are those of the unmoved child.
        self._child_bounds = self.child.bounds
        dx, dy = self.x, self.y
        self._moved = self.child.move(dx, dy)

    @property
    def children(self):
        """Return a one-element list holding the moved drawable."""
        return [self._moved]

    @property
    def bounds(self):
        """Return the bounds recorded from the unmoved child in ``setup``."""
        return self._child_bounds

    def draw(self, canvas):
        """Draw the moved drawable onto ``canvas``."""
        self._moved.draw(canvas)

class TSVGOToIceberg(Transformer):
    """Lark transformer that turns a parsed expression tree into iceberg drawables.

    Each method is named after a grammar rule or alias and receives that
    node's already-transformed children. Leaf aliases (digits, signs, colors,
    directions, booleans) map to plain Python values or iceberg colors; the
    shape rules build drawables from them.

    Args:
        visit_tokens: Forwarded to ``Transformer.__init__``.
        stroke_width_divisor: Parsed stroke widths are divided by this value.
        size_multiplier: Parsed widths, heights and offsets are multiplied by
            this value.
    """

    def __init__(
        self,
        visit_tokens: bool = True,
        stroke_width_divisor: float = 2.0,
        size_multiplier: float = 6.0,
    ) -> None:
        super().__init__(visit_tokens)
        self._stroke_width_divisor = stroke_width_divisor
        self._size_multiplier = size_multiplier

    def _scaled_shape_params(self, children):
        """Unpack and scale the nine children shared by ``rect`` and ``ellipse``.

        Returns:
            ``(w, h, fill_color, stroke_color, stroke_width, ox, oy)`` where
            the size and offsets are multiplied by the size multiplier, the
            offsets carry their sign, and the stroke width is divided by the
            stroke-width divisor.
        """
        w, h, fill_color, stroke_color, stroke_width, signx, ox, signy, oy = children

        stroke_width = stroke_width / self._stroke_width_divisor
        w = w * self._size_multiplier
        h = h * self._size_multiplier

        # The sign children are +1 / -1 (see ``plus`` / ``minus``).
        ox = ox * signx * self._size_multiplier
        oy = oy * signy * self._size_multiplier

        return w, h, fill_color, stroke_color, stroke_width, ox, oy

    def custom_color(self, children):
        """Return a color built from four numeric children, or pass a named color through.

        With four children they are treated as color components and each is
        divided by 255. Otherwise the single child (an already-transformed
        named color, possibly ``None``) is returned unchanged.
        """
        # NOTE(review): the grammar's `number` rule only yields single digits
        # (0-9), so each component here is at most 9 / 255 -- confirm that
        # this scale is intended.
        if len(children) == 4:
            return ice.Color(
                children[0] / 255,
                children[1] / 255,
                children[2] / 255,
                children[3] / 255,
            )
        else:
            return children[0]

    @v_args(meta=True)
    def rect(self, meta, children):
        """Build a rectangle wrapped in a ``_Move`` carrying its offset.

        The node's ``meta`` is attached to the inner rectangle as
        ``_lark_meta``, not to the ``_Move`` wrapper.
        """
        w, h, fill_color, stroke_color, stroke_width, ox, oy = (
            self._scaled_shape_params(children)
        )

        rv = ice.Rectangle(
            ice.Bounds(size=(w, h)),
            fill_color=fill_color,
            border_color=stroke_color,
            border_thickness=stroke_width,
            anti_alias=False,
            dont_modify_bounds=True,
        )
        rv._lark_meta = meta

        return _Move(child=rv, x=ox, y=oy)

    @v_args(meta=True)
    def ellipse(self, meta, children):
        """Build an ellipse wrapped in a ``_Move`` carrying its offset.

        The node's ``meta`` is attached to the inner ellipse as
        ``_lark_meta``, not to the ``_Move`` wrapper.
        """
        w, h, fill_color, stroke_color, stroke_width, ox, oy = (
            self._scaled_shape_params(children)
        )

        rv = ice.Ellipse(
            rectangle=ice.Bounds(size=(w, h)),
            fill_color=fill_color,
            border_color=stroke_color,
            border_thickness=stroke_width,
            anti_alias=True,
            dont_modify_bounds=True,
        )
        rv._lark_meta = meta

        return _Move(child=rv, x=ox, y=oy)

    def arrange(self, children):
        """Arrange two drawables horizontally (``"h"``) or vertically with a gap."""
        direction, left, right, gap = children

        return ice.Arrange(
            [left, right],
            arrange_direction=(
                ice.Arrange.Direction.HORIZONTAL
                if direction == "h"
                else ice.Arrange.Direction.VERTICAL
            ),
            gap=gap,
        )

    def move(self, children):
        """Wrap a drawable in a ``_Move`` with scaled, optionally negated offsets.

        NOTE(review): the ``move`` rule is commented out in ``_grammar_spec``,
        so this handler is presumably not reached at the moment.
        """
        drawable, x, y, negx, negy = children

        x = x * self._size_multiplier
        y = y * self._size_multiplier

        x = x if not negx else -x
        y = y if not negy else -y
        return _Move(child=drawable, x=x, y=y)

    def s(self, children):
        """Unwrap the single child of the start rule."""
        return children[0]

    # --- direction aliases ---

    def v(self, _):
        return "v"

    def h(self, _):
        return "h"

    # --- color aliases ---

    def red(self, _):
        return ice.Colors.RED

    def green(self, _):
        return ice.Colors.GREEN

    def blue(self, _):
        return ice.Colors.BLUE

    def yellow(self, _):
        return ice.Colors.YELLOW

    def purple(self, _):
        return ice.Color.from_hex("#800080")

    def orange(self, _):
        return ice.Color.from_hex("#FFA500")

    def black(self, _):
        return ice.Colors.BLACK

    def white(self, _):
        return ice.Colors.WHITE

    def none(self, _):
        return None

    # --- digit aliases ---

    def zero(self, _):
        return 0

    def one(self, _):
        return 1

    def two(self, _):
        return 2

    def three(self, _):
        return 3

    def four(self, _):
        return 4

    def five(self, _):
        return 5

    def six(self, _):
        return 6

    def seven(self, _):
        return 7

    def eight(self, _):
        return 8

    def nine(self, _):
        return 9

    # --- boolean aliases ---

    def true(self, _):
        return True

    def false(self, _):
        return False

    # --- sign aliases (used as multipliers for offsets) ---

    def plus(self, _):
        return 1

    def minus(self, _):
        return -1

class TSVGOCompiler(Compiler):
    """Compiler that renders an expression tree to an image array."""

    def __init__(self) -> None:
        super().__init__()
        self._expression_to_iceberg = TSVGOToIceberg()

    def compile(self, expression: Tree):
        """Render ``expression`` centered on the shared canvas.

        The tree is transformed into a drawable, anchored together with the
        module-level canvas, and rendered with the module-level renderer.

        Returns:
            The first three channels of the rendered image, divided by 255.0.
        """
        shape = self._expression_to_iceberg.transform(expression)
        centered = _ice_canvas.add_centered(shape)
        scene = ice.Anchor((_ice_canvas, centered))

        _ice_renderer.render(scene)
        rendered = _ice_renderer.get_rendered_image()

        # Keep only the first three channels and rescale by 255.
        first_three_channels = rendered[:, :, :3]
        return first_three_channels / 255.0

class TinySVGOffset(Environment):
    """Environment bundling the grammar, compiler and goal checker defined in this module."""

    def __init__(self) -> None:
        super().__init__()

        self._grammar = Grammar(
            _grammar_spec,
            start="s",
            primitives=["rect", "ellipse"],
        )

        self._compiler = TSVGOCompiler()
        self._goal_checker = GaussianImageGoalChecker(self.compiled_shape, sigma=0.1)

    @property
    def grammar(self) -> Grammar:
        """Grammar built from ``_grammar_spec`` with start symbol ``"s"``."""
        return self._grammar

    @property
    def compiler(self) -> Compiler:
        """The ``TSVGOCompiler`` instance owned by this environment."""
        return self._compiler

    @property
    def compiled_shape(self) -> Tuple[int, ...]:
        """Shape tuple ``(_CANVAS_WIDTH, _CANVAS_HEIGHT, 3)``."""
        return _CANVAS_WIDTH, _CANVAS_HEIGHT, 3

    @classmethod
    def name(cls) -> str:
        """Return the environment's identifier string."""
        # The first parameter of a classmethod is the class, so it is named
        # `cls` (it was previously named `self`).
        return "tinysvgoffset"

    def goal_reached(self, compiledA, compiledB) -> bool:
        """Delegate to the goal checker's comparison of the two compiled outputs."""
        return self._goal_checker.goal_reached(compiledA, compiledB)