cs231n / cs231n.github.io

two_layer_net.ipynb defines the lambda with parameter W which is redundant #254

Open oonisim opened 3 years ago

oonisim commented 3 years ago

In two_layer_net.ipynb, the lambda passed to eval_numerical_gradient is defined with a parameter W that is never used, so the parameter is redundant.

from cs231n.gradient_check import eval_numerical_gradient
loss, grads = net.loss(X, y, reg=0.05)
for param_name in grads:

    f = lambda W: net.loss(X, y, reg=0.05)[0]   # <--- W is not used anywhere
    # f = lambda : net.loss(X, y, reg=0.05)[0]  # <--- Should be like this because W is redundant 

    param_grad_num = \
        eval_numerical_gradient(f,              # <--- lambda passed as f
                                net.params[param_name], verbose=False)

eval_numerical_gradient in cs231n.gradient_check invokes the lambda as f(x), but the argument x is never used. Without W in the lambda, it could simply be invoked as f().

def eval_numerical_gradient(f, x, verbose=True, h=0.00001):
    """
    a naive implementation of numerical gradient of f at x
    - f should be a function that takes a single argument  # <--- Should have no need to take an argument
    - x is the point (numpy array) to evaluate the gradient at
    """
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    while not it.finished:

        # evaluate function at x+h
        ix = it.multi_index
        oldval = x[ix]
        x[ix] = oldval + h # increment by h
        fxph = f(x) # evaluate f(x + h)         # <--- x will not be used
        # fxph = f()                            # <--- Should be like this
        x[ix] = oldval - h
        fxmh = f(x) # evaluate f(x - h)
        x[ix] = oldval # restore

        # compute the partial derivative with centered formula
        grad[ix] = (fxph - fxmh) / (2 * h) # the slope
        it.iternext() # step to next dimension

    return grad

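It happens to work because the x passed in is the same array object as net.params[param_name]: the in-place writes x[ix] = oldval + h and x[ix] = oldval - h, at the indices produced by np.nditer(x, flags=['multi_index'], op_flags=['readwrite']), directly update net.params[param_name], so net.loss sees the perturbed value even though the lambda ignores its argument.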
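
To illustrate the aliasing described above, here is a minimal sketch. TinyNet is a hypothetical stand-in for the notebook's network (its loss is just sum(W ** 2)), and eval_numerical_gradient is a trimmed copy of the function quoted in this issue. It compares passing the parameter array itself against passing a copy of it.

```python
import numpy as np


def eval_numerical_gradient(f, x, h=1e-5):
    """Centered-difference gradient of f at x; perturbs x in place."""
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    while not it.finished:
        ix = it.multi_index
        oldval = x[ix]
        x[ix] = oldval + h              # in-place write to the caller's array
        fxph = f(x)
        x[ix] = oldval - h
        fxmh = f(x)
        x[ix] = oldval                  # restore the original value
        grad[ix] = (fxph - fxmh) / (2 * h)
        it.iternext()
    return grad


class TinyNet:
    """Hypothetical model, not the notebook's TwoLayerNet: loss = sum(W ** 2)."""

    def __init__(self):
        self.params = {'W': np.array([[1.0, -2.0], [3.0, 0.5]])}

    def loss(self):
        return np.sum(self.params['W'] ** 2)


net = TinyNet()
f = lambda _unused: net.loss()          # ignores its argument, like the notebook's lambda

# Same array object: the perturbations reach net.params['W'], so f sees them.
grad_aliased = eval_numerical_gradient(f, net.params['W'])

# A copy: the perturbations never reach net.params['W'], so f is constant.
grad_copied = eval_numerical_gradient(f, net.params['W'].copy())

analytic = 2 * net.params['W']          # exact gradient of sum(W ** 2)
print(np.allclose(grad_aliased, analytic, atol=1e-6))
print(np.allclose(grad_copied, 0.0))
```

In this sketch the gradient is only correct when the array handed to eval_numerical_gradient is the very object the model reads from. That is the implicit coupling this issue points at, whichever signature the lambda ends up with.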