kinalmehta opened this issue 1 year ago
I managed to get the inference server up, but it gets stuck at the action selection step.
Based on my understanding, launchpad only supports calling a CourierNode with Python types. However, when passing observations and LSTM states, which are usually `numpy` or `jax` objects, the system stops.
Any hints on how this can be handled? I tried using pickle to serialize and deserialize the objects, but launchpad doesn't seem to accept `bytes` objects either.
Hi, did you make any headway on this? Running into the same thing. I saw you commented on a related launchpad issue since your last update. Thanks!
I am trying to run r2d2 with batched inference and am also running into this issue!
I managed to get the Inference Server up and running, though I use it in my personal code-repo, not with acme. But there were a lot of other hacks I had to do to get it running.
Based on my experiments, I found that the standard IMPALA architecture with CPU inference is much faster than InferenceServer-based action prediction. It could be that my network is small and the overhead of calling the inference server outweighs the cost of doing inference on the CPU.
Thanks for the response! Seems like in our case, at the very least there are some inconsistencies with how timelimits are being passed around by acme that we need to get to the bottom of.
If you installed launchpad from pip, this is a launchpad bug.
It can be fixed by editing `python3.9/site-packages/launchpad/nodes/courier/courier_utils.py` in your Python environment's launchpad installation.
You can refer to the attached Python script to patch your launchpad installation for the time-limits bug: fix_launchpad.txt
Hi, would you mind sharing more of what you did to get InferenceServer working? I think I'm treading the same path as you. I made the change to the `timeout`, which got the inference server called, and turned off `jit` because of the `variable_client` issue. Replacing the numpy arrays in obs/action/reward with jax arrays stopped InferenceServer from hanging and lets the handler be called.
But now it seems that the arguments to the `select_action` function aren't working with the `vmap` in `make_distributed_experiment`. I'm not positive, but I think the main problem is that the LSTM state isn't something you can easily `vmap` over. I'm thinking of designing around it by directly passing the parts of the recurrent state that are tensors, instead of the whole class. But I wanted to know if you can help me skip to the finish line instead :)
Thanks for your insights so far!
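A hypothetical sketch of the workaround mentioned above (passing only the tensors inside the recurrent state over the courier call rather than the `hk.LSTMState` object itself); `split_state` and `join_state` are illustrative names, not acme or launchpad APIs:

```python
import haiku as hk
import jax.numpy as jnp

def split_state(state: hk.LSTMState):
  # Ship only the raw arrays over the courier call.
  return state.hidden, state.cell

def join_state(hidden, cell) -> hk.LSTMState:
  # Rebuild the namedtuple on the inference-server side.
  return hk.LSTMState(hidden=jnp.asarray(hidden), cell=jnp.asarray(cell))

state = hk.LSTMState(hidden=jnp.zeros((1, 4)), cell=jnp.zeros((1, 4)))
hidden, cell = split_state(state)   # plain arrays, straightforward to batch / vmap
restored = join_state(hidden, cell)
```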
Here is the modified inference server script I've been using
import dataclasses
import datetime
import threading
from typing import Any, Callable, Generic, Optional, Sequence, TypeVar
import acme
import haiku as hk
import jax
import jax.numpy as jnp
import jaxlib
import launchpad as lp
import numpy as np
from acme.jax import variable_utils
@dataclasses.dataclass
class InferenceServerConfig:
"""Configuration options for centralised inference.
Attributes:
batch_size: How many elements to batch together per single inference call.
Auto-computed when not specified.
update_period: Frequency of updating variables from the variable source.
It is passed to VariableClient. Auto-computed when not specified.
    timeout: Time after which an incomplete batch is executed (the batch is
      padded, so the batch handler is always called with batch_size elements).
      By default the timeout is effectively disabled (set to 30 days).
"""
batch_size: Optional[int] = None
update_period: Optional[int] = None
timeout: datetime.timedelta = datetime.timedelta(days=30)
InferenceServerHandler = TypeVar('InferenceServerHandler')
class InferenceServer(Generic[InferenceServerHandler]):
"""Centralised, batched inference server."""
def __init__(self, handler: InferenceServerHandler,
variable_source: acme.VariableSource,
devices: Sequence[jax.xla.Device],
config: InferenceServerConfig):
"""Constructs an inference server object.
Args:
handler: A callable or a mapping of callables to be exposed
through the inference server.
variable_source: Source of variables
devices: Devices used for executing handlers. All devices are used in
parallel.
config: Inference Server configuration.
"""
self._variable_source = variable_source
self._variable_client = None
self._keys = []
self._devices = devices
self._config = config
self._call_cnt = 0
self._device_params = [None] * len(self._devices)
self._device_params_ids = [None] * len(self._devices)
self._mutex = threading.Lock()
self._handler = jax.tree_map(self._build_handler, handler, is_leaf=callable)
@property
def handler(self) -> InferenceServerHandler:
return self._handler
def _dereference_params(self, arg):
"""Replaces VariableReferences with their corresponding param values."""
if not isinstance(arg, variable_utils.VariableReference):
# All arguments but VariableReference are returned without modifications.
arg = postprocess_data(arg)
return arg
# Due to batching dimension we take the first element.
variable_name = arg.variable_name[0]
if variable_name not in self._keys:
# Create a new VariableClient which also serves new variables.
self._keys.append(variable_name)
self._variable_client = variable_utils.VariableClient(
client=self._variable_source,
key=self._keys,
update_period=self._config.update_period)
params = self._variable_client.params
device_idx = self._call_cnt % len(self._devices)
# Select device via round robin, and update its params if they changed.
if self._device_params_ids[device_idx] != id(params):
self._device_params_ids[device_idx] = id(params)
self._device_params[device_idx] = jax.device_put(
params, self._devices[device_idx])
# Return the params that are located on the chosen device.
device_params = self._device_params[device_idx]
if len(self._keys) == 1:
return device_params
return device_params[self._keys.index(variable_name)]
def _build_handler(self, handler: Callable[..., Any]) -> Callable[..., Any]:
"""Builds a batched handler for a given callable handler and its name."""
def dereference_params_and_call_handler(*args, **kwargs):
with self._mutex:
# Dereference args corresponding to params, leaving others unchanged.
args_with_dereferenced_params = [
self._dereference_params(arg) for arg in args
]
kwargs_with_dereferenced_params = {
key: self._dereference_params(value)
for key, value in kwargs.items()
}
self._call_cnt += 1
# Maybe update params, depending on client configuration.
if self._variable_client is not None:
self._variable_client.update()
op = handler(*args_with_dereferenced_params,
**kwargs_with_dereferenced_params)
op = jax.tree_util.tree_map(lambda x: list(x), op)
return op
return lp.batched_handler(
batch_size=self._config.batch_size,
timeout=self._config.timeout,
pad_batch=True,
max_parallelism=2 * len(self._devices))(
dereference_params_and_call_handler)
def postprocess_data(data):
if type(data) is dict:
data = {k: postprocess_data(v) for k, v in data.items()}
return data
elif type(data) is list:
return jnp.array(data)
elif type(data) is hk.LSTMState:
return hk.LSTMState(
hidden=jnp.array(data.hidden), cell=jnp.array(data.cell))
else:
raise ValueError(f'Unsupported data type: {type(data)}')
You might have to modify `postprocess_data` based on the data you pass to the inference server. Basically, every leaf of the data tree you pass to the handler should have the batch dimension as its zeroth dimension.
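To illustrate the batch-dimension point, here is a small, hypothetical example of what `postprocess_data` receives and returns, assuming the definition from the script above is in scope. As far as I understand, `lp.batched_handler` delivers each argument with leaves batched into lists, and `postprocess_data` turns those lists into stacked JAX arrays:

```python
import haiku as hk
import numpy as np

# What the handler might receive for a batch of 2 calls: leaves are lists.
batched_obs = {'pixels': [np.zeros((84, 84, 3)), np.ones((84, 84, 3))]}
batched_state = hk.LSTMState(hidden=[np.zeros(4), np.zeros(4)],
                             cell=[np.zeros(4), np.zeros(4)])

obs = postprocess_data(batched_obs)      # {'pixels': array of shape (2, 84, 84, 3)}
state = postprocess_data(batched_state)  # LSTMState with (2, 4) hidden and cell

assert obs['pixels'].shape == (2, 84, 84, 3)
assert state.hidden.shape == (2, 4)
```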
Cool, thanks a bunch! I ended up with a similar solution; posting it here in case it's helpful to you or the next person. It uses this gist that stacks arbitrary pytrees, and then unstacks them to return a list of values. I needed to do some other postprocessing to deal with how acme stacks the observations, but I think that was unavoidable.
# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines Inference Server class used for centralised inference."""
import dataclasses
import datetime
import threading
from typing import Any, Callable, Generic, Optional, Sequence, TypeVar
import acme
from acme.jax import variable_utils
import jax
import launchpad as lp
import numpy as np
from jax import numpy as jnp
# from jax.lib import pytree
from jax.tree_util import tree_flatten, tree_unflatten
def tree_stack(trees):
"""Takes a list of trees and stacks every corresponding leaf.
For example, given two trees ((a, b), c) and ((a', b'), c'), returns
((stack(a, a'), stack(b, b')), stack(c, c')).
Useful for turning a list of objects into something you can feed to a
vmapped function.
"""
leaves_list = []
treedef_list = []
for tree in trees:
leaves, treedef = tree_flatten(tree)
leaves_list.append(leaves)
treedef_list.append(treedef)
grouped_leaves = zip(*leaves_list)
result_leaves = [jnp.stack(l) for l in grouped_leaves]
# return treedef_list[0].unflatten(result_leaves), treedef_list[0]
return treedef_list[0].unflatten(result_leaves)
def tree_unstack(tree):
"""Takes a tree and turns it into a list of trees. Inverse of tree_stack.
For example, given a tree ((a, b), c), where a, b, and c all have first
dimension k, will make k trees
[((a[0], b[0]), c[0]), ..., ((a[k-1], b[k-1]), c[k-1])]
Useful for turning the output of a vmapped function into normal objects.
"""
leaves, treedef = tree_flatten(tree)
n_trees = leaves[0].shape[0]
new_leaves = [[] for _ in range(n_trees)]
for leaf in leaves:
for i in range(n_trees):
new_leaves[i].append(leaf[i])
new_trees = [treedef.unflatten(l) for l in new_leaves]
return new_trees
tree_stack = jax.jit(tree_stack)
tree_unstack = jax.jit(tree_unstack)
@dataclasses.dataclass
class InferenceServerConfig:
"""Configuration options for centralised inference.
Attributes:
batch_size: How many elements to batch together per single inference call.
Auto-computed when not specified.
update_period: Frequency of updating variables from the variable source.
It is passed to VariableClient. Auto-computed when not specified.
    timeout: Time after which an incomplete batch is executed (the batch is
      padded, so the batch handler is always called with batch_size elements).
      By default the timeout is effectively disabled (set to 30 days).
"""
batch_size: Optional[int] = None
update_period: Optional[int] = None
timeout: datetime.timedelta = datetime.timedelta(days=30)
InferenceServerHandler = TypeVar('InferenceServerHandler')
def reverse_oar_thing(oar_thing):
from acme.wrappers.observation_action_reward import OAR
num_things = len(oar_thing.action)
new_oars = [OAR(
action=oar_thing.action[i],
reward=oar_thing.reward[i],
observation=oar_thing.observation[i]
) for i in range(num_things)]
return new_oars
class InferenceServer(Generic[InferenceServerHandler]):
"""Centralised, batched inference server."""
def __init__(self, handler: InferenceServerHandler,
variable_source: acme.VariableSource,
devices: Sequence[jax.xla.Device],
config: InferenceServerConfig):
"""Constructs an inference server object.
Args:
handler: A callable or a mapping of callables to be exposed
through the inference server.
variable_source: Source of variables
devices: Devices used for executing handlers. All devices are used in
parallel.
config: Inference Server configuration.
"""
self._variable_source = variable_source
self._variable_client = None
self._keys = []
self._devices = devices
self._config = config
self._call_cnt = 0
self._device_params = [None] * len(self._devices)
self._device_params_ids = [None] * len(self._devices)
self._mutex = threading.Lock()
self._handler = jax.tree_map(self._build_handler, handler, is_leaf=callable)
@property
def handler(self) -> InferenceServerHandler:
return self._handler
def _dereference_params(self, arg):
"""Replaces VariableReferences with their corresponding param values."""
if not isinstance(arg, variable_utils.VariableReference):
# All arguments but VariableReference are returned without modifications.
return arg
# Due to batching dimension we take the first element.
variable_name = arg.variable_name[0]
if variable_name not in self._keys:
# Create a new VariableClient which also serves new variables.
self._keys.append(variable_name)
self._variable_client = variable_utils.VariableClient(
client=self._variable_source,
key=self._keys,
update_period=self._config.update_period)
params = self._variable_client.params
device_idx = self._call_cnt % len(self._devices)
# Select device via round robin, and update its params if they changed.
if self._device_params_ids[device_idx] != id(params):
self._device_params_ids[device_idx] = id(params)
self._device_params[device_idx] = jax.device_put(
params, self._devices[device_idx])
# Return the params that are located on the chosen device.
device_params = self._device_params[device_idx]
if len(self._keys) == 1:
return device_params
return device_params[self._keys.index(variable_name)]
def _build_handler(self, handler: Callable[..., Any]) -> Callable[..., Any]:
"""Builds a batched handler for a given callable handler and its name."""
print('[inference_server] calling _build_handler')
def dereference_params_and_call_handler(*args, **kwargs):
with self._mutex:
# Dereference args corresponding to params, leaving others unchanged.
args_with_dereferenced_params = [
self._dereference_params(arg) for arg in args
]
kwargs_with_dereferenced_params = {
key: self._dereference_params(value)
for key, value in kwargs.items()
}
self._call_cnt += 1
# Maybe update params, depending on client configuration.
if self._variable_client is not None:
self._variable_client.update()
params = args_with_dereferenced_params[0]
oar_thing = args_with_dereferenced_params[1]
reversed_oar = reverse_oar_thing(oar_thing)
recurrent_state_thing = args_with_dereferenced_params[2]
reversed_oar_stacked = tree_stack(reversed_oar)
recurrent_state_thing_stacked = tree_stack(recurrent_state_thing)
to_return = handler(params, reversed_oar_stacked, recurrent_state_thing_stacked)
unstacked_to_return = tree_unstack(to_return)
print('[inference_server] returning unstacked')
return unstacked_to_return
# return handler(*args_with_dereferenced_params,
# **kwargs_with_dereferenced_params)
to_return = lp.batched_handler(
batch_size=self._config.batch_size,
timeout=self._config.timeout,
pad_batch=True,
max_parallelism=2 * len(self._devices))(
dereference_params_and_call_handler)
return to_return
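For reference, a quick sanity check (not part of the original reply) of what `tree_stack` and `tree_unstack` do, assuming the definitions from the script above are in scope:

```python
import jax.numpy as jnp

# A batch of 4 per-actor recurrent states, each a small pytree.
trees = [{'hidden': jnp.zeros(3), 'cell': jnp.ones(3)} for _ in range(4)]

stacked = tree_stack(trees)        # {'hidden': (4, 3) array, 'cell': (4, 3) array}
unstacked = tree_unstack(stacked)  # back to a list of 4 trees with (3,) leaves

assert stacked['hidden'].shape == (4, 3)
assert len(unstacked) == 4 and unstacked[0]['cell'].shape == (3,)
```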
Great!! BTW, if you don't mind, please do share how the performance compares with standard CPU-based inference in your case, and whether using InferenceServer helped or not.
We're getting slightly better performance with the inference_server compared to without (3800 FPS vs 2200 FPS, as well as slightly better training steps/second). The key was to instantiate multiple copies of the inference server. I noticed that the inference_server had very poor GPU utilization on its own. Ideally we'd figure out why that is instead of running multiple copies.
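In case it helps the next reader, here is a rough sketch (not the poster's actual code) of one way to stand up several copies of the `InferenceServer` defined above as separate launchpad courier nodes; `program`, `variable_source_handle`, and `policy_fn` are placeholders for whatever your experiment builder provides:

```python
import jax
import launchpad as lp

NUM_INFERENCE_SERVERS = 2  # hypothetical: run several servers to improve GPU utilization

inference_handles = []
for i in range(NUM_INFERENCE_SERVERS):
  node = lp.CourierNode(
      InferenceServer,                          # the class defined earlier in this thread
      handler=policy_fn,                        # e.g. a jitted, batched select_action
      variable_source=variable_source_handle,   # handle to the learner / variable source node
      devices=jax.local_devices(),
      config=InferenceServerConfig(batch_size=64))
  inference_handles.append(program.add_node(node, label=f'inference_server_{i}'))

# Each actor can then be pointed at inference_handles[actor_id % NUM_INFERENCE_SERVERS].
```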
Hey @kinalmehta did you ever figure out how to pass around numpy arrays (or non-python primitives in general) using courier? I am able to pass around python ints and floats, but need to pass around np uint arrays, but as you said, doing that causes the system to stop.
Hi @abagaria, as mentioned here, launchpad seems to work fine with JAX arrays, so I use JAX arrays instead of numpy arrays.
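For anyone hitting the same hang, a minimal sketch (mine, not @kinalmehta's exact code) of converting the numpy leaves of an observation pytree to JAX arrays before making the courier call; `inference_client` is a hypothetical handle to the inference server:

```python
import jax
import jax.numpy as jnp
import numpy as np

def to_jax(tree):
  """Replace every numpy leaf of a pytree with a jax.numpy array."""
  return jax.tree_util.tree_map(jnp.asarray, tree)

observation = {'pixels': np.zeros((84, 84, 3), dtype=np.uint8),
               'reward': np.float32(0.0)}
observation = to_jax(observation)  # leaves are now jnp arrays, which launchpad accepts
# inference_client.select_action(observation)  # hypothetical courier call
```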
I am trying to use a SEED-RL-style centralized server for inference. I modified the IMPALA example by adding the `inference_server.InferenceServerConfig`, which is passed to `experiments.make_distributed_experiment`. I keep getting the below error when running the modified code.
Code (only the `main` function has been modified, as below)
Error
Could you please help me run the `InferenceServer`-based training? Thanks, Kinal