AlexanderLutsenko / nobuco

Pytorch to Keras/Tensorflow/TFLite conversion made intuitive
MIT License

Identity layer is reported as 'Unimplemented nodes' #38

crimson206 closed this issue 6 months ago

crimson206 commented 6 months ago

Reproduction

An identity layer (a module whose forward() just returns its input) causes the 'Unimplemented nodes' error.

import torch
import torch.nn as nn
import nobuco
from nobuco import ChannelOrder

class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x  # pass-through: forward() calls no ops

net = Identity()

net.eval()
input_tensor = torch.randn(10,20)

keras_mode = nobuco.pytorch_to_keras(
    net,
    args = [input_tensor], kwargs=None,
    inputs_channel_order=ChannelOrder.TENSORFLOW,
    outputs_channel_order=ChannelOrder.TENSORFLOW
)
---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
Cell In[11], line 18
     15 net.eval()
     16 input_tensor = torch.randn(10,20)
---> 18 keras_mode = nobuco.pytorch_to_keras(
     19     net,
     20     args = [input_tensor], kwargs=None,
     21     inputs_channel_order=ChannelOrder.TENSORFLOW,
     22     outputs_channel_order=ChannelOrder.TENSORFLOW
     23 )

File ~nobuco/nobuco/convert.py:335, in pytorch_to_keras(model, args, kwargs, input_shapes, inputs_channel_order, outputs_channel_order, trace_shape, enable_torch_tracing, constants_to_variables, full_validation, validation_tolerance, return_outputs_pt, save_trace_html, debug_traces)
    333     print('Unimplemented nodes:')
    334     print(unimplemented_hierarchy.__str__(**vis_params))
--> 335     raise Exception('Unimplemented nodes')
    337 keras_op = keras_converted_node.keras_op
    339 args_tf, kwargs_tf = prepare_inputs_tf((args, kwargs), inputs_channel_order, input_shapes)

Exception: Unimplemented nodes

Reason

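nobuco treats the Identity module as a node of the traced hierarchy. To be convertible, a node must either have a converter registered in converter_dict itself or have children that are all convertible. Identity has no registered converter and, because its forward() calls no operations, no children either, so find_unimplemented reports it as unimplemented.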

from typing import Dict, Optional

def find_unimplemented(hierarchy: PytorchNodeHierarchy, converter_dict: Dict[object, Pytorch2KerasNodeConverter]) -> Optional[PytorchNodeHierarchy]:
    # Test if the node itself has a converter
    if has_converter(hierarchy.node, converter_dict):
        return None
    elif len(hierarchy.children) == 0:
        # Leaf without a converter: nothing to fall back on (the Identity case)
        return PytorchNodeHierarchy(hierarchy.node, hierarchy.children)
    else:
        children_unimplemented = []
        for child in hierarchy.children:
            child_unimplemented = find_unimplemented(child, converter_dict)
            if child_unimplemented is not None:
                children_unimplemented.append(child_unimplemented)

        if len(children_unimplemented) > 0:
            # The node is unimplemented
            return PytorchNodeHierarchy(hierarchy.node, children_unimplemented)
        # Every child is convertible, so the node is too
        return None
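The decision rule can be tried out without PyTorch or nobuco. The sketch below is a simplified model of the same recursion, not nobuco's code: `Node`, the string-keyed `converters` registry, and the module names are hypothetical stand-ins for `PytorchNodeHierarchy`, `converter_dict`, and real traced ops. It shows that a childless node without a converter is always reported, while a container whose children are all convertible passes.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional


@dataclass
class Node:
    """Hypothetical stand-in for a traced module or op (not nobuco's class)."""
    name: str
    children: List["Node"] = field(default_factory=list)


def find_unimplemented(node: Node, converter_dict: Dict[str, Callable]) -> Optional[Node]:
    # A node with its own converter is convertible, whatever its children are
    if node.name in converter_dict:
        return None
    # A leaf without a converter has nothing to fall back on: report it
    if not node.children:
        return Node(node.name, [])
    # Otherwise recurse and keep only the children that are not convertible
    missing = []
    for child in node.children:
        child_missing = find_unimplemented(child, converter_dict)
        if child_missing is not None:
            missing.append(child_missing)
    return Node(node.name, missing) if missing else None


# Hypothetical registry: names of ops that have a converter
converters = {"Linear": lambda *args: None, "ReLU": lambda *args: None}

# Identity.forward() calls no ops, so its traced node has no children
identity = Node("Identity")
mlp = Node("MLP", [Node("Linear"), Node("ReLU")])

print(find_unimplemented(identity, converters))  # Node(name='Identity', children=[])
print(find_unimplemented(mlp, converters))       # None
```

In this model, adding `"Identity"` to `converters` makes the first call return `None`, which mirrors registering a converter for the module.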
AlexanderLutsenko commented 6 months ago

Fixed that in v0.14.1, thanks.