pranavm-nvidia opened 21 hours ago
We may actually still want warnings in all cases, since the extra evaluations will trigger compilation, which could make tracing extremely slow. One way to mitigate this could be to cache the evaluated tensors somewhere so that only other evaluations consume their values, while the compiler still traces them like non-evaluated ops (see the sketch after the example below).
For example, if we have a graph like:

```
A -> B -> C -> D
```

and we print `B` and `C`, we should only need to compile the `A->B` part once during evaluation and compile `C` separately. However, in the final compiled executable, we still want the full graph, not just the `->D` part.
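One shape this caching could take (purely a sketch; every name below is hypothetical rather than an existing Tripy internal):

```python
# Cache evaluated values out-of-band instead of rewriting the graph, so each
# subgraph is compiled at most once while the trace keeps the original ops.
evaluated_cache: dict[int, object] = {}  # id(trace tensor) -> evaluated value

def trace_up_to_cached(tensor, cache):
    ...  # stub: build the subgraph rooted at `tensor`, stopping at cached tensors

def compile_and_run(subgraph):
    ...  # stub: compile the subgraph and execute it

def evaluate(tensor):
    key = id(tensor)
    if key in evaluated_cache:
        return evaluated_cache[key]
    # Only the not-yet-evaluated portion of the graph gets compiled:
    subgraph = trace_up_to_cached(tensor, evaluated_cache)
    value = compile_and_run(subgraph)
    evaluated_cache[key] = value
    return value
```

With something like this, printing `B` compiles and caches `A->B` once; printing `C` then reuses `B`'s cached value and compiles only the `B->C` edge, while the trace still contains the full `A -> B -> C -> D` graph.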
When a tensor is evaluated during compile, we currently raise an error or print a warning. However, we could make Tripy functionally correct even in the case of evaluation by simply not updating the frontend tensor's op to `Storage`. This would preserve the computation graph and make compilation work correctly.

This might be as simple as adding one condition to `Tensor.eval()`:
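Something along these lines (a rough sketch only - the `is_compiling()` check, the `execute_subgraph` helper, and the exact attribute names are assumptions for illustration, not Tripy's actual internals):

```python
class Tensor:
    def eval(self):
        data = execute_subgraph(self)  # hypothetical: trace, compile, and run this tensor's subgraph

        # The one new condition: only swap the producer op for a Storage node
        # when we are NOT inside compile(), so the traced graph stays intact
        # and the final executable still contains the full computation.
        if not is_compiling():  # hypothetical compile-mode check
            self.trace_tensor.producer = Storage(data)

        return data
```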
We will need the warnings though, since there may be cases where the evaluated result is erroneously used later in the graph - e.g.:
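One illustrative case (my own example, not from the original discussion): feeding an evaluated value into Python control flow bakes whichever branch was traced into the compiled executable:

```python
def func(a):
    b = a + 1
    # `b.tolist()` evaluates `b` mid-compile; the comparison below sees
    # concrete values, so the branch taken during tracing is frozen into
    # the executable regardless of future inputs.
    if b.tolist()[0] > 0:
        return b * 2
    return b / 2
```

Here the warning is the only signal that the compiled function may not match the Python code.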
We could also suppress warnings in some cases, e.g. if the tensor is only being printed. A simple way to achieve that would be to add a `suppress_warnings` parameter to `eval` and set it to `True` when calling it from `__repr__`. We will need to pass it through `tolist`, so we will likely want a `_tolist_helper`, since the public `tolist` method should not have this option exposed.
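The plumbing might look roughly like this (a sketch - the warning text, `is_compiling()`, and `memref_to_python_list` are placeholders, not real Tripy APIs):

```python
import warnings

class Tensor:
    def eval(self, suppress_warnings: bool = False):
        if is_compiling() and not suppress_warnings:  # hypothetical compile-mode check
            warnings.warn("Tensor was evaluated while compiling; this may slow down tracing.")
        ...  # existing evaluation logic, unchanged

    def _tolist_helper(self, suppress_warnings: bool = False):
        data = self.eval(suppress_warnings=suppress_warnings)
        return memref_to_python_list(data)  # hypothetical conversion helper

    def tolist(self):
        # Public API stays unchanged: no warning-suppression knob exposed.
        return self._tolist_helper()

    def __repr__(self):
        # Printing alone cannot feed values back into the graph, so it is
        # safe to evaluate without warning the user.
        return str(self._tolist_helper(suppress_warnings=True))
```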
We can, however, drop the errors that are thrown in `compile`, which means we can also drop the `eval_stack_info` field of `TraceTensor`.
Finally, to test it, we should verify the following are true when evaluating while compiling:
- The computation graph is preserved (i.e. the frontend tensor's op is not replaced with a `Storage` tensor)
- Warnings are suppressed when the evaluation is triggered by `__repr__` (i.e. we can safely assume the output is unused later in the graph)
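A test could cover both points at once - a hedged sketch (the `tp.compile`/`tp.InputInfo` usage is from memory of the public API, and `tp.iota` is just a convenient input):

```python
import warnings
import tripy as tp

def test_eval_during_compile():
    def func(a):
        b = a + 1
        repr(b)  # forces an evaluation mid-compile via __repr__
        return b * 2

    with warnings.catch_warnings():
        warnings.simplefilter("error")  # any warning fails the test
        compiled = tp.compile(func, args=[tp.InputInfo((2,), dtype=tp.float32)])

    # If the graph had been truncated to a Storage constant, the output would
    # no longer depend on the input; check that it still does:
    inp = tp.iota((2,), dtype=tp.float32)
    assert compiled(inp).tolist() == ((inp + 1) * 2).tolist()
```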