fastmachinelearning / qonnx

QONNX: Arbitrary-Precision Quantized Neural Networks in ONNX
https://qonnx.readthedocs.io/
Apache License 2.0

add type check for input name of resolve_datatype() #121

Open makoeppel opened 4 months ago

makoeppel commented 4 months ago

I found a potential issue in the resolve_datatype() function in the datatype.py file:

```python
def resolve_datatype(name):
    _special_types = {
        "BINARY": IntType(1, False),
        "BIPOLAR": BipolarType(),
        "TERNARY": TernaryType(),
        "FLOAT32": FloatType(),
        "FLOAT16": Float16Type(),
    }
    if name in _special_types.keys():
        return _special_types[name]
    elif name.startswith("UINT"):
        bitwidth = int(name.replace("UINT", ""))
        return IntType(bitwidth, False)
    elif name.startswith("INT"):
    ...
```

Looking at this function, I first assumed that name has to be a string. However, if name is, for example, a qonnx.core.datatype.BipolarType object whose name matches a key in the _special_types dictionary (here "BIPOLAR"), the check if name in _special_types.keys(): evaluates to True. If you instead pass, for example, a qonnx.core.datatype.IntType object, that check evaluates to False and the function crashes at the next branch, elif name.startswith("UINT"):, because that line assumes name is a str (an IntType object has no startswith method, so an AttributeError is raised). In short, the first if branch accepts either a str or a datatype object whose name is a key of _special_types, while all the other branches only work with a str.
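The behavior described above can be reproduced without qonnx installed. The sketch below is a simplified, hypothetical stand-in: FakeDataType assumes that qonnx datatype objects compare equal to, and hash like, their canonical name string (which would explain why the "BIPOLAR" lookup succeeds), and resolve_datatype_unchecked only copies the branch structure of the original function, returning placeholder values instead of real datatype objects.

```python
class FakeDataType:
    """Hypothetical stand-in for a qonnx datatype object (assumed behavior)."""

    def __init__(self, canonical_name):
        self.canonical_name = canonical_name

    def __eq__(self, other):
        # Assumption: datatype objects compare equal to their name string
        if isinstance(other, FakeDataType):
            return self.canonical_name == other.canonical_name
        if isinstance(other, str):
            return self.canonical_name == other
        return NotImplemented

    def __hash__(self):
        # Hash like the name string, so a dict lookup by str key can match
        return hash(self.canonical_name)


def resolve_datatype_unchecked(name):
    """Simplified copy of the original control flow, with no type check."""
    # Placeholder values stand in for the real qonnx datatype objects
    special_types = {"BINARY": "binary", "BIPOLAR": "bipolar"}
    if name in special_types.keys():
        return special_types[name]
    elif name.startswith("UINT"):  # AttributeError if name is not a str
        return ("uint", int(name.replace("UINT", "")))
    elif name.startswith("INT"):
        return ("int", int(name.replace("INT", "")))
    raise KeyError(f"Could not resolve DataType {name}")
```

With this sketch, resolve_datatype_unchecked("UINT4") resolves normally and FakeDataType("BIPOLAR") slips through the dictionary lookup, but FakeDataType("UINT4") reaches name.startswith() and fails with an AttributeError.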

I propose adding a type check on the name input to get clearer error handling. I have also added one test case that checks that this error is raised and another that tests the resolve_datatype() function itself.
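A minimal sketch of the proposed check, again using placeholder return values rather than the real qonnx types. The function and test names are hypothetical and not taken from the actual change. The check runs before any str method is called, so a non-string input fails immediately with a TypeError that names the offending type.

```python
def resolve_datatype_checked(name):
    """Sketch of resolve_datatype() with an up-front type check (hypothetical)."""
    # Reject non-strings before any str method is used, so callers get a
    # clear TypeError instead of an AttributeError from a later branch.
    if not isinstance(name, str):
        raise TypeError(
            f"resolve_datatype() expects a str, got {type(name).__name__}: {name!r}"
        )
    # Placeholder values stand in for the real qonnx datatype objects
    special_types = {"BINARY": "binary", "BIPOLAR": "bipolar", "TERNARY": "ternary"}
    if name in special_types:
        return special_types[name]
    elif name.startswith("UINT"):
        return ("uint", int(name.replace("UINT", "")))
    elif name.startswith("INT"):
        return ("int", int(name.replace("INT", "")))
    raise KeyError(f"Could not resolve DataType {name}")


# Hypothetical pytest-style tests mirroring the two cases described above
def test_resolve_datatype_type_check():
    for bad in (4, None, object()):
        try:
            resolve_datatype_checked(bad)
        except TypeError:
            continue
        raise AssertionError(f"no TypeError raised for {bad!r}")


def test_resolve_datatype():
    assert resolve_datatype_checked("BIPOLAR") == "bipolar"
    assert resolve_datatype_checked("UINT4") == ("uint", 4)
    assert resolve_datatype_checked("INT8") == ("int", 8)
```

One consequence of this design is that datatype objects such as BipolarType, which previously passed the first branch by accident, are now rejected too, so callers must pass the name string. An alternative would be to accept datatype objects and convert them to their name before the lookup; the issue proposes the stricter type check.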