mossprescott / pynand

Nand2Tetris in Python.
Other
12 stars 3 forks source link

Implement a simple type checker for Jack #41

Open mossprescott opened 2 months ago

mossprescott commented 2 months ago

When writing any kind of non-trivial Jack code, the complete lack of type checking means both wasting time fixing silly errors and writing type declarations that are effectively never used (except as documentation).

A simple type-checker would catch a lot of the simplest problems, although the language and its types are pretty weak.

Motivating example: the Scheme interpreter needs to convert between tagged values (typed as Obj) and pointers (Rib) before using them. These conversions are all over the code, which is fairly complex to begin with, and it's very easy to miss one. On the other hand, the interpreter also indexes directly into ribs to avoid the overhead of method calls, and a type checker would report those accesses as (unhelpful) type errors. Arguably, the fix for that is inlining the methods.

mossprescott commented 2 months ago

Here's an attempt that identifies errors in assignments. How could it be extended to catch errors in expressions as well?

"""Simple type-checker for Jack.
"""

from typing import Generator, List, NamedTuple, Optional, Sequence
from nand.jack_ast import *

class TypeError(NamedTuple):
    """One diagnostic produced by check_types.

    Note: this name shadows Python's built-in TypeError everywhere in this module, so
    `raise TypeError(...)` / `except TypeError` here refer to this record, not the exception.
    """
    loc: str  # "<class>.<subroutine>" the diagnostic was found in
    severity: str  # "error", "warn", or "info"
    msg: str  # human-readable description of the problem
    # The offending statement. Optional so that diagnostics not tied to a single
    # statement can still be reported.
    stmt: Optional[Statement] = None

class Member(NamedTuple):
    """Signature of one subroutine, as seen by its callers (built by `exported`)."""
    class_: str  # name of the class that declares the subroutine
    name: str  # the subroutine's name
    type_: SubKind  # "function", "method", or "constructor"
    params: List[Type]  # declared parameter types, in declaration order
    result: Type  # declared return type

def exported(ast: Class) -> List[Member]:
    """Return the signature of every subroutine declared in `ast`, in declaration order.

    check_types merges these with its `imports`, so that a class's calls to its own
    subroutines can be resolved.
    """
    signatures: List[Member] = []
    for dec in ast.subroutineDecs:
        param_types = [param.type for param in dec.params]
        signatures.append(Member(ast.name, dec.name, dec.kind, param_types, dec.result))
    return signatures

def check_types(ast: Class, imports: Sequence[Member]) -> Generator[TypeError, None, None]:
    """Type-check the subroutines of a single Jack class.

    So far only `let` statements at the top level of each subroutine body are checked;
    any other statement (and any statements nested inside it) is skipped.

    Types are the declared type names from the AST, plus two internal markers:
    "any" (type unknown, e.g. an array element; compatible with everything) and
    "⊥" (the expression is ill-typed, or refers to something that isn't declared).

    :param ast: the class to check.
    :param imports: signatures of subroutines defined in other classes. The class's own
        subroutines (see `exported`) are added automatically.
    :return: a generator of TypeError records; `severity` is "error", "warn", or "info".
    """
    imports_by_name = {
        (m.class_, m.name): m
        for m in list(imports) + exported(ast)
    }

    class_vars_by_name = {n: vd for vd in ast.varDecs for n in vd.names}

    # Declared types that are plain values rather than object references: they can't be
    # indexed, and don't take part in pointer arithmetic.
    primitive_types = ("int", "char", "boolean")

    # Binary operators, grouped by their result when both operands are ints.
    int_ops = ("+", "-", "*", "/", "&", "|")
    comparison_ops = ("<", ">", "=")

    for sub in ast.subroutineDecs:
        params_by_name = {p.name: p.type for p in sub.params}
        local_vars_by_name = {n: vd for vd in sub.body.varDecs for n in vd.names}

        def var_type(name: str) -> Type:
            """Declared type of a variable visible in `sub`, or "⊥" if there is none.

            Locals shadow parameters, which shadow class variables.
            """
            if name in local_vars_by_name:
                return local_vars_by_name[name].type
            elif name in params_by_name:
                return params_by_name[name]
            elif name in class_vars_by_name:
                return class_vars_by_name[name].type
            else:
                return "⊥"

        def expr_type(expr: Expression) -> Type:
            """Infer the type of an expression: "any" if unknown, "⊥" if ill-typed."""
            if isinstance(expr, IntegerConstant):
                return "int"
            elif isinstance(expr, StringConstant):
                # Jack string literals are instances of the String class.
                return "String"
            elif isinstance(expr, KeywordConstant):
                if isinstance(expr.value, bool):
                    return "boolean"
                # Presumably `null` or `this`. null is compatible with any object type, and
                # how `this` is represented in the AST isn't visible here, so don't guess.
                # TODO: give `this` the type ast.name.
                return "any"

            elif isinstance(expr, VarRef):
                return var_type(expr.name)

            elif isinstance(expr, ArrayRef):
                # Array elements are untyped.
                return "any"

            elif isinstance(expr, SubroutineCall):
                if expr.class_name is not None:
                    # Qualified with the class name (ie. static)
                    cn = expr.class_name
                elif expr.var_name is not None:
                    # Method call on a variable: look it up in the variable's declared class.
                    cn = var_type(expr.var_name)
                else:
                    # No qualifier, refers to "this" implicitly:
                    cn = ast.name

                member = imports_by_name.get((cn, expr.sub_name))
                # Unknown subroutine, or a call through an undeclared variable.
                return member.result if member is not None else "⊥"

            elif isinstance(expr, BinaryExpression):
                left_type = expr_type(expr.left)
                right_type = expr_type(expr.right)
                op = expr.op.symbol
                if "⊥" in (left_type, right_type):
                    # Propagate errors from either operand.
                    return "⊥"
                if "any" in (left_type, right_type):
                    # Unknown operand: comparisons still yield boolean; otherwise don't guess.
                    return "boolean" if op in comparison_ops else "any"
                if left_type == "int" and right_type == "int":
                    if op in int_ops:
                        return "int"
                    elif op in comparison_ops:
                        return "boolean"
                    else:
                        return "⊥"
                if left_type == "boolean" and right_type == "boolean" and op in ("&", "|", "="):
                    return "boolean"
                if left_type == "Array" and op in ("+", "-") and right_type == "int":
                    return "Array"
                if left_type not in primitive_types and op in ("+", "-") and right_type == "int":
                    # Pointer arithmetic on an object reference.
                    # TODO: warn here
                    return left_type
                return "⊥"

            elif isinstance(expr, UnaryExpression):
                child_type = expr_type(expr.expr)
                op = expr.op.symbol
                if child_type == "any":
                    return "any"
                elif child_type == "int":
                    # Both negation ("-") and bitwise not ("~") apply to ints.
                    return "int"
                elif child_type == "boolean" and op == "~":
                    return "boolean"
                else:
                    return "⊥"
            else:
                raise Exception(f"{expr}?")

        def check_stmt(stmt: Statement) -> Generator[TypeError, None, None]:
            """Yield any type errors found in a single statement."""
            loc = f"{ast.name}.{sub.name}"

            if not isinstance(stmt, LetStatement):
                # TODO: check other kinds of statement, and the statements nested in them.
                return

            if stmt.name in local_vars_by_name:
                left_type = local_vars_by_name[stmt.name].type
            elif stmt.name in params_by_name:
                # params_by_name already maps each name directly to its type.
                left_type = params_by_name[stmt.name]
            elif stmt.name in class_vars_by_name:
                cvd = class_vars_by_name[stmt.name]
                if sub.kind == "function" and not cvd.static:
                    yield TypeError(loc, "error", "reference to instance variable in function", stmt)
                # TODO: catch other invalid references
                left_type = cvd.type
            else:
                yield TypeError(loc, "error", f'no variable "{stmt.name}" in scope', stmt)
                # Without a declared type there is nothing to compare the value against.
                return

            # TODO: yield errors from within the expr itself
            right_type = expr_type(stmt.expr)

            if stmt.array_index is not None:
                # Array elements are untyped, so only the indexed variable can be checked.
                if left_type in primitive_types:
                    yield TypeError(loc, "error", f"Indexing a non-array: {stmt.name} is {left_type}", stmt)
                elif left_type != "Array":
                    yield TypeError(loc, "info", f"Indexing a class: {stmt.name} is {left_type}", stmt)
                return

            if left_type == right_type or right_type == "any":
                pass
            elif left_type not in primitive_types and right_type == "Array":
                yield TypeError(loc, "info", f"Implicit cast from Array to {left_type}", stmt)
            elif left_type == "Array" and right_type == "int":
                yield TypeError(loc, "info", "Implicit cast from int to Array", stmt)
            elif left_type not in primitive_types and right_type == "int":
                yield TypeError(loc, "warn", f"Implicit cast from int to {left_type}", stmt)
            else:
                yield TypeError(loc, "error", f"expected {left_type}, found {right_type}", stmt)

        for stmt in sub.body.statements:
            yield from check_stmt(stmt)

def stmt_to_str(stmt: Statement) -> str:
    """Render a statement as approximate Jack source, for use in error messages.

    Diagnostics carry no source locations, so showing the offending statement is the
    practical way to point at it. Only `let` statements get a custom rendering (with the
    expressions rendered by expr_to_str, which may not fully disambiguate nested
    expressions); every other statement falls back to ``str(stmt)``.
    """
    if not isinstance(stmt, LetStatement):
        return str(stmt)

    target = stmt.name
    if stmt.array_index is not None:
        target = f"{target}[{expr_to_str(stmt.array_index)}]"
    return f"{target} = {expr_to_str(stmt.expr)}"

def expr_to_str(expr: Expression) -> str:
    """Render an expression as approximate Jack source, for use in error messages.

    Binary sub-expressions used as operands are parenthesized. Jack has no operator
    precedence, so without the parentheses e.g. `a * (b + c)` would print as the
    different expression `a * b + c`.

    :raises Exception: if `expr` is not one of the known expression kinds.
    """
    def operand(e: Expression) -> str:
        # Wrap compound operands so the text groups the same way as the tree.
        text = expr_to_str(e)
        return f"({text})" if isinstance(e, BinaryExpression) else text

    if isinstance(expr, IntegerConstant):
        return str(expr.value)
    elif isinstance(expr, StringConstant):
        # Quote the literal so it can't be mistaken for a variable name.
        return f'"{expr.value}"'
    elif isinstance(expr, KeywordConstant):
        if isinstance(expr.value, bool):
            # Jack spells these in lower case, unlike Python's str(True).
            return "true" if expr.value else "false"
        return str(expr.value)
    elif isinstance(expr, VarRef):
        return expr.name
    elif isinstance(expr, ArrayRef):
        return f"{expr.name}[{expr_to_str(expr.array_index)}]"
    elif isinstance(expr, SubroutineCall):
        if expr.class_name is not None:
            prefix = f"{expr.class_name}."
        elif expr.var_name is not None:
            # Method call through a variable: qualify with the variable's name.
            prefix = f"{expr.var_name}."
        else:
            prefix = ""
        args = ", ".join(expr_to_str(x) for x in expr.args)
        return f"{prefix}{expr.sub_name}({args})"
    elif isinstance(expr, BinaryExpression):
        return f"{operand(expr.left)} {expr.op.symbol} {operand(expr.right)}"
    elif isinstance(expr, UnaryExpression):
        return f"{expr.op.symbol}{operand(expr.expr)}"
    else:
        raise Exception(f"unexpected: {expr}")