cornell-zhang / heterocl

HeteroCL: A Multi-Paradigm Programming Infrastructure for Software-Defined Heterogeneous Computing
https://cornell-zhang.github.io/heterocl/
Apache License 2.0

[Fix] Remove CastRemover for Tensor Indices (issue #386) #453

Closed zzzDavid closed 2 years ago

zzzDavid commented 2 years ago

Remove CastRemover for Tensor Indices

Fixed issue: #386

Detailed description:

The cast operations on tensor indices are incorrectly removed by CastRemover, which causes a data type mismatch in binary operations. This PR fixes the issue by removing the CastRemovers and enforcing that indices have the i32 data type for the LLVM backend.

Link to the tests: tests/issues/test_issue_386.py

zzzDavid commented 2 years ago

More explanation on why CastRemover is deleted in tensor.py and added in generate_reuse_buffer.cc:

Why there's a CastRemover in the first place

The ReuseBufferInserter mutator calculates a tensor index expression to determine the reuse buffer size. The reuse buffer size needs to be a constant:

https://github.com/cornell-zhang/heterocl/blob/267d670bfd7ec14276a7ddd21103a6e6ed668e5b/tvm/src/pass/generate_reuse_buffer.cc#L225

https://github.com/cornell-zhang/heterocl/blob/267d670bfd7ec14276a7ddd21103a6e6ed668e5b/tvm/src/pass/generate_reuse_buffer.cc#L234

The index expression is simplified by the Simplify function. However, Simplify does not fold expressions "across" a cast node. For example:

(int33) (y + 1) - (int33) y

does not get simplified to 1. We need to remove the cast operations and give the simplifier an expression such as: y + 1 - y

The CastRemover in tensor.py is added for this purpose: https://github.com/cornell-zhang/heterocl/blob/267d670bfd7ec14276a7ddd21103a6e6ed668e5b/python/heterocl/tensor.py#L151
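The folding problem can be illustrated with a toy model. This is only a sketch: the node classes and the `linear`, `simplify`, and `remove_casts` helpers below are hypothetical stand-ins written for this example, not the TVM/HeteroCL IR or its Simplify function. The one assumption carried over from the explanation above is that the simplifier treats a cast node as opaque.

```python
from dataclasses import dataclass


# Hypothetical mini-IR, for illustration only (not the HeteroCL/TVM IR).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Add:
    a: object
    b: object


@dataclass(frozen=True)
class Sub:
    a: object
    b: object


@dataclass(frozen=True)
class Cast:
    dtype: str
    value: object


def linear(expr):
    """Return (coefficients, constant) for a linear expression, or None if opaque."""
    if isinstance(expr, Const):
        return {}, expr.value
    if isinstance(expr, Var):
        return {expr.name: 1}, 0
    if isinstance(expr, (Add, Sub)):
        lhs, rhs = linear(expr.a), linear(expr.b)
        if lhs is None or rhs is None:
            return None
        sign = 1 if isinstance(expr, Add) else -1
        coeffs = dict(lhs[0])
        for name, c in rhs[0].items():
            coeffs[name] = coeffs.get(name, 0) + sign * c
        return coeffs, lhs[1] + sign * rhs[1]
    return None  # assumption: a Cast is opaque to the simplifier


def simplify(expr):
    """Fold to a Const when every variable cancels; otherwise return expr unchanged."""
    lin = linear(expr)
    if lin is not None and all(c == 0 for c in lin[0].values()):
        return Const(lin[1])
    return expr


def remove_casts(expr):
    """Strip every Cast node from an Add/Sub expression tree."""
    if isinstance(expr, Cast):
        return remove_casts(expr.value)
    if isinstance(expr, (Add, Sub)):
        return type(expr)(remove_casts(expr.a), remove_casts(expr.b))
    return expr


y = Var("y")
# (int33) (y + 1) - (int33) y
with_casts = Sub(Cast("int33", Add(y, Const(1))), Cast("int33", y))

folded_with_casts = simplify(with_casts)                    # left unchanged
folded_without_casts = simplify(remove_casts(with_casts))   # folds to Const(1)
```

With the casts in place the toy simplifier gives up and returns the expression unchanged; once they are stripped, the `y` terms cancel and a constant remains, which is what the reuse buffer size calculation needs.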

What caused the issue

However, removing all cast nodes from an index expression leads to binary operations with unsafe operand types. For example:

https://github.com/zzzDavid/heterocl/blob/8875efb47a1978ce8b1064ebb271c76455c4a307/tests/issues/test_issue_386.py#L9-L17

When all the cast operations in the tensor index expression are removed, the Mul operation will have a mismatch for its two operand types.
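A minimal sketch of this failure mode, using a hypothetical typed mini-IR rather than the actual test case linked above. The variable names, the `uint8`/`int32` types, and the `dtype_of` and `remove_all_casts` helpers are assumptions made up for the illustration.

```python
from dataclasses import dataclass


# Hypothetical typed mini-IR, for illustration only.
@dataclass(frozen=True)
class Var:
    name: str
    dtype: str


@dataclass(frozen=True)
class Cast:
    dtype: str
    value: object


@dataclass(frozen=True)
class Mul:
    a: object
    b: object


def dtype_of(expr):
    """Return the result type, rejecting a Mul whose operand types differ."""
    if isinstance(expr, Mul):
        lhs, rhs = dtype_of(expr.a), dtype_of(expr.b)
        if lhs != rhs:
            raise TypeError(f"Mul operand type mismatch: {lhs} vs {rhs}")
        return lhs
    return expr.dtype  # Var and Cast both carry their own dtype


def remove_all_casts(expr):
    """Strip every Cast without checking types (the unsafe behaviour)."""
    if isinstance(expr, Cast):
        return remove_all_casts(expr.value)
    if isinstance(expr, Mul):
        return Mul(remove_all_casts(expr.a), remove_all_casts(expr.b))
    return expr


k = Var("k", "uint8")   # assumed narrow index operand
i = Var("i", "int32")   # assumed loop variable

well_typed = Mul(Cast("int32", k), i)   # the cast makes both operands int32
well_typed_dtype = dtype_of(well_typed)

try:
    dtype_of(remove_all_casts(well_typed))
    mismatch = None
except TypeError as err:
    mismatch = str(err)
```

Here the cast is what keeps the two operands of `Mul` at the same type, so stripping it unconditionally produces an ill-typed node.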

How this PR fixes the issue

This PR fixes the issue by removing the CastRemover in the frontend, so the type mismatch no longer occurs. But ReuseBufferInserter still faces the problem that the simplifier doesn't work across cast operations. Therefore a CastRemover is added precisely where it's needed: right before an index expression gets simplified.

Q&A

Is it safe to remove the cast operations? Seems the CastRemover does not enumerate all the operations but only several binary ones.

As explained above, CastRemover only removes casts in the index expression that needs to be simplified, so only binary operations and single variables need to be covered; it's not meant to remove casts in all cases.

Other than that, I also added type checking in the CastRemover, in case the binary operation cannot be cast-removed.
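One way such a type check could look, as a rough sketch only. The `BinOp` node, the `remove_casts_checked` helper, and the policy of keeping the original casts on a mismatch are assumptions for this illustration; the actual check lives in the C++ pass and may behave differently.

```python
from dataclasses import dataclass


# Hypothetical typed mini-IR, for illustration only.
@dataclass(frozen=True)
class Var:
    name: str
    dtype: str


@dataclass(frozen=True)
class Const:
    value: int
    dtype: str


@dataclass(frozen=True)
class Cast:
    dtype: str
    value: object


@dataclass(frozen=True)
class BinOp:
    op: str
    a: object
    b: object


def dtype_of(expr):
    """Result type; a BinOp is assumed to take the type of its left operand."""
    if isinstance(expr, BinOp):
        return dtype_of(expr.a)
    return expr.dtype


def remove_casts_checked(expr):
    """Strip casts only when the operands still agree on a type afterwards."""
    if isinstance(expr, Cast):
        return remove_casts_checked(expr.value)
    if isinstance(expr, BinOp):
        a = remove_casts_checked(expr.a)
        b = remove_casts_checked(expr.b)
        if dtype_of(a) == dtype_of(b):
            return BinOp(expr.op, a, b)
        return expr  # assumed policy: keep the original casts on a mismatch
    return expr


y = Var("y", "int32")
one = Const(1, "int32")

# (int33) (y + 1) - (int33) y: both operands are int32 once the casts are gone
safe = BinOp("sub", Cast("int33", BinOp("add", y, one)), Cast("int33", y))
safe_result = remove_casts_checked(safe)

# (int32) k * i with k: uint8: stripping the cast would mix uint8 and int32
k = Var("k", "uint8")
unsafe = BinOp("mul", Cast("int32", k), y)
unsafe_result = remove_casts_checked(unsafe)
```

In this sketch the index difference becomes cast-free and can be handed to a simplifier, while the multiplication keeps its cast because removing it would leave mismatched operand types.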

Can you also provide the test cases for reuse_at? Seems you changed the reuse buffer pass.

The changes are meant to make reuse_at work without removing all casts in tensor indices, not to change its behaviour.

So the fact that the original tests in test_schedule_memory.py pass indicates the fix is working as expected.

zzzDavid commented 2 years ago

@chhzh123 @Shaojie Please review again and merge if there's no issue

hecmay commented 2 years ago

@zzzDavid you probably need to change the github workflow a bit: https://github.com/cornell-zhang/heterocl/pull/456/files#diff-7829468e86c1cc5d5133195b5cb48e1ff6c75e3e9203777f6b2e379d9e4882b3

hecmay commented 2 years ago

@zzzDavid I've merged the fix into master. you can pull it back and it should pass the CI/CD tests

hecmay commented 2 years ago

the zhang-x1 local CI/CD runner cannot run through. @zzzDavid

Traceback (most recent call last):
  File "tests/test_cont_integration.py", line 1, in <module>
    import heterocl as hcl
.......
OSError: /actions-runner/_work/heterocl/heterocl/libhcl.so: undefined symbol: _ZN4llvm16MetadataTracking5trackEPvRNS_8MetadataENS_12PointerUnionIIPNS_15MetadataAsValueEPS2_EEE

I switched to my local runner

zzzDavid commented 2 years ago

@hecmay Thanks, I'll take a look at it. It seems my local runner is not using the correct g++. Currently it works on your local runner. Please feel free to merge it.