If multiple preprocessing transforms are wrapped in a `torch.nn.Sequential` object, training fails because the example input array is generated incorrectly: `torch.nn.Sequential` does not have an `in_features` attribute.
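For concreteness, a minimal sketch of the missing attribute, using `nn.Linear` layers as stand-ins for the actual preprocessing transforms:

```python
import torch

# A single transform exposes in_features, so an example input can be built:
single = torch.nn.Linear(4, 8)
example = torch.randn(1, single.in_features)  # works

# Wrapping several transforms in nn.Sequential loses that attribute:
seq = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
print(hasattr(seq, "in_features"))  # False, so example-input generation fails
```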
We could create a wrapper like `SequentialTransform` that forwards through each transform while exposing the first transform's `in_features`:
```python
import torch


class SequentialTransform(torch.nn.Module):
    def __init__(self, modules) -> None:
        super().__init__()
        # Use nn.ModuleList so the submodules are registered properly;
        # assigning to self.modules would also shadow nn.Module.modules().
        self.transforms = torch.nn.ModuleList(modules)
        # Expose the first transform's in_features so example-input
        # generation keeps working.
        self.in_features = modules[0].in_features

    def forward(self, x):
        for transform in self.transforms:
            x = transform(x)
        return x
```
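For illustration, hypothetical usage of the wrapper, again with `nn.Linear` layers standing in for the real preprocessing transforms:

```python
transforms = [torch.nn.Linear(4, 8), torch.nn.Linear(8, 2)]
pipeline = SequentialTransform(transforms)

# The wrapper exposes the first transform's in_features, so an
# example input can be generated as before:
example_input = torch.randn(1, pipeline.in_features)  # shape (1, 4)
output = pipeline(example_input)                      # shape (1, 2)
```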