[Open] yuanzhi-zhu opened this issue 9 months ago
For models that take two inputs, can I wrap them like this?
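```python
import torch
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # Random one-hot class labels over 1000 classes, built on x's device
        idx = torch.randint(0, 1000, (x.shape[0],), device=x.device)
        class_labels = torch.eye(1000, device=x.device)[idx]
        return self.model(x, class_labels)
```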
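One way to sanity-check this pattern is to run it against a small stand-in model. The sketch below assumes the wrapped model takes `(x, class_labels)` with one-hot labels over 1000 classes. `ToyCondModel` and `FixedLabelWrapper` are hypothetical names used only for illustration. Note that the wrapper above draws new random labels on every call, so the same `x` can give different outputs. The seeded `torch.Generator` here is one optional way to make the sampled labels reproducible; it is not something the original model requires.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

NUM_CLASSES = 1000  # assumption: 1000-class conditioning, as in the question

class ToyCondModel(nn.Module):
    """Hypothetical two-input model: adds a projected class embedding to x."""
    def __init__(self, dim=8, num_classes=NUM_CLASSES):
        super().__init__()
        self.proj = nn.Linear(num_classes, dim)

    def forward(self, x, class_labels):
        return x + self.proj(class_labels)

class FixedLabelWrapper(nn.Module):
    """Single-input wrapper; labels come from a seeded CPU generator (hypothetical variant)."""
    def __init__(self, model, num_classes=NUM_CLASSES, seed=0):
        super().__init__()
        self.model = model
        self.num_classes = num_classes
        self.generator = torch.Generator().manual_seed(seed)

    def forward(self, x):
        # Sample one class index per batch element from the seeded generator
        idx = torch.randint(0, self.num_classes, (x.shape[0],), generator=self.generator)
        # One-hot encode, then match x's dtype and device
        class_labels = F.one_hot(idx, self.num_classes).to(dtype=x.dtype, device=x.device)
        return self.model(x, class_labels)

if __name__ == "__main__":
    wrapped = FixedLabelWrapper(ToyCondModel(dim=8))
    out = wrapped(torch.randn(4, 8))
    print(out.shape)  # torch.Size([4, 8])
```

Anything that expects a single-input `nn.Module` can then call `wrapped(x)` directly. If you want the labels to be an explicit input rather than sampled internally, passing them through the wrapper's constructor is another option.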