HazyResearch / flyingsquid

More interactive weak supervision with FlyingSquid
Apache License 2.0

Questions about the pytorch integration #13

Open dspoka opened 4 years ago

dspoka commented 4 years ago

Hey, I didn't fully understand how the PyTorch integration fits into the online learning setting. I was wondering if you could make a simple tutorial? The other two are fantastic! Thanks!

DanFu09 commented 4 years ago

Hey, great question!

The idea is that when you're computing the loss function, you can use a bunch of weak labels instead of a single ground-truth label. We do this by wrapping FlyingSquid in a small PyTorch wrapper: FlyingSquid computes a single probabilistic label, which you can then use for cross-entropy loss, compute gradients on, etc.

Basically the idea is to replace these lines in a training loop:

import torch.nn as nn

model = ... # whatever model you have
optimizer = ... # set up optimizer, schedule, etc

criterion = nn.BCEWithLogitsLoss()

# one gold label for each data point
for inputs, gold_labels in dataloader:
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, gold_labels)

    loss.backward()
    optimizer.step()

with something like this:

from flyingsquid.pytorch_loss import FSLoss

model = ... # whatever model you have
optimizer = ... # set up optimizer, schedule, etc

m = ... # number of weak labels per data point

# buffer_capacity is how many points to use for online learning with FlyingSquid
criterion = FSLoss(m, buffer_capacity = 100)

# here we load up m WEAK labels for each data point
for inputs, weak_labels in dataloader:
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, weak_labels)

    loss.backward()
    optimizer.step()
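
For concreteness, here's a minimal sketch of what that weak-label dataloader might look like. The sizes n, d, and m and the random votes are made up for illustration; FlyingSquid expects each labeling function vote to be in {-1, 0, +1}, where 0 means abstain:

import torch
from torch.utils.data import DataLoader, TensorDataset

n, d, m = 1000, 64, 5  # made-up sizes: n data points, d features, m labeling functions
inputs = torch.randn(n, d)

# each labeling function votes -1 or +1, or abstains with 0
weak_labels = torch.randint(-1, 2, (n, m)).float()

dataloader = DataLoader(TensorDataset(inputs, weak_labels), batch_size=32)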

Under the hood, every time the forward function of FSLoss gets called (i.e., criterion(outputs, weak_labels)), the new weak labels are added to a buffer, and FlyingSquid re-learns the labeling function accuracies from everything in its buffer. The new accuracies and the votes from the latest batch are then used to generate probabilistic labels, which get fed into a normal PyTorch loss function.

In pseudocode, the forward function looks something like this:

def forward(self, outputs, weak_labels):
    # add the new weak labels to the buffer, keeping at most buffer_capacity points
    self.buffer = (self.buffer + weak_labels)[-self.buffer_capacity:]
    # re-learn the label model from everything in the buffer
    self.flying_squid_label_model.fit(self.buffer)
    # predict probabilistic labels for the latest batch
    prob_labels = self.flying_squid_label_model.predict_proba(weak_labels)
    # feed the probabilistic labels into a normal PyTorch loss
    return self.pytorch_criterion(outputs, prob_labels) # self.pytorch_criterion is an nn.BCEWithLogitsLoss
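
To make that concrete, here is a minimal, self-contained sketch of such a wrapper, assuming FlyingSquid's LabelModel with the fit / predict_proba calls used above. The class name WeakLabelLoss is mine, and the assumption that predict_proba returns one probability column per class (positive class last) is also mine; the real FSLoss in flyingsquid.pytorch_loss may differ in its API and internals:

import numpy as np
import torch
import torch.nn as nn
from flyingsquid.label_model import LabelModel

class WeakLabelLoss(nn.Module):  # hypothetical name, not the library's FSLoss
    def __init__(self, m, buffer_capacity=100):
        super().__init__()
        self.m = m
        self.buffer_capacity = buffer_capacity
        self.buffer = np.zeros((0, m))  # (n, m) matrix of votes in {-1, 0, +1}
        self.label_model = LabelModel(m)
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, outputs, weak_labels):
        votes = weak_labels.detach().cpu().numpy()
        # append the new votes and keep only the most recent buffer_capacity points
        self.buffer = np.concatenate([self.buffer, votes])[-self.buffer_capacity:]
        # re-learn labeling function accuracies from the buffer (no gradients flow here)
        self.label_model.fit(self.buffer)
        # probabilistic labels for the current batch; assuming predict_proba
        # returns one column per class, we take the positive class
        probs = self.label_model.predict_proba(votes)[:, 1]
        prob_labels = torch.as_tensor(probs, dtype=outputs.dtype).to(outputs.device)
        return self.criterion(outputs.squeeze(-1), prob_labels)

In practice you'd also want to handle the first few batches specially, since the buffer may hold too few points for the label model to fit reliably.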