funkelab / gunpowder

A library to facilitate machine learning on multi-dimensional images.
https://funkelab.github.io/gunpowder/
MIT License

Add hook for data-dependent snapshots #151

Closed. bentaculum closed this 3 years ago

bentaculum commented 3 years ago

It can be useful to write out snapshots based on some condition in the current batch, especially for developing and debugging. Here is an example subclass that writes out a snapshot if there's a spike in the recorded loss.

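from gunpowder import Snapshot

class SnapshotLossIncrease(Snapshot):
    """Write a snapshot when the loss grows by more than `factor` between batches."""

    def __init__(self, factor, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor
        # Holds the loss of the previous batch; inf means the first batch never triggers.
        self.running_loss = float('inf')

    def write_if(self, batch):
        # Hook proposed in this PR: return True to write a snapshot for this batch.
        out = self.running_loss * self.factor < batch.loss
        self.running_loss = batch.loss
        return out

pattonw commented 3 years ago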

I think this only partially works. We decide in the prepare function whether to make a snapshot, and we request the data for the snapshot there. On a non-snapshot iteration, the process function won't receive all of the data that should go into the snapshot, so your snapshot would be missing data. Right now it looks like you would only be missing data from additional_request, but the snapshot node probably shouldn't request anything on non-snapshot iterations.

We should probably start by moving these lines down into the if self.record_snapshot case, to fix the bug where we request a lot of data when we don't plan to write anything. This would break the dynamic snapshot, though, since you would have no data to save except on the regular snapshot iterations. We could maybe make the write_if function an argument to the snapshot node. Then in prepare we could always request snapshot data if self.record_snapshot or self.write_if is not None, but only write out based on the evaluation of write_if on the batch.
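A minimal sketch of this proposal, for illustration only. SnapshotSketch is a hypothetical stand-in written in plain Python, not gunpowder's Snapshot class; its attributes (requested, written) and the dict-based batch are assumptions made so the example runs on its own. It shows the split described above: prepare can only look at the schedule, while process can evaluate write_if on the actual batch.

```python
from typing import Callable, Optional


class SnapshotSketch:
    """Hypothetical stand-in for a snapshot node; NOT gunpowder's Snapshot."""

    def __init__(self, every: int = 1,
                 write_if: Optional[Callable[[dict], bool]] = None):
        self.every = every
        self.write_if = write_if
        self.n = 0                    # iteration counter
        self.record_snapshot = False  # is this a regular snapshot iteration?
        self.requested = False        # did prepare() ask upstream for data?
        self.written = []             # iterations at which a snapshot was written

    def prepare(self):
        # The batch does not exist yet, so only the schedule can be checked.
        self.record_snapshot = self.n % self.every == 0
        # Request snapshot data whenever a write is still possible.
        self.requested = self.record_snapshot or self.write_if is not None

    def process(self, batch: dict):
        if self.write_if is not None:
            # Data-dependent decision, made once the batch is available.
            should_write = self.requested and self.write_if(batch)
        else:
            should_write = self.record_snapshot
        if should_write:
            self.written.append(self.n)
        self.n += 1


node = SnapshotSketch(every=10, write_if=lambda batch: batch["loss"] > 0.5)
for loss in [0.1, 0.9, 0.2, 0.7]:
    node.prepare()
    node.process({"loss": loss})
print(node.written)
```

In this sketch, passing a write_if means data is requested on every iteration, which is the cost of deciding only after the batch has been produced.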

bentaculum commented 3 years ago

Thanks for elaborating. I was not aware that the plan is for Snapshot to no longer request anything by default, but with this context I agree with all your points. My PR actually combines periodic writing with data-dependent writing, which I now realize is not clean. With your suggestion, pure data-dependent writing would be achieved by setting every=1 and passing a write_if.

Actually, I first implemented write_if as a function to be passed to Snapshot, but imo creating a hook is cleaner and better OOP style.

pattonw commented 3 years ago

Well, nvm. I chatted with other lab members and it seems we can keep the current request in the dependencies, since there is no overhead to requesting data that is already requested. This is definitely a useful feature that we want; we just need to make the behavior consistent and clear to users. We can't use write_if in the prepare method to determine whether the batch should be written, so I think we just stick to the every term and use write_if as a filter. That way users can still do pure data-dependent writing, as you suggested, with every=1 and a write_if.

As for a passed-in function vs. subclassing, I think a passed-in function would be simpler for the simple use cases, like write_if = lambda batch: batch.loss > 0.5. But I could also see people wanting to keep moving averages and only write a snapshot if the loss is outside some expected range. That would be easier to do with a subclass. I'm starting to lean more towards the subclass, since that's how the RandomLocation node does it with accepts.
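A rough sketch of the "every as schedule, write_if as filter" behavior with a stateful subclass. Everything here is hypothetical: PeriodicFilteredSnapshot and SnapshotOutsideRange are plain-Python stand-ins rather than gunpowder classes, the tolerance and momentum parameters are invented for illustration, and calling the hook on every batch (so the moving average sees every loss) is an assumption, not something settled in this thread.

```python
class PeriodicFilteredSnapshot:
    """Hypothetical stand-in, not gunpowder's Snapshot."""

    def __init__(self, every=1):
        self.every = every
        self.n = 0          # iteration counter
        self.written = []   # iterations at which a snapshot was written

    def write_if(self, batch):
        """Hook: override to veto a scheduled snapshot. Default: always write."""
        return True

    def step(self, batch):
        scheduled = self.n % self.every == 0  # known before the batch exists
        passes = self.write_if(batch)         # assumption: hook sees every batch
        if scheduled and passes:
            self.written.append(self.n)
        self.n += 1


class SnapshotOutsideRange(PeriodicFilteredSnapshot):
    """Write only if the loss strays from its exponential moving average."""

    def __init__(self, tolerance, momentum=0.9, **kwargs):
        super().__init__(**kwargs)
        self.tolerance = tolerance
        self.momentum = momentum
        self.mean = None  # moving average of the loss, unset until first batch

    def write_if(self, batch):
        loss = batch["loss"]
        if self.mean is None:
            self.mean = loss
            return False
        out = abs(loss - self.mean) > self.tolerance
        self.mean = self.momentum * self.mean + (1 - self.momentum) * loss
        return out


node = SnapshotOutsideRange(tolerance=0.5, every=1)
for loss in [1.0, 1.1, 3.0, 1.0]:
    node.step({"loss": loss})
print(node.written)
```

With every=1 the schedule never blocks a write, so the hook alone decides; with a larger every, the hook can only veto iterations that were already scheduled.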

bentaculum commented 3 years ago

Closing in favor of #152.