hammlab / Crowd-ML

Framework for Crowd-sourced Machine Learning
Apache License 2.0

NOT A BUG: Problem with the server config file #33

Closed: 3ygun closed this issue 7 years ago

3ygun commented 7 years ago

The model is not training.

In TensorFlowTrainer.java, replace:

// Copy the weights Tensor into the weights array.
Trace.beginSection("fetch");
trainingInterface.fetch(weightsOp, w);
Trace.endSection();

with:

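float[] w_new = new float[D * K];
// Copy the weights Tensor into a fresh array so the previous weights stay available for comparison.
Trace.beginSection("fetch");
trainingInterface.fetch(weightsOp, w_new);
Trace.endSection();

// Find the first weight whose value changed by more than the tolerance (-1 if none did).
int z = -1;
boolean equal = true;
for (int x = 0; x < D * K && equal; x++) {
    // Compare the signed difference so increases and sign flips are caught too.
    if (Math.abs(w[x] - w_new[x]) > 0.0005) {
        z = x;
        equal = false;
    }
}

Log.d(" ", " ");
// w.equals(w_new) would only compare array references, so log the computed flag instead.
Log.d("Weights Equal", "" + equal + " (first changed index: " + z + ")");
Log.d("Original Weights", "" + w[0] + " " + w[1] + " " + w[2]);
Log.d("New Weights", "" + w_new[0] + " " + w_new[1] + " " + w_new[2]);
Log.d(" ", "");
w = w_new;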

Then look at the debug log: the weights aren't changing between runs. Only the data that the initial random weights are evaluated against changes, which explains why we're not getting more than 20% accuracy.
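A small helper along these lines makes the weight comparison reusable and easy to test. This is a minimal sketch, not code from the Crowd-ML repository; the class and method names (`WeightDiff`, `firstChangedIndex`, `maxAbsChange`) are hypothetical. It compares the signed difference of each pair of weights, and it uses `java.util.Arrays.equals` for content equality, because calling `equals` on a Java array only compares references.

```java
import java.util.Arrays;

// Hypothetical helper for checking whether fetched weights actually changed.
public final class WeightDiff {

    private WeightDiff() {}

    /** Returns the index of the first weight whose absolute change exceeds tol, or -1 if none does. */
    public static int firstChangedIndex(float[] before, float[] after, float tol) {
        if (before.length != after.length) {
            throw new IllegalArgumentException("weight arrays differ in length");
        }
        for (int i = 0; i < before.length; i++) {
            // Use the signed difference, so increases and sign flips are detected.
            if (Math.abs(before[i] - after[i]) > tol) {
                return i;
            }
        }
        return -1;
    }

    /** Largest absolute per-element change between two weight vectors of equal length. */
    public static float maxAbsChange(float[] before, float[] after) {
        float max = 0f;
        for (int i = 0; i < before.length; i++) {
            max = Math.max(max, Math.abs(before[i] - after[i]));
        }
        return max;
    }

    public static void main(String[] args) {
        float[] w = {0.1f, -0.2f, 0.3f};
        float[] wNew = {0.1f, 0.2f, 0.3f}; // sign flip at index 1
        // A test on |w| - |wNew| would miss this change, since 0.2 - 0.2 = 0.
        System.out.println(firstChangedIndex(w, wNew, 0.0005f)); // 1
        System.out.println(Arrays.equals(w, w.clone()));          // true: content comparison
        System.out.println(w.equals(w.clone()));                  // false: reference comparison
    }
}
```

In the trainer, the previously held weights would be passed as `before` and the freshly fetched array as `after`; a result of -1 on every round suggests the weights are not being updated.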

tylermzeller commented 7 years ago

I haven't seen this issue; I've reached accuracy as high as 87%.

One thing in this code base that needs to be changed is the init op for TensorFlow. Check line 99 of TensorFlowTrainer.java:

trainingInterface.run(new String[]{}, new String[]{initName});

In the current code base, this line runs after every communication round with the server, which is wrong: the init op should run only once, before training starts. I have since fixed this, and the fix will appear in future commits.
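A minimal sketch of running the init op only once, assuming the trainer object lives across communication rounds. `TrainingSession`, `Trainer`, `trainRound`, and the op name `"train"` are hypothetical; only the `run(new String[]{}, new String[]{initName})` call shape comes from the thread. The idea is a boolean guard, so the variable initializer executes before the first round and never again.

```java
import java.util.ArrayList;
import java.util.List;

public class InitOnceDemo {

    // Hypothetical stand-in for the thread's trainingInterface. Only the
    // run(outputNames, targetNodeNames) call shape used in the issue is assumed.
    interface TrainingSession {
        void run(String[] outputNames, String[] targetNodeNames);
    }

    // Hypothetical trainer wrapper: the init op runs once,
    // and every later communication round runs only the training op.
    static class Trainer {
        private final TrainingSession session;
        private final String initName;
        private final String trainName;
        private boolean initialized = false;

        Trainer(TrainingSession session, String initName, String trainName) {
            this.session = session;
            this.initName = initName;
            this.trainName = trainName;
        }

        void trainRound() {
            if (!initialized) {
                // Initialize variables exactly once, before the first round.
                session.run(new String[]{}, new String[]{initName});
                initialized = true;
            }
            // Later rounds only run the training step, so learned weights are not reset.
            session.run(new String[]{}, new String[]{trainName});
        }
    }

    public static void main(String[] args) {
        // Record which target op each run() call requested.
        List<String> calls = new ArrayList<>();
        Trainer trainer = new Trainer((outputs, targets) -> calls.add(targets[0]), "init", "train");
        for (int round = 0; round < 3; round++) {
            trainer.trainRound();
        }
        System.out.println(calls); // [init, train, train, train]
    }
}
```

An alternative is to run the init op in the trainer's constructor; the guard keeps initialization lazy, so it still happens only if a round is actually run.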

3ygun commented 7 years ago

@tylermzeller I guess don't update to TensorFlow 1.12 yet then.

tylermzeller commented 7 years ago

@3ygun oh jeez

3ygun commented 7 years ago

Could you post the server config and the TensorFlow Python model file you are using? I suspect the problem may stem from one of those.

3ygun commented 7 years ago

Resolved the issue... TL;DR: make sure you are iterating through the correct files.

Basically, I was using MNISTTestImages.50.l2.dat with MNISTTestLabels.dat, so only a few of the labels matched 🤕. Now I'm getting up to 83% accuracy after looking at ~20,000 of the samples 🍷