bobrarity / Iris-Classification-Dockerized


Better to create a load method for IrisNet #5

Open · kirillepam opened this issue 7 months ago

kirillepam commented 7 months ago

https://github.com/bobrarity/Docker-project-1/blob/59be91ad2a43c3400f4eabda681e3dd0d1bfc5f5/inference/inference.py#L21

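```python
import torch
import torch.nn as nn


class IrisNet(nn.Module):
    """Small feed-forward classifier: 4 iris features -> 100 hidden units -> 3 classes."""

    def __init__(self):
        super(IrisNet, self).__init__()
        self.fc1 = nn.Linear(4, 100)
        self.fc2 = nn.Linear(100, 3)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x  # raw logits, one per class

    def load(self, filepath):
        # Restore weights saved with torch.save(model.state_dict(), filepath)
        self.load_state_dict(torch.load(filepath))
```
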
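A minimal usage sketch of the suggested `load` method, assuming the training side saved the weights with `torch.save(model.state_dict(), path)`. The file name `iris_model.pth` is hypothetical, and passing `map_location="cpu"` to `torch.load` is an added assumption (that the inference container has no GPU), not part of the suggestion itself. The sketch saves a freshly initialised model, restores it into a new instance through `load`, and checks that both produce the same output.

```python
import os
import tempfile

import torch
import torch.nn as nn


class IrisNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 100)
        self.fc2 = nn.Linear(100, 3)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

    def load(self, filepath):
        # Assumption: map_location="cpu" so a GPU-trained checkpoint
        # can still be loaded inside a CPU-only inference container.
        self.load_state_dict(torch.load(filepath, map_location="cpu"))


# "Training" side: save the weights of a freshly initialised model.
trained = IrisNet()
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "iris_model.pth")  # hypothetical file name
    torch.save(trained.state_dict(), path)

    # Inference side: loading is now a single call on the model itself.
    model = IrisNet()
    model.load(path)
    model.eval()

sample = torch.tensor([[5.1, 3.5, 1.4, 0.2]])
with torch.no_grad():
    same = torch.allclose(trained(sample), model(sample))
print(same)  # True
```

Keeping the loading logic on the class means inference code no longer has to know how the checkpoint is stored; only `IrisNet.load` would change if the format did.
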
bobrarity commented 7 months ago

Ok, got it, thanks)