y0ast / deterministic-uncertainty-quantification

Code for "Uncertainty Estimation Using a Single Deep Deterministic Neural Network"
https://arxiv.org/abs/2003.02037
MIT License

Some questions of paper and codes #6

Closed: WangDeyu closed this issue 2 years ago

WangDeyu commented 2 years ago

The method of updating centroids was introduced in the Appendix of van den Oord et al. (2017) for updating quantised latent variable.

I only found the paper but couldn't find the Appendix of van den Oord et al. (2017). Can you provide a link to the Appendix?

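import torch
import torch.nn.functional as F

# model, optimizer, num_classes, l_gradient_penalty and calc_gradient_penalty
# are defined elsewhere in the training script.
def step(engine, batch):
    model.train()
    optimizer.zero_grad()

    x, y = batch
    x, y = x.cuda(), y.cuda()

    # Track gradients with respect to the input so the gradient penalty can be computed.
    x.requires_grad_(True)

    y_pred = model(x)

    # One-hot targets for the per-class binary cross-entropy.
    y = F.one_hot(y, num_classes).float()

    loss = F.binary_cross_entropy(y_pred, y, reduction="mean")

    if l_gradient_penalty > 0:
        # The penalty is part of the loss, so loss.backward() below
        # backpropagates it together with the BCE term.
        gp = calc_gradient_penalty(x, y_pred)
        loss += l_gradient_penalty * gp

    loss.backward()
    optimizer.step()

    # Input gradients are no longer needed.
    x.requires_grad_(False)

    # Update the class centroids (exponential moving averages) without tracking gradients.
    with torch.no_grad():
        model.eval()
        model.update_embeddings(x, y)

    return loss.item()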

Is the gradient of x only needed for calculating the gradient penalty? How is the l_gradient_penalty * gp term backpropagated?

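def update_embeddings(self, x, y):
    # Exponential moving average of the number of points per class
    # (y is one-hot, so y.sum(0) counts the points of each class in the batch).
    self.N = self.gamma * self.N + (1 - self.gamma) * y.sum(0)

    z = self.feature_extractor(x)

    # Project features with the per-class weights W:
    # (batch, features) x (embedding, classes, features) -> (batch, embedding, classes).
    z = torch.einsum("ij,mnj->imn", z, self.W)
    # Sum the embeddings of each class's points over the batch -> (embedding, classes).
    embedding_sum = torch.einsum("ijk,ik->jk", z, y)

    # Exponential moving average of the per-class embedding sums.
    self.m = self.gamma * self.m + (1 - self.gamma) * embedding_sum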

Could you please explain what model.update_embeddings does? What do self.N and self.m mean?

Thank you so much!

y0ast commented 2 years ago
  1. The update mechanism is described in equations 4, 5 and 6 of the paper: https://arxiv.org/abs/2003.02037
  2. For the appendix of van den Oord et al. (2017), see page 11 of https://arxiv.org/pdf/1711.00937.pdf
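Reading the update_embeddings code quoted above, the update for class c at step t can be written as follows. The notation here is illustrative: n_{c,t} is the number of class-c points in the minibatch, f_θ is the feature extractor, W_c is the class-c slice of self.W, and N and m correspond to self.N and self.m. The last line (centroid = m divided by N) is not part of the snippet and is an assumption about how the two averages are combined.

```latex
\begin{aligned}
N_{c,t} &= \gamma\, N_{c,t-1} + (1-\gamma)\, n_{c,t} \\
m_{c,t} &= \gamma\, m_{c,t-1} + (1-\gamma) \sum_{i:\, y_i = c} W_c\, f_\theta(x_i) \\
e_{c,t} &= \frac{m_{c,t}}{N_{c,t}}
\end{aligned}
```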

Is the gradient of x only needed for calculating the gradient penalty? How is the l_gradient_penalty * gp term backpropagated?

Yes, that's right. The penalty is added to the loss, so the subsequent loss.backward() call backpropagates it together with the cross-entropy term.
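A toy illustration of why the penalty term reaches the parameters, in plain Python rather than PyTorch: for a hypothetical 1-D model y = sigmoid(w*x + b), the input gradient dy/dx = w*y*(1-y) is itself a function of the parameter w, so a penalty built from it changes the gradient of the loss with respect to w. Finite differences stand in for autograd, and the two-sided (|dy/dx| - 1)^2 form of the penalty is an assumption for illustration. In PyTorch, keeping an input gradient differentiable is what the create_graph=True option of torch.autograd.grad is for; presumably calc_gradient_penalty relies on something similar.

```python
import math
from typing import Sequence


def sigmoid(t: float) -> float:
    return 1.0 / (1.0 + math.exp(-t))


def total_loss(w: float, b: float, xs: Sequence[float], ts: Sequence[float], lam: float) -> float:
    """Mean BCE of the toy model y = sigmoid(w*x + b) plus lam times a gradient penalty."""
    bce = 0.0
    gp = 0.0
    for x, t in zip(xs, ts):
        y = sigmoid(w * x + b)
        bce += -(t * math.log(y) + (1 - t) * math.log(1 - y))
        # Gradient of the output with respect to the *input* x; note that it depends on w.
        dy_dx = w * y * (1 - y)
        # Assumed two-sided penalty: push the input-gradient magnitude towards 1.
        gp += (abs(dy_dx) - 1.0) ** 2
    n = len(xs)
    return bce / n + lam * gp / n


def grad_w(w: float, b: float, xs: Sequence[float], ts: Sequence[float],
           lam: float, eps: float = 1e-6) -> float:
    """Central finite difference of total_loss with respect to w (stands in for autograd)."""
    return (total_loss(w + eps, b, xs, ts, lam) - total_loss(w - eps, b, xs, ts, lam)) / (2 * eps)


if __name__ == "__main__":
    xs, ts = [-1.0, 0.5, 2.0], [0.0, 1.0, 1.0]
    print("BCE only:     ", grad_w(0.8, 0.1, xs, ts, lam=0.0))
    print("BCE + penalty:", grad_w(0.8, 0.1, xs, ts, lam=1.0))
```

With lam=0 the finite-difference gradient matches the analytic BCE gradient mean((y - t) * x); with lam=1 it differs, and that difference is the penalty's contribution to the parameter update.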

Could you please explain what model.update_embeddings does? What do self.N and self.m mean?

model.update_embeddings is called after each model update step. self.N is the number of points of each class in a minibatch and self.m is the sum of the embeddings per class; both are exponentially averaged over minibatches.

This is explained below equation 6 in the paper (see 1. above).
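A minimal pure-Python sketch of the bookkeeping described above, written with plain lists so it runs without PyTorch. It mirrors update_embeddings: the hypothetical helper project replaces the first einsum, the hypothetical CentroidTracker.update performs the two exponential moving averages, and centroids() divides m by N. The initial count n_init and the zero initialisation of m are assumptions for illustration, not values taken from the repo.

```python
from typing import List

Vector = List[float]
Matrix = List[Vector]


def project(features: Matrix, W: List[Matrix]) -> List[Matrix]:
    """z[i][e][c] = sum_j W[e][c][j] * features[i][j], like einsum("ij,mnj->imn")."""
    return [
        [
            [sum(w_j * f_j for w_j, f_j in zip(W[e][c], feat)) for c in range(len(W[e]))]
            for e in range(len(W))
        ]
        for feat in features
    ]


class CentroidTracker:
    """Hypothetical stand-in for the N (per-class counts) and m (per-class sums) buffers."""

    def __init__(self, embedding_size: int, num_classes: int, gamma: float, n_init: float):
        self.gamma = gamma
        # EMA of the number of points per class (assumed initial value n_init).
        self.N = [n_init] * num_classes
        # EMA of per-class embedding sums, shape [embedding][class] (assumed to start at 0).
        self.m = [[0.0] * num_classes for _ in range(embedding_size)]

    def update(self, z: List[Matrix], y: Matrix) -> None:
        """z: [batch][embedding][class] embeddings; y: [batch][class] one-hot labels."""
        g = self.gamma
        num_classes = len(self.N)
        # Like y.sum(0): number of points of each class in this minibatch.
        counts = [sum(y_i[c] for y_i in y) for c in range(num_classes)]
        self.N = [g * n + (1 - g) * cnt for n, cnt in zip(self.N, counts)]
        for e in range(len(self.m)):
            for c in range(num_classes):
                # Like einsum("ijk,ik->jk", z, y): only class-c points contribute.
                emb_sum = sum(z_i[e][c] * y_i[c] for z_i, y_i in zip(z, y))
                self.m[e][c] = g * self.m[e][c] + (1 - g) * emb_sum

    def centroids(self) -> Matrix:
        """Centroid of class c is the column m[:, c] divided by N[c]."""
        return [[self.m[e][c] / self.N[c] for c in range(len(self.N))] for e in range(len(self.m))]


if __name__ == "__main__":
    # Toy setup: 2 input features, embedding size 2, 2 classes.
    W = [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]
    features = [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0]]
    y = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
    tracker = CentroidTracker(embedding_size=2, num_classes=2, gamma=0.5, n_init=13.0)
    z = project(features, W)
    for _ in range(60):
        tracker.update(z, y)
    print(tracker.centroids())  # approaches [[1.0, 4.0], [2.0, 3.0]]
```

When every point of a class has the same embedding, both averages converge and m/N approaches that embedding regardless of n_init. Because N and m share the same gamma, their ratio behaves like an exponentially weighted mean of the class embeddings rather than a running sum.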