luigibonati / mlcolvar

A unified framework for machine learning collective variables for enhanced sampling simulations
MIT License

linear activation function gives errors #134

Closed · EnricoTrizio closed this issue 2 months ago

EnricoTrizio commented 3 months ago

In the list of activation functions there is also linear, which doesn't do anything except print a warning.

import torch


def get_activation(activation: str):
    """Return activation module given string."""
    activ = None
    if activation == "relu":
        activ = torch.nn.ReLU(True)
    ...
    elif activation == "linear":
        # activ is left as None here, only a warning is printed
        print("WARNING: no activation selected")
    elif activation is None:
        pass
    else:
        raise ValueError(
            f"Unknown activation: {activation}. options: 'relu','elu','tanh','softplus','shifted_softplus','linear'. "
        )
    return activ

So if 'linear' is actually used (which is a perfectly sensible choice), activ is left as None, and the calling code errors out.
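For illustration, a minimal example of how the None return bites (the call site here is hypothetical, but anything that treats the returned value as a module fails similarly):

import torch

activ = get_activation("linear")  # prints the warning, returns None
model = torch.nn.Sequential(torch.nn.Linear(10, 5), activ)
out = model(torch.randn(2, 10))  # TypeError: 'NoneType' object is not callable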

Maybe we can create a dummy activation function that doesn't do anything but is still a torch.nn.Module:

class No_Activation(torch.nn.Module):
    """Pass-through activation: forward returns the input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input

Also, the list of available activations must be updated.
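With such a class, the 'linear' branch of get_activation could return a real module instead of None. A quick check that the pass-through behaves as expected (for what it's worth, torch.nn.Identity already provides exactly this behavior out of the box):

import torch

activ = No_Activation()  # or torch.nn.Identity()
x = torch.randn(3, 4)
assert torch.equal(activ(x), x)  # the input passes through unchanged

# the branch in get_activation would then become:
#     elif activation == "linear":
#         activ = No_Activation()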

andrrizzi commented 3 months ago

I think that works, but it'll probably end up adding more lines of code than it saves compared to handling the case explicitly as usual ('linear' here really means None). I don't think that function is called in many places anyway. What if we remove that case and call the function like this:

if activation != 'linear':
    activ = get_activation(activation)
else:
    activ = None  # 'linear' really means no activation
    logger.warning('No activation function')
EnricoTrizio commented 3 months ago

Sorry, I ran into this because the committor has an activation on the last layer, but it's also useful to turn it off sometimes. I can move it to the post-processing part in that case.

I was thinking that, if someone chooses linear, it would also be nice to have it actually work.
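For context, here is a sketch of the kind of optional last-layer activation I have in mind (all names are illustrative, not the actual mlcolvar committor code):

import torch

class CommittorHead(torch.nn.Module):
    """Illustrative last layer whose activation can be switched off."""

    def __init__(self, in_features: int, apply_activation: bool = True):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, 1)
        # with a pass-through module, "no activation" needs no special-casing
        self.activation = (
            torch.nn.Sigmoid() if apply_activation else torch.nn.Identity()
        )

    def forward(self, x):
        return self.activation(self.linear(x))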