facebookresearch / encodec

State-of-the-art deep learning based audio codec supporting both mono 24 kHz audio and stereo 48 kHz audio.
MIT License
3.51k stars 304 forks source link

How to run the balancer's backward pass when using Hugging Face Accelerate #47

Open skiiwoo opened 1 year ago

skiiwoo commented 1 year ago

❓ Questions

When training EnCodec with Hugging Face's Accelerate package, is it not possible to use the balancer?

this is part of my training script

            self.balancer._set_losses_and_input(
                losses={'t': recon_loss, 'f': m_recon_loss, 'g': ads_loss, 'feat': rfm_loss},
                input=output
            )
            # self.balancer.backward()
            self.accelerator.backward(self.balancer)

and I changed the Balancer a little:

    def __mul__(self, other):
        for name, loss in self.losses.items():
            self.losses[name] = loss * other

    def __truediv__(self, other):
        for name, loss in self.losses.items():
            self.losses[name] = loss / other

    def _set_losses_and_input(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor):
        self.losses = losses
        self.input = input

    @property
    def metrics(self):
        return self._metrics

    def backward(self):
        losses = self.losses
        input = self.input

        norms = {}
        grads = {}
        for name, loss in losses.items():
            grad, = autograd.grad(loss, [input], retain_graph=True)
            if self.per_batch_item:
                dims = tuple(range(1, grad.dim()))
                norm = grad.norm(dim=dims).mean()
            else:
                norm = grad.norm()
            norms[name] = norm
            grads[name] = grad

        count = 1
        if self.per_batch_item:
            count = len(grad)
        avg_norms = average_metrics(self.averager(norms), count)
        total = sum(avg_norms.values())

        self._metrics = {}
        if self.monitor:
            for k, v in avg_norms.items():
                self._metrics[f'ratio_{k}'] = v / total

        total_weights = sum([self.weights[k] for k in avg_norms])
        ratios = {k: w / total_weights for k, w in self.weights.items()}

        out_grad: tp.Any = 0
        for name, avg_norm in avg_norms.items():
            if self.recale_grads:
                scale = ratios[name] * self.total_norm / (self.epsilon + avg_norm)
                grad = grads[name] * scale
            else:
                grad = self.weights[name] * grads[name]
            out_grad += grad

        input.backward(out_grad)