hammlab / Crowd-ML

Framework for Crowd-sourced Machine Learning
Apache License 2.0

Android TensorFlow (with wifi & refactoring changes) #32

Closed 3ygun closed 7 years ago

3ygun commented 7 years ago

Goal

Update the Android app to support a TensorFlow back-end (#23), together with the wifi and refactoring changes from #27. The validity of this approach was checked in #11.

Status

NOTE: Don't merge until all are complete!

End of Scope for this PR

Other Parts

3ygun commented 7 years ago

Should be good, I think. @tylermzeller, could you look over it?

3ygun commented 7 years ago

I tested with the included server/config/tensorflow.json and the mnist_mlp.pb file generated by the following:

import tensorflow as tf

with tf.Session() as sess:

    x = tf.placeholder(tf.float32, shape=[None, 50], name="x")
    y = tf.placeholder(tf.float32, [None, 10], name="y")
    w = tf.placeholder(tf.float32, [50, 10], name="weights_in")

    W = tf.Variable(tf.zeros([50, 10]), name="weights")
    b = tf.Variable(tf.zeros([10]))

    y_out = tf.add(tf.matmul(x, W), b, name="y_out")

    #cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_out), reduction_indices=[1]))
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_out))
    train_step = tf.train.AdamOptimizer(0.005).minimize(cross_entropy, name="train")

    correct_prediction = tf.equal(tf.argmax(y_out,1), tf.argmax(y,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="test")

    init = tf.variables_initializer(tf.global_variables(), name="init")

    tf.train.write_graph(sess.graph_def,
                         './',
                         'mnist_mlp.pb', as_text=False)
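
For anyone wondering why the hand-written cross-entropy line is commented out in the script above: `y_out` holds raw logits (no softmax is applied), so taking `tf.log(y_out)` directly would not give a valid loss, which is presumably why `tf.nn.softmax_cross_entropy_with_logits` is used instead. Below is a minimal plain-Python sketch of what that op computes per example. It is only an illustration under that assumption, not code from this PR, and the function name `softmax_cross_entropy` is made up for this example.

```python
import math

def softmax_cross_entropy(logits, labels):
    """Cross-entropy between `labels` and softmax(`logits`) for one example.

    Hypothetical helper for illustration; it mirrors the math of a
    softmax-cross-entropy-with-logits op, not TensorFlow's implementation.
    """
    # Subtract the max logit before exponentiating so exp() cannot overflow.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # log(softmax(z)_i) == z_i - log_sum_exp
    return -sum(t * (z - log_sum_exp) for z, t in zip(logits, labels))

# With all-zero weights and bias, as in the script, every logit is 0,
# so the loss of a 10-class example equals log(10).
zero_logits = [0.0] * 10
one_hot = [0.0] * 10
one_hot[3] = 1.0
initial_loss = softmax_cross_entropy(zero_logits, one_hot)
print(initial_loss)

# Large logits stay finite thanks to the max-subtraction step.
confident_loss = softmax_cross_entropy([1000.0, 0.0], [1.0, 0.0])
print(confident_loss)
```

A quick sanity check when testing the generated graph on a device: since `weights` and the bias start at zero, the first reported loss on 10-class data should be close to log(10).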