quantumjot / cellx

DL/ML library for CellX project
MIT License

Add a CNN to networks #5

Open quantumjot opened 3 years ago

quantumjot commented 3 years ago

We routinely use CNNs for image classification, so we should add a generic classifier to networks. We can use layers.Encoder2D to build a model of arbitrary size and architecture, then reduce its output to a dense layer with N outputs. The final layer should return raw, unscaled logits.

So we could build something like this:

class CNNClassifier(K.Model):
    def __init__(self, encoder: layers.Encoder, outputs: int = 5):
        super().__init__()
        self.encoder = encoder
        self.outputs = outputs

This would be a subclass of the Keras Model, as described in the docs. That way we could build CNNs with arbitrary encoder networks that reduce to a one-hot or binary classification output.
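To make the proposal concrete, here is a minimal sketch of how the full model might look. It is not the cellx implementation. It assumes TensorFlow's bundled Keras, uses a small `keras.Sequential` stack as a hypothetical stand-in for `layers.Encoder2D`, and adds a global average pooling step before the dense layer, which is my assumption about how to reduce encoder outputs of arbitrary spatial size. The argument is named `n_outputs` here to avoid clashing with the `outputs` attribute that Keras models already expose.

```python
import numpy as np
from tensorflow import keras


class CNNClassifier(keras.Model):
    """Wrap any image encoder and reduce its features to raw class logits."""

    def __init__(self, encoder: keras.layers.Layer, n_outputs: int = 5):
        super().__init__()
        self.encoder = encoder
        # Assumption: pool over the spatial axes so any encoder size works.
        self.pool = keras.layers.GlobalAveragePooling2D()
        # No activation: the head returns raw, unscaled logits.
        self.logits = keras.layers.Dense(n_outputs, activation=None)

    def call(self, x, training=None):
        x = self.encoder(x, training=training)
        return self.logits(self.pool(x))


# Hypothetical stand-in encoder; in cellx this role would be played by
# layers.Encoder2D, whose exact signature is not shown in this issue.
encoder = keras.Sequential([
    keras.layers.Conv2D(8, 3, padding="same", activation="relu"),
    keras.layers.MaxPooling2D(),
    keras.layers.Conv2D(16, 3, padding="same", activation="relu"),
    keras.layers.MaxPooling2D(),
])

model = CNNClassifier(encoder, n_outputs=5)
# Because the model emits logits, the loss must be told to apply softmax itself.
model.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# Forward pass on a dummy batch of two 32x32 single-channel images.
logits = model(np.zeros((2, 32, 32, 1), dtype="float32"))
```

With one-hot labels, `keras.losses.CategoricalCrossentropy(from_logits=True)` would replace the sparse loss. For a binary classifier, one option is `n_outputs=1` with `keras.losses.BinaryCrossentropy(from_logits=True)`.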

quantumjot commented 3 years ago

We should finish off this PR, so that we can simply import the classifier network from the cellx library in the cnn-annotator notebooks.

Also, we should add the option discussed here: https://github.com/lowe-lab-ucl/cnn-annotator/issues/28