jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

lax.map does not work in this NeRF code #4126

Open BoyuanJackChen opened 4 years ago

BoyuanJackChen commented 4 years ago

This problem has been bothering me for months, and I really want to know how to fix it. When implementing NeRF with JAX, I want to divide the data into batches so that the code can run on larger images with more sample points without losing speed. Following the suggestions from another thread, I tried lax.map and jax.remat, but neither worked. I also tried hand-batching the data and computing each batch's loss in a for loop, but that was unacceptably slow (about 10x slower).

In the original code published by the team, they used TensorFlow 1.14. The lines where they batch the input look like this:

def batchify(fn, chunk=1024 * 32):
    return lambda inputs: jnp.concatenate(
        [fn(inputs[i : i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0,
    )
raw = batchify(net_fn)(pts_flat)

In jax, I tried the following:

raw = lax.map(net_fn, jnp.reshape(pts_flat, [-1, batch_size, pts_flat.shape[-1]]))

and

def batchify(fn, chunk=1024*32):
    return jax.remat(lambda inputs : jnp.concatenate([fn(inputs[i:i + chunk])
                                         for i in range(0, inputs.shape[0], chunk)], 0))
...
raw = batchify(net_fn)(pts_flat)

Both compiled and ran, yet neither saved memory.

My guess about the cause is that since you have to render the whole image to compute the loss, this batching procedure doesn't help. Nonetheless, the TensorFlow 1.14 code has exactly the same structure, and yet tf.GradientTape seemed to handle the batching well.

I am grateful for the code example from myagues in my previous post #3865. Unfortunately, their code did not solve the problem either. Here are the links to the two failed attempts in JAX: mine at https://github.com/BoyuanJackChen/NeRF-Implementation.git and myagues's at https://github.com/myagues/potpourri/blob/master/jax/tiny_nerf_jax.ipynb. I created the model with flax.nn, and myagues used jax.experimental.stax. Both were able to learn, yet neither saved memory. If you want to take a look, I think my code is a little simpler to read.

I sincerely hope this problem can be fixed. I would be super grateful for any help!

shoyer commented 4 years ago

Can you try lax.map(jax.remat(net_fn), ...)?

The remat call has to decorate the function for which you want to save only a single gradient checkpoint. Within remat, values from the forward pass get recalculated instead of saved.

Generally speaking, there isn't any point in applying remat only once over the whole computation: you don't end up saving memory once you add back the re-evaluation of the forward pass.
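
For example, a minimal sketch of this pattern with a stand-in net_fn (the real network, shapes, and chunk size come from the NeRF code in this thread):

import jax
import jax.numpy as jnp
from jax import lax

def net_fn(x):  # stand-in for the real NeRF network
    return jnp.tanh(x @ jnp.ones((x.shape[-1], 4)))

pts_flat = jnp.ones((1024 * 64, 51))  # (num_points, num_features)
chunk = 1024                          # points processed per lax.map step

# lax.map applies net_fn chunk by chunk; wrapping it in jax.remat means each
# chunk's forward activations are recomputed during the backward pass instead
# of all being kept live at once.
batched = jnp.reshape(pts_flat, (-1, chunk, pts_flat.shape[-1]))
raw = lax.map(jax.remat(net_fn), batched)  # shape: (num_chunks, chunk, 4)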

BoyuanJackChen commented 4 years ago

@shoyer Changing only that one line didn't work. Could you elaborate on where else to add remat?

BoyuanJackChen commented 4 years ago

For convenient viewing, the whole loss function looks like this:

        def loss_fn(network_fn):
            def batchify(fn):
                return jax.remat(lambda inputs: jnp.concatenate([fn(inputs[i:i + batchify_size])
                                                                 for i in range(0, inputs.shape[0], batchify_size)], 0))
            z_vals = jnp.linspace(near, far, N_samples)
            if rand:
                key, subkey = jax.random.split(thekey)
                z_vals += jax.random.uniform(subkey, list(rays_o.shape[:-1]) + [N_samples], dtype=jnp.float32) \
                          * (far - near) / N_samples
            pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
            pts_flat = jnp.reshape(pts, [-1, 3])
            pts_flat = embed_fn(pts_flat)  # pts_flat is an array of shape (H*W*N*3, 51)
            raw = lax.map(jax.remat(network_fn), jnp.reshape(pts_flat, [-1, batchify_size, pts_flat.shape[-1]]))
            raw = jnp.reshape(raw, list(pts.shape[:-1]) + [4])
            sigma_a = nn.relu(raw[..., 3])  # (H, W, N)
            rgb = nn.sigmoid(raw[..., :3])  # (H, W, N, 3)
            dists = jnp.concatenate((z_vals[..., 1:] - z_vals[..., :-1],
                                     jnp.broadcast_to([1e10], z_vals[..., :1].shape)), -1)  # (H, W, N)
            alpha = 1. - jnp.exp(-sigma_a * dists)
            weights = alpha * jnp.cumprod(1. - alpha + 1e-10, axis=-1, dtype=jnp.float32)  # (H, W, N)
            rgb = jnp.sum(weights[..., None] * rgb, -2)
            loss = jnp.mean(jnp.square(rgb - target))
            return loss, rgb
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, logits), grad = grad_fn(optimizer.target)
        optimizer = optimizer.apply_gradient(grad)
        print(f"Step {i} done; loss is {loss}")

cgarciae commented 4 years ago

@BoyuanJackChen I don't know if I am being too naive, but what if you just batch the data outside of JAX? I mean, just create a generator that produces NumPy arrays of a certain batch size. You can easily do this with a tf.data.Dataset and use it with JAX by converting it to an iterator of NumPy arrays via its .as_numpy_iterator() method; you can even use PyTorch's Dataset + DataLoader without the to_tensor transformation.
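
A minimal sketch of the idea (the array shapes, batch size, and data are placeholders, not from the NeRF code):

import numpy as np
import tensorflow as tf
import jax.numpy as jnp

images = np.random.rand(100, 64, 64, 3).astype(np.float32)  # placeholder data
ds = tf.data.Dataset.from_tensor_slices(images).batch(8)

# as_numpy_iterator() yields plain NumPy arrays, which JAX accepts directly.
for batch in ds.as_numpy_iterator():
    batch = jnp.asarray(batch)  # the batch is transferred to the accelerator here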

BTW: Is there a reason why the authors implemented their own training loop in the colab you sent? It seems tf.keras.Model.fit could do the job.

BoyuanJackChen commented 4 years ago

@cgarciae Thanks for your advice! Iterators are a blind spot for me, so I am not sure how fast they would perform. Do you think it would be as fast as TensorFlow's automatic batching? Or is that batching itself built on something like an iterator? One of my previous attempts was to divide the image into batches of rows (say, 24 rows per batch) and then use a for loop to calculate the loss and grad for each batch. It worked but was super slow. Would the iterator have a similar structure? Could you enlighten me further with more details?

cgarciae commented 4 years ago

@BoyuanJackChen tf.data uses a lot of tricks to keep the GPU busy. In particular, the idea is to do the preprocessing and batching on the CPU in parallel with the forward + backprop steps on the GPU, so the GPU doesn't have to wait for each new batch to be ready.

Check out this video on tf.data.
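
In practice that overlap comes from ending the pipeline with prefetch; a generic sketch (not taken from the notebooks above):

import numpy as np
import tensorflow as tf

images = np.random.rand(100, 64, 64, 3).astype(np.float32)  # placeholder data
ds = (tf.data.Dataset.from_tensor_slices(images)
      .shuffle(buffer_size=100)
      .batch(8)
      .prefetch(tf.data.experimental.AUTOTUNE))  # build the next batch on the
                                                 # CPU while the current one
                                                 # runs on the GPU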

BoyuanJackChen commented 4 years ago

@cgarciae Another tricky part is that I don't put multiple images into one batch; instead, I need to divide each image into multiple batches. I guess I should make a tf.data.Dataset for the rays of each selected image, and then consume each batch in a for loop, something like the sketch below?
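
Something like this, maybe (the shapes and chunk size here are placeholders):

import numpy as np
import tensorflow as tf

rays = np.random.rand(64, 64, 3).astype(np.float32)  # placeholder (H, W, 3) rays of one image
rays_flat = rays.reshape(-1, 3)                      # (H*W, 3)

# One dataset per image, split into chunks of rays.
ds = tf.data.Dataset.from_tensor_slices(rays_flat).batch(1024)
for ray_chunk in ds.as_numpy_iterator():
    pass  # accumulate the per-chunk loss/gradients here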

BoyuanJackChen commented 4 years ago

@cgarciae Thanks for your advice! It actually improved the speed on CPU to a great extent! Below is a demo of how I used it:

target_data = tf.data.Dataset.from_tensor_slices(target_batched)
rays_o_data = tf.data.Dataset.from_tensor_slices(rays_o_batched)
rays_d_data = tf.data.Dataset.from_tensor_slices(rays_d_batched)
# Iterate the datasets (not the raw arrays) as NumPy so JAX can consume them.
target_iter = target_data.as_numpy_iterator()
rays_o_iter = rays_o_data.as_numpy_iterator()
rays_d_iter = rays_d_data.as_numpy_iterator()
...
for target_batch, rays_o_batch, rays_d_batch in zip(target_iter, rays_o_iter, rays_d_iter):
    ...

Nonetheless, it still doesn't work on GPU: the code is even slower than the for loop. My guess is that the iterator is placed on the CPU by default. Is there any way to make it run faster on GPU?