dask / dask-ml

Scalable Machine Learning with Dask
http://ml.dask.org
BSD 3-Clause "New" or "Revised" License

Parameter Server #171

Open mrocklin opened 6 years ago

mrocklin commented 6 years ago

I'm restarting the discussion on parameter servers from https://github.com/dask/dask-glm/issues/57

When doing distributed training it is often useful to rapidly update and share parameters between distributed workers. The frequency with which these updates can be shared seems to strongly affect the rate of convergence. The common practice today seems to be to have dedicated nodes within the network that serve as parameter servers, accepting and aggregating updates from workers, and periodically publishing those back out to workers.

Dask's existing task scheduling model is not ideal for this. Every computation or communication checks in with the scheduler, which adds unnecessary network roundtrips and provides a central bottleneck for metadata tracking. To do parameter servers well we probably need to break out of this model and engage more peer-to-peer coordination, without frequent checking in with the scheduler.

Building this infrastructure is not hard, but if possible it would be good to build general machinery that can be used to solve both this problem, and potentially others. I would appreciate having a conversation with a few people to identify what programming and networking constructs might suffice for this.

cc @stsievert @fabianp @mlnick

mrocklin commented 6 years ago

One general mechanism that might fit this need well is Pub/Sub. Under this construct any worker can join a topic as either a publisher or subscriber (or both). It learns from the scheduler who all of its peers are on that topic. When it publishes data on that topic it sends this data directly to all subscribers (or, in the future, to some subscribers that then forward it along to others). Under this model the parameter server subscribes to updates and publishes parameters.

An implementation of a single parameter server might look like the following:

import asyncio
from typing import List

import numpy as np
from distributed import Pub, Sub, get_worker

# should_stop, update, and compute_update are placeholders for
# application-specific logic; they are not part of Dask.

def parameter_server(...):
    """ This runs on one special worker """
    parameters = np.random.random(...)

    updates_sub = Sub('updates')        # subscribe to updates from workers
    parameters_pub = Pub('parameters')  # publish new parameters to workers

    async def publish_parameters():
        while not should_stop():
            parameters_pub.put(parameters)  # maybe send only to a few round-robin instead?
            await asyncio.sleep(0.01)       # maybe control this based on network traffic?

    get_worker().loop.add_callback(publish_parameters)

    for new_data in updates_sub:        # blocks this task's thread while waiting for updates
        update(parameters, new_data)    # modify parameters in place
        if should_stop():
            return parameters

def worker(data: List[np.ndarray],    # these correspond to chunks of data from a dask array
           labels: List[np.ndarray]):
    """ This runs on each worker and gets the arrays present on that worker """
    data = np.concatenate(data)
    labels = np.concatenate(labels)

    updates_pub = Pub('updates')
    parameters_sub = Sub('parameters')
    parameters = parameters_sub.get()   # block until the first parameters arrive

    async def update_parameters():
        nonlocal parameters             # rebind the enclosing variable, not a local copy
        async for params in parameters_sub:
            parameters = params         # swap in the latest published parameters

    get_worker().loop.add_callback(update_parameters)

    while not should_stop():
        i = np.random.randint(0, len(data))  # pick a random training example
        result = compute_update(parameters, data[i], labels[i])
        updates_pub.put(result)
mrocklin commented 6 years ago

To be clear, Pub and Sub don't exist today, but they could exist tomorrow if this is an appropriate model. If it isn't then we can build something else. Other ideas are welcome.

stsievert commented 6 years ago

The common practice today seems to be to have dedicated nodes within the network that serve as parameter servers,

That is common practice, and is what distributed Tensorflow uses.

accepting and aggregating updates from workers, and periodically publishing those back out to workers.

In the vanilla case, we only want to apply updates that were computed from the parameters we most recently sent out to the workers.

In the simplest use case for distributed TensorFlow, each worker passes around a parameter global_step that identifies when the parameters were pulled from the parameter server, and the server gathers updates from a certain number of workers. An update is discarded if the global_step a worker sends back is stale. A good reference for this is "Scaling Distributed Machine Learning with the Parameter Server".
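A minimal sketch of the staleness check described above, assuming a single-process server and plain Python lists for the parameters. The class and its `pull`/`push` methods are hypothetical illustrations, not TensorFlow's implementation: a gradient is accepted only if it carries the current `global_step`, and once enough fresh gradients arrive they are averaged and applied.

```python
class SyncParameterServer:
    """Toy synchronous parameter server that discards stale updates."""

    def __init__(self, params, updates_per_step, lr=0.1):
        self.params = list(params)
        self.updates_per_step = updates_per_step  # gradients to gather per step
        self.lr = lr
        self.global_step = 0
        self.pending = []  # fresh gradients collected for the current step

    def pull(self):
        """Return a copy of the parameters and the step they belong to."""
        return list(self.params), self.global_step

    def push(self, grad, step):
        """Return True if the gradient was used, False if it was stale."""
        if step != self.global_step:
            return False  # computed from old parameters: drop it
        self.pending.append(grad)
        if len(self.pending) == self.updates_per_step:
            n = len(self.pending)
            for j in range(len(self.params)):
                avg = sum(g[j] for g in self.pending) / n
                self.params[j] -= self.lr * avg
            self.pending = []
            self.global_step += 1  # parameters changed: older steps are now stale
        return True


if __name__ == "__main__":
    ps = SyncParameterServer([0.0, 0.0], updates_per_step=2, lr=0.5)
    _, step = ps.pull()
    ps.push([1.0, 2.0], step)         # accepted, waiting for a second update
    ps.push([3.0, 2.0], step)         # accepted; step 0 -> 1, params become [-1.0, -1.0]
    print(ps.push([5.0, 5.0], step))  # stale (step 0 != 1): prints False
```

Discarding rather than down-weighting stale gradients is the simplest policy; it trades some wasted worker computation for updates that are always computed from the current model.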

This pub/sub framework LGTM. Is there any way to have a worker subscribe to a topic published by another worker in particular? Something like Sub('foo', worker=1) or Pub('foo', worker='nearest_neighbor'). This could enable some of the distributed algorithms (e.g., where nodes are connected to their neighbors and form a ring).

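Two extensions:

  • Async read/writes could be useful
  • Gradient coding, which codes the gradient ndarray to some object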

I'm inclined to believe these extensions are best suited for a class interface, at least privately.

mrocklin commented 6 years ago

In the simplest use case for distributed TensorFlow, each worker passes around a parameter global_step that identifies when the parameters were pulled from the parameter server, and the server gathers updates from a certain number of workers. An update is discarded if the global_step a worker sends back is stale.

I encourage people to start thinking aspirationally here, a bit beyond the simplest case.

I was looking at Downpour and also this paper, which had different algorithms that I felt could both be approximated decently well with PubSub. It may be, though, that my approximations lose something significant.

Is there any way to have a worker subscribe to a topic published by another worker in particular?

This would be outside of the paradigm of pub/sub but is certainly something that is doable. I expect pubsub to be something like 200-400 lines of code. We can do more things if they are important.

I am not personally familiar with algorithms that are based on ring-like topologies. Is this common?

Async read/writes could be useful

The pseudocode above does async reads/writes. The pub.put commands return immediately. The subscribe commands would block, but are happening in an async block.
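To illustrate these semantics (a `put` that returns immediately, and subscribers that wait inside async code without blocking anything else), here is a toy in-process pub/sub built on `asyncio.Queue`. It is not Dask's implementation: there is no networking or scheduler involved, and the `Topic` class is hypothetical.

```python
import asyncio


class Topic:
    """Toy in-process pub/sub topic: every subscriber gets its own queue."""

    def __init__(self):
        self.queues = []

    def subscribe(self):
        queue = asyncio.Queue()
        self.queues.append(queue)
        return queue

    def put(self, msg):
        # Non-blocking: enqueue for every subscriber and return immediately.
        for queue in self.queues:
            queue.put_nowait(msg)


async def subscriber(queue, n_messages, received):
    for _ in range(n_messages):
        received.append(await queue.get())  # waits without blocking the event loop


async def main():
    topic = Topic()
    received_a, received_b = [], []
    tasks = [
        asyncio.create_task(subscriber(topic.subscribe(), 3, received_a)),
        asyncio.create_task(subscriber(topic.subscribe(), 3, received_b)),
    ]
    for i in range(3):
        topic.put(i)            # returns immediately
        await asyncio.sleep(0)  # yield so subscribers can run
    await asyncio.gather(*tasks)
    return received_a, received_b


if __name__ == "__main__":
    print(asyncio.run(main()))  # both subscribers receive [0, 1, 2]
```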

Gradient coding, which codes the gradient ndarray to some object

I'm hoping that this is independent of communication patterns and is something that folks like you would feel comfortable writing after pubsub, or whatever network machinery we need, is put in place.

I'm inclined to believe these extensions are best suited for a class interface, at least privately.

This work is likely to happen in two levels:

  1. core dask stuff, like pubsub, which has nothing at all to do with parameter servers
  2. ML stuff, which can be as specialized as is necessary
stsievert commented 6 years ago

Got it. This makes more sense now.

I am not personally familiar with algorithms that are based on ring-like topologies. Is this common?

At least one algorithm depends on a ring structure: Hogwild++. This is suited for HPC applications, and shows much better speedups than Hogwild as the number of workers grows. I don't think I'd say this is relevant or that distributed ring-like topologies are super common, but they could be useful.

core dask stuff, like pubsub, which has nothing at all to do with parameter servers ... ML stuff, which can be as specialized as is necessary

Pub/sub looks sufficient for parameter servers. Almost everything I said was ML specific.

mrocklin commented 6 years ago

A draft pubsub implementation is available here: https://github.com/dask/distributed/pull/1999 I suspect that it is still buggy.

I'll work on setting up a mocked out example of the system above and then hopefully others can start playing with things.

My initial experiments show that we have about 1ms of latency when sending messages between workers. I am curious to hear from people who understand what kinds of latencies we need to get decent performance in these applications.

stsievert commented 6 years ago

Moderate sized deep learning models have latency of at least 10ms for feeding 1 image through the network for prediction (source: https://dawn.cs.stanford.edu/benchmark/). The time required for optimization of these networks is (likely) no more than 5-6 times as slow, so probably at least 50-100ms for one optimization step.

mrocklin commented 6 years ago

Well then I guess we're in the clear?

stsievert commented 6 years ago

1ms latency seems reasonable – latency between cluster nodes for me is (on average) 0.3ms. I'd say it's a good first draft, and enough to start experimenting. I can implement a parameter server and look for any bottlenecks.

mrocklin commented 6 years ago

PubSub is merged. I have a tiny toy example here: https://gist.github.com/mrocklin/0d906828544ddeb8e6e1d3d193172ae9

stsievert commented 6 years ago

@mrocklin could I have your timing script to measure the latency? I'd like to compare on the same machine. I measure the latency of PyTorch's distributed module at 0.029ms with https://gist.github.com/stsievert/1a93cf732d66f22a3080fcd0729364d6

mrocklin commented 6 years ago

See

https://github.com/dask/distributed/blob/12ddc080d1d876b0fce6fe4b2e863fe7c7b31543/distributed/tests/test_pubsub.py#L13

stsievert commented 6 years ago

Side by side on my machine I measure PyTorch 0.4.0 to have a latency of 0.031ms and dask.distributed master to have a latency of 1.398ms, both between two processes with Python 3.6.5. I measure with the same gist, https://gist.github.com/stsievert/1a93cf732d66f22a3080fcd0729364d6

mrocklin commented 6 years ago

I'm not surprised to see PyTorch be faster, however I am surprised to see that it can do things in 30us. Some things to check if you're interested:

  1. Try a roundtrip. Rank 0 should send data, then receive that same data in each step. This will avoid things like buffering or async sends from giving an overly optimistic result.
  2. Consider doing this with just plain Python sockets to get a baseline. This page seems to give a decent example: https://pymotw.com/2/socket/tcp.html
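Both suggestions can be combined in a minimal sketch: a localhost TCP echo roundtrip using only the standard library. The payload size, trial count, and function names are arbitrary choices for illustration; this is not the code from either gist. Halving the median roundtrip gives a one-way latency estimate.

```python
import socket
import statistics
import threading
import time


def echo_server(server: socket.socket) -> None:
    """Accept one connection and echo every chunk back until the client closes."""
    conn, _ = server.accept()
    with conn:
        conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        while True:
            data = conn.recv(4096)
            if not data:  # client closed the connection
                break
            conn.sendall(data)


def measure_roundtrip(n_trials=1000, payload=b"x" * 64):
    """Return the wall-clock time (seconds) of each send-then-receive roundtrip."""
    server = socket.create_server(("127.0.0.1", 0))  # port 0: the OS picks a free port
    port = server.getsockname()[1]
    thread = threading.Thread(target=echo_server, args=(server,), daemon=True)
    thread.start()

    times = []
    with socket.create_connection(("127.0.0.1", port)) as client:
        # Disable Nagle's algorithm so small messages are sent without delay
        client.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        for _ in range(n_trials):
            start = time.perf_counter()
            client.sendall(payload)
            received = 0
            while received < len(payload):  # TCP is a byte stream: read the full echo
                received += len(client.recv(4096))
            times.append(time.perf_counter() - start)
    thread.join()
    server.close()
    return times


if __name__ == "__main__":
    rtt = statistics.median(measure_roundtrip())
    print(f"median roundtrip: {rtt * 1e6:.1f} us, one-way estimate: {rtt / 2 * 1e6:.1f} us")
```

Using the median rather than the mean keeps occasional scheduler hiccups from dominating the estimate.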
stsievert commented 6 years ago

I measure a latency of 16us with Python 3.6.5 and 19us with Python 2.7.5 for the code in the updated gist, https://gist.github.com/stsievert/1a93cf732d66f22a3080fcd0729364d6. I've slightly modified the example from https://pymotw.com/3/socket/tcp.html.

For PyTorch with a roundtrip, I'm measuring a latency of 37us.

mrocklin commented 6 years ago

OK cool. I'll readjust my internal expectations. I'm now a bit interested in setting up a single-threaded async example with Dask to measure where our overhead is coming from. This is probably not something that I would recommend you focus on though.

mrocklin commented 6 years ago

Raw Tornado has a roundtrip limit of around 200us on my machine

https://gist.github.com/609824993811575dd7d774f1eb5becc9

Under PyPy this goes down to about 60us

Again, this isn't a priority (we don't have any applications that currently care about this) but it's fun to fool around with
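For comparison with the raw-Tornado number, a single-threaded async roundtrip can be sketched with the standard library's `asyncio` streams. This is a substitution: it uses asyncio rather than Tornado, the function names are hypothetical, and it is not the gist's code. It measures event-loop overhead on top of the socket cost.

```python
import asyncio
import statistics
import time


async def handle_echo(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
    """Echo each chunk back until the client disconnects."""
    while data := await reader.read(4096):
        writer.write(data)
        await writer.drain()
    writer.close()
    await writer.wait_closed()


async def async_roundtrips(n_trials=1000, payload=b"x" * 64):
    """Return per-roundtrip times (seconds) with client and server on one event loop."""
    server = await asyncio.start_server(handle_echo, "127.0.0.1", 0)
    port = server.sockets[0].getsockname()[1]
    reader, writer = await asyncio.open_connection("127.0.0.1", port)
    times = []
    for _ in range(n_trials):
        start = time.perf_counter()
        writer.write(payload)
        await writer.drain()
        await reader.readexactly(len(payload))  # wait for the complete echo
        times.append(time.perf_counter() - start)
    writer.close()
    await writer.wait_closed()
    server.close()
    await server.wait_closed()
    return times


if __name__ == "__main__":
    rtt = statistics.median(asyncio.run(async_roundtrips()))
    print(f"median async roundtrip: {rtt * 1e6:.1f} us")
```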

stsievert commented 6 years ago

Clarification: all the times I measured were estimates of the one-way latency between two machines, not the roundtrip time. These differ by a factor of 2, so by my measure the raw Tornado example would be about 100us, and 30us under PyPy.

This is probably not something that I would recommend you focus on though.

Don't worry, I'm not planning to focus on this. I wanted to provide hard numbers for "1ms seems reasonable" https://github.com/dask/dask-ml/issues/171#issuecomment-391428668.

mrocklin commented 6 years ago

Another possibly cleaner way of implementing this would be to first implement actors (see https://github.com/dask/distributed/issues/2109). However, that task is probably blocked on me for the near future.

Regardless though, I believe that before we spend engineering time on this problem we should have a case study problem against which we could apply it in order to guide our choices.

mrocklin commented 6 years ago

I tried raising a few case study examples from a conversation with Olivier, but I wasn't able to come up with a clean case study for a problem that requires a parameter server. @stsievert perhaps this is something you could think about and develop? I believe that developing a case study should block making technical progress on this issue.

stsievert commented 6 years ago

a clean case study for a problem that requires a parameter server.

I see a Dask parameter server as being useful in the same cases where the Tensorflow parameter server is used, because it also only works between machines. I see both of these as being useful when either

I see a Dask parameter server as playing a data management role and handing the computation to the optimization or deep learning library. It would only handle giving the data to each worker and communicating the updates produced by the optimization process. I think a Dask parameter server would make it easy to scale to larger datasets, and would be similar to, but different from, various Spark learning libraries (e.g., BigDL, MLlib).

I do not see it as being as useful in the same cases as Horovod. I think Horovod is more performant than Dask in inter-GPU communication, but is harder to use and less fault-tolerant across multiple machines.

something you could ... develop?

I'll look for a use case where the Tensorflow parameter server is required and mirror that.

Another possibly cleaner way of implementing this would be by first implementing actors

I think actors should be implemented. This would only be useful when the workers store state, which would enable fancier communication methods for synchronous optimization. For an example of the speedups of different communication methods, see https://talwalkarlab.github.io/paleo/ and select strong scaling instead of the default weak scaling.

We don't need this if we want only asynchronous algorithms because the model has to be sent over the wire. We could do this with synchronous algorithms too, but it'd unnecessarily waste half the bandwidth.

I believe that developing a case study should block making technical progress on this issue.

:+1:

stsievert commented 6 years ago

@mrocklin is implementing stateful "actors" in https://github.com/dask/distributed/pull/2133, which are required for parameter servers.
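A minimal sketch of the kind of stateful object such an actor would hold: a centralized parameter server as a plain Python class. The class, its `push`/`get_model` methods, and the quadratic toy gradient are hypothetical. With Dask actors, a class like this could be placed on a worker with `client.submit(ParameterServer, ..., actor=True)`, and method calls on the resulting handle would run on that worker; here it is exercised locally so the example runs without a cluster.

```python
class ParameterServer:
    """Holds the model; workers push gradients and pull the latest parameters."""

    def __init__(self, n_params, lr=0.1):
        self.model = [0.0] * n_params
        self.lr = lr
        self.n_updates = 0

    def push(self, grad):
        # Apply each gradient as soon as it arrives (no waiting on other workers).
        for j, g in enumerate(grad):
            self.model[j] -= self.lr * g
        self.n_updates += 1
        return self.n_updates

    def get_model(self):
        return list(self.model)


def local_gradient(model, target):
    """Gradient of 0.5 * ||model - target||^2; stands in for a real training step."""
    return [m - t for m, t in zip(model, target)]


if __name__ == "__main__":
    ps = ParameterServer(n_params=2, lr=0.5)
    for _ in range(20):  # each iteration plays the role of one worker round
        ps.push(local_gradient(ps.get_model(), target=[1.0, -2.0]))
    print(ps.get_model())  # close to [1.0, -2.0]
```

Keeping all state in one object is what makes the centralized design simple and resilient: only one worker holds the model.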

stsievert commented 6 years ago

I've sketched a very rough parameter server: https://gist.github.com/stsievert/ff8a1df9300a82f15a2704e913469522.

  • This is a decentralized example, so every worker has state. This allows fancy communication schemes (e.g., all-reduce).
  • This integrates pretty nicely with PyTorch's MNIST example https://github.com/pytorch/examples/tree/master/mnist; only small modifications to their functions were required.

And this naturally depends on https://github.com/dask/distributed/pull/2133.

mrocklin commented 6 years ago

Cool. I'm looking forward to playing with this.

What are your thoughts on a centralized parameter server approach? Not worth doing? I'd be curious to know what one gains from a fully decentralized system. This is probably just my general tendency to always start with simple things and work up, though.

stsievert commented 6 years ago

What are your thoughts on a centralized parameter server approach?

Good question. I've considered this.

Decentralized parameter server really means that every worker can store state because it has an actor. With this, the communication between these workers can be customized while letting the actor manage the computation. Customization of the communication scheme can include:

There are real speed benefits to the all-reduce communication scheme: https://talwalkarlab.github.io/paleo/ (select "strong scaling" instead of "weak scaling").
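The all-reduce pattern can be sketched as a single-process simulation of ring all-reduce (a reduce-scatter phase followed by an all-gather phase). This illustrates the communication pattern only, assuming every worker holds a same-length list of numbers; it is not Horovod's or any library's implementation.

```python
def ring_allreduce(vectors):
    """Simulate a sum all-reduce over a ring of len(vectors) workers.

    Worker i only ever sends to worker (i + 1) % n. Each vector is split into n
    chunks; after reduce-scatter, worker i holds the complete sum of chunk
    (i + 1) % n, and all-gather passes those finished chunks around the ring.
    """
    n = len(vectors)
    length = len(vectors[0])
    bounds = [length * k // n for k in range(n + 1)]  # chunk k is [bounds[k], bounds[k + 1])
    chunks = [[list(v[bounds[k]:bounds[k + 1]]) for k in range(n)] for v in vectors]

    # Reduce-scatter: at step s, worker i sends chunk (i - s) % n to its neighbour,
    # which adds it to its own copy of that chunk.
    for step in range(n - 1):
        sends = [(i, (i - step) % n, chunks[i][(i - step) % n]) for i in range(n)]
        for sender, k, data in sends:  # lists are replaced, never mutated, so sends see pre-step values
            receiver = (sender + 1) % n
            chunks[receiver][k] = [a + b for a, b in zip(chunks[receiver][k], data)]

    # All-gather: at step s, worker i forwards its finished chunk (i + 1 - s) % n.
    for step in range(n - 1):
        sends = [(i, (i + 1 - step) % n, chunks[i][(i + 1 - step) % n]) for i in range(n)]
        for sender, k, data in sends:
            chunks[(sender + 1) % n][k] = list(data)

    return [[x for chunk in worker_chunks for x in chunk] for worker_chunks in chunks]


if __name__ == "__main__":
    grads = [[1, 2, 3, 4, 5], [10, 20, 30, 40, 50], [100, 200, 300, 400, 500]]
    print(ring_allreduce(grads))  # every worker ends with [111, 222, 333, 444, 555]
```

Each worker sends 2(n - 1) messages of roughly 1/n of the model, so its traffic stays near twice the model size as n grows, with no single node that every message must pass through.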

One good example of centralized vs decentralized parameter servers is Hogwild and Hogwild++ respectively, where Hogwild is an asynchronous optimization algorithm and Hogwild++ is a slight modification of it. Hogwild has moderate speedups in time to a particular accuracy, and Hogwild++ has near-linear speedups (see Figure 5 in their paper).

mrocklin commented 6 years ago

I understand that there are potential performance benefits to decentralizing things. However I'd like for us to start small if possible, if for no other reason than to demonstrate that a distributed approach is valuable. I suspect that we'll also learn a lot in the process.

mrocklin commented 6 years ago

I had a good time running through the notebook last night and things ran smoothly. One thing I really missed was getting feedback both about how the cluster was performing (the diagnostics are silent with actors) and on how well the model was being trained.

FT

mrocklin commented 6 years ago

Whoops, misfire.

The diagnostics issue is something that I'll take care of for actors, probably by creating a diagnostics page on each worker's Bokeh server.

The "how well is the model training" is a more open question though. What's the right way to do this? Periodically ask some worker to score the testing or validation set and plot the results over time?

stsievert commented 6 years ago

to demonstrate that a distributed approach is valuable.

Good point. I should have numbers to back my decisions. I'll generate a comparison.

The centralized PS does spend half its bandwidth on communicating the model, something that's not necessary with the decentralized approach. The centralized approach has approximately the same latency cost (about the same number of connections, depending on the reduce algorithm used). However, it is more resilient because only one worker holds state.

Most existing work uses all-reduce, including Horovod and some distributed gradient coding papers for their experiments (QSGD, TernGrad). I've seen a centralized parameter server used with distributed Tensorflow.

stsievert commented 6 years ago

"how well is the model training"

I'd want to see {test, train} accuracy over {time, epochs}.

The diagnostics issue is something that I'll take care of for actors

Good to hear. It's also difficult to debug actors. What can be done to help with that?

stsievert commented 6 years ago

I've updated the gist with a centralized parameter server: https://gist.github.com/stsievert/ff8a1df9300a82f15a2704e913469522

I'll profile these two implementations while varying the model size. This small network for MNIST is 84KB with 21k 32-bit floats. A small ResNet is 44MB, and a deep ResNet is 169MB.

I can look at dataset size too, but the data is only sent once, at the beginning. A bigger issue is probably scaling the number of workers.
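A back-of-the-envelope sketch of the per-step traffic for the model sizes above. It assumes a synchronous centralized PS receives one gradient and sends one model per worker each step (gradient size equal to model size, so half the traffic is the model), compares that with the per-worker traffic of ring all-reduce, 2(n - 1)/n times the model size, and treats MB as 10^6 bytes. These are idealized counts that ignore serialization and compression.

```python
BYTES_PER_FLOAT32 = 4


def model_bytes(n_params):
    return n_params * BYTES_PER_FLOAT32


def centralized_ps_bytes(model_size, n_workers):
    """Bytes through the PS node per step: n gradients in plus n models out."""
    return 2 * n_workers * model_size


def ring_allreduce_bytes(model_size, n_workers):
    """Bytes sent by each worker per step in ring all-reduce."""
    return 2 * (n_workers - 1) * model_size / n_workers


if __name__ == "__main__":
    mnist = model_bytes(21_000)  # 84,000 bytes, matching the ~84KB above
    for name, size in [("MNIST net", mnist), ("small ResNet", 44e6), ("deep ResNet", 169e6)]:
        for n in (4, 16):
            ps = centralized_ps_bytes(size, n) / 1e6
            ring = ring_allreduce_bytes(size, n) / 1e6
            print(f"{name:12s} n={n:2d}: PS node {ps:8.1f} MB/step, ring worker {ring:7.1f} MB/step")
```

The PS node's traffic grows linearly with the number of workers, which is one reason scaling the worker count is likely the bigger issue.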

mrocklin commented 6 years ago

Do you have a cluster to try out scaling? If not, we could work through http://dask.pydata.org/en/latest/setup/kubernetes-helm.html together tomorrow if that would be helpful.

mrocklin commented 6 years ago

@lesteve you might find @stsievert 's notebooks above interesting

stsievert commented 6 years ago

I've started a repo for those notebooks: https://github.com/stsievert/dask-ps, and have modified the implementation a bit. I see these timings with dask.distributed master:

Currently, the parameter server is spending about 50% of its time waiting (I think; this time is spent in e.wait(10) on distributed/utils.py#L275). Most of the other 50% is spent in my train function computing the gradient, so I think this is most of the comm_model time.

This result can be reproduced by running Centralized-PS.ipynb in https://github.com/stsievert/dask-ps with dask.distributed master. Next steps I'm planning are to test this with a larger model and see where the bottlenecks are.

mrocklin commented 6 years ago

@stsievert it looks like we're still spending 60% or so of our time waiting on synchronization, which seems to be the main bottleneck for scaling. Do you have any thoughts on asynchronous methods that don't wait for all updates to arrive before modifying the model?

mrocklin commented 6 years ago

I've pushed what I think an async parameter server would look like here: https://gist.github.com/e57551e451bd633e13009ccce9c6ff67

Feedback on whether or not this is valid would be appreciated

mrocklin commented 6 years ago

It looks like we're currently bound by serializing and communicating models around. We can balance this out by using larger batches, but that probably affects scalability of training. It would be useful to know how much increasing the batch sizes affects training performance over time.

What is the right way to score a model like this? Perhaps we can have the client periodically pull down a model and score it against some testing dataset and then keep a time series of these scores over time?

mrocklin commented 6 years ago

As in the following:

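from time import sleep, time

import matplotlib.pyplot as plt

# PS, worker, score, test_data, and n_workers are placeholders from the
# discussion above; they are not defined here.
ps = client.submit(PS, ..., actor=True).result()  # .result() gives the actor handle
# pure=False so that each submit creates a distinct task instead of one shared key
futures = [client.submit(worker, ps, ..., pure=False) for _ in range(n_workers)]

times = []
scores = []
while not all(future.done() for future in futures):
    model = ps.model  # attribute access on an actor fetches the current value
    times.append(time())
    scores.append(score(model, test_data))
    sleep(0.5)

plt.plot(times, scores)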
stsievert commented 6 years ago

I've reworked your notebook: https://github.com/stsievert/dask-ps/blob/master/Centralized-PS.ipynb. This notebook is functionally the same as my previous implementation, but it spends less time waiting. I'm not sure why.

The communication scheme this notebook implements is "synchronous with backup workers", which is detailed below.

Do you have any thoughts on asynchronous methods that don't wait for all updates to arrive before modifying the model?

Depends on what you mean by "asynchronous" – either the workers don't wait on each other, or the optimization algorithm can return partially updated models. Figure 5 of the Tensorflow paper has a good depiction of the difference:

[Screenshot: Figure 5 of the TensorFlow paper]

Using the "synchronous with backup workers" method will use the user's optimization method and allow the workers to operate independently, plus the timing won't be affected by stragglers.

More on this difference and its performance benefits is in "Revisiting distributed synchronous SGD".
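A minimal sketch of one "synchronous with backup workers" step, assuming a thread pool stands in for the workers and a caller-supplied gradient function: launch n gradient computations, average the first n - b to finish, and ignore the b stragglers. The function names and the toy quadratic objective are hypothetical; see the paper above for the actual algorithm and its analysis.

```python
import concurrent.futures as cf
import time


def sync_step_with_backups(params, grad_fn, n_workers, n_backup, lr, pool):
    """One synchronous SGD step that averages only the first n_workers - n_backup gradients."""
    needed = n_workers - n_backup
    futures = [pool.submit(grad_fn, list(params), worker_id) for worker_id in range(n_workers)]
    grads = []
    for future in cf.as_completed(futures):
        grads.append(future.result())
        if len(grads) == needed:
            break  # the slowest n_backup results are never used
    for future in futures:
        future.cancel()  # only has an effect on work that has not started yet
    avg = [sum(g[j] for g in grads) / needed for j in range(len(params))]
    return [p - lr * a for p, a in zip(params, avg)]


def toy_grad(params, worker_id):
    """Gradient of 0.5 * sum((p - 3)^2); worker 0 is an artificial straggler."""
    if worker_id == 0:
        time.sleep(0.05)
    return [p - 3.0 for p in params]


if __name__ == "__main__":
    with cf.ThreadPoolExecutor(max_workers=4) as pool:
        params = [0.0]
        for _ in range(10):
            params = sync_step_with_backups(params, toy_grad, n_workers=4, n_backup=1, lr=0.5, pool=pool)
    print(params)  # approaches [3.0]
```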

gist.github.com/e57551e451bd633e13009ccce9c6ff67 Feedback on whether or not this is valid would be appreciated

I think this implements Hogwild, where workers blindly run SGD (hence the name). This does have convergence guarantees, but they're fairly weak. They require a model sparsity constraint and a limit on how long the workers can take to update the model.
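As a toy illustration of the lock-free idea behind Hogwild (not the paper's algorithm, which relies on sparse updates to shared memory), several threads below run SGD on one shared parameter with no lock, each possibly reading a stale value. The problem (fitting y = 3x) and all names are hypothetical, and because of Python's GIL this shows the update pattern, not a speedup.

```python
import random
import threading


def hogwild_sgd(data, n_workers=4, steps_per_worker=2000, lr=0.05, seed=0):
    """Fit y ~ w * x with lock-free SGD from several threads sharing one parameter."""
    w = [0.0]  # shared state; read and written by every thread without a lock

    def work(worker_seed):
        rng = random.Random(worker_seed)
        for _ in range(steps_per_worker):
            x, y = rng.choice(data)
            grad = (w[0] * x - y) * x  # may use a value another thread is about to overwrite
            w[0] -= lr * grad          # unsynchronized read-modify-write; lost updates are tolerated

    threads = [threading.Thread(target=work, args=(seed + i,)) for i in range(n_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return w[0]


if __name__ == "__main__":
    rng = random.Random(42)
    data = [(x, 3.0 * x) for x in (rng.uniform(0.5, 1.5) for _ in range(100))]
    print(hogwild_sgd(data))  # close to 3.0
```

The step size is kept small so that a stale gradient can only nudge the parameter, which loosely mirrors the bounded-delay condition in Hogwild's convergence guarantee.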

What is the right way to score a model like this?

Your implementation looks good. I think that'd be the easiest method.

stsievert commented 6 years ago

I've timed the parameter server in the notebook Centralized-PS.ipynb@a971d2. The parameter server behind this graph implements the "synchronous communication with backup workers", but without any backup workers (i.e., workers never wait and can repeat gradient computations).

Here's a quick summary from that notebook:

[Figures: "Total optimization time" and "Optimization time components"]

The solid lines are the median, the shaded color borders specify the 95% confidence interval. I get this on my local machine (16GB RAM, no GPU) when timing the MNIST PyTorch example (which is unrealistic: gradients take a long time to compute but the model is small).

mrocklin commented 6 years ago

So, looking at the profile page I see something similar-ish.

Of the 11.50 CPU-seconds it took to run a 4-worker training process

There are a couple other tiny things, but not worth mentioning.

There were 215 push calls recorded on the parameter server, which places each of them at around 30ms, which is longer than I'd like. During this time bandwidth on the worker with the actor reached around 50MB/s, which is nowhere near saturating local network speed. The Worker's CPU was being used during this time, mostly on networking overhead and data compression, and a little bit on deserialization.

In isolation a single call to train takes around 10ms and pulling a model takes 9ms (though this is proxied through the scheduler so this is conservative)

mrocklin commented 6 years ago

There are some things we can do to reduce worker overhead:

  1. Experiment with turning off compression (I'll do this now)
  2. Improve serialization of PyTorch objects so that we get straight memoryviews out quickly

I'm curious though how expensive train calls are going to be normally. Is 10ms typical, or do these become more expensive? It may be that training a batch on MNIST is too fast for dask to be of much use.

mrocklin commented 6 years ago

Turning compression off results in 380MB/s bandwidth and slightly faster performance overall. Offloading compression to a separate thread seems to get to lower bandwidth and decent performance. This might be the solution longer term. The actor worker is now not quite as hot (event loop taking up around 60% of a CPU) and actor wait time is down to around 4s total.

stsievert commented 6 years ago

how expensive train calls are going to be normally. Is 10ms typical, or do these become more expensive?

A good benchmark for this is DAWNBench. Top-performing deep learning models have an inference latency of 10ms, and a couple of spots down it is about 20ms. That means the train call probably costs 40 or 50ms. The fastest I've seen for training is about 100ms.

It may be that training a batch on MNIST is too fast for dask to be of much use.

Sounds like we should scale to a larger model, which isn't too hard. I was only using MNIST as an easy and basic example.

mrocklin commented 6 years ago

Ah interesting, I hadn't drawn that connection before. So because training is just applying the model + back propagation (which is roughly the same order of magnitude) we would expect something like 2x of the inference cost.

But for a minibatch would we then multiply by the batch size? Or is it generally much cheaper to compute a batch than to compute many samples individually?

stsievert commented 6 years ago

But for a minibatch would we then multiply by the batch size? Or is it generally much cheaper to compute a batch than to compute many samples individually.

We can multiply by the batch size. The only faster way to calculate the gradient is to feed the examples through a GPU (which is why batch sizes are typically multiples of 32).

Calculating the gradient for many examples requires knowing how wrong the prediction is on each example, and then adding up their gradients.

training is just applying the model + back propagation (which is roughly the same order of magnitude)

Correct.

Details: the gradient calculation is done via reverse-mode autodiff, which typically requires no more than 5-6x more floating point operations for a typical gradient than a function evaluation. On my machine, it takes 2x longer on average to calculate the gradient vs getting the output with PyTorch's (simple) MNIST network.

(also, turns out autodiff is faster than hard-coding the gradient with NumPy even though it does more work because the graph autodiff requires can be parallelized: https://stsievert.com/blog/2017/09/07/pytorch/#speed)
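To make the cost argument concrete, here is a toy scalar reverse-mode autodiff (a hypothetical `Var` class supporting only `+` and `*`). One forward pass plus one backward sweep yields all n partial derivatives, whereas forward finite differences need n + 1 function evaluations. This illustrates the technique only; it is not how PyTorch implements autograd.

```python
class Var:
    """Scalar node in a computation graph, recording local derivatives for reverse mode."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuples of (parent Var, d(self)/d(parent))
        self.grad = 0.0

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value, ((self, other.value), (other, self.value)))


def backward(output):
    """Propagate d(output)/d(node) to every node in one reverse sweep."""
    order, seen = [], set()

    def visit(node):  # depth-first topological sort
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(output)
    output.grad = 1.0
    for node in reversed(order):  # each node's grad is complete before it is propagated
        for parent, local in node.parents:
            parent.grad += local * node.grad


def f(xs):
    """Example function: sum of products of neighbouring inputs."""
    total = xs[0] * xs[1]
    for a, b in zip(xs[1:], xs[2:]):
        total = total + a * b
    return total


def reverse_mode_gradient(values):
    xs = [Var(v) for v in values]
    backward(f(xs))  # one forward pass plus one backward sweep
    return [x.grad for x in xs]


def finite_difference_gradient(values, h=1e-6):
    base = f([Var(v) for v in values]).value  # 1 evaluation
    grads = []
    for i in range(len(values)):  # plus n more evaluations
        bumped = list(values)
        bumped[i] += h
        grads.append((f([Var(v) for v in bumped]).value - base) / h)
    return grads


if __name__ == "__main__":
    print(reverse_mode_gradient([1.0, 2.0, 3.0]))  # [2.0, 4.0, 2.0]
```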

mrocklin commented 6 years ago

typically requires no more than 5-6x more floating point operations for a typical gradient

Heh, not just typically. My understanding is that the 5x number is part of a theorem in the early work on the subject :)

1-2x matches my experience as well, though obviously this depends on the system you use. There is a lot of fun work in that field to suit all sorts :)

stsievert commented 6 years ago

My understanding is that the 5x number is part of a theorem in the early work on the subject :)

Almost – you're off by one! It's guaranteed to be strictly less than 6:

a constant guaranteed to be c < 6 and typically c ~ [2, 3] —http://jmlr.org/papers/volume18/17-468/17-468.pdf, page 13

mrocklin commented 6 years ago

Interesting, this differs from Griewank. Going back in time a bit...

In 1982 Phil Wolfe [31] made the following observation regarding the ratio between the cost of evaluating a gradient with n components and the cost of evaluating the underlying scalar function.

If care is taken in handling quantities which are common to the function and derivatives, the ratio is usually around 1.5, not n+1. [31]

The main purpose of this article is to demonstrate that Phil Wolfe's observation is in fact a theorem (with the average 1.5 replaced by an upper bound 5)

https://www.researchgate.net/profile/Andreas_Griewank/publication/2703247_On_Automatic_Differentiation/links/0c96052529013aed9e000000/On-Automatic-Differentiation.pdf

stsievert commented 6 years ago

Interesting, this differs from Griewank.

Hm... the review paper I referenced cited Griewank and Walther's 2008 textbook Evaluating Derivatives. In this, I couldn't find a claim that c < 6. I only found that c <= 4 on page 85, but I only skimmed. Without a bound on memory access time, this bound grows to c <= 5. This does rely on some (fair) assumptions about the relative cost (e.g., "that a memory access is at least as slow as a multiplication which is in turn at least as slow as an addition").

Thanks for the reference.