Regarding the `swap_layers(self, model, *args, **kwargs)` function
In your refactoring, the `swap_layers(self, model, *args, **kwargs)` function tries to binarize every layer, which gives the user no way to exclude specific modules from binarization.
This function would be more flexible if it accepted a dictionary mapping layer types to whether the user allows them to be binarized. Something like this:
```python
import torch
import torch.nn as nn

# Pick out the layers you want to binarize.
# This dict is created by the user and passed as an argument to swap_layers(...).
BIN_LAYERS = {
    nn.Conv2d: True,
    nn.Linear: True,
    nn.BatchNorm2d: False,
}

def swap_layers(self, model, bin_layers, *args, **kwargs):
    list_model = list(model.children())
    for idx, layer in enumerate(list_model):
        # Read the value, not just the key: a type mapped to False
        # (or missing from the dict) is left untouched.
        if bin_layers.get(type(layer), False):
            try:
                layer.weight.data = self.binarize(layer.weight.data)
            except AttributeError:
                # TODO: make this message clearer
                print(f"Cannot binarize weight data of {type(layer)}")
            try:
                layer.bias.data = self.binarize(layer.bias.data)
            except AttributeError:
                # Raised e.g. when the layer has no bias (layer.bias is None)
                # TODO: make this message clearer
                print(f"Cannot binarize bias data of {type(layer)}")
            list_model[idx] = layer.type(torch.int8)
    return nn.Sequential(*list_model)
```
The code above lets the user say that all `Conv2d` and `Linear` layers should be binarized, while `BatchNorm2d` layers are left as they are.
I think this would be an important addition to have before the merge. What do you think?
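One detail worth checking when implementing this: a membership test such as `type(layer) in BIN_LAYERS` only asks whether the type is a key, so `nn.BatchNorm2d` would still be selected even though it maps to `False`. The lookup has to read the value, e.g. with `dict.get`. The minimal sketch below shows the difference. To keep it runnable without PyTorch, it uses hypothetical empty stand-in classes named after the `nn` layers, and `should_binarize` is a hypothetical helper name, not part of the existing code.

```python
# Hypothetical stand-in classes so the sketch runs without PyTorch;
# in the real code these would be nn.Conv2d, nn.Linear and nn.BatchNorm2d.
class Conv2d: ...
class Linear: ...
class BatchNorm2d: ...

BIN_LAYERS = {Conv2d: True, Linear: True, BatchNorm2d: False}

def should_binarize(layer, bin_layers):
    """Return True only if the layer's exact type is mapped to True."""
    return bin_layers.get(type(layer), False)

layers = [Conv2d(), BatchNorm2d(), Linear()]

# Key membership ignores the True/False value, so BatchNorm2d slips through.
by_membership = [type(l).__name__ for l in layers if type(l) in BIN_LAYERS]

# Value lookup respects the user's opt-out.
by_value = [type(l).__name__ for l in layers if should_binarize(l, BIN_LAYERS)]

print(by_membership)  # ['Conv2d', 'BatchNorm2d', 'Linear']
print(by_value)       # ['Conv2d', 'Linear']
```

Because the lookup uses the exact `type(layer)`, a subclass of `nn.Conv2d` would not match unless it is listed as well; supporting subclasses would require an `isinstance` check instead.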