triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

[RFC] Unfungle Triton Import Graph? #884

Open crutcher opened 1 year ago

crutcher commented 1 year ago

Triton's module imports are currently highly circular: many modules import other modules while those modules are only partially initialized. core.py, jit.py, and semantic.py are a particular example.
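To make "importing a partially initialized module" concrete, here is a minimal, self-contained sketch. It builds a hypothetical two-module package (toypkg, with core and jit modules loosely echoing the names above, not Triton's real code) in a temporary directory: core imports jit before it defines tensor, and jit needs core.tensor at import time, so Python raises an ImportError.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path


def demo_circular_import() -> str:
    """Import a toy package whose modules import each other; return the error text."""
    with tempfile.TemporaryDirectory() as tmp:
        pkg = Path(tmp) / "toypkg"  # hypothetical package name
        pkg.mkdir()
        (pkg / "__init__.py").write_text("")
        # core imports jit first, and only afterwards defines `tensor`.
        (pkg / "core.py").write_text(textwrap.dedent("""
            from . import jit

            class tensor:
                pass
            """))
        # jit needs core.tensor while core is still half-executed.
        (pkg / "jit.py").write_text("from .core import tensor\n")
        sys.path.insert(0, tmp)
        try:
            importlib.import_module("toypkg.core")
        except ImportError as exc:
            return str(exc)
        finally:
            # Undo the side effects so the demo can be rerun cleanly.
            sys.path.remove(tmp)
            for name in [m for m in sys.modules if m.split(".")[0] == "toypkg"]:
                del sys.modules[name]
    return "no error"
```

On Python 3.8+, print(demo_circular_import()) reports that it "cannot import name 'tensor' from partially initialized module 'toypkg.core'". In a large codebase the same situation often goes unnoticed because the name happens to be looked up later, at call time, rather than at import time.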

With some incremental work, I could unroll these modules into a strictly ordered graph. It could even be done without changing the appearance of the public API, by introducing a graph of incremental declaration components and re-exporting them from the appropriate modules to maintain compatibility.
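The re-exporting idea can be sketched with a hypothetical package (toylib; the _base, _core, and language module names are illustrative, not the layout from the PR). Declarations live in an ordered chain of private modules, and the old public module only re-exports them, so existing imports such as "from toylib.language import tensor" keep resolving to the same objects.

```python
import importlib
import sys
import tempfile
from pathlib import Path
from typing import Tuple

# Hypothetical layout: _base <- _core <- language (re-export only), no cycles.
FILES = {
    "__init__.py": "from . import language\n",
    "_base.py": "class tensor:\n    pass\n",
    "_core.py": (
        "from ._base import tensor\n\n"
        "def where(cond, x, y):\n"
        "    return x if cond else y\n"
    ),
    "language.py": (
        "from ._base import tensor\n"
        "from ._core import where\n\n"
        "__all__ = ['tensor', 'where']\n"
    ),
}


def load_reexported() -> Tuple[bool, int]:
    """Build the toy package, then check the public module re-exports the originals."""
    with tempfile.TemporaryDirectory() as tmp:
        pkg = Path(tmp) / "toylib"
        pkg.mkdir()
        for name, src in FILES.items():
            (pkg / name).write_text(src)
        sys.path.insert(0, tmp)
        try:
            language = importlib.import_module("toylib.language")
            base = importlib.import_module("toylib._base")
            # Same class object, reached through the compatibility path.
            return language.tensor is base.tensor, language.where(True, 1, 2)
        finally:
            sys.path.remove(tmp)
            for mod in [m for m in sys.modules if m.split(".")[0] == "toylib"]:
                del sys.modules[mod]
```

Because every module only imports from modules earlier in the chain, the import order is fixed regardless of which module a caller touches first.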

This transformation would also help bring mypy type-checking to Triton.

ptillet commented 1 year ago

Yes, I support that! I think it should be done in the triton-mlir branch, though (unless it happens after the merge).

crutcher commented 1 year ago

I've spent some time taking a stab at this on main. The approach I've been pursuing is to unroll the entire library into one very long module and then try to get that module into a clean declaration order.

Semantically, because Python executes module imports in a fixed construction order, any import graph is forced to be equivalent to some ordered import stack, so as a general strategy this is a good approach. Specifically, though, Triton has a lot of reflexive self-knowledge and metaprogramming, and understanding where those points of reference are is complicated.
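The "ordered import stack" argument amounts to a topological sort: a cycle-free import graph always admits a linear declaration order, and a cycle means no such order exists. The sketch below uses the standard library's graphlib (Python 3.9+); the dependency edges are assumptions loosely based on the module names and comments in the listing further down, not a measured graph of Triton.

```python
from graphlib import CycleError, TopologicalSorter
from typing import List, Mapping, Optional, Set

# Hypothetical "needs declarations from" edges: module -> its prerequisites.
ACYCLIC = {
    "base": set(),
    "core": {"base"},
    "jitlib": {"base"},
    "compiler": {"core", "jitlib"},
    "autotuner": {"jitlib", "compiler"},
}


def declaration_order(deps: Mapping[str, Set[str]]) -> Optional[List[str]]:
    """Return one valid declaration order, or None if the graph has a cycle."""
    try:
        # static_order() yields every node after all of its prerequisites.
        return list(TopologicalSorter(deps).static_order())
    except CycleError:
        return None


order = declaration_order(ACYCLIC)
# A mutual dependency such as core <-> semantic has no valid order.
cyclic = declaration_order({"core": {"semantic"}, "semantic": {"core"}})
```

When several modules are ready at once (core and jitlib here), any of their relative orders is valid; choosing one is the design decision that a mechanical replay against a new branch would then preserve.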

If I can get a clean single-module instance (or at least one covering everything that comes in with "import triton"), I can recover a satisfying declaration order. That's a good starting point for planning a module hierarchy, which could then be replayed (mechanically, not with merge tools) against the new target branch.

crutcher commented 1 year ago

@ptillet @Jokeren re: the discussion about the utility of mypy

I am personally a strong proponent of full mypy coverage on all Python projects.

It is painful to move a project that isn't mypy-clean to one that is, because complex codebases that weren't using mypy tend to rely on construction orders and type unions that don't have clean types. Work is needed to massage such a codebase into compliance: not just adding type annotations, but in some places changing design decisions so that the annotations make sense.

A lot of the current code would be cleaner if constexpr (and possibly tensor) were generic classes with specialization types. constexpr[int] is easy, but it's a harder pull for tensor, where you may also want to model polyhedral types, and I'm not sure Python's type system is up for that yet.
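As a rough illustration of the "easy" constexpr[int] case, here is a hypothetical generic wrapper built on typing.Generic. It is not Triton's constexpr class, only a sketch of how a specialization type lets mypy know what .value holds.

```python
from typing import Generic, TypeVar

T = TypeVar("T")


class constexpr(Generic[T]):
    """Hypothetical generic compile-time constant (not Triton's actual class)."""

    def __init__(self, value: T) -> None:
        self.value: T = value

    def __repr__(self) -> str:
        return f"constexpr[{type(self.value).__name__}]({self.value!r})"


def block_size(n: constexpr[int]) -> int:
    # mypy infers n.value as int, so arithmetic on it type-checks;
    # passing constexpr("64") here would be flagged statically.
    return n.value * 2


size = block_size(constexpr(64))
```

The tensor case is harder because a useful type would need to carry shape and layout information, which Python's type system cannot express as directly as a single type parameter.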

But when a codebase is mypy-clean, you don't just get type checking and a lot of guardrails on code semantics; you also enable code indexing, completion, and refactoring in strong IDEs (such as IntelliJ) that approach what those IDEs can do for Java.

The type annotations make IDEs more powerful when working on the codebase, which in turn accelerates development by making refactors cheaper.

ptillet commented 1 year ago

I agree with all of that! Beyond type annotations, mypy will force us to follow good coding practices. This will also make code reviews easier.

Jokeren commented 1 year ago

Sounds good

crutcher commented 1 year ago

I've got a flattened version up as a first pass at the strategy above.

crutcher commented 1 year ago

I've got a re-rolled version up for discussion: https://github.com/openai/triton/pull/887

This is still not based on the MLIR branch, but it sorts out the import order and demonstrates that mypy typing is tractable.

ptillet commented 1 year ago

Awesome! Thank you so much for this :) As I said in the PR, we should probably aim to do it first thing after triton-mlir is merged, but I'm super excited!

crutcher commented 1 year ago

Here's the core spine of the order. I've omitted the __all__ statements from this comment, but they're present, and they're important for getting tools that import these libraries to behave well.
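One concrete way __all__ shapes what tools see: "from module import *" binds exactly the names in __all__ when it is defined, and otherwise every non-underscore name. mypy applies a similar rule under --no-implicit-reexport (enabled by --strict), where a name imported into a module counts as re-exported only if it is listed in __all__ or imported as "import x as x". The throwaway module below (toy_lang, with hypothetical load/store/helper names) demonstrates the runtime half.

```python
import sys
import types
from typing import Dict, List, Optional


def star_import_names(all_names: Optional[List[str]] = None) -> List[str]:
    """Return the names that `from toy_lang import *` binds, with or without __all__."""
    mod = types.ModuleType("toy_lang")  # hypothetical module
    mod.load = lambda ptr: ptr            # intended public API
    mod.store = lambda ptr, v: None       # intended public API
    mod.helper = object()                 # public-looking, but not meant to be exported
    mod._load = lambda ptr: ptr           # private impl; never star-exported
    if all_names is not None:
        mod.__all__ = all_names
    sys.modules["toy_lang"] = mod
    try:
        ns: Dict[str, object] = {}
        exec("from toy_lang import *", ns)
        return sorted(k for k in ns if k != "__builtins__")
    finally:
        del sys.modules["toy_lang"]
```

Without __all__, the accidental helper leaks out; with __all__ = ["load", "store"], only the intended API is exported.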

I did a lot of work trying to track down and squash mypy errors, and came to a few conclusions:

By pulling all non-impl methods out of the standard library, the true core was reduced to about 3k LOC, and the standard library could then be broken up into tractable submodules, mostly with their impl bodies co-located with the functions that use them.

I made no effort here to split up the tests, but aligning the tests with the submodules they exercise is an obvious next step in a real refactor, to speed up iteration.

Some packages (such as triton/runtime) exist solely to maintain backward compatibility with existing imports.
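The listing below pairs private impls with public wrappers in the same submodule (for example _exp and exp in .math). The following is a hypothetical sketch of that co-location pattern, loosely modeled on the idea of a builtin wrapper receiving a builder; builtin, Builder, and the error message here are stand-ins, not Triton's real API.

```python
import math
from typing import Callable, List, Optional, TypeVar

F = TypeVar("F", bound=Callable[..., object])


def builtin(fn: F) -> F:
    """Hypothetical marker decorator: tags a function as a builtin and returns it unchanged."""
    setattr(fn, "__triton_builtin__", True)
    return fn


class Builder:
    """Hypothetical IR builder that just records which ops were emitted."""

    def __init__(self) -> None:
        self.ops: List[str] = []


def _exp(x: float, builder: Builder) -> float:
    # Impl body: lives right next to its public wrapper in the same submodule.
    builder.ops.append("exp")
    return math.exp(x)


@builtin
def exp(x: float, _builder: Optional[Builder] = None) -> float:
    # Public wrapper: checks the calling context, then defers to the impl.
    if _builder is None:
        raise ValueError("exp() needs a builder; call it from compiled code")
    return _exp(x, _builder)
```

Keeping the pair in one small submodule means each file only imports the core it depends on, which is what lets the submodules be loaded in a strict order.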

triton/__init__.py

from .utils import cdiv, next_power_of_2

from triton._C.libtriton.triton import ir

from .impl import (
    autotune,
    CompilationError,
    compile,
    CompiledKernel,
    Config,
    extern,
    heuristics,
    Heuristics,
    jit,
    JITFunction,
    KernelInterface,
    MockTensor,
    OutOfResources,
    reinterpret,
    TensorWrapper,
)

from . import language

from . import runtime
from . import testing
from . import ops

triton/impl/__init__.py

from .. import ir

from ..utils import MockTensor

from .base import TensorWrapper, reinterpret

# No @tr.jit() interface can be called until after .compiler is loaded;
# and .compiler depends upon the core stack.
from .jitlib import (
    extern,
    jit,
    JITFunction,
    KernelInterface,
)

# .compiler depends upon core.minimum and core.where
from .compiler import (
    CompilationError,
    compile,
    CompiledKernel,
    OutOfResources,
)

from .autotuner import (
    autotune,
    Config,
    heuristics,
    Heuristics,
)

triton/language/__init__.py

from ..impl.base import (
    _add,
    _and_,
    bfloat16,
    _binary_op_type_checking_impl,
    _bitcast,
    _bitwise_op_type_checking_impl,
    block_type,
    _bool_like,
    _broadcast_impl_shape,
    _broadcast_impl_value,
    builtin,
    _cast,
    _check_ptr_type_impl,
    _computation_type_impl,
    constexpr,
    _constexpr_to_value,
    cvalue,
    dtype,
    _equal,
    _fdiv,
    float16,
    float32,
    float64,
    float8,
    _floordiv,
    _greater_equal,
    _greater_than,
    int1,
    int16,
    int32,
    int64,
    int8,
    _integer_promote_impl,
    _invert,
    ir,
    is_builtin,
    is_triton_tensor,
    _less_equal,
    _less_than,
    _lshr,
    _mod,
    _mul,
    _not_equal,
    _or_,
    pi32_t,
    pointer_type,
    reinterpret,
    _reshape,
    _shl,
    _sub,
    tensor,
    TensorWrapper,
    _to_tensor,
    _truediv,
    uint16,
    uint32,
    uint64,
    uint8,
    void,
    _where,
    _xor_,
)
from ..impl.core import (
    minimum,
    where,
)
from .meta import (
    _globaltimer,
    globaltimer,
    _clock,
    clock,
    debug_barrier,
    _program_id,
    program_id,
    _num_programs,
    num_programs,
    _multiple_of,
    multiple_of,
    _max_contiguous,
    max_contiguous,
)
from .constructors import (
    _arange,
    arange,
    _zeros,
    zeros,
    zeros_like,
    _cat,
    cat,
)
from .quantization import (
    _dequantize,
    dequantize,
)
from .broadcasting import (
    broadcast,
    broadcast_to,
)
from .transfer import (
    _load,
    load,
    _store,
    store,
)
from .structure import (
    reshape,
    ravel,
    swizzle2d,
)
from .matrix import (
    _dot,
    dot,
)
from .atomic import (
    _atomic_cas,
    atomic_cas,
    _atomic_xchg,
    atomic_xchg,
    _atomic_add,
    atomic_add,
    _atomic_max,
    atomic_max,
    _atomic_min,
    atomic_min,
    atomic_and,
    _atomic_or,
    atomic_or,
    _atomic_xor,
    atomic_xor,
)
from .math import (
    _umulhi,
    umulhi,
    fdiv,
    _exp,
    exp,
    _log,
    log,
    _cos,
    cos,
    _sin,
    sin,
    _sqrt,
    sqrt,
    abs,
    cdiv,
)
from .reductions import (
    _argmax,
    argmax,
    _argmin,
    argmin,
    _max,
    max,
    maximum,
    _min,
    min,
    _reduce_impl,
    sigmoid,
    softmax,
    _sum,
    sum,
    _xor_sum,
    xor_sum,
)
from . import random
from .random import (
    pair_uniform_to_normal,
    philox,
    philox_impl,
    rand,
    rand4x,
    randint,
    randint4x,
    randn,
    randn4x,
    triton,
    uint32_to_uniform_float,
)
from . import libdevice

crutcher commented 1 year ago

I've got an outstanding PR against triton-mlir that reorders the import graph. @ptillet has suggested that this is too much change to land before triton-mlir lands, so we're going to hold off on it.

That said, the core ideas are worth talking about:

This is intended as enabling tech for threading mypy types through everything (since the definitions will now be visible to all the interfaces).

Additionally, I worked out how (in the other PR) to make the insides of @triton.builtin methods analyzable by mypy, with the small help of a cvalue(<constexpr>) special form handled by the compiler.
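The thread doesn't show what cvalue looks like. One plausible reading, sketched here purely as an assumption, is that cvalue is typed as an ordinary generic function: mypy sees a call that turns a constexpr[T] into a plain T, while the compiler is free to treat the call specially. Both the generic constexpr class and the clamp_positive helper below are hypothetical.

```python
from typing import Generic, TypeVar

T = TypeVar("T")


class constexpr(Generic[T]):
    """Hypothetical generic compile-time constant wrapper."""

    def __init__(self, value: T) -> None:
        self.value: T = value


def cvalue(c: constexpr[T]) -> T:
    """Hypothetical special form: unwrap a constexpr to its Python value.

    At runtime it simply returns the wrapped value; to mypy it is a normal
    function whose return type is the wrapped type T.
    """
    return c.value


def clamp_positive(n: constexpr[int]) -> int:
    # Inside a builtin body, cvalue() hands mypy a plain int to reason about.
    v = cvalue(n)
    return v if v > 0 else 1
```

The appeal of this shape is that no mypy plugin is needed: the type information flows through an ordinary annotated signature.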