numagic / lumos

scalable accelerated optimal control
MIT License

automatically collects outputs from submodel #47

Open yunlongxu-numagic opened 2 years ago

yunlongxu-numagic commented 2 years ago

First draft to show some ideas:

yunlongxu-numagic commented 2 years ago

Testing on the vehicle model, the JIT time for JAX seems to increase (from 45 s to 80 s). JAX execution time also increases.

Reducing the outputs from submodels (e.g., cutting down the number of tire outputs) helps alleviate this issue. This suggests that even though we only use con_outputs in the NLP functions, creating a large outputs array and then extracting con_outputs from it causes a performance drop as the number of outputs grows! (Maybe in the ops the full outputs array is still created, and then it is indexed/gathered to collect con_outputs.)
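
A quick way to check this suspicion is to look at the traced program with jax.make_jaxpr. The sketch below uses a hypothetical toy submodel (all_outputs, CON_IDX and the two extraction functions are made up for illustration, not lumos code): stacking every output and then indexing with an integer array shows up as a gather op in the jaxpr, while stacking only the needed scalars does not. Whether XLA later simplifies that gather is a separate question this check does not answer.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical positions of the con_outputs within the full outputs vector.
CON_IDX = (0, 3, 7)

def all_outputs(xs):
    # Hypothetical toy submodel: one output scalar per input scalar.
    return [jnp.sin(x) for x in xs]

def con_outputs_via_gather(xs):
    # Materialize the full outputs array, then pick con_outputs with an index array.
    outputs = jnp.stack(all_outputs(xs))
    return outputs[np.array(CON_IDX, dtype=np.int32)]

def con_outputs_direct(xs):
    # Only stack the scalars that are actually needed.
    outs = all_outputs(xs)
    return jnp.stack([outs[i] for i in CON_IDX])

xs = tuple(float(i) for i in range(10))
jaxpr_gather = str(jax.make_jaxpr(con_outputs_via_gather)(xs))
jaxpr_direct = str(jax.make_jaxpr(con_outputs_direct)(xs))
print("gather" in jaxpr_gather, "gather" in jaxpr_direct)
```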

yunlongxu-numagic commented 2 years ago

Directly indexing con_outputs from outputs also didn't help with JIT time:

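    def _extract_con_outputs(self, outputs: lnp.ndarray) -> lnp.ndarray:
        # Positions of each con_output within the full outputs vector.
        idx = np.array(
            [
                self.get_var_index("outputs", c)
                for c in self.get_group_names("con_outputs")
            ],
            dtype=np.int32,
        )
        # Integer-array indexing: with the JAX backend this is traced as a gather.
        return outputs[idx]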
yunlongxu-numagic commented 2 years ago

When we add extra dummy outputs to the tire model, even just 10, the JIT time increases significantly, and execution also slows down (to a lesser extent).

I suspect this is because, as submodel outputs are passed up to the top level, they go through a chain of ops:
1) scalars are concatenated into an array (make_vector(...));
2) the array is turned back into a dictionary (combine_submodel_outputs, array_to_dict);
3) the scalars are concatenated into an array again (in the parent model);
4) and so on...
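
To see that chain in isolation, here is a hypothetical reconstruction with toy versions of make_vector and array_to_dict (their bodies are assumptions for illustration, not the lumos implementations). Printing the jaxpr of parent shows a concatenate for every make_vector call plus a separate indexing op for every element that array_to_dict unpacks, so the op count grows with the number of outputs at every level of nesting.

```python
import jax
import jax.numpy as jnp

N_OUT = 4  # hypothetical number of submodel outputs

def make_vector(*scalars):
    # Hypothetical stand-in: concatenate scalars into a 1-D array.
    return jnp.stack(scalars)

def array_to_dict(arr, names):
    # Hypothetical stand-in: unpack the array back into named scalars.
    return {name: arr[i] for i, name in enumerate(names)}

def submodel(x):
    names = [f"out_{k}" for k in range(N_OUT)]
    outputs = make_vector(*[x * k for k in range(N_OUT)])  # 1) scalars -> array
    return array_to_dict(outputs, names)                   # 2) array -> dict

def parent(x):
    sub = submodel(x)
    return make_vector(*sub.values(), x + 1.0)             # 3) scalars -> array again

print(jax.make_jaxpr(parent)(1.0))
```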

So maybe PR #49, which uses array-only I/O at the autograd level, is a way to improve this? (Internally we would just use dictionaries, which hopefully would then never see scatter/gather except at the top level.)
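
As a rough sketch of that split (not the actual PR #49 code, which isn't shown here), JAX's own jax.flatten_util.ravel_pytree can do the dict <-> flat-array conversion exactly once at the top-level boundary, while the model body stays dictionary-based. model_dict and its keys are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def model_dict(inputs):
    # Hypothetical dict-based model body: works on named scalars only.
    force = inputs["mass"] * inputs["accel"]
    return {"force": force, "power": force * inputs["speed"]}

example_inputs = {"mass": jnp.array(2.0), "accel": jnp.array(3.0), "speed": jnp.array(4.0)}
# Flat vector (leaves in sorted key order: accel, mass, speed) plus its inverse map.
flat_inputs, unravel_inputs = ravel_pytree(example_inputs)

def model_array(x):
    # The only array <-> dict conversions happen here, at the autograd boundary.
    outputs = model_dict(unravel_inputs(x))
    flat_outputs, _ = ravel_pytree(outputs)  # sorted keys: force, power
    return flat_outputs

jac = jax.jacfwd(model_array)(flat_inputs)  # shape (2, 3)
print(jac)
```

Autograd only ever sees one flat input vector and one flat output vector, while everything inside model_dict stays on named scalars.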

yunlongxu-numagic commented 2 years ago

For JAX compilation, dictionary inputs are flattened and 'separated', so a dictionary input with 100 floats is turned into 100 separate inputs rather than a single array of size 100! This suggests that dictionary I/O would indeed help avoid lots of scatter/gather ops if the computations are mostly done on scalars.

For example:

import numpy as np
import jax
import jax.numpy as jnp

# By default we use 64-bit, as over/underflow is quite likely to happen with 32-bit
# and 2nd-derivative autograd, without a lot of careful management...
jax.config.update("jax_enable_x64", True)

# For the initial phase, we also report verbosely if JAX is recompiling.
jax.config.update("jax_log_compiles", True)

def sum_dict(inputs):
    # Each dict value is a separate jit argument; stack them all into one array.
    array_to_sum = jnp.array(list(inputs.values()))
    return jnp.sum(array_to_sum)

def create_dummy_dict(size=100):
    names = [f"name_{i}" for i in range(size)]
    values = np.random.randn(size)
    return dict(zip(names, values))

if __name__ == "__main__":
    inputs = create_dummy_dict(size=100)
    outputs = jax.jit(sum_dict)(inputs)

gives:

WARNING:absl:Finished tracing + transforming sum_dict for jit in 0.7426469326019287 sec
WARNING:absl:Compiling sum_dict (139826457603456) for 100 args.
WARNING:absl:Finished XLA compilation of sum_dict in 0.05084347724914551 sec
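
The "100 args" in the log can also be checked without compiling: jax.make_jaxpr exposes one input variable per pytree leaf. In the minimal sketch below (sizes and values are illustrative), the same sum over a dict of 100 scalars and over a single array of 100 elements gives 100 jaxpr inputs versus 1.

```python
import jax
import jax.numpy as jnp
import numpy as np

def sum_dict(inputs):
    return jnp.sum(jnp.array(list(inputs.values())))

inputs_dict = {f"name_{i}": np.float32(i) for i in range(100)}
inputs_array = np.arange(100, dtype=np.float32)

# Each dict value is a separate pytree leaf, hence a separate traced input.
n_leaves = len(jax.tree_util.tree_leaves(inputs_dict))
n_args_dict = len(jax.make_jaxpr(sum_dict)(inputs_dict).jaxpr.invars)
n_args_array = len(jax.make_jaxpr(jnp.sum)(inputs_array).jaxpr.invars)
print(n_leaves, n_args_dict, n_args_array)  # 100 100 1
```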

Note that the compilation time increases super-linearly: for 1000 args we get 2.5 s (50x), and for 10000 args we get 160 s (another 60x).
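
To reproduce a sweep like this, JAX's ahead-of-time API (jax.jit(f).lower(...).compile()) lets tracing/lowering and XLA compilation be timed separately. This is a sketch: absolute numbers depend on hardware and JAX version, and the sizes are kept small so it runs quickly.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np

def sum_dict(inputs):
    return jnp.sum(jnp.array(list(inputs.values())))

def create_dummy_dict(size):
    return {f"name_{i}": v for i, v in enumerate(np.random.randn(size))}

def time_jit(size):
    """Return (trace+lower seconds, XLA compile seconds) for a dict of `size` floats."""
    inputs = create_dummy_dict(size)
    t0 = time.perf_counter()
    lowered = jax.jit(sum_dict).lower(inputs)
    t1 = time.perf_counter()
    lowered.compile()
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1

# Extend to (100, 1000, 10000) to reproduce the full sweep (the largest size is slow).
for size in (100, 500):
    trace_s, compile_s = time_jit(size)
    print(f"{size:>6} args: trace+lower {trace_s:.3f} s, compile {compile_s:.3f} s")
```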

If we only sum 100 elements, but still pass 10000 args as inputs, the JIT time drops to 3.5 s:

def sum_dict(inputs):
    array_to_sum = jnp.array(list(inputs.values()))
    return jnp.sum(array_to_sum[:100])

And if we only sum 100 elements and never form the size-10000 array in the first place, the JIT time drops back to 0.05 s:

def sum_dict(inputs):
    sub_inputs = dict(zip(list(inputs.keys())[:100], list(inputs.values())[:100]))
    array_to_sum = jnp.array(list(sub_inputs.values()))
    return jnp.sum(array_to_sum)
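
One way to see why the last two variants differ is to check whether the traced program ever contains an array as long as the whole dict. The sketch below (N and K are illustrative and smaller than in the experiment above; the function names are made up) searches the printed jaxpr for a length-N shape: the first two variants build it, the third does not. This is a structural check only, not a timing measurement.

```python
import jax
import jax.numpy as jnp
import numpy as np

N, K = 1000, 100  # dict size and number of elements actually summed (illustrative)

def sum_all(inputs):
    return jnp.sum(jnp.array(list(inputs.values())))

def sum_first_k_after_stacking(inputs):
    # Builds the full length-N array, then slices it.
    return jnp.sum(jnp.array(list(inputs.values()))[:K])

def sum_first_k_without_stacking(inputs):
    # Selects K scalars in Python first, so only a length-K array is built.
    sub_values = list(inputs.values())[:K]
    return jnp.sum(jnp.array(sub_values))

inputs = {f"name_{i}": v for i, v in enumerate(np.random.randn(N))}

def builds_full_array(fn):
    """True if the traced program contains a length-N intermediate array."""
    return f"[{N}]" in str(jax.make_jaxpr(fn)(inputs))

for fn in (sum_all, sum_first_k_after_stacking, sum_first_k_without_stacking):
    print(fn.__name__, builds_full_array(fn))
```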

The lesson? Avoid unnecessary scatter and gather!