NVIDIA / TensorRT-Incubator

Improve behavior of evaluation during compile #409

Open · opened by pranavm-nvidia 21 hours ago

pranavm-nvidia commented 21 hours ago

When a tensor is evaluated during compile, we currently raise an error or print a warning. However, we could make Tripy functionally correct even in the case of evaluation by simply not updating the frontend tensor's op to Storage. This would preserve the computation graph and make compilation work correctly.

This might be as simple as adding one condition to Tensor.eval():

# Bake the evaluated data back into the frontend tensor only when we are
# not tracing for compilation; otherwise leave the original op intact so
# the computation graph is preserved.
if not self.trace_tensor.is_compile_tracer:
    Storage.build_internal([], [self.trace_tensor], data)
    ...

We will still need the warnings, though, since there may be cases where the evaluated result is erroneously used later in the graph, e.g.:

batch = int(x.shape[0]) # Eval happens here: `batch` is frozen to a Python int
tp.ones((batch, ...)) # Dynamic shapes broken here: `batch` no longer tracks the runtime shape of `x`

We could also suppress warnings in some cases, e.g. if the tensor is only being printed. A simple way to achieve that would be to add a suppress_warnings parameter to eval and set it to True when calling it from __repr__. Since __repr__ reaches eval through tolist, we will need to pass the parameter through; that likely means adding a private _tolist_helper, since the public tolist method should not expose this option.
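A minimal sketch of how that plumbing might look (the warning text and the _run_trace helper are placeholders, not the actual Tripy internals):

import warnings

class Tensor:
    def eval(self, suppress_warnings: bool = False):
        # Warn about mid-compile evaluation unless the caller has signalled
        # that the value cannot leak back into the graph.
        if self.trace_tensor.is_compile_tracer and not suppress_warnings:
            warnings.warn("Tensor was evaluated while compiling.")
        data = self._run_trace()  # placeholder for executing the trace
        if not self.trace_tensor.is_compile_tracer:
            Storage.build_internal([], [self.trace_tensor], data)
        return data

    def _tolist_helper(self, suppress_warnings: bool = False):
        return self.eval(suppress_warnings=suppress_warnings)

    def tolist(self):
        # Public API: the suppression knob is deliberately not exposed.
        return self._tolist_helper()

    def __repr__(self):
        # Printing cannot feed the value back into the graph, so skipping
        # the warning here is safe.
        return str(self._tolist_helper(suppress_warnings=True))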

We can, however, drop the errors that are thrown in compile, which means we can also drop the eval_stack_info field of TraceTensor.

Finally, to test it, we should verify the following are true when evaluating while compiling (a rough test sketch follows the list):

  1. We never raise an error
  2. The frontend tensor op is never updated in-place (i.e. it should not be turned into a Storage tensor)
  3. We do not emit warnings when the evaluation is triggered by __repr__ (i.e. we can safely assume the output is unused later in the graph)
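A rough pytest-style sketch of those checks, asserting the proposed behavior rather than what Tripy currently does (the tp.compile/tp.InputInfo usage follows the public API; the warning and constant-folding checks are behavioral proxies):

import warnings
import tripy as tp

def test_eval_during_compile():
    def func(x):
        y = x + 1.0
        print(y)  # mid-trace evaluation via __repr__
        return y * 2.0

    # (1) compilation must not raise, and (3) the print must not warn:
    # escalating warnings to errors makes either failure abort the test.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        compiled = tp.compile(func, args=[tp.InputInfo((2, 2), dtype=tp.float32)])

    # (2) behavioral proxy for "the op was not replaced by Storage": the
    # compiled function must still depend on its input rather than on a
    # constant baked in when y was printed.
    a = compiled(tp.ones((2, 2), dtype=tp.float32))
    b = compiled(tp.zeros((2, 2), dtype=tp.float32))
    assert a.tolist() != b.tolist()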
pranavm-nvidia commented 18 hours ago

We may actually still want warnings in all cases, since the extra evaluations will trigger compilation, which could make tracing extremely slow. One way to mitigate this could be to cache evaluated values so that only other mid-trace evaluations consume them, while the compiler keeps tracing the corresponding ops as if they had never been evaluated (a sketch of this bookkeeping follows the example below).

For example, if we have a graph like:

A -> B -> C -> D

and we print B and C, we should only need to compile and run the A->B subgraph once (when printing B), and then compile just the B->C step (when printing C, reusing B's cached value). However, in the final compiled executable, we still want the full A->B->C->D graph, not just the C->D remainder.
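One possible shape for that cache (trace_until_cached and compile_and_run are hypothetical helpers, not existing Tripy functions):

# Hypothetical cache from trace tensor id -> evaluated value. Evaluating B
# compiles and runs A->B once; evaluating C then treats B's cached value as
# a constant input and compiles only the B->C step. The trace ops are never
# rewritten, so compiling the full function still sees A->B->C->D.
_eval_cache = {}

def eval_for_inspection(tensor):
    key = id(tensor.trace_tensor)
    if key not in _eval_cache:
        # Walk producers backwards, stopping at tensors already in the
        # cache; those become constant inputs to this small subgraph.
        subgraph = trace_until_cached(tensor, _eval_cache)
        _eval_cache[key] = compile_and_run(subgraph, _eval_cache)
    # The frontend tensor's op is left untouched; the cached value only
    # serves other inspection-time evaluations.
    return _eval_cache[key]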