IoBT-VISTEC / MIN2Net

End-to-End Multi-Task Learning for Subject-Independent Motor Imagery EEG Classification (IEEE Transactions on Biomedical Engineering)
https://min2net.github.io
Apache License 2.0
76 stars 25 forks source link

I would like to express the MIN2Net model in pytorch, is that correct? #4

Open Buddies-as-you-know opened 1 year ago

Buddies-as-you-know commented 1 year ago

Only the MI-EEG classification part was expressed in pytorch. Does it match?

class Conv2D_Norm_Constrained(nn.Conv2d):
    def __init__(self, max_norm_val, norm_dim, **kwargs):
        super().__init__(**kwargs)
        self.max_norm_val = max_norm_val
        self.norm_dim = norm_dim

    def get_constrained_weights(self, epsilon=1e-8):
        norm = self.weight.norm(2, dim=self.norm_dim, keepdim=True)
        return self.weight * (torch.clamp(norm, 0, self.max_norm_val) / (norm + epsilon))

    def forward(self, input):
        return F.conv2d(input, self.get_constrained_weights(), self.bias, self.stride, self.padding, self.dilation, self.groups)

class ConstrainedLinear(nn.Linear):
    def forward(self, input):
        return F.linear(input, self.weight.clamp(min=-1.0, max=0.5), self.bias)
class MinNet(nn.Module): # input = (1,16,125)
  def __init__(self, input_shape=(1,400,20)):
    super().__init__()
    self.D, self.T, self.C = input_shape
    self.subsampling_size = 100
    self.pool_size_1 = (1,self.T//self.subsampling_size)
    self.en_conv = nn.Sequential(
                    Conv2D_Norm_Constrained(in_channels=1, out_channels=16, kernel_size=(1, 64), padding="same", max_norm_val=2.0, norm_dim=(0, 1, 2)),
                    nn.ELU(),
                    nn.BatchNorm2d(16,eps=1e-05, momentum=0.1),
                    nn.AvgPool2d((1,self.pool_size_1)),
                    nn.Flatten(),
                    ConstrainedLinear(32000,64),
                    nn.ELU(),
                    ConstrainedLinear(64,3)
                )
  def forward(self,x):
      x = self.en_conv(x)
      return x