unpackAI / unpackai

The Unpack.AI library
https://www.unpackai.com
MIT License

DL101 Week 1 interp.plot_top_losses plots only one image #64

Open jamescavanagh opened 2 years ago

jamescavanagh commented 2 years ago

Describe the bug

In the DL101 notebook, the `interp.plot_top_losses(5, nrows=5)` call displays only one image instead of five. Students need this cell to work for the notebook to run correctly.

To Reproduce

Run the notebook, train the model, then run this cell:

```python
interp.plot_top_losses(5, nrows=5)
```

Expected behavior

The cell should plot the five images with the highest losses (one per row), not just one.

Additional context

Several students have reported this, and it matters for their experience of the course.

jfthuong commented 2 years ago

@Dennis-fast-ai

Could you post the workaround code that you found on the fastai forums as a comment on this issue?

Thanks

Note: as mentioned in our chats, we could add a function to our library, inspired by fastai, with stable code, tests to check that it does not break, and 2 additional parameters:

jamescavanagh commented 2 years ago

@jfthuong Here is the code. It is a bit messy, but this is how it came.

```python
#@title Getting a visual result for the top losses
from fastai.vision.all import *  # provides Tensor, L, is_listy and the plot_top_losses dispatcher

def plot_top_losses_fix(interp, k, largest=True, **kwargs):
    # Values and indices of the k largest (or smallest) losses
    losses, idx = interp.top_losses(k, largest)
    if not isinstance(interp.inputs, tuple): interp.inputs = (interp.inputs,)
    # Gather the inputs of the selected samples
    if isinstance(interp.inputs[0], Tensor): inps = tuple(o[idx] for o in interp.inputs)
    else: inps = interp.dl.create_batch(interp.dl.before_batch([tuple(o[i] for o in interp.inputs) for i in idx]))
    # Inputs + targets, decoded for display
    b = inps + tuple(o[idx] for o in (interp.targs if is_listy(interp.targs) else (interp.targs,)))
    x, y, its = interp.dl._pre_show_batch(b, max_n=k)
    # Inputs + decoded predictions, decoded for display
    b_out = inps + tuple(o[idx] for o in (interp.decoded if is_listy(interp.decoded) else (interp.decoded,)))
    x1, y1, outs = interp.dl._pre_show_batch(b_out, max_n=k)
    if its is not None:
        # Original fastai call, replaced by the line below:
        # plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), L(self.preds).itemgot(idx), losses, **kwargs)
        plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), interp.preds[idx], losses, **kwargs)
    #TODO: figure out if this is needed
    #its None means that a batch knows how to show itself as a whole, so we pass x, x1
    #else: show_results(x, x1, its, ctxs=ctxs, max_n=max_n, **kwargs)

plot_top_losses_fix(interp, 10, nrows=2)
```
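Why the fix seems to work (my reading, not verified against every fastai version): the two `plot_top_losses(...)` calls in the code differ only in how the raw predictions are passed, `L(self.preds).itemgot(idx)` versus `interp.preds[idx]`. Assuming fastcore's `L` wraps a tensor as a single item (its `listify` helper treats array-like objects as one element rather than iterating over them), `L(preds).itemgot(idx)` is a one-element list holding `preds[idx]`. The image version of `plot_top_losses` zips the axes, samples, outputs, raw predictions and losses together, so the loop stops after one image. The stdlib-only script below mimics that behavior; `FakeTensor` and `wrap_like_listify` are hypothetical stand-ins, not fastai or fastcore code.

```python
# Stdlib-only illustration of why only one image gets plotted.
# Hypothetical stand-ins (NOT fastai/fastcore code):
#   - FakeTensor mimics a prediction tensor indexable by a list of row indices.
#   - wrap_like_listify mimics fastcore's listify, which (assumption) wraps an
#     array-like object as ONE item instead of iterating over its rows.


class FakeTensor:
    """Hypothetical stand-in for a 2-D prediction tensor (one row per sample)."""

    def __init__(self, rows):
        self.rows = list(rows)

    def __len__(self):
        return len(self.rows)

    def __iter__(self):
        return iter(self.rows)

    def __getitem__(self, idx):
        # Indexing with a list of row indices returns a new FakeTensor,
        # like preds[idx] on a real tensor.
        if isinstance(idx, list):
            return FakeTensor(self.rows[i] for i in idx)
        return self.rows[idx]


def wrap_like_listify(obj):
    # Hypothetical: an array-like object becomes a one-element list.
    return [obj]


preds = FakeTensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
images = ["img_a", "img_b", "img_c"]
idx = [2, 0, 1]                       # top-loss indices, k = 3
samples = [images[i] for i in idx]    # the k samples to display

# Buggy path: L(preds).itemgot(idx) -> [preds[idx]], a container of length 1.
raws_buggy = [o[idx] for o in wrap_like_listify(preds)]
# Fixed path: preds[idx] -> one row per selected sample.
raws_fixed = preds[idx]

# The plotting loop zips per-sample sequences, so it stops at the shortest one.
shown_buggy = [s for s, _ in zip(samples, raws_buggy)]
shown_fixed = [s for s, _ in zip(samples, raws_fixed)]

print(len(shown_buggy), len(shown_fixed))  # 1 3
```

With the wrapped predictions, `zip` yields a single pair, which matches the one-image symptom; indexing the tensor directly keeps one prediction per sample.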