jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Looking for ways to save memory through batchify #3865

Open BoyuanJackChen opened 4 years ago

BoyuanJackChen commented 4 years ago

I'm currently implementing NeRF, a neural rendering method published this year with impressive results. The original code is in TensorFlow, and my version is mainly adapted from this vanilla Lego example: https://colab.research.google.com/github/bmild/nerf/blob/master/tiny_nerf.ipynb

A runnable version of my code is in this GitHub repository: https://github.com/BoyuanJackChen/NeRF-Implementation (you can run "/Code/NeRF-jax.py" directly). I put a JAX profiler export right after raw = batchify(network_fn)(pts_flat). You can see that memory usage increases proportionally as you increase the N_samples variable.

While porting it to JAX, I ran into a problem batchifying the training data. Instead of hand-batchifying, which is troublesome, bug-prone, and slow, I hope there is a way to do the job that is as elegant as the TensorFlow counterpart. I've included the code below. Hope to receive some enlightenment!

The idea of NeRF is that for each pixel in an image, we shoot one ray through that pixel, sample N points along the ray, use the model to compute the (r, g, b, alpha) value at each point, and then combine them in a weighted sum along the ray (volume rendering) to get the RGB value for that pixel. The loss is then computed from the difference between the model-generated image and the original image.
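
The compositing step for a single ray can be sketched in JAX as below. This is a minimal sketch, assuming the exclusive transmittance T_i = prod_{j<i} (1 - alpha_j) from the NeRF volume-rendering equation (the original TF notebook computes it with tf.math.cumprod(..., exclusive=True)); composite_ray and the toy inputs are hypothetical.

```python
import jax.numpy as jnp


def composite_ray(sigma, rgb, z_vals):
    """Alpha-composite N samples along one ray (hypothetical helper).

    sigma:  (N,)   non-negative densities at each sample
    rgb:    (N, 3) colors at each sample
    z_vals: (N,)   sample depths along the ray, sorted ascending
    Returns the pixel color (3,) and the per-sample weights (N,).
    """
    # Distance between adjacent samples; the last sample gets a very large distance
    dists = jnp.concatenate([z_vals[1:] - z_vals[:-1], jnp.array([1e10])])
    alpha = 1.0 - jnp.exp(-sigma * dists)
    # Exclusive cumulative product: T_i = prod_{j<i} (1 - alpha_j), so T_0 = 1
    trans = jnp.cumprod(1.0 - alpha + 1e-10)
    trans = jnp.concatenate([jnp.ones((1,)), trans[:-1]])
    weights = alpha * trans
    color = jnp.sum(weights[:, None] * rgb, axis=0)
    return color, weights


# Toy ray: empty space for the first two samples, then a dense red surface
sigma = jnp.array([0.0, 0.0, 5.0, 5.0])
rgb = jnp.tile(jnp.array([[1.0, 0.0, 0.0]]), (4, 1))
z_vals = jnp.linspace(2.0, 6.0, 4)
color, weights = composite_ray(sigma, rgb, z_vals)
```

Using jnp.cumprod on its own (without the shift) gives the inclusive product, in which each weight also includes its own (1 - alpha_i) factor.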

The problem is that computing so many sample points for over 8k pixels at once uses too much memory. Therefore, the original code batchifies the network calls in TensorFlow:

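import tensorflow as tf

def batchify(fn, chunk=1024*32):
    # Apply fn to consecutive slices of `chunk` rows and concatenate the results
    return lambda inputs: tf.concat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)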
...
raw = batchify(network_fn)(pts_flat)

With this, TensorFlow supposedly computes the gradients one chunk of points at a time.

I tried to do the same thing by simply changing tf.concat to jnp.concatenate. However, the code doesn't seem to batchify at all, and my memory usage soon explodes! Of course, I could hand-batchify by splitting pts_flat with reshape and computing the grads for each sub-array, but that gets ugly and is prone to inefficiency. I wonder if there is a quick, clean solution.

By the way, I checked some previous threads on jax.remat. I added "@jax.remat" above render_rays, and it didn't work. I also tried adding it above batchify, which failed with "<class 'function'> is not a valid JAX type". Maybe I just didn't use it the right way. Hope you guys can help me!

import jax
import jax.numpy as jnp
import numpy as np

def train_nerf(optimizer, images, poses, focal, near, far, batchify_size=1024*32, N_samples=64, N_iters=3000):
    H, W = images.shape[1:3]
    rand_key = jax.random.PRNGKey(0)
    for i in range(N_iters + 1):
        img_i = np.random.randint(0, images.shape[0])   # upper bound is exclusive
        target = images[img_i]
        pose = poses[img_i]
        rays_o, rays_d = get_rays(H, W, focal, pose)
        # Split off a fresh key each step so the sample jitter differs between iterations
        rand_key, step_key = jax.random.split(rand_key)
        # Render the image and calculate the loss
        def loss_fn(model):
            rgb, depth, acc = render_rays(model, rays_o, rays_d, near=near, far=far,
                                          batchify_size=batchify_size, N_samples=N_samples,
                                          rand=True, key=step_key)
            loss = jnp.mean(jnp.square(rgb - target))
            return loss, rgb
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, rgb), grad = grad_fn(optimizer.target)
        optimizer = optimizer.apply_gradient(grad)
        print(f"Step {i} done; loss is {loss}")
    return optimizer

@jax.remat
def render_rays(network_fn, rays_o, rays_d, near, far, N_samples, rand=False,
                key=jax.random.PRNGKey(0), batchify_size=1024*32):
    def batchify(fn, chunk=batchify_size):
        return lambda inputs: jnp.concatenate([fn(inputs[i:i + chunk])
                                         for i in range(0, inputs.shape[0], chunk)], 0)
    z_vals = jnp.linspace(near, far, N_samples)
    if rand:
        key, subkey = jax.random.split(key)
        z_vals += jax.random.uniform(subkey, list(rays_o.shape[:-1]) + [N_samples], dtype=jnp.float32) \
                  * (far - near) / N_samples
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]

    pts_flat = jnp.reshape(pts, [-1, 3])
    pts_flat = embed_fn(pts_flat)     # pts_flat now has shape (H*W*N_samples, 51)
    # --- This is where I wish it to batchify, but it fails to! ---
    raw = batchify(network_fn)(pts_flat)
    raw = jnp.reshape(raw, list(pts.shape[:-1]) + [4])

    # Compute opacities and colors
    sigma_a = nn.relu(raw[..., 3])    # (H, W, N_samples)
    rgb = nn.sigmoid(raw[..., :3])    # (H, W, N_samples, 3)

    # Do volume rendering (P6 equation (3))
    dists = jnp.concatenate((z_vals[..., 1:] - z_vals[..., :-1],
                       jnp.broadcast_to([1e10], z_vals[..., :1].shape)), -1)   # (H, W, N_samples)
    alpha = 1. - jnp.exp(-sigma_a * dists)
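    # Transmittance: exclusive cumulative product, T_i = prod_{j<i} (1 - alpha_j),
    # matching tf.math.cumprod(..., exclusive=True) in the original TF code
    trans = jnp.cumprod(1. - alpha + 1e-10, axis=-1)
    trans = jnp.concatenate([jnp.ones_like(trans[..., :1]), trans[..., :-1]], axis=-1)
    weights = alpha * trans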
    rgb_map   = jnp.sum(weights[..., None] * rgb, -2)
    depth_map = jnp.sum(weights * z_vals, -1)
    acc_map   = jnp.sum(weights, -1)
    return rgb_map, depth_map, acc_map

I can provide further information if you need!

mattjj commented 4 years ago

Hey, thanks for the question! NeRF is awesome.

I attached "@jax.remat" above render_rays, and it didn't work. I also tried to attach it above batchify without success, saying "<class 'function'> is not a valid JAX type".

Ah, the trouble is that batchify returns a function, while jax.remat (like jax.jit and all JAX transformations) can only apply to functions that take array inputs and produce array outputs (well, actually containers of arrays are allowed at inputs and outputs).
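
To make the distinction concrete: jax.checkpoint (jax.remat is an alias for it) can wrap the array-to-array function that batchify returns, but not batchify itself, whose argument is a Python function. A minimal sketch, where net and its params are hypothetical stand-ins for the NeRF MLP:

```python
import jax
import jax.numpy as jnp


def batchify(fn, chunk=4):
    # Same Python-loop chunking as in the question, with a small chunk for the demo
    return lambda inputs: jnp.concatenate(
        [fn(inputs[i:i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0)


def net(params, x):
    # Hypothetical stand-in for the NeRF MLP: (n, 3) -> (n, 4)
    return jnp.tanh(x @ params)


def chunked_loss(params, pts):
    # Wrap the array -> array function returned by batchify;
    # jax.checkpoint(batchify) would fail because batchify's argument is a function
    chunked = jax.checkpoint(batchify(lambda x: net(params, x)))
    return jnp.sum(chunked(pts) ** 2)


def plain_loss(params, pts):
    return jnp.sum(net(params, pts) ** 2)


params = 0.1 * jnp.ones((3, 4))
pts = jnp.arange(30.0).reshape(10, 3) / 30.0
g_chunked = jax.grad(chunked_loss)(params, pts)
g_plain = jax.grad(plain_loss)(params, pts)
```

This only resolves the type error. Because the whole chunked function is a single checkpointed region, the backward pass recomputes all chunks together, so by itself it is not expected to lower peak memory.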

I tried to do the same thing by simply changing tf.concat to jnp.concatenate. Nonetheless, the code is simply not batchifying, and my memory soon explodes!

Are you using jax.jit? If not, and you're seeing memory issues, then we might not be freeing up memory as efficiently as possible in JAX autodiff, and we'd want to fix that. If you are using jax.jit, then memory use is all under XLA's control; XLA usually does a fantastic job of reducing memory usage, but perhaps it's not here for some reason, and I'd want to follow up with the XLA folks about it.

Overall, if you could share a full runnable repro that shows the memory explosions you're seeing, that would make it much easier to help! Think that's possible?

jekbradbury commented 4 years ago

I think what you're looking for (in JAX terms) is a map that is partially-sequential, partially-vectorized (so intermediate between lax.map, which applies a function sequentially over instances, and jax.vmap, which applies a function in parallel over instances). You can simulate this right now by reshaping your inputs and applying vmap to the inner batch dimension and lax.map to the outer one, and we're working on more convenient mechanisms. Does that sound like what you're interested in doing?
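
The reshape-then-map pattern described above can be sketched as follows. This is a minimal sketch, assuming the number of rows is a multiple of the chunk size; chunked_map and point_fn are hypothetical names. Recent JAX releases also give lax.map a batch_size argument for this pattern; check the documentation for your installed version.

```python
import jax
import jax.numpy as jnp
from jax import lax


def chunked_map(fn, xs, chunk_size):
    """Apply a per-example fn to every row of xs, chunk_size rows at a time.

    The outer loop over chunks is sequential (lax.map); the inner loop over
    rows in a chunk is vectorized (jax.vmap). Assumes xs.shape[0] is a
    multiple of chunk_size.
    """
    n = xs.shape[0]
    chunks = xs.reshape((n // chunk_size, chunk_size) + xs.shape[1:])
    out = lax.map(jax.vmap(fn), chunks)          # (num_chunks, chunk_size, ...)
    return out.reshape((n,) + out.shape[2:])      # back to (n, ...)


def point_fn(p):
    # Hypothetical per-point network: (3,) -> (4,)
    return jnp.concatenate([jnp.sin(p), jnp.sum(p ** 2, keepdims=True)])


xs = jnp.arange(24.0).reshape(8, 3)
ys = chunked_map(point_fn, xs, chunk_size=4)
```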

mattjj commented 4 years ago

I was thinking along similar lines, but then batchify is already chunking up arrays into pieces rather than doing it all in a vectorized fashion.

One thing to try is putting jax.remat here, inside batchify rather than on the outside:

def batchify(fn, chunk=1024*32):
    return jax.remat(lambda inputs : ...)

But under a jit XLA might be able to figure out a small working set size without us requiring jax.remat.

mattjj commented 4 years ago

Hmm, it's not clear to me how that batchify implementation will save memory when reverse-mode differentiated unless it's staged out to a TF graph somehow or unless there's some manual gradient checkpointing / remat going on which I'm missing in the notebook...
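
One way to get a small working set under reverse mode, sketched under assumptions: run the chunks sequentially with lax.map and wrap the per-chunk function in jax.checkpoint, so the forward pass keeps only each chunk's input and the backward pass recomputes one chunk's activations at a time. The two-layer mlp and its parameter shapes are hypothetical; the sketch checks only that the gradients match the unchunked computation, not the memory behavior itself.

```python
import jax
import jax.numpy as jnp
from jax import lax


def mlp(params, x):
    # Hypothetical 2-layer network standing in for NeRF's MLP: (n, d_in) -> (n, 4)
    w1, w2 = params
    return jnp.maximum(x @ w1, 0.0) @ w2


def chunked_apply(params, x, chunk_size):
    # Sequential over chunks; jax.checkpoint makes the backward pass recompute
    # each chunk's activations instead of storing them for every chunk
    n = x.shape[0]
    chunks = x.reshape(n // chunk_size, chunk_size, x.shape[-1])
    body = jax.checkpoint(lambda c: mlp(params, c))
    out = lax.map(body, chunks)
    return out.reshape(n, -1)


def loss(params, x, chunk_size):
    return jnp.mean(chunked_apply(params, x, chunk_size) ** 2)


key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(key1, (8, 16)), jax.random.normal(key2, (16, 4)))
x = jax.random.normal(key3, (64, 8))
grad_chunked = jax.jit(jax.grad(loss), static_argnums=2)(params, x, 16)
grad_full = jax.grad(lambda p, x: jnp.mean(mlp(p, x) ** 2))(params, x)
```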

BoyuanJackChen commented 4 years ago

Overall, if you could share a full runnable repro that shows the memory explosions you're seeing, that would make it much easier to help! Think that's possible?

The repo has been created! You can clone it and run the code in /Code/NeRF-jax.py. More information is in my original post. Unfortunately, adding jax.remat inside batchify did not work.

BoyuanJackChen commented 4 years ago

I think what you're looking for (in JAX terms) is a map that is partially-sequential, partially-vectorized (so intermediate between lax.map, which applies a function sequentially over instances, and jax.vmap, which applies a function in parallel over instances). You can simulate this right now by reshaping your inputs and applying vmap to the inner batch dimension and lax.map to the outer one, and we're working on more convenient mechanisms. Does that sound like what you're interested in doing?

Yes, that sounds promising. I never thought of combining lax.map with jax.vmap. I will definitely try it later.

mattjj commented 4 years ago

Unfortunately, adding jax.remat inside batchify did not work.

Didn't work in that you got an error, or in that there was no error but you're still seeing too much memory use?

BoyuanJackChen commented 4 years ago

Unfortunately, adding jax.remat inside batchify did not work.

Didn't work in that you got an error, or in that there was no error but you're still seeing too much memory use?

The latter. Memory use was still high.

BoyuanJackChen commented 4 years ago

I think what you're looking for (in JAX terms) is a map that is partially-sequential, partially-vectorized (so intermediate between lax.map, which applies a function sequentially over instances, and jax.vmap, which applies a function in parallel over instances). You can simulate this right now by reshaping your inputs and applying vmap to the inner batch dimension and lax.map to the outer one, and we're working on more convenient mechanisms. Does that sound like what you're interested in doing?

Thanks for your advice, but I'm not very familiar with lax.map. Below, I batchify the input by image rows. For example, the image size here is (95, 126), so I group every row_batch=5 rows of data into a batch. The loss function hasn't been changed yet. The first problem is that I have to "vmap" over both pts_batched and target_batched, and I wonder how to do that. The second is how to update the model while batchifying. Could you give some pseudocode showing how to use lax.map and jax.vmap? That would be enlightening!
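
Both jax.vmap and lax.map accept a tuple (or any pytree) of arrays and map over the leading axis of every leaf, so points and targets can be mapped together. The whole mapped loss is then differentiated once and the optimizer updated once per step, so the model never needs updating inside the map. A minimal sketch; render_batch is a hypothetical placeholder for per-batch rendering, using a linear "model" so the example is self-contained:

```python
import jax
import jax.numpy as jnp
from jax import lax


def render_batch(params, pts):
    # Hypothetical stand-in for rendering: average a linear color over samples.
    # pts: (rows, W, N, 3) -> rgb: (rows, W, 3)
    return jnp.mean(pts @ params, axis=-2)


def loss_fn(params, pts_batched, target_batched):
    # lax.map maps over the leading axis of each array in the tuple
    def batch_loss(batch):
        pts, target = batch
        rgb = render_batch(params, pts)
        return jnp.sum(jnp.square(rgb - target))
    per_batch = lax.map(batch_loss, (pts_batched, target_batched))
    # Divide by the total element count so this equals the full-image mean
    return jnp.sum(per_batch) / target_batched.size


num_batch, row_batch, W, N = 3, 2, 5, 4
pts_batched = jnp.ones((num_batch, row_batch, W, N, 3))
target_batched = jnp.zeros((num_batch, row_batch, W, 3))
params = jnp.eye(3)
loss, grads = jax.value_and_grad(loss_fn)(params, pts_batched, target_batched)
```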

row_batch = 5
def train_nerf_batchifyrow(optimizer, images, poses, focal, near, far, batchify_size=1024*32,
                           N_samples=32, N_iters=2000, test_index=0):
    H, W = images.shape[1:3]
    testimg = images[test_index]; testpose = poses[test_index]
    for i in range(N_iters + 1):
        key = jax.random.PRNGKey(i)
        img_i = np.random.randint(0, images.shape[0])    # upper bound is exclusive
        target = images[img_i]; pose = poses[img_i]
        rays_o, rays_d = get_rays(H, W, focal, pose)     # (H, W, 3)
        z_vals = jnp.linspace(near, far, N_samples)      # (N,)
        key, subkey = jax.random.split(key)
        z_vals += jax.random.uniform(subkey, list(rays_o.shape[:-1]) + [N_samples], dtype=jnp.float32) \
                  * (far - near) / N_samples            # now (H, W, N)
        pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]   # (H, W, N, 3)
        num_batch = pts.shape[0] // row_batch            # assumes H is divisible by row_batch
        pts_batched = jnp.reshape(pts, (num_batch, row_batch) + pts.shape[1:])
        target_batched = jnp.reshape(target, (num_batch, row_batch) + target.shape[1:])
        # pts_batched is (num_batch, row_batch, W, N, 3);
        # target_batched is (num_batch, row_batch, W, 3)

        def loss_fn(model):
            rgb_batch, depth, acc = render_rays_batchifyrow(model, z_vals, pts_batched)
            loss = jnp.mean(jnp.square(rgb_batch - target_batched))
            return loss, rgb_batch
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, rgb), grad = grad_fn(optimizer.target)
        optimizer = optimizer.apply_gradient(grad)
        print(f"Step {i} done; loss is {loss}")
    return optimizer

bionicles commented 4 years ago

Is it possible to move the definition of the objective function outside of that loop?
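
Moving the objective out of the loop lets the whole update step be wrapped in jax.jit once, instead of building a new Python closure every iteration. A minimal sketch, assuming plain SGD on a pytree of parameters rather than the flax optimizer used above; model_apply, train_step, and LEARNING_RATE are hypothetical:

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical step size


def model_apply(params, x):
    # Hypothetical model: a single linear layer
    return x @ params["w"] + params["b"]


def loss_fn(params, x, target):
    # Defined once at module level; per-step data is passed in as arguments
    pred = model_apply(params, x)
    return jnp.mean(jnp.square(pred - target)), pred


@jax.jit
def train_step(params, x, target):
    (loss, _), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, target)
    # Plain SGD update applied leaf by leaf
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (32, 3))
target = x @ jax.random.normal(key_w, (3, 4))
params = {"w": jnp.zeros((3, 4)), "b": jnp.zeros((4,))}
losses = []
for step in range(50):
    params, loss = train_step(params, x, target)
    losses.append(float(loss))
```

The step is compiled on the first call and reused afterwards, since the shapes of params, x, and target stay the same.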

myagues commented 4 years ago

lax.map is what worked for me when playing with tiny NeRF. I'll leave my tiny_NeRF notebook, which is almost a direct transcription of the original TF version but may help you build the full model. The relevant part of the code is:

pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
pts_flat = jnp.reshape(pts, [-1, 3])
pts_flat = embed_fn(pts_flat, L_embed)

# def batchify(fn, chunk=1024 * 32):
#     return lambda inputs: jnp.concatenate(
#         [fn(inputs[i : i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0,
#     )
# raw = batchify(net_fn)(pts_flat)

raw = lax.map(net_fn, jnp.reshape(pts_flat, [-1, batch_size, pts_flat.shape[-1]]))
raw = jnp.reshape(raw, list(pts.shape[:-1]) + [4])
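
The reshape in lax.map(net_fn, jnp.reshape(pts_flat, [-1, batch_size, pts_flat.shape[-1]])) requires the number of points to be a multiple of batch_size. A sketch of one way to relax that, by zero-padding to a whole number of chunks and slicing the padding back off; map_in_chunks and the toy net_fn are hypothetical:

```python
import jax.numpy as jnp
from jax import lax


def map_in_chunks(net_fn, x, batch_size):
    """Apply net_fn, which maps (batch_size, d) -> (batch_size, k), over x of shape (n, d)."""
    n = x.shape[0]
    pad = (-n) % batch_size                      # rows needed to fill the last chunk
    x_padded = jnp.pad(x, ((0, pad), (0, 0)))    # zero rows; their outputs are discarded
    chunks = x_padded.reshape(-1, batch_size, x.shape[-1])
    out = lax.map(net_fn, chunks)                # (num_chunks, batch_size, k)
    return out.reshape(-1, out.shape[-1])[:n]    # drop the padded rows


def net_fn(x):
    # Toy per-chunk network: (b, 3) -> (b, 4)
    return jnp.concatenate([x, jnp.sum(x, axis=-1, keepdims=True)], axis=-1)


pts_flat = jnp.arange(30.0).reshape(10, 3)
raw = map_in_chunks(net_fn, pts_flat, batch_size=4)
```

This only addresses the shape constraint; it does not by itself change how much memory reverse-mode differentiation keeps for the backward pass.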

BoyuanJackChen commented 4 years ago

Hi myagues! Thanks for sharing your code. Unfortunately, your code does not work either. I ran it on Google Colab with N_samples increased to 640 instead of 64, and got an OOM error. I also checked another way, by calling jax.profiler.save_device_memory_profile("row.prof") right after the line that creates raw in render_rays. Changing batch_size does not change the memory usage.

I am doing a project on NeRF and I really want this code to run well. I think it is worth figuring out why lax.map and jax.remat don't help in this case. My understanding is that the problem is that you need values from all batches to compute the loss, which is the distance between the whole-image rendering and the image. Nonetheless, it worked fine in TensorFlow 1.14, as shown in the team's code: https://github.com/bmild/nerf. In my experiments, both my code in flax.nn and myagues' code in jax.experimental.stax had this problem. Maybe there is a specific way to deal with this issue.

Any further ideas, everyone? @mattjj could you look further into the code, please?

BoyuanJackChen commented 4 years ago

An alternative I tried was to hand-batchify the image into rows: I used a for loop to go through each batch, computed the loss on each sub-image, summed up the grads, and then applied the summed grads to the optimizer. Unfortunately, it was about 8-10 times slower than expected.
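
Because the loss here is a mean of per-pixel squared errors, the full-image gradient equals a size-weighted sum of per-chunk gradients, so the accumulation described above could stay on the device inside a single jax.jit-compiled lax.scan rather than a Python loop. A minimal sketch, assuming equally sized chunks of rays (pixels) and a hypothetical predict function in place of NeRF rendering:

```python
import jax
import jax.numpy as jnp
from jax import lax


def predict(params, x):
    # Hypothetical per-pixel model: (n, 3) -> (n, 3)
    return jnp.tanh(x @ params)


def chunk_sse(params, x, y):
    # Sum (not mean) of squared errors for one chunk of pixels
    return jnp.sum(jnp.square(predict(params, x) - y))


@jax.jit
def accumulated_grad(params, x_chunks, y_chunks):
    grad_fn = jax.value_and_grad(chunk_sse)

    def step(carry, chunk):
        # Forward and backward for one chunk per scan iteration
        loss_sum, grad_sum = carry
        loss, grad = grad_fn(params, *chunk)
        return (loss_sum + loss, grad_sum + grad), None

    init = (jnp.zeros(()), jnp.zeros_like(params))
    (loss_sum, grad_sum), _ = lax.scan(step, init, (x_chunks, y_chunks))
    # Dividing the sums by the total element count recovers the full-image mean and its gradient
    total = y_chunks.size
    return loss_sum / total, grad_sum / total


kx, ky, kp = jax.random.split(jax.random.PRNGKey(1), 3)
x = jax.random.normal(kx, (120, 3))
y = jax.random.normal(ky, (120, 3))
params = 0.1 * jax.random.normal(kp, (3, 3))
loss, grad = accumulated_grad(params, x.reshape(6, 20, 3), y.reshape(6, 20, 3))
full_loss, full_grad = jax.value_and_grad(
    lambda p: jnp.mean(jnp.square(predict(p, x) - y)))(params)
```

Since each scan step runs both the forward and backward pass for one chunk, in principle only one chunk's activations need to be live at a time.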

BoyuanJackChen commented 4 years ago

I think what you're looking for (in JAX terms) is a map that is partially-sequential, partially-vectorized (so intermediate between lax.map, which applies a function sequentially over instances, and jax.vmap, which applies a function in parallel over instances). You can simulate this right now by reshaping your inputs and applying vmap to the inner batch dimension and lax.map to the outer one, and we're working on more convenient mechanisms. Does that sound like what you're interested in doing?

Could you please elaborate on your point with the code?