Our tril implementation was using numpy -- according to the ONNX docs, this is not the right approach during tracing;
it should use tensor/PyTorch ops. PyTorch now has the proper operators for this, but when I switched
to them, I got an error on export. I tracked that down to this ticket: https://github.com/pytorch/pytorch/issues/34129
It seems that it's fixed in the latest PyTorch, but we are still supporting older versions, so for now, follow
the recipe in the ticket to ensure that it does what we want.
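
For illustration, here is a minimal sketch of building a lower-triangular mask from tensor ops only, with no numpy and no call to torch.tril. It is an assumption about the general shape of such a workaround, not a copy of the recipe in the ticket, and the helper name tril_mask is hypothetical.

```python
import torch


def tril_mask(size: int, device=None) -> torch.Tensor:
    """Hypothetical helper: boolean lower-triangular mask from tensor ops only."""
    idx = torch.arange(size, device=device)
    # Compare row indices (column vector) against column indices (row vector):
    # the result is True on and below the main diagonal.
    return idx.unsqueeze(1) >= idx.unsqueeze(0)


# Example: a 4x4 mask applied to a tensor of ones, analogous to a tril of ones.
mask = tril_mask(4)
lower = torch.ones(4, 4) * mask.to(torch.float32)
```

Because the mask comes from arange and a comparison, the tracer sees ordinary tensor operators rather than a value computed outside of PyTorch.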