zhuhaozhe / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Te inplace v2 #3

Closed zhuhaozhe closed 2 years ago

zhuhaozhe commented 2 years ago

Summary

This PR aims to partially support in-place ops in TE. This enables TE fusion for patterns like "at::conv, at::relu" and "at::add, at::relu, at::add", which improves performance.

Options

Option 1: Replace the in-place operator with the out-place operator

Option 2: Lower the body in terms of in-place directly

We choose Option 1 in this PR; we can consider Option 2 if we observe that Option 1 fails in many real-world scenarios.
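The idea behind Option 1 can be sketched as a toy rewrite in Python (this is an illustration, not the actual C++ implementation; the op names and mapping table are assumptions for the example):

```python
# Toy sketch of Option 1: before fusing, rewrite each supported in-place
# op name to its out-of-place counterpart. The mapping here is illustrative.
INPLACE_TO_OUTPLACE = {
    "aten::relu_": "aten::relu",
    "aten::add_": "aten::add",
    "aten::sigmoid_": "aten::sigmoid",
}

def rewrite_for_fusion(ops):
    """Replace in-place ops with out-of-place versions (Option 1)."""
    return [INPLACE_TO_OUTPLACE.get(op, op) for op in ops]

pattern = ["aten::conv2d", "aten::relu_", "aten::add_"]
print(rewrite_for_fusion(pattern))
# -> ['aten::conv2d', 'aten::relu', 'aten::add']
```

With the names normalized to out-of-place versions, the existing TE fusion logic can treat the pattern like any other functional op sequence.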

Implement Details


We pull in-place ops into TE fusion groups if they satisfy two conditions: the op must have an out-of-place version, and the replacement must be safe (see "Details on safe replacement" below).

To achieve this, we extend the behavior of the operator-supported check and `tryMerge`.

In the operator-supported check (e.g., `isSupported`):

If an in-place op can be safely replaced with its out-of-place version, we create a temporary out-of-place node and send that node through the TE checks. After the checks are done, we destroy the temporary node.

In `tryMerge`:

At this point we know all checks have passed, so we replace the in-place op with its out-of-place version just before merging the node into the fusion group.
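The two-phase flow above can be modeled with a small Python toy (assumed names and a made-up supported set, not the real JIT API):

```python
# Toy model of the two-phase flow:
# 1) is_supported: pretend to build a temporary out-of-place node, run the
#    checks on it, then discard it (the graph is not modified here);
# 2) try_merge: once checks pass, do the real replacement right before
#    merging into the fusion group.
TE_SUPPORTED = {"aten::relu", "aten::add", "aten::sigmoid"}
INPLACE_TO_OUTPLACE = {"aten::relu_": "aten::relu", "aten::add_": "aten::add"}

def is_supported(op):
    outplace = INPLACE_TO_OUTPLACE.get(op)
    if outplace is not None:
        # Check the temporary out-of-place version, then throw it away.
        return outplace in TE_SUPPORTED
    return op in TE_SUPPORTED

def try_merge(fusion_group, op):
    if not is_supported(op):
        return False
    # All checks passed: replace in-place with out-of-place, then merge.
    fusion_group.append(INPLACE_TO_OUTPLACE.get(op, op))
    return True

group = []
try_merge(group, "aten::relu_")
print(group)  # -> ['aten::relu']
```

The key point the toy captures is that the support check is side-effect free, while the actual graph mutation happens only at merge time.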

Details on safe replacement

Whether an in-place op can be replaced safely depends on the behavior of `RemoveTensorMutation`. The two cases below will not be replaced:

def fn(a, b):
    # `a` is a graph input; mutating it is observable by the caller,
    # so `a.relu_()` cannot be replaced with `a.relu()`.
    return a.relu_() + b.relu()

def fn(a, b):
    c = a + b
    # `c` has another use (`c.sigmoid()`); replacing `c.relu_()` with an
    # out-of-place op could change the value that use observes.
    return c.sigmoid().add(c.relu_())

Ops like `__and__`/`__or__`: their in-place versions are `__iand__`/`__ior__`. These ops cannot be replaced by `RemoveTensorMutation`, so we do not support them yet.
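The safety rules above can be summarized in a toy predicate (a loose Python sketch of the conditions, not `RemoveTensorMutation` itself; the mapping table and flag names are assumptions):

```python
# Toy safety check: an in-place op is replaceable only when a functional
# counterpart exists (e.g. __iand__/__ior__ have none here), its target is
# not a graph input, and its target has no other uses.
OUTPLACE_OF = {"aten::relu_": "aten::relu", "aten::add_": "aten::add"}
# __iand__/__ior__ deliberately absent: RemoveTensorMutation cannot
# convert them, so they are not supported yet.

def can_replace(op, target_is_graph_input, target_has_other_uses):
    if op not in OUTPLACE_OF:          # e.g. "aten::__iand__"
        return False
    if target_is_graph_input:          # case 1: a.relu_() on a graph input
        return False
    if target_has_other_uses:          # case 2: c is also used by sigmoid()
        return False
    return True

assert can_replace("aten::relu_", False, False)
assert not can_replace("aten::relu_", True, False)   # mutates a graph input
assert not can_replace("aten::add_", False, True)    # value has other uses
assert not can_replace("aten::__iand__", False, False)
```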

Unit test details

Since only element-wise ops have out-of-place versions, we show the op coverage status based on this TE support list: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp#L11.

UT covered ops

      {"aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", "broadcast_three"}, -> test_ternary_ops
      {"aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor", "broadcast_three"}, -> test_ternary_ops
      {"aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor", "unary"}, -> test_clamp_int, test_clamp_double
      {"aten::lgamma(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::log10(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::log(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::log2(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::log1p(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::exp(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::erf(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::erfc(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::cos(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::sin(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::tan(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::acos(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::asin(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::atan(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::cosh(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::sinh(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::tanh(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::hardsigmoid(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::hardswish(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::sqrt(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::rsqrt(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::abs(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::floor(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::ceil(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::round(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::trunc(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor", "unary"},
      {"aten::sigmoid(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::relu(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor", "unary"}, ->no inplace version
      {"aten::relu6(Tensor self) -> Tensor", "unary"},  -> test_unary_ops
      {"aten::gelu(Tensor self, *, str approximate='none') -> Tensor", "unary"}, ->test_gelu
      {"aten::neg(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::reciprocal(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::expm1(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::frac(Tensor self) -> Tensor", "unary"}, -> test_unary_ops
      {"aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"}, ->test_binary_scalar_ops
      {"aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"}, ->test_binary_scalar_ops
      {"aten::mul.Scalar(Tensor self, Scalar other) -> Tensor", "unary"}, ->test_binary_scalar_ops
      {"aten::div.Scalar(Tensor self, Scalar other) -> Tensor", "unary"}, ->test_binary_scalar_ops
      {"aten::eq.Scalar(Tensor self, Scalar other) -> Tensor", "unary"}, ->test_binary_scalar_ops
      {"aten::ne.Scalar(Tensor self, Scalar other) -> Tensor", "unary"}, ->test_binary_scalar_ops
      {"aten::ge.Scalar(Tensor self, Scalar other) -> Tensor", "unary"}, ->test_binary_scalar_ops
      {"aten::gt.Scalar(Tensor self, Scalar other) -> Tensor", "unary"}, ->test_binary_scalar_ops
      {"aten::le.Scalar(Tensor self, Scalar other) -> Tensor", "unary"}, ->test_binary_scalar_ops
      {"aten::lt.Scalar(Tensor self, Scalar other) -> Tensor", "unary"}, ->test_binary_scalar_ops
      {"aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor", "unary"}, ->test_binary_pow
      {"aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "broadcast"}, ->test_binary_ops
      {"aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "broadcast"}, ->test_binary_ops
      {"aten::mul.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"}, ->test_binary_ops
      {"aten::div.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"}, ->test_binary_ops
      {"aten::eq.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"}, ->test_binary_ops
      {"aten::ne.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"}, ->test_binary_ops
      {"aten::ge.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"}, ->test_binary_ops
      {"aten::gt.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"}, ->test_binary_ops
      {"aten::le.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"}, ->test_binary_ops
      {"aten::lt.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"}, ->test_binary_ops
      {"aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor", "broadcast"}, ->test_binary_ops
      {"aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"}, ->test_binary_ops
      {"aten::atan2(Tensor self, Tensor other) -> Tensor", "broadcast"}, ->test_binary_ops
      {"aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"}, ->test_binary_ops
      {"aten::min.other(Tensor self, Tensor other) -> Tensor", "broadcast"}, ->test_binary_ops
      {"aten::max.other(Tensor self, Tensor other) -> Tensor", "broadcast"}, ->test_binary_ops

Element-wise ops with out-of-place support but without in-place support:

      {"aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor", "unary"}, -> not supported, cannot be converted by RemoveTensorMutation
      {"aten::__or__.Scalar(Tensor self, Scalar other) -> Tensor", "unary"}, -> not supported, cannot be converted by RemoveTensorMutation
      {"aten::__xor__.Scalar(Tensor self, Scalar other) -> Tensor", "unary"}, -> not supported, cannot be converted by RemoveTensorMutation
      {"aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor", "unary"}, -> not supported, cannot be converted by RemoveTensorMutation
      {"aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor", "unary"}, -> not supported, cannot be converted by RemoveTensorMutation
      {"aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor", "unary"}, -> no inplace version
      {"aten::to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor", "unary"}, -> no inplace version
      {"aten::to.dtype_layout(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor", "unary"}, -> no inplace version
      {"aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)", "unary"}, -> no inplace version
      {"aten::to.prim_dtype(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)", "unary"}, -> no inplace version
      {"aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)", "unary"}, -> no inplace version
      {"aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a)", "unary"}, -> no inplace version
      {"aten::isnan(Tensor self) -> Tensor", "unary"}, -> no inplace version
      {"aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor", "broadcast"}, -> no inplace version
      {"aten::type_as(Tensor self, Tensor other) -> Tensor", "unary"}, -> no inplace version
      {"aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"}, -> not supported, cannot be converted by RemoveTensorMutation
      {"aten::__or__.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"}, -> not supported, cannot be converted by RemoveTensorMutation
      {"aten::__xor__.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"}, -> not supported, cannot be converted by RemoveTensorMutation
      {"aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"}, -> not supported, cannot be converted by RemoveTensorMutation
      {"aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"}, -> not supported, cannot be converted by RemoveTensorMutation
      {"aten::where.self(Tensor condition, Tensor self, Tensor other) -> Tensor", "broadcast_three"}, -> no inplace version
      {"aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "broadcast_one_three"}, -> no inplace version