AlexanderLutsenko / nobuco

Pytorch to Keras/Tensorflow/TFLite conversion made intuitive
MIT License
251 stars, 15 forks

Debugging tensor shapes when using dynamic axes #24

Open BenCrulis opened 6 months ago

BenCrulis commented 6 months ago

Hello, first of all thank you for this fantastic tool.

I am using it to convert the YOLO-World model from PyTorch to Tensorflow, but I am encountering a strange issue.

The conversion completes without any issue, but when I change the size of the input tensor, inference crashes because of a shape mismatch somewhere in the operation graph. The same change works with the original PyTorch model, and it also works when I convert the model using static input shapes. It only fails when I use dynamic axes and change the dimension in question.

After some investigation, it appears that a tensor has been "doubled" in size along this axis somewhere in the TF graph, for some reason. I was lucky that the error happened near the end of the graph, so I managed to cut the conversion short, stopping just before the operation that crashes because of the shape mismatch. I can then see that some columns are repeated. The problem is that I don't know where the columns are doubled, so I can't really find or fix the issue.

As I am not able to create a minimal reproducible example for now, in this issue I am asking for ways to debug the output computational graph, for instance by having a way to get the computed shapes even when using dynamic shapes. Or perhaps a way to insert "fake nodes" in order to step into the graph more easily with the python debugger.

If I can find the source of the issue, I will open another issue for it specifically.

AlexanderLutsenko commented 6 months ago

Hi! Here's one failure case:

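import torch
from torch import nn
import tensorflow as tf

import nobuco
from nobuco import ChannelOrder

class DynamicShape(nn.Module):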
    def __init__(self):
        super().__init__()

    def forward(self, x):
        _, h, w = x.shape

        # DO NOT cast pytorch tensors to numpy / python primitives if you want to trace them!
        h, w = int(h), int(w)

        b = torch.ones((1, h, w))
        return x + b

input = torch.normal(0, 1, size=(1, 128, 128))
pytorch_module = DynamicShape().eval()

keras_model = nobuco.pytorch_to_keras(
    pytorch_module,
    args=[input],
    input_shapes={input: (None, None, None)},  # Annotate dynamic axes with None
    trace_shape=True,
    constants_to_variables=False,
    inputs_channel_order=ChannelOrder.PYTORCH,
)

input_tf = tf.zeros(shape=(1, 64, 64))
keras_model(input_tf)

Nobuco can only operate in the realm of tensors. Once you step outside, the warranty is void. The worst part is that I do not see a good way to identify such situations.

After some investigation, it appears that a tensor has been "doubled" in size along this axis somewhere in the TF graph, for some reason.

Is that so? Surely a bug, if true. Care to provide more details?

AlexanderLutsenko commented 6 months ago

Or perhaps a way to insert "fake nodes" in order to step into the graph more easily with the python debugger.

You can try something like this:

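import torch
from torch import nn
import tensorflow as tf

import nobuco
from nobuco import ChannelOrderingStrategy

@nobuco.traceable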
def print_shape(x, name=None):
    print('[Pytorch shape]', f'{name}:' if name is not None else '', x.shape)
    return x

class PrintShape(tf.keras.layers.Layer):
    def call(self, x, name=None):
        print('[Tensorflow shape]', f'{name}:' if name is not None else '', tf.shape(x))
        return x

@nobuco.converter(print_shape, channel_ordering_strategy=ChannelOrderingStrategy.MINIMUM_TRANSPOSITIONS)
def converter_print_shape(x, name=None):
    return PrintShape()

class DynamicShape(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        _, h, w = x.shape

        # DO NOT cast pytorch tensors to numpy / python primitives if you want to trace them!
        h, w = int(h), int(w)

        b = torch.ones((1, h, w))
        b = print_shape(b, name='b')
        return x + b

BenCrulis commented 6 months ago

Nobuco can only operate in the realm of tensors. Once you step outside, the warranty is void.

I do not see any calls to int in the YOLO-World code that would be relevant to this issue, but the problem might actually originate far upstream in the layers. I will investigate using your method with printer nodes.

The worst part is that I do not see a good way to identify such situations.

Could it be possible to mock the int function, or the __int__ method in order to identify when a traced shape tensor is going to return to a native int type? For instance I can do something like this:

from unittest.mock import patch

def my_func(x):
    return int(x) + 1

orig_int = int
def traced_int(x):
    print(f'Called int({x})')
    return orig_int(x)

x = 1
with patch('__main__.int', traced_int):
    print(my_func(x)) # prints Called int(1) \n 2
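
The same idea can be pushed one step further by recording where the conversion happens. The following is a minimal sketch, not nobuco code: TracedDim is a hypothetical stand-in for a traced shape value, and install_int_probe is an illustrative helper name. It wraps __int__ on the class instead of replacing the int builtin, so isinstance checks elsewhere keep working.

```python
import traceback


class TracedDim:
    """Hypothetical stand-in for a traced shape value (not a nobuco class)."""

    def __init__(self, value):
        self.value = value

    def __int__(self):
        return int(self.value)


escapes = []  # (function name, line number) of every conversion to a native int


def install_int_probe(cls):
    """Wrap cls.__int__ so each int(...) call on an instance records its call site."""
    original = cls.__int__

    def probed(self):
        # The last stack entry is this wrapper; the one before it is the
        # Python frame that called int(...), since builtins add no frame.
        caller = traceback.extract_stack()[-2]
        escapes.append((caller.name, caller.lineno))
        return original(self)

    cls.__int__ = probed
    return original  # returned so the caller can restore it afterwards


def forward(h):
    # The kind of cast that silently leaves the traced graph.
    return int(h) + 1


original = install_int_probe(TracedDim)
result = forward(TracedDim(128))
TracedDim.__int__ = original  # restore the unwrapped method

for name, lineno in escapes:
    print(f"int() called on a traced value in {name}, line {lineno}")
```

Against a real model, one would presumably wrap torch.Tensor.__int__ in the same way for the duration of the trace. That is an assumption and has not been tested here.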

Is that so? Surely a bug, if true. Care to provide more details?

From what I can tell so far, at one point a tensor of shape (1,64,84,84) is concatenated with a tensor of shape (1,3,84,84) along dim=1 in the original PyTorch code, so it should give me a tensor of shape (1,67,84,84).

The second tensor is the one with a dynamic dimension, the second dimension (corresponding to 3) is the number of classes passed as an input to YOLO-World, and since it uses some kind of Attention mechanism between class embeddings and visual patches, it should adapt to the number of classes even after the conversion process is done.

However, the model crashed at the next node, a reshape node, which complained that it had received a tensor of shape (1,70,84,84).

I had the same issue when I used only two classes: the reshape node crashed, saying it had received a tensor of shape (1,68,84,84) when it should have been (1,66,84,84).

I temporarily worked around the issue by skipping the concatenation right before the reshape operation and reshaping each tensor separately, but I still end up with a second tensor of shape (1,6,84,84) (then further reshaped into (1,6,7056)), when it should be (1,3,7056). There are many places where the second dimension might have doubled in size before that without provoking a crash, given that the code is written to accept any number of classes dynamically.

I will investigate further using your method tomorrow, thanks! Do you want me to open a new issue when I find the origin of the bug, or shall we continue here?

AlexanderLutsenko commented 6 months ago

Could it be possible to mock the int function, or the __int__ method in order to identify when a traced shape tensor is going to return to a native int type?

Definitely possible. That's how I do graph tracing in the first place.

Do you want me to open a new issue when I find the origin of the bug, or shall we continue here?

I leave it to your judgement.

By the way, if you set debug_traces=nobuco.TraceLevel.ALWAYS, you'll get code pointers to each operation/module in the graph. This should aid your search.

BenCrulis commented 6 months ago

OK, I feel dumb now. After further debugging with print_shape, I found that the issue came from the reshape node itself: it used a hard-coded value for the number of classes instead of deriving it from the shape of the text embedding tensor. Seems like YOLO-World really wasn't written with exporting the model in mind. I had it backward with the doubling: I was testing by passing a tensor twice the size after the conversion, and since I assumed the reshape had the correct target shape, I concluded that the tensor was doubled again before the reshape, which led to the crash. It was not a bug after all.
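
To make the arithmetic concrete, here is a small sketch in plain Python that only tracks shapes. It is not YOLO-World or nobuco code: the helper names and NUM_CLASSES_AT_EXPORT are made up for illustration, and the shapes are the ones quoted earlier in this thread.

```python
from math import prod


def concat_channels(a, b):
    """Shape of concatenating two NCHW tensors along dim=1."""
    assert a[0] == b[0] and a[2:] == b[2:], "non-channel dims must match"
    return (a[0], a[1] + b[1], *a[2:])


def reshape(shape, target):
    """Return target if it holds exactly as many elements as shape, else raise."""
    if prod(shape) != prod(target):
        raise ValueError(f"cannot reshape {shape} into {target}")
    return target


NUM_CLASSES_AT_EXPORT = 3  # hypothetical name: the value baked in at conversion time


def head_hardcoded(features, class_map):
    x = concat_channels(features, class_map)
    # The channel count is computed from a constant, not from class_map.
    return reshape(x, (x[0], features[1] + NUM_CLASSES_AT_EXPORT, x[2] * x[3]))


def head_derived(features, class_map):
    x = concat_channels(features, class_map)
    # The channel count follows whatever class_map actually contains.
    return reshape(x, (x[0], x[1], x[2] * x[3]))


features = (1, 64, 84, 84)

print(head_hardcoded(features, (1, 3, 84, 84)))  # same class count as at export
print(head_derived(features, (1, 6, 84, 84)))    # twice the classes, still fine
try:
    head_hardcoded(features, (1, 6, 84, 84))     # twice the classes, constant target
except ValueError as err:
    print(err)
```

With six classes the concatenation really does produce 70 channels. The reshape fails only because its target still assumes 67, which is why the input looked "doubled" when the actual fault was in the target shape.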

In any case, the code you provided was very useful for pinpointing the exact place where the hard coded value was used. I am wondering if there would be an easier way to insert such debugging operations into the pytorch code, perhaps we could imagine using PyTorch hooks and also record the stack trace for this?

However, I do not understand why the TensorFlow reshape node was displayed as if it used a dynamic axis when the value is in fact hard-coded. Because of this, I was debugging under the wrong assumption that the reshape had the correct target shape.

Thank you again!

AlexanderLutsenko commented 6 months ago

Seems like YOLO-World really wasn't written with exporting the model in mind.

Yep, the flexibility of Pytorch spoils people.

I am wondering if there would be an easier way to insert such debugging operations into the pytorch code, perhaps we could imagine using PyTorch hooks and also record the stack trace for this?

Easier in what way exactly? Now that I think about it, one can put something like print_shape after every operation and draw that colored debug graph for the Keras model as well, with links to the original Pytorch code. Need to mull it over.

I do not understand why the TensorFlow reshape node was displayed as if it used a dynamic axis when the value is in fact hard-coded

Go figure, maybe that changes in Keras 3, like lots of other weird stuff.

BenCrulis commented 6 months ago

Easier in what way exactly? Now that I think about it, one can put something like print_shape after every operation and draw that colored debug graph for the Keras model as well, with links to the original Pytorch code. Need to mull it over.

Well, the conversion log given by nobuco already indicates the file and line of the original PyTorch operations, so I should instead ask for a way to match the nodes in the conversion log with the nodes displayed in the debug graph, including when the model crashes at inference time after the conversion succeeded.

TensorFlow seems to annotate each node with an index, like reshape_23, which was crashing for me, but the reshape operations in the nobuco logs don't have indices. It might help debugging to show the corresponding indices in the log somehow.

At the moment, I don't think there is an easy way to know where the TF model crash takes place in the corresponding PyTorch code. Except by inserting code like print_shape everywhere of course.

AlexanderLutsenko commented 6 months ago

Except by inserting code like print_shape everywhere of course.

I can't see any other way. It's certainly possible to insert a bunch of print_shapes automatically, but I'm not sure how useful that could be.
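
As a rough illustration of what automatic insertion could look like, the sketch below wraps each operation in a decorator that records the output shape together with the file and line of the call site. It uses plain Python stand-ins only: FakeTensor, log_shape and the two toy operations are hypothetical names, and nothing here hooks into nobuco, PyTorch or TensorFlow.

```python
import functools
import traceback

shape_log = []  # (op name, output shape, "file:line" of the call site)


def log_shape(op):
    """Wrap op so each call records its output shape and where it was called."""

    @functools.wraps(op)
    def wrapper(*args, **kwargs):
        out = op(*args, **kwargs)
        # The last stack entry is this wrapper; the one before it is the caller.
        caller = traceback.extract_stack()[-2]
        shape_log.append(
            (op.__name__, out.shape, f"{caller.filename}:{caller.lineno}")
        )
        return out

    return wrapper


class FakeTensor:
    """Hypothetical stand-in that only carries a shape."""

    def __init__(self, shape):
        self.shape = tuple(shape)


@log_shape
def cat_channels(a, b):
    return FakeTensor((a.shape[0], a.shape[1] + b.shape[1], *a.shape[2:]))


@log_shape
def flatten_spatial(x):
    n, c, h, w = x.shape
    return FakeTensor((n, c, h * w))


def forward(features, class_map):
    x = cat_channels(features, class_map)
    return flatten_spatial(x)


out = forward(FakeTensor((1, 64, 84, 84)), FakeTensor((1, 3, 84, 84)))
for name, shape, where in shape_log:
    print(f"{name}: {shape} at {where}")
```

Running the same wrapper on both the PyTorch side and the converted side would give two logs to compare entry by entry. Whether that is worth the noise on a large model is exactly the open question raised above.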