csarofeen / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
http://pytorch.org

New lowering pass: loop rotation #2500

Closed (zasdfgbnm closed this 1 year ago)

zasdfgbnm commented 1 year ago

This PR adds a new lowering pass: loop rotation.

Loop rotation is a loop transformation that turns code like this:

for i in range(n):
  statement1(i)
  statement2(i)
  statement3(i)
  statement4(i)

into something like:

if 0 < n:
  for i = 0:
    statement1(i)
    statement2(i)
for i in range(n):
  statement3(i)
  statement4(i)
  if i + 1 < n:
    statement1(i+1)
    statement2(i+1)
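
As an illustration only (this is not code from the PR), the transformation can be checked with a small runnable Python sketch. The names `original`, `rotated`, and the `("s1", i)` trace tuples are hypothetical stand-ins for statement1 to statement4; the sketch records the order in which statements run and compares the two loop forms.

```python
def original(n):
    """Reference loop: all four statements run in order for each i."""
    trace = []
    for i in range(n):
        trace.append(("s1", i))
        trace.append(("s2", i))
        trace.append(("s3", i))
        trace.append(("s4", i))
    return trace


def rotated(n):
    """Rotated loop: s1/s2 are peeled into a prologue and shifted by one."""
    trace = []
    # Prologue: the trivial "for i = 0" loop, guarded by 0 < n.
    if 0 < n:
        for i in range(1):
            trace.append(("s1", i))
            trace.append(("s2", i))
    # Main loop: s3/s4 for iteration i, then s1/s2 for iteration i + 1.
    for i in range(n):
        trace.append(("s3", i))
        trace.append(("s4", i))
        if i + 1 < n:
            trace.append(("s1", i + 1))
            trace.append(("s2", i + 1))
    return trace
```

Under these assumptions both functions produce the same sequence of statement executions for any n, including n = 0, which is what makes the rotation legal.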

Right now, because existing predicates should already cover all illegal accesses, I am not materializing the 0 < n and i + 1 < n predicates, so what I actually generate is:

for i = 0:
  statement1(i)
  statement2(i)
for i in range(n):
  statement3(i)
  statement4(i)
  if True:
    statement1(i+1)
    statement2(i+1)
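
A rough Python sketch of why dropping the guards can be safe, assuming each rotated statement already carries its own bounds predicate. The bounds check inside `load` is a hypothetical stand-in for the "existing predicates" mentioned above, not nvfuser's actual predicate logic; `load`, `compute`, and `staged` are made-up names.

```python
def rotated_unguarded(data):
    """Rotated loop with the 0 < n and i + 1 < n guards replaced by True."""
    n = len(data)
    staged = {}  # values loaded ahead of their use
    out = []

    def load(i):
        # Assumed pre-existing predicate: skip any out-of-bounds access.
        if 0 <= i < n:
            staged[i] = data[i]

    def compute(i):
        out.append(staged[i] * 2)

    # Prologue, no "0 < n" guard: load(0) is a no-op when data is empty.
    for i in range(1):
        load(i)
    # Main loop, "if True" instead of "if i + 1 < n".
    for i in range(n):
        compute(i)
        if True:
            load(i + 1)  # predicated off on the last iteration
    return out
```

In this sketch the extra `load(n)` on the last iteration and the `load(0)` on an empty input are both suppressed by the statement's own predicate, so the outer guards add nothing.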

Loop rotation is conceptually very simple with respect to index and predicate lowering, and the changes it requires there are mostly mechanical. For the trivial loop for i = 0, our index and predicate lowering should already handle it correctly. For the exprs inside the if i + 1 < n block, the only change needed is to replace every loop->index() with loop->index() + 1, which requires me to detect the if i + 1 < n block and mark its body as "rotated" during index and predicate lowering.
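
The index substitution can be pictured with a toy expression tree in Python. This is only a sketch: `Var`, `Const`, `Add`, `Mul`, `shift_index`, and `evaluate` are hypothetical names and do not correspond to nvfuser's IR classes or lowering passes.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Add:
    lhs: "Expr"
    rhs: "Expr"


@dataclass(frozen=True)
class Mul:
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Var, Const, Add, Mul]


def shift_index(expr, loop_var):
    """Replace every use of loop_var with loop_var + 1 (the 'rotated' body)."""
    if isinstance(expr, Var):
        return Add(expr, Const(1)) if expr == loop_var else expr
    if isinstance(expr, Const):
        return expr
    # Add or Mul: rebuild the node with both operands rewritten.
    return type(expr)(shift_index(expr.lhs, loop_var),
                      shift_index(expr.rhs, loop_var))


def evaluate(expr, env):
    """Evaluate an expression tree given integer values for its variables."""
    if isinstance(expr, Var):
        return env[expr.name]
    if isinstance(expr, Const):
        return expr.value
    lhs, rhs = evaluate(expr.lhs, env), evaluate(expr.rhs, env)
    return lhs + rhs if isinstance(expr, Add) else lhs * rhs


i, j = Var("i"), Var("j")
index = Add(Mul(i, Const(8)), j)        # i * 8 + j
rotated_index = shift_index(index, i)   # (i + 1) * 8 + j
```

Only uses of the rotated loop's own index are rewritten; indices of other loops (here `j`) are left untouched.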

To use loop rotation, users call fusion.rotateLoop(tv, dim, {tv1, tv2}). During lowering, the loop tv->axis(dim) is then rotated, and the expressions allocating and computing tv1 and tv2 become the rotated exprs.

naoyam commented 1 year ago

I get this build error (clang 13):

/home/nmaruyama/pytorch/debug2/third_party/nvfuser/csrc/ir_cloner.h:67:10: error: lambda capture 'this' is not used [-Werror,-Wunused-lambda-capture]
        [this](auto&... x) { return std::make_tuple<Ts...>(clone(x)...); },
         ^
/home/nmaruyama/pytorch/debug2/third_party/nvfuser/csrc/ir_cloner.h:49:22: note: in instantiation of function template specialization 'nvfuser::IrCloner::clone<nvfuser::TensorView *, long, std::unordered_set<nvfuser::Statement *>>' requested here
      copy.push_back(clone(p));
                     ^
/home/nmaruyama/pytorch/debug2/third_party/nvfuser/csrc/fusion.cpp:83:40: note: in instantiation of function template specialization 'nvfuser::IrCloner::clone<std::tuple<nvfuser::TensorView *, long, std::unordered_set<nvfuser::Statement *>>>' requested here
  to->loop_rotation_param_ = ir_cloner.clone(from->loop_rotation_param_);
                                       ^
1 error generated.

zasdfgbnm commented 1 year ago

I think this is a bug in clang, but I changed the code to use this->clone instead of clone. Hopefully this makes clang happy.

naoyam commented 1 year ago

Yeah, looks like it. I don't see any error with the change.

naoyam commented 1 year ago

Just curious, can loop rotation be nested? Not that I think it's important.

zasdfgbnm commented 1 year ago

I don't think so. To support nested loop rotation, we would need to carefully handle registerReplace and registerInsertBefore to make sure they do not change the abandoned main loop.

naoyam commented 1 year ago

OK. Maybe add a comment about it?

zasdfgbnm commented 1 year ago

Actually, it is pretty easy to support it with very little extra work: I can just rotate one loop each pass.

I updated this PR to support this in https://github.com/csarofeen/pytorch/pull/2500/commits/67f7df347d86fdf035323046aa083c2c7635b6b5