mead-ml / mead-baseline

Deep-Learning Model Exploration and Development for NLP
Apache License 2.0

WIP: switch to ONNX compat torch-only tril #885

Closed dpressel closed 2 years ago

dpressel commented 2 years ago

Our tril implementation was using numpy. According to the ONNX docs, that is not the right approach during tracing; it should use tensor/PyTorch ops instead. PyTorch now has the proper operators for this, but when I switched to them, I got an error on export. I tracked that down to this ticket: https://github.com/pytorch/pytorch/issues/34129

It seems that it's fixed in the latest PyTorch, but we still support older versions, so for now this follows the recipe in the ticket to ensure that it does what we want.
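For readers without the ticket open, here is a minimal sketch of one way to build a lower-triangular mask from tensor ops only, without numpy and without calling `torch.tril`. It is an illustration under assumptions, not necessarily the recipe from the ticket or the code in this PR; the name `tril_mask` is hypothetical.

```python
import torch


def tril_mask(size: int, device=None) -> torch.Tensor:
    """Build a (size, size) lower-triangular mask of ones using only tensor ops.

    Hypothetical helper for illustration; not the mead-baseline implementation.
    """
    idx = torch.arange(size, device=device)
    # Broadcasting a column of row indices against a row of column indices:
    # row >= col selects the lower triangle, diagonal included.
    mask = idx.unsqueeze(1) >= idx.unsqueeze(0)
    return mask.to(torch.float32)


mask = tril_mask(4)
```

Because the mask comes from `arange` and a comparison, tracing records tensor operations rather than a numpy result computed outside the graph.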