Hi all!
@garciadelcastillo and I have been working on modularizing SketchRNN for a few projects to generate a drawing prediction from an input set of strokes. I hope this can help with porting SketchRNN to the ml5-library.
Right now, this simple_predict.js module has the logic to load a model from a local file, set input strokes as relative or absolute, and generate a drawing prediction. (It contains the guts of the original simple_predict demo sketch plus some helpers to work with both relative and absolute sketch coordinates.)
One of our main goals was to serve SketchRNN as an HTTP service (see http-server.js) and a WebSocket client (see websocket-client.js). A sample use of the HTTP service is this p5 sketch.
Also, this video (minute 4:30) explains how SketchRNN encodes its strokes. Internally, each position is a relative movement from the previous position.
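As a quick illustration (a hypothetical helper, not part of the module), accumulating those relative offsets into absolute points could look something like this:

// Hypothetical helper: turn relative SketchRNN strokes into absolute points.
// Each stroke is [dx, dy, penDown, penUp, penEnd], with the pen state one-hot.
function strokesToAbsolute(strokes, startX = 0, startY = 0) {
  let x = startX;
  let y = startY;
  const points = [];
  for (const [dx, dy, penDown, penUp, penEnd] of strokes) {
    x += dx;
    y += dy;
    points.push({ x, y, penDown: penDown === 1 });
    if (penEnd === 1) break; // the model signals the end of the drawing
  }
  return points;
}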
=)
Hi @nonoesp!
This looks very cool! It would be awesome to have an ml5.SketchRNN() class!
I imagine we could have a collection of models and a few examples to play with. The models you are using don't seem to be in the repo you linked to? I couldn't find them. Just wondering if you are preprocessing them in any way.
If you and @garciadelcastillo are interested, I will be glad to help push a PR to incorporate this. @shiffman, thoughts?
Hi @cvalenzuela,
We have been working a lot with the library these days and would be happy to find some time to do a formal PR. However, some thoughts:
We are currently wrapping the Sketch-RNN functionality inside an HTTP server and serving it via POST requests. I wonder how ml5 serves other similar models and deals with models-as-a-service ("MaaS"? :boom:)
All the trained models (which we are not processing at all) occupy approximately 1.5 GB, which is quite a large download. The library could be written to request models ad hoc from a CDN, but that would slow the process down tremendously (most models are around 11 MB in size). Thoughts?
JL
Hi @cvalenzuela! I'd be glad to help with this as well.
Here is a list with all the generative models. By changing gen to vae you can download the full variational auto-encoder model (which allows the use of latent vectors). It would make sense to download them on-demand, as in the sketch-rnn demo.
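A minimal sketch of that on-demand approach (illustrative only; loadModel, the base URL constant, and the caching are made up, the cache just avoids re-downloading an ~11 MB file):

// Illustrative only: fetch a model's raw weights on demand and cache them.
const MODEL_BASE_URL = 'https://storage.googleapis.com/quickdraw-models/sketchRNN/large_models';
const modelCache = {};

async function loadModel(category, type = 'gen') {
  const key = `${category}.${type}`;
  if (!modelCache[key]) {
    const response = await fetch(`${MODEL_BASE_URL}/${key}.json`);
    if (!response.ok) throw new Error(`Could not load model ${key}`);
    modelCache[key] = await response.json();
  }
  return modelCache[key];
}

// loadModel('butterfly').then((weights) => { /* build the RNN from these weights */ });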
From the tensorflow/magenta-demos repo:
Pre-trained weight files
The RNN model has 2 modes: unconditional and conditional generation. Unconditional generation means the model will just generate a random vector image from scratch and not use any latent vectors as an input. Conditional generation mode requires a latent vector (128-dim) as an input, and whatever the model generates will be defined by those 128 numbers that can control various aspects of the image. Whether conditional or not, all of the raw weights of these models are individually stored as .json files inside the models directory. For example, for the 'butterfly' class, there are 2 models that come pretrained:
butterfly.gen.json - unconditional model
butterfly.vae.json - conditional model
Thanks @nonoesp and @garciadelcastillo!
What kind of functionality are you wrapping in the server? Or is it just serving the .json files?
We try to keep ml5 as "client-side" as possible. We just fetch weights, when necessary, from a constant URL to keep the library small. So if the server you are running is just storing the URLs for those .json files, I imagine an on-demand approach would be best. It might look something like this:
// Providing the 'cat' attribute will make the class fetch the right .json file
let catRNN = new ml5.SketchRNN('cat', onModelLoaded);

// Callback when the model loads
function onModelLoaded() {
  // Generate
  catRNN.generate();
}
The sketch-rnn demo takes a couple of seconds to download a model; I guess that's fine for our case too.
I am so excited about this! I can imagine a server-side component for ml5 eventually, but I agree with @cvalenzuela that coming up with a client-side-only example first would be great. With ml5 we are also not as concerned with perfection/accuracy as we are with ease of use and friendliness, so sacrificing some quality for smaller model files is something we can explore/discuss too.
I wonder as a step 2 (or 200?) if there is a way we can do either transfer learning or training from scratch also with new user data.
@hardmaru mentioned to me he was interested in helping make this happen!
Hi! Sorry for the radio silence!
@garciadelcastillo and I were experimenting with serving SketchRNN (and other libraries) over HTTP and WebSocket for a workshop, to let participants interact with machine learning libraries (such as SketchRNN) from different coding environments.
I'm currently porting the barebones of simple_predict.js as an ml5 module in the nonoesp-sketchrnn branch (still very much a work in progress). This module would run on the client side, potentially loading models from the same source as the sketch-rnn demo does:
https://storage.googleapis.com/quickdraw-models/sketchRNN/large_models/category.type.json
From this URL template, you could access the generative model of, say, a bird, bicycle, or angel by replacing category with the model name and type with gen. (More on this here.)
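A small sketch of how those URLs could be constructed from the template (modelURL is just an illustrative helper; gen is the generative/unconditional type described above):

// Build a concrete model URL from the category.type.json template above.
const base = 'https://storage.googleapis.com/quickdraw-models/sketchRNN/large_models';

function modelURL(category, type = 'gen') {
  return `${base}/${category}.${type}.json`;
}

// e.g. modelURL('bird')    -> .../large_models/bird.gen.json
//      modelURL('bicycle') -> .../large_models/bicycle.gen.json
//      modelURL('angel')   -> .../large_models/angel.gen.json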
Of course, I expect @hardmaru to be able to expose a lot more functionality.
This is very exciting! It probably makes sense for us to mirror the ml5.LSTM API where appropriate. Building on @cvalenzuela's earlier comment, I'm thinking a simple example could look something like this...?
const sketchRNN = ml5.sketchRNN('rainbow', modelReady);

function modelReady() {
  console.log('Ready to generate');
}

function setup() {
  createCanvas(400, 400);
  // These would all be optional; it could generate something by default?
  // A seed and options would be optional?
  // Seed would be an array of objects with x, y, and pen state?
  let initialSketch = [
    { x: 100, y: 100, pen: true },
    { x: 100, y: 200, pen: true },
    { x: 200, y: 200, pen: false }
  ];
  let options = {
    temperature: 0.5, // temperature
    length: 100, // how many points of a path to generate
    seed: initialSketch
  };
  sketchRNN.generate(options, gotSketch);
}

function gotSketch(sketch) {
  // "sketch" is an array of objects with x, y, and pen state?
  for (let i = 1; i < sketch.length; i++) {
    let current = sketch[i];
    if (current.pen) {
      let previous = sketch[i - 1];
      line(previous.x, previous.y, current.x, current.y);
    }
  }
}
We could consider integrating with p5.Vector, but perhaps this would tie it too closely to p5?
Hi @shiffman! Thanks for all the discussion.
That's a nice suggestion. In my original model API, I had it sample each point incrementally rather than sampling the entire drawing, since that might allow for more creative applications, such as letting the algorithm extend what the user has drawn.
When I wrote the model, deeplearn.js wasn't available yet, so I just implemented my own LSTM in JavaScript, but the interface in the code shouldn't be too difficult to port over to your ml5.LSTM. That being said, the code as it is now is fairly efficient and works quite fast on the client side, even on an old mobile device.
I've also been thinking of cleaning up an old script that can convert TensorFlow-trained sketch-rnn models over to the compressed JSON format that the JS version can use, so in theory we can use non-quickdraw datasets. Will probably try to do that first.
@hardmaru yes, that makes a lot of sense! Perhaps the default behavior could be to sample one point at a time, with an option to ask for an array? My concern is that sampling will require a callback, which could get quite confusing for a beginner trying to do something with a draw() loop in p5. Would it be able to do one point at a time without a callback, i.e.?
let sketchRNN = ml5.sketchRNN('rainbow', modelReady);
let ready = false;
let previous = null;

// Using preload would make this much simpler for a beginner example!
function modelReady() {
  ready = true;
}

function setup() {
  createCanvas(400, 400);
}

function draw() {
  if (ready) {
    // optionally can pass in a seed / temperature, etc.?
    let next = sketchRNN.generate();
    if (previous && next.pen) {
      line(previous.x, previous.y, next.x, next.y);
    }
    previous = next; // keep track of the last point even when the pen is up
  }
}
@shiffman that seems nice and simple enough to understand, I like it! I wonder, with this example, how would one create a demo where we let the user start a sketch and have sketchRNN finish it? Maybe we will have to encode that feature separately somehow.
In the version I had, it might be more straightforward to extend to do such things, at the expense of a bit more complexity, which is always a tradeoff:
// ... initialization code before draw(), see doc
function draw() {
  // see if we finished drawing
  if (prev_pen[2] == 1) {
    p.noLoop(); // stop drawing
    return;
  }
  // using the previous pen states, and hidden state, get next hidden state
  // the below line takes the most CPU power, especially for large models.
  rnn_state = model.update([dx, dy, pen_down, pen_up, pen_end], rnn_state);
  // get the parameters of the probability distribution (pdf) from hidden state
  pdf = model.get_pdf(rnn_state);
  // sample the next pen's states from our probability distribution
  [dx, dy, pen_down, pen_up, pen_end] = model.sample(pdf, temperature);
  // only draw on the paper if the pen is touching the paper
  if (prev_pen[0] == 1) {
    p.stroke(line_color);
    p.strokeWeight(2.0);
    p.line(x, y, x+dx, y+dy); // draw line connecting prev point to current point.
  }
  // update the absolute coordinates from the offsets
  x += dx;
  y += dy;
  // update the previous pen's state to the current one we just sampled
  prev_pen = [pen_down, pen_up, pen_end];
}
In this version, since the prev_pen state can either be sampled from SketchRNN's pdf (your generate()) or be overwritten by the user's actual mouse/touch movement, it is easy to incorporate the interactive component and get the human in the loop. Perhaps there's an elegant way to incorporate this into your proposed framework too, something like:
let sketchRNN = ml5.sketchRNN('rainbow', modelReady, optionalTemperature, optionalSeed);
let ready = false;
let previous = null;

// Using preload would make this much simpler for a beginner example!
function modelReady() {
  ready = true;
}

function setup() {
  createCanvas(400, 400);
}

function draw() {
  if (ready) {
    let next = sketchRNN.generate();
    if (previous && next.pen) {
      line(previous.x, previous.y, next.x, next.y);
    }
    previous = next;
    // previous can be overwritten by human input, so it doesn't necessarily
    // have to be what was generated by sketchRNN
    sketchRNN.update(previous);
  }
}
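For concreteness, in the lower-level loop from the earlier snippet, the override might look roughly like this (the mouse handling here is just illustrative; model.update() is the call quoted above):

// Inside the draw loop above, before calling model.update():
if (p.mouseIsPressed) {
  // Override the sampled stroke with the user's actual movement.
  dx = p.mouseX - p.pmouseX;
  dy = p.mouseY - p.pmouseY;
  [pen_down, pen_up, pen_end] = [1, 0, 0]; // pen is on the paper
}
// The hidden state is updated with whatever actually happened on the canvas,
// so the model can pick up from the human's strokes at any point.
rnn_state = model.update([dx, dy, pen_down, pen_up, pen_end], rnn_state);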
Ah yes, this makes sense! We should definitely allow the user to pass in human input and override the model's generated data. This could also possibly be an argument to generate(), something like:
// previous can be overwritten by human input!
let next = sketchRNN.generate(previous);
if (next.pen) {
  line(previous.x, previous.y, next.x, next.y);
}
previous = next;
In looking at your code I see that the model provides dx, dy rather than literal x, y coordinates. I think this makes sense to keep; I was just making stuff up without looking closely!
We can probably conflate pen_up and pen_down into one state pen (true or false)? What is pen_end?
Hi @shiffman,
In addition to modelling when the pen should touch the canvas and when it should be lifted away from the canvas, Sketch-RNN also models when to finish the drawing (via the pen_end event). So [pen_down, pen_up, pen_end] is a one-hot vector sampled from a categorical distribution.
Unlike an LSTM generating Hemingway forever, if we let an LSTM doodle birds without end, it will eventually fill the entire canvas with black ink (e.g. the kanji example)!
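To make the one-hot pen vector concrete, here is a tiny (hypothetical) helper that names the three states:

// [pen_down, pen_up, pen_end] is one-hot: exactly one entry is 1.
function penState([penDown, penUp, penEnd]) {
  if (penDown === 1) return 'down'; // next offset is drawn on the canvas
  if (penUp === 1) return 'up';     // pen moves without drawing
  if (penEnd === 1) return 'end';   // the model decided the sketch is finished
}

// penState([1, 0, 0]) -> 'down'
// penState([0, 0, 1]) -> 'end'  (time to stop sampling)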
Hi @shiffman @cvalenzuela @nonoesp @garciadelcastillo
A few updates from me:
1) I ported the sketch-rnn-js model over to TensorFlow.js using the TypeScript style of the magenta.js project. The API is very similar to sketch-rnn-js, just GPU accelerated. I'll try to put this on the magenta.js repo soon, after porting a few demos over and testing a few things.
2) Wrote a small IPython notebook to show how to quickly train a sketch-rnn model with TensorFlow, and convert that model over to the JSON format that can be used by sketch-rnn-js (and the TensorFlow.js version in (1)): https://github.com/tensorflow/magenta-demos/blob/master/jupyter-notebooks/Sketch_RNN_TF_To_JS_Tutorial.ipynb
After I put (1) out it should be fairly easy to wrap ml5.js over it so that sketch-rnn can be readily available.
Currently, this is how I deal with the model loading in magenta.js but I think the ml5.js way is more elegant:
var sketch = function( p ) {
  "use strict";

  console.log("SketchRNN JS demo.");

  var model;
  var dx, dy; // offsets of the pen strokes, in pixels
  var pen_down, pen_up, pen_end; // keep track of whether pen is touching paper
  var x, y; // absolute coordinates on the screen of where the pen is
  var prev_pen = [1, 0, 0]; // group all p0, p1, p2 together
  var rnn_state; // store the hidden states of rnn's neurons
  var pdf; // store all the parameters of a mixture-density distribution
  var temperature = 0.45; // controls the amount of uncertainty of the model
  var line_color;
  var model_loaded = false;

  // loads the TensorFlow.js version of sketch-rnn model, with the "cat" model's weights.
  model = new ms.SketchRNN("https://storage.googleapis.com/quickdraw-models/sketchRNN/models/cat.gen.json");

  Promise.all([model.initialize()]).then(function() {
    // initialize the scale factor for the model. Bigger -> large outputs
    model.set_pixel_factor(3.0);
    // initialize pen's states to zero.
    [dx, dy, pen_down, pen_up, pen_end] = model.zero_input(); // the pen's states
    // zero out the rnn's initial states
    rnn_state = model.zero_state();
    model_loaded = true;
    console.log("model loaded.");
  });

  p.setup = function() {
    var screen_width = p.windowWidth; //window.innerWidth
    var screen_height = p.windowHeight; //window.innerHeight
    x = screen_width/2.0;
    y = screen_height/3.0;
    p.createCanvas(screen_width, screen_height);
    p.frameRate(60);
    // define color of line
    line_color = p.color(p.random(64, 224), p.random(64, 224), p.random(64, 224));
  };

  p.draw = function() {
    if (!model_loaded) {
      return;
    }
    // see if we finished drawing
    if (prev_pen[2] == 1) {
      p.noLoop(); // stop drawing
      return;
    }
    // using the previous pen states, and hidden state, get next hidden state
    // the below line takes the most CPU power, especially for large models.
    rnn_state = model.update([dx, dy, pen_down, pen_up, pen_end], rnn_state);
    // get the parameters of the probability distribution (pdf) from hidden state
    pdf = model.get_pdf(rnn_state, temperature);
    // sample the next pen's states from our probability distribution
    [dx, dy, pen_down, pen_up, pen_end] = model.sample(pdf);
    // only draw on the paper if the pen is touching the paper
    if (prev_pen[0] == 1) {
      p.stroke(line_color);
      p.strokeWeight(2.0);
      p.line(x, y, x+dx, y+dy); // draw line connecting prev point to current point.
    }
    // update the absolute coordinates from the offsets
    x += dx;
    y += dy;
    // update the previous pen's state to the current one we just sampled
    prev_pen = [pen_down, pen_up, pen_end];
  };
};

var custom_p5 = new p5(sketch, 'sketch');
Amazing @hardmaru! This will be super nice to have in ml5. Let us know when you publish your code so we can make a wrapper around it!
Once that is ready, we can also put the training instructions and script here: https://ml5js.org/docs/training-introduction
Nice! So glad to hear about TypeScript @hardmaru, and thanks so much for sharing the IPython notebook. Looking forward to the release.
I put the code in my fork for now, but it should be merged in the next few days.
There are 3 working demos that use sketch-rnn with TensorFlow.js, linked in the README.md
The TensorFlow.js version has gone through code review and been accepted into the main repo. The current interface is more or less inspired by the p5.js-style workflow, and in fact all the demos use p5.js:
https://github.com/tensorflow/magenta-js/blob/master/sketch/README.md
The next step is to try to wrap it with ml5.js and make the ml5.SketchRNN() class.
Great! I'll make a branch and start working on it.
Yay! I am so excited about this! I would love to help work on this too.
Thanks for the help @reiinakano @cvalenzuela @shiffman
I published a more optimized version 0.1.2 (no change to the API) today:
https://www.npmjs.com/package/@magenta/sketch
This version reduces the number of dataSync() calls and improves performance a bit.
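For anyone who wants to try the package directly, a minimal setup might look like this (a sketch only, assuming the package exposes the same SketchRNN class and model URL pattern as the demo quoted above):

// npm install @magenta/sketch
import * as ms from '@magenta/sketch';

const model = new ms.SketchRNN(
  'https://storage.googleapis.com/quickdraw-models/sketchRNN/models/cat.gen.json'
);

model.initialize().then(() => {
  model.set_pixel_factor(3.0); // scale factor, as in the demo above
  console.log('sketch-rnn model loaded');
});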
I'm briefly re-opening this issue to cover some API decisions @cvalenzuela just made in our weekly ml5 meeting!
Instead of storing the sketch data as:
var initialStrokes = [
  [-4, 0, 1, 0, 0],
  [-15, 9, 0, 1, 0],
  [-10, 17, 0, 0, 1]
];
we propose:
var initialStroke = [
  { dx: -4, dy: 0, pen: "down" },
  { dx: -15, dy: 9, pen: "up" },
  { dx: -10, dy: 17, pen: "end" }
];
and then sketch data generated would look like:
function gotResult(err, result) {
  if (previous.pen === "down") {
    stroke(255, 0, 0);
    strokeWeight(3.0);
    line(x, y, x + result.dx, y + result.dy);
  }
  x += result.dx;
  y += result.dy;
  previous = result;
}
Feel free to weigh in with any thoughts or comments!
Looks sensible to me! Makes it a lot more readable.
In the future, the backend might still want to deal with a raw 2D array of floats if we want to generate an entire sketch in one GPU call, but we can just convert back to the array format when needed.
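For what it's worth, converting between the raw array format and the proposed object format is tiny either way (hypothetical helpers, just to show the mapping):

const PEN_STATES = ['down', 'up', 'end'];

// raw [dx, dy, p1, p2, p3] (one-hot pen state) -> { dx, dy, pen }
function toObject([dx, dy, down, up, end]) {
  return { dx, dy, pen: PEN_STATES[[down, up, end].indexOf(1)] };
}

// { dx, dy, pen } -> raw [dx, dy, p1, p2, p3]
function toArray({ dx, dy, pen }) {
  return [dx, dy, ...PEN_STATES.map((s) => (s === pen ? 1 : 0))];
}

// toObject([-4, 0, 1, 0, 0])                -> { dx: -4, dy: 0, pen: "down" }
// toArray({ dx: -10, dy: 17, pen: "end" })  -> [-10, 17, 0, 0, 1]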
Was also taking a closer look at the ml5.js API for sketch-rnn. I'm not sure if I have the right understanding, but I think with the current abstraction we can only generate an entire sketch (from the beginning), and can't do things like feed in an incomplete drawing and have sketch-rnn finish it. I guess it's a tradeoff between complexity and simplicity of the API, although if we only generate complete sketches, one could also pull ground-truth "human-generated" data directly from the quickdraw dataset too :)
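One rough way to support that with the current magenta API would be to feed the existing strokes through model.update() before sampling. A sketch, assuming the stroke-5 format and the calls quoted earlier in this thread (userStrokes is a made-up variable):

// userStrokes: array of [dx, dy, pen_down, pen_up, pen_end] from the user's drawing
const temperature = 0.45;
let rnnState = model.zero_state();
let lastStroke = model.zero_input();

// Warm up the hidden state on the human-drawn portion of the sketch.
for (const stroke of userStrokes) {
  rnnState = model.update(stroke, rnnState);
  lastStroke = stroke;
}

// From here, sample as usual and the model continues the drawing.
const pdf = model.getPDF(rnnState, temperature);
const next = model.sample(pdf);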
Ah, yes this is a very good point! I think this relates (?) to the current discussion about stateful LSTMs in this pull request! I wonder if we could adopt a similar API for SketchRNN where we have a simple "generate a drawing" mode as well as a "generate one pen motion at a time, where the user can take over" mode, etc.
See: https://github.com/ml5js/ml5-library/pull/221#issuecomment-429060273 for more.
Something like:
function draw() {
  if (user is drawing) {
    var next = {
      dx: mouseX - pmouseX,
      dy: mouseY - pmouseY,
      pen: "down" // dynamic based on mouseIsPressed?
    };
    line(mouseX, mouseY, pmouseX, pmouseY);
    sketchRNN.update(next);
    previous = next;
  } else if (model is drawing) {
    let next = sketchRNN.next(0.1);
    if (previous.pen === "down") {
      stroke(255, 0, 0);
      strokeWeight(3.0);
      line(x, y, x + next.dx, y + next.dy);
    }
    x += next.dx;
    y += next.dy;
    previous = next;
    sketchRNN.update(next);
  }
}
I'm ignoring the asynchronous aspect here and making up variables, but is this the right idea?
The way I handled the interactivity is to completely abandon the async nature of the API (although this might be the wrong decision since there is a tradeoff vs performance).
In the current magenta version of sketch-rnn (https://www.npmjs.com/package/@magenta/sketch?activeTab=readme), the API is basically completely synchronous, and the code is similar to what your comment describes. Here is the sketch loop for generating a sketch:
function draw() {
  // see if we finished drawing
  if (prev_pen[2] == 1) {
    noLoop(); // stop drawing
    return;
  }
  // using the previous pen states, and hidden state, get next hidden state
  // the below line takes the most CPU power, especially for large models.
  rnn_state = model.update([dx, dy, pen_down, pen_up, pen_end], rnn_state);
  // get the parameters of the probability distribution (pdf) from hidden state
  pdf = model.getPDF(rnn_state, temperature);
  // sample the next pen's states from our probability distribution
  [dx, dy, pen_down, pen_up, pen_end] = model.sample(pdf);
  // only draw on the paper if the pen is touching the paper
  if (prev_pen[0] == 1) {
    stroke(line_color);
    strokeWeight(3.0);
    line(x, y, x+dx, y+dy); // draw line connecting prev point to current point.
  }
  // update the absolute coordinates from the offsets
  x += dx;
  y += dy;
  // update the previous pen's state to the current one we just sampled
  prev_pen = [pen_down, pen_up, pen_end];
}
So to incorporate the interactivity, I can just override what sketch-rnn generates with what the user draws using the mouse/tablet data in the draw loop.
Maybe an easy way is to leave the current mode for async, and copy in the non-async API from the magenta version (with the syntactic sugar and also renaming to the dx/dy/pen state names)?
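A rough sketch of what that sugar layer might look like, assuming the magenta model API quoted above (the actual ml5 class will surely differ):

// Hypothetical synchronous wrapper around the magenta SketchRNN model.
class SketchRNNWrapper {
  constructor(model, temperature = 0.45) {
    this.model = model;
    this.temperature = temperature;
    this.rnnState = model.zero_state();
    this.lastStroke = model.zero_input();
  }

  // Sample one pen motion and return it as { dx, dy, pen }.
  next() {
    this.rnnState = this.model.update(this.lastStroke, this.rnnState);
    const pdf = this.model.getPDF(this.rnnState, this.temperature);
    const [dx, dy, down, up, end] = this.model.sample(pdf);
    this.lastStroke = [dx, dy, down, up, end];
    return { dx, dy, pen: down ? 'down' : up ? 'up' : 'end' };
  }

  // Let human input (in the same { dx, dy, pen } shape) override the last stroke.
  update({ dx, dy, pen }) {
    this.lastStroke = [
      dx,
      dy,
      pen === 'down' ? 1 : 0,
      pen === 'up' ? 1 : 0,
      pen === 'end' ? 1 : 0
    ];
  }
}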
As an update, I have a working example for my A2Z class here:
https://github.com/shiffman/A2Z-F18/tree/master/week8-charRNN/04_sketchRNN
Are certain models there automatically and others I'll need to download? Right now it works with "cat" out of the box. Next, I'll work on the SketchRNN class to implement some of the feature suggestions in this thread, as well as make an example with interactivity.
Looks fun! The pre-trained models are all in a JSON format that can be dynamically loaded.
There are a few interactive demos in the magenta-js version that can probably be ported to this version (though the API will probably need to be refactored depending on the level of abstraction we want to give the user).
Closing!!! (New issues coming with the remaining to-dos...)
A simple example using sketch-rnn plus p5.js could be integrated as part of this project, or at least linked to! (cc @hardmaru yet again!)