pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Make sure the dynamic layering mechanism composes with the LazyTensor prototype #24

Open zou3519 opened 3 years ago

zou3519 commented 3 years ago

In the future, users should be able to use composable function transforms together with LazyTensor to get better performance.
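For reference, here is a minimal sketch of the kind of composition this would enable. It assumes the torch_xla-style lazy device API used alongside the LTC prototype (`xm.xla_device()`, `xm.mark_step()`); the exact entry points in the prototype may differ.

```python
import torch
from functorch import grad
import torch_xla.core.xla_model as xm

# Tensors on this device are lazy: ops are traced and compiled through XLA.
device = xm.xla_device()

def f(x):
    return (x.sin() ** 2).sum()

x = torch.randn(3, device=device)

# The goal: grad(f) (and other functorch transforms) should compose with lazy
# tensors exactly as they do with eager tensors, with the traced graph handed
# off to the backend for compilation.
dx = grad(f)(x)

xm.mark_step()  # force materialization of the lazy computation
```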

zou3519 commented 3 years ago

I installed the LTC prototype as well as an XLA backend and ran all of our tests on CPU using XLA.

Only a few tests failed:

- The per_sample_grad tests failed with incorrect values, which suggests something fishy is going on (or a batching rule is implemented incorrectly); see the sketch at the end of this comment.

- The test_maml_omniglot_cpu test failed with an internal assert in XLA:

E       RuntimeError: Internal: From /job:localservice/replica:0/task:0:
E       2 root error(s) found.
E         (0) Internal: RET_CHECK failure (tensorflow/compiler/xla/service/cpu/ir_emitter.cc:3211) ShapeUtil::SameElementType(operands[0]->shape(), operand->shape())
E                [[{{node XRTCompile}}]]
E         (1) Internal: RET_CHECK failure (tensorflow/compiler/xla/service/cpu/ir_emitter.cc:3211) ShapeUtil::SameElementType(operands[0]->shape(), operand->shape())
E                [[{{node XRTCompile}}]]
E                [[XRTCompile_G6]]
E       0 successful operations.
E       0 derived errors ignored.
E       Recent warning and error logs:
E         Internal: RET_CHECK failure (tensorflow/compiler/xla/service/cpu/ir_emitter.cc:3211) ShapeUtil::SameElementType(operands[0]->shape(), operand->shape())
E       *** Begin stack trace ***
E               tensorflow::CurrentStackTrace[abi:cxx11]()
E
E               xla::status_macros::MakeErrorStream::Impl::GetStatus()
E               xla::cpu::IrEmitter::ElementTypesSameAndSupported(xla::HloInstruction const&, absl::lts_2020_02_25::Span<xla::HloInstruction const* const>, absl::lts_2020_02_25::Span<xla::PrimitiveType const>)
E               xla::cpu::IrEmitter::HandleDot(xla:
E         OP_REQUIRES failed at xrt_compile_ops.cc:215 : Internal: RET_CHECK failure (tensorflow/compiler/xla/service/cpu/ir_emitter.cc:3211) ShapeUtil::SameElementType(operands[0]->shape(), operand->shape())

/raid/rzou/pt/ltc/torch/nn/functional.py:1847: RuntimeError
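For context, the failing per_sample_grad tests exercise roughly the following composition of vmap over grad; the model and loss below are illustrative stand-ins, not the actual test code.

```python
import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(3, 3)
func_model, params = make_functional(model)

def compute_loss(params, sample, target):
    # Add a batch dimension of 1 because the functional model expects batched input.
    prediction = func_model(params, sample.unsqueeze(0))
    return torch.nn.functional.mse_loss(prediction, target.unsqueeze(0))

data, targets = torch.randn(64, 3), torch.randn(64, 3)

# grad differentiates compute_loss w.r.t. params; vmap maps that computation
# over the batch dimension of data/targets, producing one gradient per sample.
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(params, data, targets)
```

When the batching rules and the lazy backend interact correctly, these per-sample gradients should match a reference loop that calls grad on each sample individually; incorrect values here point at either a batching rule or the backend lowering.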