google-deepmind / acme

A library of reinforcement learning components and agents
Apache License 2.0
3.5k stars 426 forks

InferenceServer Error #279

Open · kinalmehta opened this issue 1 year ago

kinalmehta commented 1 year ago

I am trying to use a SEED RL-style centralized server for inference. I modified the IMPALA example by adding an inference_server.InferenceServerConfig, which is passed to experiments.make_distributed_experiment. I keep getting the error below when running the modified code.

Code

Only the main function has been modified, as shown below:


def main(_):
  experiment_config = build_experiment_config()
  if RUN_DISTRIBUTED.value:
    program = experiments.make_distributed_experiment(
        experiment=experiment_config,
        num_actors=4,
        inference_server_config=inference_server.InferenceServerConfig(
            batch_size=32,
            update_period=1000,
            timeout=datetime.timedelta(seconds=100, milliseconds=0, microseconds=0)))
    lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program))
  else:
    experiments.run_experiment(experiment_config)

Error

    batcher = pybind.BuildBatchedHandlerWrapper(func.__name__, handler,
TypeError: BuildBatchedHandlerWrapper(): incompatible function arguments. The following argument types are supported:
    1. (arg0: str, arg1: courier::HandlerInterface, arg2: int, arg3: int, arg4: int, arg5: bool) -> courier::HandlerInterface

Invoked with: 'dereference_params_and_call_handler', <courier.handlers.python.pybind.HandlerInterface object at 0x7f36002f7b30>, 32, 2, 0.0, True
terminate called without an active exception
Fatal Python error: Aborted

Thread 0x00007f36c93a5740 (most recent call first):

Could you please help me run the InferenceServer-based training?

Thanks, Kinal

kinalmehta commented 1 year ago

I managed to get the inference server up, but it gets stuck at the action-selection step. Based on my understanding, launchpad only supports calling a CourierNode using Python types. However, when passing observations and LSTM states, which are usually numpy or JAX objects, the system stops.

Any hints on how this can be handled? I tried using pickle to serialize and deserialize the objects, but launchpad doesn't seem to accept bytes objects.

samlobel commented 1 year ago

Hi, did you make any headway on this? I'm running into the same thing. I saw you commented on a related launchpad issue since your last update. Thanks!

abagaria commented 1 year ago

I am trying to run r2d2 with batched inference and am also running into this issue!

kinalmehta commented 1 year ago

I managed to get the Inference Server up and running, though I use it in my personal code repo, not with acme. There were a lot of other hacks I had to do to get it running.

And based on my experiments, I found that the standard IMPALA architecture with CPU inference is much faster than Inference Server-based action prediction. It could be that my network is small, and the overhead of calling the inference server is higher than the cost of doing the inference on CPU.

samlobel commented 1 year ago

Thanks for the response! Seems like in our case, at the very least there are some inconsistencies with how timelimits are being passed around by acme that we need to get to the bottom of.

kinalmehta commented 1 year ago

If you installed launchpad from pip, this is a launchpad bug. It can be fixed by editing python3.9/site-packages/launchpad/nodes/courier/courier_utils.py in your Python environment's launchpad installation.

You can refer to the attached Python script to fix your launchpad installation for the timelimits bug: fix_launchpad.txt
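
In case the attachment isn't accessible, here is the rough idea (a sketch only, not the literal contents of fix_launchpad.txt; the variable names are mine, and your copy of courier_utils.py may look different). The pybind signature in the traceback above expects an int for the timeout slot, but the wrapper is invoked with a float (0.0), so the change boils down to casting that argument before it reaches BuildBatchedHandlerWrapper:

# Sketch of the fix in launchpad/nodes/courier/courier_utils.py.
# Argument names are inferred from the traceback and from lp.batched_handler:
# (name, handler, batch_size, max_parallelism, timeout, pad_batch).
timeout_arg = int(timeout.total_seconds())  # hypothetical; cast the float timeout to an int
batcher = pybind.BuildBatchedHandlerWrapper(func.__name__, handler, batch_size,
                                            max_parallelism, timeout_arg,
                                            pad_batch)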

samlobel commented 1 year ago

Hi, would you mind sharing more of what you did to get InferenceServer working? I think I'm treading the same path as you. I made the change to the timeout, which got the inference server called. I turned off jit because of the variable_client issue. I replaced the numpy arrays in obs/action/reward with JAX arrays, which stopped InferenceServer from hanging and lets the handler be called.

But now it seems that the arguments to the select_action function aren't working with the vmap in make_distributed_experiment. I'm not positive, but I think the main problem is that the LSTM state isn't something you can vmap easily? I'm thinking of designing around it by directly passing the parts of the recurrent state that are tensors instead of the whole class. But I wanted to know if you can help me skip to the finish line instead :)
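
For concreteness, here is roughly what I mean by passing the tensor parts (an untested sketch, assuming the recurrent state is an hk.LSTMState):

import haiku as hk
import jax.numpy as jnp

# Actor side: ship only the raw arrays, not the hk.LSTMState wrapper.
def pack_recurrent_state(state: hk.LSTMState):
  return state.hidden, state.cell

# Inference-server side: rebuild the LSTMState before calling the network,
# so vmap only ever sees plain arrays with a leading batch dimension.
def unpack_recurrent_state(hidden, cell) -> hk.LSTMState:
  return hk.LSTMState(hidden=jnp.asarray(hidden), cell=jnp.asarray(cell))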

Thanks for your insights so far!

kinalmehta commented 1 year ago

Here is the modified inference server script I've been using:

import dataclasses
import datetime
import threading
from typing import Any, Callable, Generic, Optional, Sequence, TypeVar

import acme
import haiku as hk
import jax
import jax.numpy as jnp
import jaxlib
import launchpad as lp
import numpy as np
from acme.jax import variable_utils

@dataclasses.dataclass
class InferenceServerConfig:
  """Configuration options for centralised inference.

  Attributes:
    batch_size: How many elements to batch together per single inference call.
        Auto-computed when not specified.
    update_period: Frequency of updating variables from the variable source.
        It is passed to VariableClient. Auto-computed when not specified.
    timeout: Time after which incomplete batch is executed (batch is padded,
        so the batch handler is always called with batch_size elements).
        By default timeout is effectively disabled (set to 30 days).
  """
  batch_size: Optional[int] = None
  update_period: Optional[int] = None
  timeout: datetime.timedelta = datetime.timedelta(days=30)

InferenceServerHandler = TypeVar('InferenceServerHandler')

class InferenceServer(Generic[InferenceServerHandler]):
  """Centralised, batched inference server."""

  def __init__(self, handler: InferenceServerHandler,
               variable_source: acme.VariableSource,
               devices: Sequence[jax.xla.Device],
               config: InferenceServerConfig):
    """Constructs an inference server object.

    Args:
      handler: A callable or a mapping of callables to be exposed
        through the inference server.
      variable_source: Source of variables
      devices: Devices used for executing handlers. All devices are used in
        parallel.
      config: Inference Server configuration.
    """
    self._variable_source = variable_source
    self._variable_client = None
    self._keys = []
    self._devices = devices
    self._config = config
    self._call_cnt = 0
    self._device_params = [None] * len(self._devices)
    self._device_params_ids = [None] * len(self._devices)
    self._mutex = threading.Lock()
    self._handler = jax.tree_map(self._build_handler, handler, is_leaf=callable)

  @property
  def handler(self) -> InferenceServerHandler:
    return self._handler

  def _dereference_params(self, arg):
    """Replaces VariableReferences with their corresponding param values."""

    if not isinstance(arg, variable_utils.VariableReference):
      # All arguments but VariableReference are returned without modifications.
      arg = postprocess_data(arg)
      return arg

    # Due to batching dimension we take the first element.
    variable_name = arg.variable_name[0]

    if variable_name not in self._keys:
      # Create a new VariableClient which also serves new variables.
      self._keys.append(variable_name)
      self._variable_client = variable_utils.VariableClient(
          client=self._variable_source,
          key=self._keys,
          update_period=self._config.update_period)

    params = self._variable_client.params
    device_idx = self._call_cnt % len(self._devices)
    # Select device via round robin, and update its params if they changed.
    if self._device_params_ids[device_idx] != id(params):
      self._device_params_ids[device_idx] = id(params)
      self._device_params[device_idx] = jax.device_put(
          params, self._devices[device_idx])

    # Return the params that are located on the chosen device.
    device_params = self._device_params[device_idx]
    if len(self._keys) == 1:
      return device_params
    return device_params[self._keys.index(variable_name)]

  def _build_handler(self, handler: Callable[..., Any]) -> Callable[..., Any]:
    """Builds a batched handler for a given callable handler and its name."""

    def dereference_params_and_call_handler(*args, **kwargs):
      with self._mutex:
        # Dereference args corresponding to params, leaving others unchanged.
        args_with_dereferenced_params = [
            self._dereference_params(arg) for arg in args
        ]
        kwargs_with_dereferenced_params = {
            key: self._dereference_params(value)
            for key, value in kwargs.items()
        }
        self._call_cnt += 1

        # Maybe update params, depending on client configuration.
        if self._variable_client is not None:
          self._variable_client.update()

      op = handler(*args_with_dereferenced_params,
                   **kwargs_with_dereferenced_params)

      op = jax.tree_util.tree_map(lambda x: list(x), op)
      return op

    return lp.batched_handler(
        batch_size=self._config.batch_size,
        timeout=self._config.timeout,
        pad_batch=True,
        max_parallelism=2 * len(self._devices))(
            dereference_params_and_call_handler)

def postprocess_data(data):
  if type(data) is dict:
    data = {k: postprocess_data(v) for k, v in data.items()}
    return data
  elif type(data) is list:
    return jnp.array(data)
  elif type(data) is hk.LSTMState:
    return hk.LSTMState(
        hidden=jnp.array(data.hidden), cell=jnp.array(data.cell))
  else:
    raise ValueError(f'Unsupported data type: {type(data)}')

You might have to modify postprocess_data based on the data you pass to the inference server.

Basically, each leaf of the data tree you pass to the handler should have the batch dimension as its zeroth dimension.
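
For example, if your actors also send plain numpy arrays or acme's OAR namedtuples, a hypothetical extension of postprocess_data could look like this (adjust to whatever types you actually pass):

import haiku as hk
import jax.numpy as jnp
import numpy as np
from acme.wrappers.observation_action_reward import OAR

def postprocess_data(data):
  if isinstance(data, dict):
    return {k: postprocess_data(v) for k, v in data.items()}
  elif isinstance(data, (list, np.ndarray)):
    # A list of per-actor leaves becomes a single array whose zeroth
    # dimension is the batch dimension.
    return jnp.array(data)
  elif isinstance(data, OAR):
    return OAR(observation=jnp.array(data.observation),
               action=jnp.array(data.action),
               reward=jnp.array(data.reward))
  elif isinstance(data, hk.LSTMState):
    return hk.LSTMState(hidden=jnp.array(data.hidden),
                        cell=jnp.array(data.cell))
  else:
    raise ValueError(f'Unsupported data type: {type(data)}')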

samlobel commented 1 year ago

Cool, thanks a bunch! I ended up with a similar solution; posting here in case it's helpful to you or the next person. It uses this gist that stacks arbitrary pytrees, and then unstacks them to return a list of values. I needed to do some other postprocessing to deal with how acme stacks the observations, but I think that was unavoidable.

# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines Inference Server class used for centralised inference."""

import dataclasses
import datetime
import threading
from typing import Any, Callable, Generic, Optional, Sequence, TypeVar
import acme
from acme.jax import variable_utils
import jax
import launchpad as lp

import numpy as np

from jax import numpy as jnp
# from jax.lib import pytree
from jax.tree_util import tree_flatten, tree_unflatten

def tree_stack(trees):
  """Takes a list of trees and stacks every corresponding leaf.
  For example, given two trees ((a, b), c) and ((a', b'), c'), returns
  ((stack(a, a'), stack(b, b')), stack(c, c')).
  Useful for turning a list of objects into something you can feed to a
  vmapped function.
  """
  leaves_list = []
  treedef_list = []
  for tree in trees:
    leaves, treedef = tree_flatten(tree)
    leaves_list.append(leaves)
    treedef_list.append(treedef)

  grouped_leaves = zip(*leaves_list)
  result_leaves = [jnp.stack(l) for l in grouped_leaves]
  # return treedef_list[0].unflatten(result_leaves), treedef_list[0]
  return treedef_list[0].unflatten(result_leaves)

def tree_unstack(tree):
  """Takes a tree and turns it into a list of trees. Inverse of tree_stack.
  For example, given a tree ((a, b), c), where a, b, and c all have first
  dimension k, will make k trees
  [((a[0], b[0]), c[0]), ..., ((a[k-1], b[k-1]), c[k-1])]
  Useful for turning the output of a vmapped function into normal objects.
  """
  leaves, treedef = tree_flatten(tree)
  n_trees = leaves[0].shape[0]
  new_leaves = [[] for _ in range(n_trees)]
  for leaf in leaves:
    for i in range(n_trees):
      new_leaves[i].append(leaf[i])
  new_trees = [treedef.unflatten(l) for l in new_leaves]
  return new_trees

tree_stack = jax.jit(tree_stack)
tree_unstack = jax.jit(tree_unstack)

@dataclasses.dataclass
class InferenceServerConfig:
  """Configuration options for centralised inference.

  Attributes:
    batch_size: How many elements to batch together per single inference call.
        Auto-computed when not specified.
    update_period: Frequency of updating variables from the variable source.
        It is passed to VariableClient. Auto-computed when not specified.
    timeout: Time after which incomplete batch is executed (batch is padded,
        so the batch handler is always called with batch_size elements).
        By default timeout is effectively disabled (set to 30 days).
  """
  batch_size: Optional[int] = None
  update_period: Optional[int] = None
  timeout: datetime.timedelta = datetime.timedelta(days=30)

InferenceServerHandler = TypeVar('InferenceServerHandler')

def reverse_oar_thing(oar_thing):
  from acme.wrappers.observation_action_reward import OAR
  num_things = len(oar_thing.action)
  new_oars = [OAR(
    action=oar_thing.action[i],
    reward=oar_thing.reward[i],
    observation=oar_thing.observation[i]
    ) for i in range(num_things)]
  return new_oars

class InferenceServer(Generic[InferenceServerHandler]):
  """Centralised, batched inference server."""

  def __init__(self, handler: InferenceServerHandler,
               variable_source: acme.VariableSource,
               devices: Sequence[jax.xla.Device],
               config: InferenceServerConfig):
    """Constructs an inference server object.

    Args:
      handler: A callable or a mapping of callables to be exposed
        through the inference server.
      variable_source: Source of variables
      devices: Devices used for executing handlers. All devices are used in
        parallel.
      config: Inference Server configuration.
    """
    self._variable_source = variable_source
    self._variable_client = None
    self._keys = []
    self._devices = devices
    self._config = config
    self._call_cnt = 0
    self._device_params = [None] * len(self._devices)
    self._device_params_ids = [None] * len(self._devices)
    self._mutex = threading.Lock()
    self._handler = jax.tree_map(self._build_handler, handler, is_leaf=callable)

  @property
  def handler(self) -> InferenceServerHandler:
    return self._handler

  def _dereference_params(self, arg):
    """Replaces VariableReferences with their corresponding param values."""

    if not isinstance(arg, variable_utils.VariableReference):
      # All arguments but VariableReference are returned without modifications.
      return arg

    # Due to batching dimension we take the first element.
    variable_name = arg.variable_name[0]

    if variable_name not in self._keys:
      # Create a new VariableClient which also serves new variables.
      self._keys.append(variable_name)
      self._variable_client = variable_utils.VariableClient(
          client=self._variable_source,
          key=self._keys,
          update_period=self._config.update_period)

    params = self._variable_client.params
    device_idx = self._call_cnt % len(self._devices)
    # Select device via round robin, and update its params if they changed.
    if self._device_params_ids[device_idx] != id(params):
      self._device_params_ids[device_idx] = id(params)
      self._device_params[device_idx] = jax.device_put(
          params, self._devices[device_idx])

    # Return the params that are located on the chosen device.
    device_params = self._device_params[device_idx]
    if len(self._keys) == 1:
      return device_params
    return device_params[self._keys.index(variable_name)]

  def _build_handler(self, handler: Callable[..., Any]) -> Callable[..., Any]:
    """Builds a batched handler for a given callable handler and its name."""
    print('[inference_server] calling _build_handler')
    def dereference_params_and_call_handler(*args, **kwargs):
      with self._mutex:
        # Dereference args corresponding to params, leaving others unchanged.
        args_with_dereferenced_params = [
            self._dereference_params(arg) for arg in args
        ]
        kwargs_with_dereferenced_params = {
            key: self._dereference_params(value)
            for key, value in kwargs.items()
        }
        self._call_cnt += 1

        # Maybe update params, depending on client configuration.
        if self._variable_client is not None:
          self._variable_client.update()
      params = args_with_dereferenced_params[0]
      oar_thing = args_with_dereferenced_params[1]
      reversed_oar = reverse_oar_thing(oar_thing)
      recurrent_state_thing = args_with_dereferenced_params[2]
      reversed_oar_stacked = tree_stack(reversed_oar)
      recurrent_state_thing_stacked = tree_stack(recurrent_state_thing)
      to_return = handler(params, reversed_oar_stacked, recurrent_state_thing_stacked)
      unstacked_to_return = tree_unstack(to_return)
      print('[inference_server] returning unstacked')
      return unstacked_to_return
      # return handler(*args_with_dereferenced_params,
      #                **kwargs_with_dereferenced_params)

    to_return = lp.batched_handler(
        batch_size=self._config.batch_size,
        timeout=self._config.timeout,
        pad_batch=True,
        max_parallelism=2 * len(self._devices))(
            dereference_params_and_call_handler)
    return to_return

kinalmehta commented 1 year ago

Great!! BTW, if you don't mind, please do share how the performance compares with standard CPU-based inference in your case and whether using InferenceServer helped or not.

samlobel commented 1 year ago

We're getting slightly better performance with the inference_server compared to without (3800 FPS vs 2200 FPS, as well as slightly better training steps/second). The key was to instantiate multiple copies of the inference server; I noticed that the inference_server had very poor GPU utilization on its own. Ideally we'd figure out why that is instead of running multiple copies.
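
To sketch the idea (a hypothetical helper, not our exact launchpad program): once the program creates several inference-server handles, the actors are simply spread across them round-robin.

def assign_actors_to_servers(num_actors, server_handles):
  """Returns, for each actor index, the inference-server handle it should call."""
  return [server_handles[i % len(server_handles)] for i in range(num_actors)]

# e.g. with 8 actors and 2 server copies, actors 0, 2, 4, 6 call copy 0 and
# actors 1, 3, 5, 7 call copy 1, which keeps each GPU server busier.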

abagaria commented 1 year ago

> I managed to get the inference server up, but it gets stuck at the action-selection step. Based on my understanding, launchpad only supports calling a CourierNode using Python types. However, when passing observations and LSTM states, which are usually numpy or JAX objects, the system stops.
>
> Any hints on how this can be handled? I tried using pickle to serialize and deserialize the objects, but launchpad doesn't seem to accept bytes objects.

Hey @kinalmehta, did you ever figure out how to pass around numpy arrays (or non-Python primitives in general) using courier? I am able to pass around Python ints and floats, but I need to pass around numpy uint arrays, and as you said, doing that causes the system to stop.

kinalmehta commented 1 year ago

Hi @abagaria, as mentioned here, launchpad seems to work fine with JAX arrays, so I use JAX arrays instead of numpy arrays.
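
For example, something like this on the actor side (a rough sketch; inference_client and the call signature are placeholders for whatever handle your actor holds) keeps raw numpy arrays out of the courier call:

import jax
import jax.numpy as jnp
import numpy as np

def to_jax(tree):
  """Converts every numpy leaf of a pytree into a JAX array before the courier call."""
  return jax.tree_util.tree_map(
      lambda x: jnp.asarray(x) if isinstance(x, np.ndarray) else x, tree)

# Hypothetical usage inside the actor:
# action, new_state = inference_client.select_action(to_jax(observation), to_jax(lstm_state))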