luigibonati / mlcolvar

A unified framework for machine learning collective variables for enhanced sampling simulations
MIT License

Multiple preprocessing transforms wrapped in sequential failing in training #140

Closed EnricoTrizio closed 2 months ago

EnricoTrizio commented 3 months ago

If multiple preprocessing transforms are wrapped in a torch.nn.Sequential object, training fails: the example input array is generated incorrectly because torch.nn.Sequential does not have an in_features attribute.
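A minimal sketch of the failure mode (torch.nn.Linear stands in here for the actual preprocessing transforms):

import torch

pre = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.Linear(5, 2))
print(pre[0].in_features)  # 10: the first module carries the attribute
print(pre.in_features)     # AttributeError: Sequential itself has no in_features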

We could create a wrapper like SequentialTransform:

import torch

class SequentialTransform(torch.nn.Module):
    """Chain transforms and expose the in_features of the first one."""

    def __init__(self, modules) -> None:
        super().__init__()
        # ModuleList registers the submodules with autograd/state_dict;
        # a plain attribute named `modules` would also shadow nn.Module.modules().
        self.transforms = torch.nn.ModuleList(modules)
        self.in_features = self.transforms[0].in_features

    def forward(self, x):
        for module in self.transforms:
            x = module(x)
        return x
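A usage sketch (torch.nn.Linear again stands in for the mlcolvar transforms, which are assumed to expose in_features as well):

pre = SequentialTransform([torch.nn.Linear(10, 5), torch.nn.Linear(5, 2)])
example = torch.randn(2, pre.in_features)  # in_features resolves to 10
out = pre(example)                         # out has shape (2, 2)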
andrrizzi commented 2 months ago

That sounds good to me. An alternative could be something like

from torch.nn import Sequential

class SequentialTransform(Sequential):
    @property
    def in_features(self):
        # self[0] is the first submodule; next(self._modules) would not work,
        # since _modules is an OrderedDict, not an iterator.
        return self[0].in_features
    # Setter if required
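Usage would be the same except that, like torch.nn.Sequential, the modules are passed positionally; a sketch with the same stand-in transforms:

import torch

pre = SequentialTransform(torch.nn.Linear(10, 5), torch.nn.Linear(5, 2))
assert pre.in_features == 10  # the property delegates to the first module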