csarofeen / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
http://pytorch.org

Expression Evaluator exposure #2194

Open csarofeen opened 1 year ago

csarofeen commented 1 year ago

🚀 The feature, motivation and pitch

It would be good to make sure our expression evaluator is well exposed. Below are some interfaces we would likely want to expose:

Input binding: I'm unifying this in the vectorization PR: https://github.com/csarofeen/pytorch/pull/2124/files#diff-1a2fb0763b31521224aad22ae7a7a7cc60a20635da4580eb9fcc46c500856b93R42

This would be useful for any runtime checks expressed through nvFuser IR. We could even represent operations in nvFuser IR that nvFuser itself will never run, i.e., generic symbolic IR support.

The Expression Evaluator that's returned has its own evaluation functions. The interface is: https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/expr_evaluator.h#L43

It returns an empty optional if the value can't be evaluated. EvaluatorValue is our dynamic type, defined in: https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/dynamic_type.h
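
A minimal, self-contained C++17 sketch of the bind-then-evaluate contract described above. Everything here is hypothetical and written only for illustration: `Expr`, `ToyEvaluator`, `bind`, and `evaluate` are not nvFuser's classes or signatures, and values are plain int64_t rather than EvaluatorValue to keep it short. The behavior it mirrors is the one stated in the issue: evaluation yields an empty optional when a value can't be evaluated, here because a required input was never bound.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>

// Hypothetical toy expression node (not nvFuser's Val/Expr): an integer
// literal, a named runtime input, or the sum/product of two nodes.
struct Expr {
  enum class Kind { Const, Symbol, Add, Mul } kind;
  int64_t value = 0;                // Const only
  std::string name;                 // Symbol only
  std::shared_ptr<const Expr> lhs;  // Add/Mul only
  std::shared_ptr<const Expr> rhs;  // Add/Mul only
};
using ExprPtr = std::shared_ptr<const Expr>;

ExprPtr constant(int64_t v) {
  return std::make_shared<Expr>(Expr{Expr::Kind::Const, v, "", nullptr, nullptr});
}
ExprPtr symbol(std::string name) {
  return std::make_shared<Expr>(Expr{Expr::Kind::Symbol, 0, std::move(name), nullptr, nullptr});
}
ExprPtr binary(Expr::Kind k, ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{k, 0, "", std::move(a), std::move(b)});
}

class ToyEvaluator {
 public:
  // Input binding: associate a concrete value with a named symbolic input.
  void bind(const std::string& name, int64_t v) { bindings_[name] = v; }

  // Returns an empty optional when the expression can't be evaluated,
  // e.g. because it depends on an input that was never bound.
  std::optional<int64_t> evaluate(const ExprPtr& e) const {
    switch (e->kind) {
      case Expr::Kind::Const:
        return e->value;
      case Expr::Kind::Symbol: {
        auto it = bindings_.find(e->name);
        if (it == bindings_.end()) return std::nullopt;
        return it->second;
      }
      case Expr::Kind::Add:
      case Expr::Kind::Mul: {
        assert(e->lhs && e->rhs);
        std::optional<int64_t> a = evaluate(e->lhs);
        std::optional<int64_t> b = evaluate(e->rhs);
        if (!a || !b) return std::nullopt;  // propagate "can't evaluate"
        return e->kind == Expr::Kind::Add ? *a + *b : *a * *b;
      }
    }
    return std::nullopt;
  }

 private:
  std::map<std::string, int64_t> bindings_;
};
```

Propagating std::nullopt upward (rather than throwing) lets a caller probe whether an expression is computable with the current bindings and fall back gracefully if it isn't.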

Much of the file can be ignored; the pattern/incantation of interest is:

- bool EvaluatorValue::isInt()
- bool EvaluatorValue::isDouble()
- bool EvaluatorValue::isBool()

Then you can grab the value with:

- int64_t EvaluatorValue::as<int64_t>()
- double EvaluatorValue::as<double>()
- bool EvaluatorValue::as<bool>()

The above can be paired with simple compile-time helpers on Val*: Val::isBool(), Val::isDouble(), and Val::isAnInt() check which type the Val is associated with (and make sure it's a scalar value rather than another type). This type matches the type held by the dynamic type (EvaluatorValue).
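
A hypothetical stand-in showing how the runtime check-then-read pattern pairs with a compile-time type tag. `ToyValue` imitates the isInt()/isDouble()/isBool() plus as<T>() shape using std::variant; `ToyScalarVal`, `DType`, and `typesAgree` are invented names standing in for a scalar Val and its isAnInt()/isDouble()/isBool() helpers. None of this is nvFuser's actual implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <variant>

// Hypothetical stand-in for a dynamic scalar in the spirit of EvaluatorValue
// (not the real class): holds exactly one of int64_t, double, or bool.
class ToyValue {
 public:
  explicit ToyValue(int64_t v) : v_(v) {}
  explicit ToyValue(double v) : v_(v) {}
  explicit ToyValue(bool v) : v_(v) {}

  bool isInt() const { return std::holds_alternative<int64_t>(v_); }
  bool isDouble() const { return std::holds_alternative<double>(v_); }
  bool isBool() const { return std::holds_alternative<bool>(v_); }

  // Precondition: the matching isX() is true (checked in debug builds;
  // with NDEBUG, std::get throws std::bad_variant_access instead).
  template <typename T>
  T as() const {
    assert(std::holds_alternative<T>(v_));
    return std::get<T>(v_);
  }

 private:
  std::variant<int64_t, double, bool> v_;
};

// Hypothetical compile-time side: a symbolic scalar that only knows its
// dtype, playing the role of Val::isAnInt()/isDouble()/isBool().
enum class DType { Int, Double, Bool };

struct ToyScalarVal {
  DType dtype;
  bool isAnInt() const { return dtype == DType::Int; }
  bool isDouble() const { return dtype == DType::Double; }
  bool isBool() const { return dtype == DType::Bool; }
};

// True when the symbolic dtype matches the type held at runtime.
bool typesAgree(const ToyScalarVal& val, const ToyValue& v) {
  return (val.isAnInt() && v.isInt()) || (val.isDouble() && v.isDouble()) ||
         (val.isBool() && v.isBool());
}
```

Constructing with an explicit `int64_t{...}` avoids overload ambiguity: a plain `int` literal converts equally well to int64_t, double, and bool.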

There's also Val::isConstScalar(), which tells you whether the scalar is actually a compile-time constant, i.e., it doesn't need runtime bindings to evaluate. For such values you can call int64_t Val::evaluateInt(), double Val::evaluateDouble(), and bool Val::evaluateBool() directly to get the evaluated compile-time constant.
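
A toy sketch of what "constant scalar" means for evaluation, assuming a tiny expression tree. `ToyVal`, `lit`, `input`, and `add` are hypothetical; only the method names isConstScalar() and evaluateInt() echo the issue, and their bodies here are an illustration, not nvFuser's logic. A node counts as constant when no runtime input appears below it, so it can be folded with no bindings at all.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy scalar node (not nvFuser's Val): an integer literal,
// a named runtime input, or the sum of two nodes.
struct ToyVal {
  enum class Kind { Literal, Input, Add } kind;
  int64_t literal = 0;
  std::string name;
  std::shared_ptr<const ToyVal> lhs, rhs;

  // Constant iff no runtime input appears below this node, so it can be
  // evaluated without any bindings.
  bool isConstScalar() const {
    switch (kind) {
      case Kind::Literal: return true;
      case Kind::Input: return false;
      case Kind::Add: return lhs->isConstScalar() && rhs->isConstScalar();
    }
    return false;
  }

  // Folds the constant directly; precondition: isConstScalar().
  int64_t evaluateInt() const {
    assert(isConstScalar());
    if (kind == Kind::Literal) return literal;
    return lhs->evaluateInt() + rhs->evaluateInt();
  }
};
using ToyValPtr = std::shared_ptr<const ToyVal>;

ToyValPtr lit(int64_t v) {
  return std::make_shared<ToyVal>(ToyVal{ToyVal::Kind::Literal, v, "", nullptr, nullptr});
}
ToyValPtr input(std::string name) {
  return std::make_shared<ToyVal>(ToyVal{ToyVal::Kind::Input, 0, std::move(name), nullptr, nullptr});
}
ToyValPtr add(ToyValPtr a, ToyValPtr b) {
  return std::make_shared<ToyVal>(ToyVal{ToyVal::Kind::Add, 0, "", std::move(a), std::move(b)});
}
```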

Arith functions should work fine on these scalar Val*s. The list of ops supported for evaluation can be found in: https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp#L169-L295

Their runtime implementation is in: https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/dynamic_type.h#L58-L316
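
A hedged sketch of evaluating a binary arith op on a dynamic scalar with std::visit. `Scalar` and `binaryArith` are hypothetical, and the promotion rule (int op int stays int, a double operand promotes the result to double, bool operands are not evaluated) is an assumption made for this example only; the actual rules live in the dynamic_type.h and expr_evaluator.cpp ranges linked above.

```cpp
#include <cstdint>
#include <optional>
#include <type_traits>
#include <variant>

// Hypothetical dynamic scalar for this sketch (not nvFuser's EvaluatorValue).
using Scalar = std::variant<int64_t, double, bool>;

// Applies a binary arithmetic op to two dynamic scalars. Assumed promotion
// rule (illustration only): int op int stays int, any double operand makes
// the result double, and a bool operand yields an empty optional
// ("can't evaluate") instead of a value.
template <typename Op>
std::optional<Scalar> binaryArith(const Scalar& a, const Scalar& b, Op op) {
  return std::visit(
      [&op](auto x, auto y) -> std::optional<Scalar> {
        using X = decltype(x);
        using Y = decltype(y);
        if constexpr (std::is_same_v<X, bool> || std::is_same_v<Y, bool>) {
          return std::nullopt;
        } else if constexpr (std::is_same_v<X, int64_t> && std::is_same_v<Y, int64_t>) {
          return Scalar(op(x, y));
        } else {
          return Scalar(op(static_cast<double>(x), static_cast<double>(y)));
        }
      },
      a, b);
}
```

Passing the op as a generic callable (e.g. `[](auto l, auto r) { return l + r; }`) keeps one dispatch routine for every arithmetic op, with the type handling done once in the visitor.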

Alternatives

No response

Additional context

No response

csarofeen commented 1 year ago

Forgot to mention, though we may want to add it later: there's an option for the Expression Evaluator to precompute any value it can, for faster evaluation. This is an optimization we use to lower kernel launch latency.
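
A sketch of the precomputation idea under stated assumptions: the expression is flattened into a table of slots in topological order, every slot that doesn't depend on a runtime input is folded once up front, and each later evaluation only recomputes the input-dependent slots. `Slot`, `ToyPrecomputed`, and this layout are hypothetical and are not how nvFuser's option is implemented.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical flattened expression (not nvFuser's layout): slots are in
// topological order, and Add/Mul slots only refer to earlier slots.
struct Slot {
  enum class Kind { Const, Input, Add, Mul } kind;
  int64_t value = 0;          // Const: the literal; Input: index into inputs
  std::size_t a = 0, b = 0;   // Add/Mul: operand slot indices
};

class ToyPrecomputed {
 public:
  explicit ToyPrecomputed(std::vector<Slot> slots)
      : slots_(std::move(slots)), folded_(slots_.size()) {
    // One-time pass: fold every slot that needs no runtime input.
    for (std::size_t i = 0; i < slots_.size(); ++i) {
      const Slot& s = slots_[i];
      if (s.kind == Slot::Kind::Const) {
        folded_[i] = s.value;
      } else if (s.kind != Slot::Kind::Input) {
        assert(s.a < i && s.b < i);
        if (folded_[s.a] && folded_[s.b]) {
          folded_[i] = apply(s, *folded_[s.a], *folded_[s.b]);
        }
      }
    }
  }

  // Per-evaluation pass: reuse folded slots, compute only the rest.
  int64_t evaluate(const std::vector<int64_t>& inputs, std::size_t slot) const {
    assert(slot < slots_.size());
    std::vector<int64_t> v(slot + 1);
    for (std::size_t i = 0; i <= slot; ++i) {
      const Slot& s = slots_[i];
      if (folded_[i]) {
        v[i] = *folded_[i];
      } else if (s.kind == Slot::Kind::Input) {
        assert(static_cast<std::size_t>(s.value) < inputs.size());
        v[i] = inputs[static_cast<std::size_t>(s.value)];
      } else {
        v[i] = apply(s, v[s.a], v[s.b]);
      }
    }
    return v[slot];
  }

  // Number of slots resolved during the one-time pass.
  std::size_t numFolded() const {
    std::size_t n = 0;
    for (const auto& f : folded_) n += f.has_value() ? 1 : 0;
    return n;
  }

 private:
  static int64_t apply(const Slot& s, int64_t x, int64_t y) {
    return s.kind == Slot::Kind::Add ? x + y : x * y;
  }

  std::vector<Slot> slots_;
  std::vector<std::optional<int64_t>> folded_;
};
```

The one-time fold moves work out of the per-evaluation path, which is where the launch-latency benefit described above would come from.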