crutcher opened 1 year ago
Yes, I support that! I think it should be done inside of the `triton-mlir` branch though (unless it happens after the merge).
I've put some time into taking a stab at this on `main`. The approach I've been pursuing is to unroll the entire lib into one very long module, and then get that module into a clean declaration order.

Semantically, this is a sound general strategy: because of the construction order of Python module imports, any import graph is forced to be equivalent to some ordered import stack. The complication is Triton-specific: the codebase has a lot of reflexive self-knowledge and meta-programming, and understanding where those points of reference are is difficult.

If I can get a clean single-module instance, or at least one covering everything that comes in with `import triton`, I can recover a satisfying declaration order. That's a good starting point for planning a module hierarchy, which could then be replayed (mechanically, not by merge tools) against the new target branch.
@ptillet @Jokeren re: the discussion about the utility of mypy
I, personally, am a strong proponent of full-mypy on all python projects.
It is a pain to move a project that isn't mypy-clean to one that is. Complex codebases that grew up without mypy tend to rely on construction orders and type unions that don't have clean types, so massaging the codebase into compliance takes real work: not just adding type annotations, but in some places changing design decisions so that the annotations make sense.
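To make the "changing design decisions" point concrete, here's a hypothetical sketch (not Triton code; names invented for illustration) of the kind of refactor this usually means:

```python
from typing import List

# Pre-mypy pattern: one function whose return type depends on a flag,
# which forces a Union[str, int] onto every caller.
def describe(xs: List[int], as_str: bool):
    return str(xs) if as_str else len(xs)

# The mypy-friendly redesign splits the behaviors so each call site
# gets a precise type; this is a design change, not just an annotation.
def length(xs: List[int]) -> int:
    return len(xs)

def render(xs: List[int]) -> str:
    return str(xs)
```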
A lot of the current code would be cleaner if `constexpr` (and possibly `tensor`) were generic classes with specialization types. `constexpr[int]` is easy; `tensor` is a harder pull, since you may also want to model polyhedral types there, and I'm not sure Python's type system is up for that yet.
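As a sketch of what the easy case might look like (the class name and `unwrap` method are hypothetical; the real `triton.language.constexpr` is not currently generic):

```python
from typing import Generic, TypeVar

T = TypeVar("T")

class Constexpr(Generic[T]):
    """Hypothetical generic compile-time-constant wrapper."""

    def __init__(self, value: T) -> None:
        self.value = value

    def unwrap(self) -> T:
        # mypy specializes this per instance:
        # Constexpr[int].unwrap() returns int, not Any.
        return self.value

block_size: "Constexpr[int]" = Constexpr(64)
n: int = block_size.unwrap()  # checked as int
```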
But when a codebase is mypy-clean, you don't just get type checking and a lot of guardrails on code semantics; you also enable code indexing, completion, and refactoring in strong IDEs (such as IntelliJ) that approach what those tools can do for Java. The type annotations make the IDEs more powerful when working on the codebase, which in turn accelerates development by cheapening refactors.
I agree with all of that! Beyond type annotations, `mypy` will force us to follow good coding practices. This will also make code reviews easier.
Sounds good
I've got a flattened version up, as a first pass of the above strategy.
I've got a re-rolled version up for discussion: https://github.com/openai/triton/pull/887
This is still not based upon the MLIR branch, but it sorts through the import order and demonstrates that mypy typing is tractable.
Awesome! Thank you so much for this :) As I said in the PR, we should probably aim to do it first thing after triton-mlir is merged, but I'm super excited!
Here's the core spine of the order. I've omitted the `__all__` statements in this comment, but they're present, and important for getting tools that import these libraries to perform well.
I did a lot of work trying to track down / squash mypy errors. I came to a few conclusions:
- `constexpr` values require type support for unpacking them in method bodies; but that unpacking can be made a free op in jitted code by teaching the compiler to recognize the unpacking function.
- By pulling all non-impl methods out of the standard lib, the true core was reduced to about 3k LOC; the standard lib could then be broken up into tractable submodules, mostly with their impl bodies co-located in the code.
- I made no effort here to split the tests up; but that's an obvious next step in a real refactor: aligning the tests with the submodules they test, to speed up iteration.
- There are some packages (such as `triton/runtime`) that exist solely to maintain backwards compatibility with existing imports.
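The first point above can be sketched like this (all names hypothetical; a stand-in for whatever unpacking helper the compiler would recognize):

```python
from typing import Any, TypeVar, Union

T = TypeVar("T")

class Constexpr:
    """Illustrative stand-in for triton.language.constexpr."""
    def __init__(self, value: Any) -> None:
        self.value = value

def unwrap(v: Union[Constexpr, T]) -> T:
    """In plain Python this just reads .value, which gives mypy a
    concrete type to propagate through the method body; a JIT that
    recognizes `unwrap` by name can compile the call away entirely,
    making it a free op in jitted code."""
    if isinstance(v, Constexpr):
        return v.value
    return v

stride = unwrap(Constexpr(4))  # a plain int at trace time
width = unwrap(8)              # non-constexpr values pass through
```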
`triton/__init__.py`:

```python
from .utils import cdiv, next_power_of_2
from triton._C.libtriton.triton import ir
from .impl import (
    autotune,
    CompilationError,
    compile,
    CompiledKernel,
    Config,
    extern,
    heuristics,
    Heuristics,
    jit,
    JITFunction,
    KernelInterface,
    MockTensor,
    OutOfResources,
    reinterpret,
    TensorWrapper,
)
from . import language
from . import runtime
from . import testing
from . import ops
```
`triton/impl/__init__.py`:

```python
from .. import ir
from ..utils import MockTensor
from .base import TensorWrapper, reinterpret

# No @tr.jit() interface can be called until after .compiler is loaded;
# and .compiler depends upon the core stack.
from .jitlib import (
    extern,
    jit,
    JITFunction,
    KernelInterface,
)

# .compiler depends upon core.minimum and core.where
from .compiler import (
    CompilationError,
    compile,
    CompiledKernel,
    OutOfResources,
)
from .autotuner import (
    autotune,
    Config,
    heuristics,
    Heuristics,
)
```
`triton/language/__init__.py`:

```python
from ..impl.base import (
    _add,
    _and_,
    bfloat16,
    _binary_op_type_checking_impl,
    _bitcast,
    _bitwise_op_type_checking_impl,
    block_type,
    _bool_like,
    _broadcast_impl_shape,
    _broadcast_impl_value,
    builtin,
    _cast,
    _check_ptr_type_impl,
    _computation_type_impl,
    constexpr,
    _constexpr_to_value,
    cvalue,
    dtype,
    _equal,
    _fdiv,
    float16,
    float32,
    float64,
    float8,
    _floordiv,
    _greater_equal,
    _greater_than,
    int1,
    int16,
    int32,
    int64,
    int8,
    _integer_promote_impl,
    _invert,
    ir,
    is_builtin,
    is_triton_tensor,
    _less_equal,
    _less_than,
    _lshr,
    _mod,
    _mul,
    _not_equal,
    _or_,
    pi32_t,
    pointer_type,
    reinterpret,
    _reshape,
    _shl,
    _sub,
    tensor,
    TensorWrapper,
    _to_tensor,
    _truediv,
    uint16,
    uint32,
    uint64,
    uint8,
    void,
    _where,
    _xor_,
)
from ..impl.core import (
    minimum,
    where,
)
from .meta import (
    _globaltimer,
    globaltimer,
    _clock,
    clock,
    debug_barrier,
    _program_id,
    program_id,
    _num_programs,
    num_programs,
    _multiple_of,
    multiple_of,
    _max_contiguous,
    max_contiguous,
)
from .constructors import (
    _arange,
    arange,
    _zeros,
    zeros,
    zeros_like,
    _cat,
    cat,
)
from .quantization import (
    _dequantize,
    dequantize,
)
from .broadcasting import (
    broadcast,
    broadcast_to,
)
from .transfer import (
    _load,
    load,
    _store,
    store,
)
from .structure import (
    reshape,
    ravel,
    swizzle2d,
)
from .matrix import (
    _dot,
    dot,
)
from .atomic import (
    _atomic_cas,
    atomic_cas,
    _atomic_xchg,
    atomic_xchg,
    _atomic_add,
    atomic_add,
    _atomic_max,
    atomic_max,
    _atomic_min,
    atomic_min,
    atomic_and,
    _atomic_or,
    atomic_or,
    _atomic_xor,
    atomic_xor,
)
from .math import (
    _umulhi,
    umulhi,
    fdiv,
    _exp,
    exp,
    _log,
    log,
    _cos,
    cos,
    _sin,
    sin,
    _sqrt,
    sqrt,
    abs,
    cdiv,
)
from .reductions import (
    _argmax,
    argmax,
    _argmin,
    argmin,
    _max,
    max,
    maximum,
    _min,
    min,
    _reduce_impl,
    sigmoid,
    softmax,
    _sum,
    sum,
    _xor_sum,
    xor_sum,
)
from . import random
from .random import (
    pair_uniform_to_normal,
    philox,
    philox_impl,
    rand,
    rand4x,
    randint,
    randint4x,
    randn,
    randn4x,
    triton,
    uint32_to_uniform_float,
)
from . import libdevice
```
I've got an outstanding PR against triton-mlir which reorders the import graph. @ptillet has suggested that this is too much change to land before triton-mlir lands, so we're going to hold off on it.
That said, the core ideas are worth talking about:
- This is intended to be enabling tech for threading mypy types through everything (since the definitions will now be visible for all the interfaces).
- Additionally, I worked out (in the other PR) how to make the insides of `@triton.builtin` methods analyzable by mypy, with the small help of a `cvalue(<constexpr>)` special form handled by the compiler.
- Triton module imports are currently very circular; many modules import other modules while they are only partially initialized. The behavior of `core.py`, `jit.py`, and `semantic.py` is a particular example. With some incremental work, I could unroll these modules into a strictly ordered graph; it could even be done without changing the appearance of the public API, by introducing a graph of incremental declaration components and re-exporting them through the appropriate modules to maintain compatibility.
- This transformation would also help bring `mypy` scanning to `triton`.
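The re-export idea can be demonstrated in miniature. This sketch builds the modules in memory with hypothetical names (`pkg`, `pkg._impl`, `pkg.language`), standing in for real files on disk: definitions live in a strictly ordered leaf module, and the old public module re-exports them, so existing import paths keep working.

```python
import sys
import types

# Leaf module with the real (strictly ordered) definition.
impl = types.ModuleType("pkg._impl")
exec("def where(cond, a, b):\n    return a if cond else b", impl.__dict__)

# Facade module that re-exports under the old public name.
pkg = types.ModuleType("pkg")
facade = types.ModuleType("pkg.language")
facade.where = impl.where
facade.__all__ = ["where"]
pkg.language = facade

sys.modules["pkg"] = pkg
sys.modules["pkg._impl"] = impl
sys.modules["pkg.language"] = facade

# Callers are unaffected: the old import path still resolves.
from pkg.language import where
```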