phlippe / uvadlc_notebooks

Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2023
https://uvadlc-notebooks.readthedocs.io/en/latest/
MIT License

code style: jax tutorial 3 (activation fns) #80

Closed: murphyk closed this issue 1 year ago

murphyk commented 1 year ago

https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial3/Activation_Functions.html

I suggest you use this more jaxonic way of getting per-example gradients :)

import jax

def get_grads(act_fn, x):
    """
    Computes the gradients of an activation function at specified positions.

    Inputs:
        act_fn - A module or function of the forward pass of the activation function.
        x - 1D input array.
    Output:
        An array with the same shape as x containing the gradients of act_fn at x.
    """
    # return jax.grad(lambda inp: act_fn(inp).sum())(x)  # obscure
    # jax.grad(act_fn) differentiates act_fn at a single scalar input;
    # jax.vmap maps that derivative over every element of x.
    return jax.vmap(jax.grad(act_fn))(x)
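To see why the two versions in get_grads agree, here is a small self-contained check. It assumes jax.nn.sigmoid as a stand-in for the tutorial's Sigmoid module. The sum trick works only because the activation is elementwise: the derivative of sum_j f(x_j) with respect to x_i is just f'(x_i). jax.vmap(jax.grad(f)) states "derivative at each point" directly and does not depend on that property.

```python
import jax
import jax.numpy as jnp

def get_grads_vmap(act_fn, x):
    # Differentiate the scalar function act_fn, then map over every element of x
    return jax.vmap(jax.grad(act_fn))(x)

def get_grads_sum(act_fn, x):
    # For an elementwise act_fn, d/dx_i sum_j act_fn(x)_j == act_fn'(x_i)
    return jax.grad(lambda inp: act_fn(inp).sum())(x)

x = jnp.linspace(-5.0, 5.0, 101)
g_vmap = get_grads_vmap(jax.nn.sigmoid, x)
g_sum = get_grads_sum(jax.nn.sigmoid, x)

# Closed-form derivative of the sigmoid for comparison: s * (1 - s)
s = jax.nn.sigmoid(x)
g_exact = s * (1.0 - s)

print("max |vmap - sum| :", float(jnp.max(jnp.abs(g_vmap - g_sum))))
print("max |vmap - exact|:", float(jnp.max(jnp.abs(g_vmap - g_exact))))
```

If act_fn mixed elements (for example, a softmax), the sum version would return the gradient of the sum rather than a per-element derivative. The vmap version would still differentiate each scalar input independently.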
murphyk commented 1 year ago

Also, ax.set_title(act_fn.name) does not seem to work (name is None), so the plots are not named. A quick fix is shown below.

import math
import matplotlib.pyplot as plt
import numpy as np

def vis_act_fn(act_fn, ax, x, name):
    # Run activation function
    y = act_fn(x)
    y_grads = get_grads(act_fn, x)
    # Push x, y and gradients back to cpu for plotting
    # x, y, y_grads = x.cpu().numpy(), y.cpu().numpy(), y_grads.cpu().numpy()
    ## Plotting
    ax.plot(x, y, linewidth=2, label=name) ### NEW
    ax.plot(x, y_grads, linewidth=2, label="Gradient")
    ax.set_title(name) ### NEW (act_fn.name is None)
    ax.legend()
    ax.set_ylim(-1.5, x.max())

# act_fn_by_name (name -> activation class) is defined earlier in the tutorial; add activation functions there if wanted
act_fns = [act_fn() for act_fn in act_fn_by_name.values()]
names = list(act_fn_by_name.keys()) ### NEW
x = np.linspace(-5, 5, 1000) # Range on which we want to visualize the activation functions
## Plotting
rows = math.ceil(len(act_fns)/2.0)
fig, ax = plt.subplots(rows, 2, figsize=(8, rows*4))
for i, act_fn in enumerate(act_fns):
    vis_act_fn(act_fn, ax[divmod(i,2)], x, names[i]) ### NEW
fig.subplots_adjust(hspace=0.3)
plt.show()
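The same fix can be written so that each name travels with its function. Iterating over dict.items() yields the name and the function together, so nothing depends on a .name attribute. The sketch below is a minimal version under stated assumptions. It uses plain JAX functions as hypothetical stand-ins for the tutorial's activation classes, and panel_data is an illustrative helper, not part of the tutorial. Matplotlib is left out so that only the data each panel needs is built.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the tutorial's act_fn_by_name (which maps names to classes)
act_fn_by_name = {
    "sigmoid": jax.nn.sigmoid,
    "tanh": jnp.tanh,
    "relu": jax.nn.relu,
}

def panel_data(fns_by_name, x):
    """Return one (name, y, y_grads) tuple per activation; the name is the dict key."""
    panels = []
    for name, fn in fns_by_name.items():
        y = fn(x)
        # Per-element derivative, as in get_grads above
        y_grads = jax.vmap(jax.grad(fn))(x)
        panels.append((name, y, y_grads))
    return panels

x = jnp.linspace(-5.0, 5.0, 11)
panels = panel_data(act_fn_by_name, x)
for name, y, y_grads in panels:
    print(name, y.shape, y_grads.shape)
```

In the plotting loop, each tuple would then feed one subplot, for example `ax[divmod(i, 2)].set_title(name)` for the i-th entry.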
phlippe commented 1 year ago

Hi, thanks for the suggestion, I indeed translated the PyTorch code too literally there. :) Changed to better JAX style in a30e19dd1c8c94e61eb93258380a13826bcbda08 and fixed the plot naming.