mila-iqia / blocks

A Theano framework for building and training neural networks

LookupTable gradient #672

Closed adbrebs closed 9 years ago

adbrebs commented 9 years ago

With the current implementation of LookupTable, the gradient is computed with respect to all the weights of the lookup table. However, we only need the gradient with respect to the rows selected by the indices passed to apply(self, indices).

If the lookup table has dimensions (D,d) and we retrieve k rows, the current gradient costs O(Dd), whereas it should only cost O(kd).
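
As a purely hypothetical illustration of the gap (sizes made up):

    # Hypothetical sizes, just to illustrate the scale of the waste:
    D, d, k = 100000, 128, 50
    print(D * d)  # 12800000 gradient entries touched by the current computation
    print(k * d)  # 6400 entries that actually carry a non-zero gradient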

dwf commented 9 years ago

Sounds like you might need more than a Brick, but maybe a custom Theano op.

bartvm commented 9 years ago

Wait, what makes you say that? The actual gradient should only be computed with respect to the rows selected by indices. It's just that if you ask Theano for the gradient w.r.t. the entire weight matrix, W, it will allocate a tensor of zeros (the size of W), fill in the rows with non-zero gradient, and return that to you.
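
A minimal Theano sketch of that behaviour (the sizes and variable names are made up for illustration):

    import numpy as np
    import theano
    from theano import tensor

    # Hypothetical lookup table: 100000 rows of dimension 128.
    W = theano.shared(np.zeros((100000, 128), dtype=theano.config.floatX))
    indices = tensor.ivector('indices')

    rows = W[indices]   # the subtensor that actually enters the graph
    cost = rows.sum()

    grad_full = tensor.grad(cost, W)     # dense (100000, 128) tensor, mostly zeros
    grad_rows = tensor.grad(cost, rows)  # only the selected rows, shape (k, 128)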

adbrebs commented 9 years ago

If I delegate the computation of the gradients to a TrainingAlgorithm such as:

    algorithm = GradientDescent(
        cost=cost, # with a lookup table in the graph
        step_rule=..,
        params=ComputationGraph(cost).parameters)

then ComputationGraph(cost).parameters contains the full lookup matrix W, whereas it should only contain W[indices].

bartvm commented 9 years ago

Why? You need a shared variable to take the gradients with respect to; returning W[indices] as a parameter wouldn't work.

adbrebs commented 9 years ago

Hmm, LookupTable's params is W. So when you let a TrainingAlgorithm compute the gradient as I described, the gradient will be computed for all the weights of W (and likewise for the updates), which is obviously a waste of computation because we already know that only the rows at indices have non-zero gradients. Am I missing something? Btw, this is confirmed by my experiments: when I increase the vocabulary size of my LookupTable, backpropagation takes longer, which should not happen.

dwf commented 9 years ago

Bart is right, though: this is a fundamental problem with Theano, not blocks.

bartvm commented 9 years ago

It would be point 3 from https://github.com/Theano/Theano/issues/2219. There's a way around it in Blocks, but I never committed it because I think it should be a fix made in Theano. If you need it short-term, give this code a try:

    import theano
    from theano import tensor

    from blocks.algorithms import GradientDescent


    class SubtensorGradientFix(GradientDescent):
        """Gradient descent with support for indexed gradients.

        Parameters
        ----------
        subtensor_params : dict
            A dictionary mapping shared variables to the subtensors which
            contribute to the gradient.

        """
        def __init__(self, cost, params, subtensor_params=None, *args, **kwargs):
            self.subtensor_params = subtensor_params = subtensor_params or {}
            # Take the gradient with respect to the subtensor where one is
            # given, instead of the full shared variable.
            params = [param
                      if param not in subtensor_params else subtensor_params[param]
                      for param in params]
            super(SubtensorGradientFix, self).__init__(cost=cost, params=params,
                                                       *args, **kwargs)

        def initialize(self):
            all_updates = self.updates
            # Note: the gradients are computed in the same order in which
            # the parameters were given. Keep it like that to ensure
            # reproducibility.
            reverse_subtensor_params = {value: key for key, value
                                        in self.subtensor_params.items()}
            for param in self.params:
                if param in reverse_subtensor_params:
                    # The parameter is a subtensor: write the step back into
                    # the underlying shared variable with inc_subtensor.
                    all_updates.append((
                        reverse_subtensor_params[param],
                        tensor.inc_subtensor(param, -self.steps[param])
                    ))
                else:
                    all_updates.append((param, param - self.steps[param]))
            all_updates += self.step_rule_updates
            self._function = theano.function(self.inputs, [], updates=all_updates)
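
A hypothetical usage sketch (the names `lookup` and `indices` and the use of `Scale` are assumptions; the cost graph has to be built from the same `W_sub` node):

    # Hypothetical usage: W is the lookup table's shared weight matrix and
    # W_sub = W[indices] is the exact node that was used to build `cost`.
    W = lookup.W
    W_sub = W[indices]
    algorithm = SubtensorGradientFix(
        cost=cost,  # built from W_sub
        params=ComputationGraph(cost).parameters,
        subtensor_params={W: W_sub},
        step_rule=Scale(learning_rate=0.1))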

dwf commented 9 years ago

There's a typo in there, "aram".

bartvm commented 9 years ago

Whoops, fixed. I swear it worked though... Maybe this was an older version of the code? Anyway, it should give you a place to start.

adbrebs commented 9 years ago

Thank you both, I will try that then!

adbrebs commented 9 years ago

With Bart's code and a small hack in the LookupTable code (to save W[indices]), it works just as expected, which means 10x speedups in my experiments on CPU with lookup tables of 100,000 rows.
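
For reference, a rough sketch of what such a hack could look like (CachingLookupTable is a made-up name, and this is untested against Blocks' application machinery; the real LookupTable.apply may flatten and reshape the indices, in which case the intermediate subtensor is the node to keep):

    from blocks.bricks.base import application
    from blocks.bricks.lookup import LookupTable


    class CachingLookupTable(LookupTable):
        """Hypothetical sketch: remember the W[indices] node used in apply.

        The stored node can then be passed to SubtensorGradientFix as
        ``subtensor_params={self.W: self.subtensor}``.
        """
        @application
        def apply(self, indices):
            # Assumes the output can be expressed as a plain subtensor of W,
            # so that inc_subtensor can later write the update back into W.
            self.subtensor = self.W[indices]
            return self.subtensor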