balbasty / nitorch

Neuroimaging in PyTorch

Added Model specific tensorboard visualisation. #39

brudfors closed this 3 years ago

balbasty commented 3 years ago

I think the signature of _board is a bit too specific, and the way it is called in the trainer class definitely is: it assumes that the batch always contains exactly two elements, which is not always the case (think of a variational autoencoder).

I think that, at the Module level, _board should have a signature closer to that of forward. Since it will also receive the output of the forward pass, it could take two arguments, both tuples:

class Module:
    ...
    def _board(self, tb, inputs, outputs):
        # inputs: tuple of tensors given to forward
        # outputs: tuple of tensors returned by forward
        pass

Then SegNet knows what to expect:

class SegNet(Module):
    ...
    def _board(self, tb, inputs, outputs):
        image = inputs[0]
        ground_truth = inputs[1] if len(inputs) > 1 else None
        prediction = outputs[0]
        ...
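To illustrate why tuple arguments generalize, here is a minimal, framework-free sketch (not nitorch code). It assumes the writer exposes `add_image(tag, img, global_step)`, the method name used by `torch.utils.tensorboard.SummaryWriter`; `RecordingWriter` is a hypothetical stand-in so the example runs without PyTorch, and `VAE` is a hypothetical model whose forward pass returns `(reconstruction, mu, logvar)` from a single-image batch. Strings stand in for tensors.

```python
class RecordingWriter:
    """Hypothetical stand-in for a TensorBoard writer: records add_image calls."""

    def __init__(self):
        self.images = []

    def add_image(self, tag, img, global_step=None):
        self.images.append((tag, img, global_step))


class Module:
    def board(self, tb, inputs, outputs):
        """Default: log nothing. `inputs` and `outputs` are both tuples."""
        pass


class SegNet(Module):
    def board(self, tb, inputs, outputs):
        image = inputs[0]
        ground_truth = inputs[1] if len(inputs) > 1 else None  # optional target
        prediction = outputs[0]
        tb.add_image('segnet/image', image)
        tb.add_image('segnet/prediction', prediction)
        if ground_truth is not None:
            tb.add_image('segnet/ground_truth', ground_truth)


class VAE(Module):
    def board(self, tb, inputs, outputs):
        # Batch holds only the image; forward returns (reconstruction, mu, logvar)
        image = inputs[0]
        reconstruction = outputs[0]
        tb.add_image('vae/image', image)
        tb.add_image('vae/reconstruction', reconstruction)


# Usage: each model unpacks tuples of different lengths on its own terms.
tb = RecordingWriter()
SegNet().board(tb, ('img', 'seg'), ('pred',))
VAE().board(tb, ('img',), ('recon', 'mu', 'logvar'))
tags = [tag for tag, _, _ in tb.images]
```

The trainer never needs to know how many elements a batch or an output holds; only the model does.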

And in the trainer you can have:

if self.tensorboard:
    self.model._board(self.tensorboard, batch, output)
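One detail this call leaves open is what happens when `batch` or `output` is a single tensor rather than a tuple: `outputs[0]` would then pick the first sample of the batch instead of the first output. A minimal sketch, assuming the trainer wraps both in tuples before calling `board`; `as_tuple`, `Trainer.step`, and `EchoModel` are hypothetical names used only for illustration.

```python
def as_tuple(x):
    """Pass lists/tuples through as tuples; wrap anything else in a 1-tuple."""
    if isinstance(x, (list, tuple)):
        return tuple(x)
    return (x,)


class Trainer:
    def __init__(self, model, tensorboard=None):
        self.model = model
        self.tensorboard = tensorboard

    def step(self, batch):
        inputs = as_tuple(batch)
        output = self.model(*inputs)
        if self.tensorboard:
            # Always hand the model tuples, whatever forward returned
            self.model.board(self.tensorboard, inputs, as_tuple(output))
        return output


class EchoModel:
    """Hypothetical model: forward returns its first input; board records its args."""

    def __init__(self):
        self.seen = None

    def __call__(self, *inputs):
        return inputs[0]

    def board(self, tb, inputs, outputs):
        self.seen = (inputs, outputs)


# Usage: a single (non-tuple) forward output still reaches board as a tuple.
model = EchoModel()
Trainer(model, tensorboard=object()).step(['image', 'label'])
```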

Finally, I think the method can be public (board instead of _board).

What do you think? It makes the signature of board a bit ugly, but at least it is generic.

brudfors commented 3 years ago

@balbasty, what do you think now?