beacon-biosignals / LighthouseFlux.jl

An adapter package that implements Lighthouse's framework interface for Flux

WIP MNIST example #20

Open ericphanson opened 3 years ago

ericphanson commented 3 years ago

pairing with @hannahilea to put up a simple example

ericphanson commented 3 years ago

Initially, I forgot the

Flux.@functor SimpleModel (chain,)

line, which caused it to not train at all. Funnily enough, the model seemed to mostly predict "3" for any input. Now with that line, I get

[image: individualImage (test set evaluation plot)]

It looks suspiciously like the error path I introduced in the fake voters, which 10% of the time vote for n+1 (mod 10) instead of the true value n. So either I mixed that up, or I have an off-by-1 error somewhere else, since it seems to predict neighbouring digits (0s as 1s, etc.). I'm also not sure what's up with the missing plots. The script returned accuracy(test_set..., model) = 0.9897, though, so the model itself seems to be learning fine. Maybe there's an off-by-1 in the classes or in the plotting somehow?
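For reference, here's roughly why that @functor line matters (a sketch: the real SimpleModel lives in the example script, and I'm only assuming its shape here):

using Flux

struct SimpleModel
    chain::Chain
end

(m::SimpleModel)(x) = m.chain(x)

# registers `chain` as SimpleModel's trainable child; without this line,
# Flux.params(model) is empty, the optimiser has nothing to update, and
# training is a no-op
Flux.@functor SimpleModel (chain,)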

@hannahilea I didn't end up mocking out the model like you suggested yet because I was hitting annoying CUDA errors for some reason. I'll try that next though.

hannahilea commented 3 years ago

lolol sure looks like an off-by-one in the classes---your predictions are suspiciously bad-in-a-good-way. Let me take a look at the code. Just to clarify, this is the test set evaluation plot, right?

ericphanson commented 3 years ago

this is the test set evaluation plot, right?

Yep! That was after 100 epochs. I made the change you suggested, and also tweaked the voter errors so they make off-by-2 mistakes instead of off-by-1 to see if that would change things here.

After 23 epochs, I get

[image: mnist (test set evaluation plot, epoch 23)]

so still off-by-1 (and not off-by-2, so it seems separate from the intentional error in the voters).

hannahilea commented 3 years ago

Yeah---the fact that you aren't getting algorithm-vs-expert agreements is pretty suspect too. I think there's still something weird with the labels. I'm trying to repro it locally, but running into "scalar getindex is disallowed" errors when doing train_set = gpu.(train_set); do you have any local changes that haven't been pushed?

ericphanson commented 3 years ago

Huh, nope, my working directory is clean. Maybe I should push a manifest? I'm on Julia 1.5.3, but I'll try 1.4 to see if I can repro.

hannahilea commented 3 years ago

OH. I definitely know what it is. This is one of those problems that would have been startlingly apparent if the labels on this dataset weren't integers, lolol.

As far as I can tell, your model training etc. is all fine. The issue shows up during evaluation: basically, all of the evaluation metrics in Lighthouse.evaluation_metrics_plot (which is called via Lighthouse.evaluate! during Lighthouse.train!) expect hard labels (e.g. predicted_hard_labels) to be integers that are the indices of your Lighthouse.classes(model). And in the MNIST example here, our integer labels aren't class indices, they are literal label values.

How could you have known this? Good question, hah. Not sure you could have! We definitely need to add it to the documentation around setting up a classifier and the training harness! We do imply it elsewhere in docs, sort of. E.g., from the definition of Lighthouse.evaluation_metrics_plot() (https://github.com/beacon-biosignals/Lighthouse.jl/blob/4ea93408df1b6eb1412900bdaa99a52da4b95c15/src/learn.jl#L653), we see that

  • predicted_soft_labels is a matrix of soft labels whose columns correspond to classes and whose rows correspond to samples in the evaluation set.

(bolding mine). Then the implicit assumption in the following two definitions

  • predicted_hard_labels is a vector of hard labels where the ith element is the hard label predicted by the model for sample i in the evaluation set.
  • elected_hard_labels is a vector of hard labels where the ith element is the hard label elected as "ground truth" for sample i in the evaluation set.

is that each of these hard labels is a class index rather than a class value.
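Concretely, for a toy two-sample evaluation the expected shapes would be something like this (made-up values):

classes = 0:9                        # what Lighthouse.classes(model) returns

# soft labels: rows are samples, columns are classes
predicted_soft_labels = [0.9 0.1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0;
                         0.0 0.0 0.0 0.0 0.8 0.2 0.0 0.0 0.0 0.0]

# hard labels: indices into `classes`, not the digit values themselves
predicted_hard_labels = [1, 5]       # classes[1] == 0, classes[5] == 4
elected_hard_labels   = [1, 5]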

In this MNIST example, our off-by-one error is introduced when we define the classes as 0:9 when we set up the classifier:

classifier = FluxClassifier(model, opt, 0:9)

Without us realizing it, class "0" now has index 1, class "1" has index 2, etc. (Hooray, 1-based indexing! Would our example have succeeded if MNIST values were instead 1:10? It sure would have!)
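To spell the mixup out with made-up values:

classes = collect(0:9)
classes[7]        # == 6: using the raw digit 7 as an index returns the wrong class
# and raw digit 0 isn't a valid index at all, so those samples vanish entirely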

Under this index-based label assumption, during evaluation, all integer labels that fall outside the range of valid indices (i.e., i < 1 or i > length(classes(model))) are ignored. This is actually an intentional/nice property---it means that when we're evaluating a set of labels with 7 classes, if we want to know how well we individually perform on a single one of those classes, we can use Lighthouse.binary_statistics() on it (https://github.com/beacon-biosignals/Lighthouse.jl/blob/4ea93408df1b6eb1412900bdaa99a52da4b95c15/src/metrics.jl#L39) and the set of labels will be treated as "class vs not class" rather than "class 1 vs class 2 vs class 3 ...".
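(A toy illustration of that property with made-up labels---not Lighthouse's actual implementation:)

classes = 1:7
hard_labels = [3, 3, 9, 3, 2]                      # 9 is outside 1:7
count(==(3), hard_labels)                          # 3 samples counted as "class"
count(l -> l in classes && l != 3, hard_labels)    # 1 sample counted as "not class"
# the out-of-range label 9 contributes to neither bucket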


To fix our example here, we need to do two things: 1) Tell our classifier that our classes are ["$i" for i in 0:9] (or, even clearer, ["class_$i" for i in 0:9]). It might be nice to make this a constant toward the top of the script, so that we can reference it later:

const MNIST_LABELS = ["class_$i" for i in 0:9]

(EDIT: I don't actually think we need to change this first point at all; it shouldn't matter one bit whether the classes are strings or not. const MNIST_LABELS = 0:9 is totally fine, and will still be indexed into correctly. Doing it this way might be more readable, but if it isn't, feel free to skip it.)

2) Set up our labels so that they correspond to these indices. We could do this inside the get_processed_data(...) function, when we're setting up our ground truth in the first place, by doing

train_labels = MNIST.labels() .+ 1
...
test_labels = MNIST.labels(:test) .+ 1

Then an image whose shifted label is 3 (i.e., whose raw digit value is 2) would be of class MNIST_LABELS[3] == "class_2".
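Putting both changes together, a quick sanity check (a sketch of the relevant lines, not the whole script):

const MNIST_LABELS = ["class_$i" for i in 0:9]
classifier = FluxClassifier(model, opt, MNIST_LABELS)

train_labels = MNIST.labels() .+ 1       # raw digit d becomes index d + 1
MNIST_LABELS[2 + 1]                      # == "class_2", matching the example above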

ericphanson commented 3 years ago

Great catch! With those changes, it works---here is epoch 71 (test evaluation):

[image: mnist-good (test set evaluation plot, epoch 71)]

Thanks for all the refs to the code and such. I'll at least rename the variables in the examples to emphasize that they are class indices, and I'll try to see where else in the Lighthouse docs I can add emphasis to that fact. It makes a lot of sense that they are indices, but somehow I didn't think about it at all.

Btw, it did run OK for me on Julia 1.4.2 (no scalar indexing error). But I noticed there are two dev'd dependencies: LighthouseFlux itself (dev'd into the docs folder on my clone) and the Lighthouse PR. I wonder if that could be a source of non-reproducibility. Also, some previous commits did have a scalar indexing error, so maybe that's related.