Open dspoka opened 4 years ago
Hey, great question!
The idea is that when you're computing the loss function, you can use a bunch of weak labels instead of a single ground-truth label. We do this by wrapping FlyingSquid in a small PyTorch wrapper: FlyingSquid computes a single probabilistic label, which you can then feed into a cross-entropy loss, compute gradients on, etc.
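To make "train on a probabilistic label" concrete, here's a minimal, dependency-free sketch of binary cross entropy on a raw logit with a soft target, using the same numerically stable formula that `nn.BCEWithLogitsLoss` is based on (the `0.8` soft label below is just an illustrative value, standing in for what a label model might output):

```python
import math

def bce_with_logits(logit, target):
    # Numerically stable binary cross entropy on a raw logit,
    # matching the formula behind nn.BCEWithLogitsLoss.
    # target may be a soft (probabilistic) label in [0, 1].
    return max(logit, 0) - logit * target + math.log(1 + math.exp(-abs(logit)))

logit = 2.0        # model output before the sigmoid
hard_label = 1.0   # usual supervised setting: one gold label
soft_label = 0.8   # probabilistic label, e.g. produced by a label model

loss_hard = bce_with_logits(logit, hard_label)  # ~0.127
loss_soft = bce_with_logits(logit, soft_label)  # ~0.527
```

The only change relative to ordinary supervised training is that the target is a probability rather than a 0/1 label; gradients flow exactly as before.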
Basically the idea is to replace these lines in a training loop:
model = ...  # whatever model you have
optimizer = ...  # set up optimizer, schedule, etc.
criterion = nn.BCEWithLogitsLoss()
# one gold label for each data point
for inputs, gold_labels in dataloader:
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, gold_labels)
    loss.backward()
    optimizer.step()
with something like this:
from flyingsquid.pytorch_loss import FSLoss

model = ...  # whatever model you have
optimizer = ...  # set up optimizer, schedule, etc.
m = ...  # number of weak labels per data point
# buffer_capacity is how many points to use for online learning with FlyingSquid
criterion = FSLoss(m, buffer_capacity=100)
# here we load up m WEAK labels for each data point
for inputs, weak_labels in dataloader:
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, weak_labels)
    loss.backward()
    optimizer.step()
Under the hood, every time the forward function of FSLoss gets called (i.e., criterion(outputs, weak_labels)), the new weak labels get added to a buffer, and FlyingSquid re-learns the labeling function accuracies from everything in the buffer. The new accuracies and the votes from the latest batch are then used to generate probabilistic labels, which get fed into a normal PyTorch loss function.
In pseudocode, the forward function looks something like this:
def forward(self, outputs, weak_labels):
    self.buffer += weak_labels  # update the buffer
    self.flying_squid_label_model.fit(self.buffer)  # re-learn the label model
    prob_labels = self.flying_squid_label_model.predict_proba(weak_labels)  # predict probabilistic labels
    return self.pytorch_criterion(outputs, prob_labels)  # self.pytorch_criterion is an nn.BCEWithLogitsLoss
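If it helps, here is a runnable, dependency-free sketch of the same buffer-then-refit-then-predict pattern. To keep it self-contained, the label model here is a toy weighted vote with accuracies estimated by agreement with the majority vote over the buffer; that stand-in (and the `SketchFSLoss` name) is my invention for illustration, not FlyingSquid's actual method-of-moments estimator or the real FSLoss API:

```python
import math

class SketchFSLoss:
    """Hypothetical, dependency-free sketch of the FSLoss pattern.
    Weak labels are vote vectors in {-1, +1}^m, one per data point."""

    def __init__(self, m, buffer_capacity=100):
        self.m = m                              # weak labels per data point
        self.buffer = []                        # past weak-label vectors
        self.buffer_capacity = buffer_capacity  # online-learning window

    def _fit(self):
        # Toy accuracy estimate: each labeling function's agreement with
        # the majority vote over the buffer (stand-in for FlyingSquid).
        acc = []
        for j in range(self.m):
            agree = sum(
                1 for votes in self.buffer
                if votes[j] == (1 if sum(votes) > 0 else -1)
            )
            acc.append(max(agree / len(self.buffer), 1e-3))
        return acc

    def _predict_proba(self, votes, acc):
        # Accuracy-weighted vote, squashed to P(true label = +1).
        score = sum(a * v for a, v in zip(acc, votes))
        return 1 / (1 + math.exp(-score))

    def __call__(self, logits, weak_labels):
        # 1. Add the new weak labels to the (capped) buffer.
        self.buffer.extend(weak_labels)
        self.buffer = self.buffer[-self.buffer_capacity:]
        # 2. Re-learn the label model from the buffer.
        acc = self._fit()
        # 3. Predict probabilistic labels for this batch and take
        #    BCE-with-logits against those soft labels.
        loss = 0.0
        for x, votes in zip(logits, weak_labels):
            z = self._predict_proba(votes, acc)
            loss += max(x, 0) - x * z + math.log(1 + math.exp(-abs(x)))
        return loss / len(logits)

# Toy batch: 2 data points, m = 3 weak labels each.
criterion = SketchFSLoss(m=3, buffer_capacity=100)
loss = criterion([1.5, -0.5], [[1, 1, -1], [-1, -1, 1]])
```

The real FSLoss does the same three steps, but with FlyingSquid's actual label model in steps 2-3 and PyTorch tensors throughout, so the returned loss supports `.backward()`.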
Hey, I didn't fully understand how the PyTorch integration fits into the online learning setting. Was wondering if you could make a simple tutorial. The other two are fantastic! Thanks!