pytorch / contrib

Implementations of ideas from recent papers

SWA wasn't applied to param {}; skipping it".format(p)) #30

Open murtazabasu opened 4 years ago

murtazabasu commented 4 years ago

Hello, I get this warning and have no clue what I am doing wrong. Here's my code:

                for _ in range(self.K_epochs):

                    # Evaluating old actions and values:
                    logprobs, values, dist_entropy = self.policy.evaluate(old_states, old_actions)
                    # Computing advantages from the detached value estimates:
                    advantages = self.calculate_advantages(reward_batch, values.detach())
                    # Finding the ratio (pi_theta / pi_theta_old):
                    ratios = torch.exp(logprobs - old_logprobs)

                    # Finding Surrogate Loss:
                    surr1 = ratios * advantages
                    surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages
                    actor_loss = -torch.min(surr1, surr2)
                    critic_loss = 0.5*self.MseLoss(values, returns)
                    # critic_loss already holds the 0.5*MSE term, so it is added only once
                    loss = actor_loss + critic_loss - 0.01*dist_entropy

                    # take gradient step
                    self.optimizer.zero_grad()
                    loss.mean().backward()
                    self.SWAoptim.step()
                self.SWAoptim.swap_swa_sgd()         
                # Copy new weights into old policy:
                self.policy_old.load_state_dict(self.policy.state_dict())
                self.loSS.append(loss.mean().item())

I get the following warning:

/home/murtaza/.local/lib/python2.7/site-packages/torchcontrib/optim/swa.py:191: UserWarning: SWA wasn't applied to param Parameter containing:
tensor([[-0.0556,  0.1067,  0.0519, -0.1137,  0.0632, -0.0402,  0.0576, -0.0704,
         -0.0888, -0.1129, -0.0102,  0.0503,  0.0469, -0.0822, -0.1028, -0.0354,
         -0.0007,  0.0863, -0.0221, -0.1036,  0.0431,  0.0164,  0.0004, -0.1106,
          0.0466, -0.0283, -0.0954, -0.1001, -0.0113,  0.0089,  0.0471, -0.0335,
          0.0501,  0.0773,  0.1195, -0.0987,  0.0455, -0.0468, -0.0520, -0.1011,
         -0.0373, -0.0642,  0.0105,  0.0455,  0.0452, -0.0569, -0.0551, -0.1137,
         -0.0057,  0.0203,  0.0088,  0.0077,  0.0917, -0.1203,  0.0266,  0.0904,
         -0.0180,  0.0097, -0.0717, -0.0547, -0.0954,  0.1197,  0.0836,  0.0938]],
       requires_grad=True); skipping it
  "SWA wasn't applied to param {}; skipping it".format(p))

janvainer commented 4 years ago

Hi, I am having the same issue.

nlml commented 4 years ago

Just had the same issue. I think the warning occurs when you call opt.swap_swa_sgd() before any stochastic weight averaging has actually taken place.

Here's my toy example:

import torch
import torchcontrib

FAILURE_MODE = True

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(10)
        self.fc = torch.nn.Linear(10, 1)
    def forward(self, x):
        return self.fc(self.bn(x))

model = Model()
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
swa_start = 2
opt = torchcontrib.optim.SWA(base_opt, swa_start=swa_start, swa_freq=1, swa_lr=0.05)
for step in range(6):
    print(step)
    opt.zero_grad()
    input = torch.randn(100, 10) + 5
    target = torch.randn(100, 1)
    loss_fn = lambda x, y: ((x - y) ** 2).mean()
    loss_fn(model(input), target).backward()
    opt.step()
    print(model.fc.weight[0, :3].detach())
    if FAILURE_MODE or step >= swa_start:
        opt.swap_swa_sgd()
    print(model.fc.weight[0, :3].detach())
    if FAILURE_MODE or step >= swa_start:
        opt.swap_swa_sgd()

If you set FAILURE_MODE = False, the warning goes away, because in that case no attempt is made to swap the SWA and SGD params before any SWA update has actually taken place :smile:
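
For anyone who wants to see the mechanism without installing anything, here is a minimal torch-free sketch. ToySWA is a hypothetical stand-in I made up for illustration, not the torchcontrib class. It assumes the wrapper only creates its per-parameter averaging buffer once the step count passes swa_start, and that swap_swa_sgd() warns and skips any parameter that has no buffer yet.

```python
import warnings


class ToySWA:
    """Hypothetical stand-in for an SWA wrapper (not the torchcontrib class)."""

    def __init__(self, params, swa_start, swa_freq):
        self.params = params        # dict: name -> float, updated in place
        self.swa_start = swa_start
        self.swa_freq = swa_freq
        self.step_counter = 0
        self.n_avg = 0
        self.swa_buffer = {}        # stays empty until averaging begins

    def step(self, grads, lr=0.1):
        # Plain SGD update on every parameter.
        for name, grad in grads.items():
            self.params[name] -= lr * grad
        self.step_counter += 1
        # Assumed schedule: average only after swa_start steps, every swa_freq steps.
        if self.step_counter > self.swa_start and self.step_counter % self.swa_freq == 0:
            for name, value in self.params.items():
                buf = self.swa_buffer.get(name, 0.0)
                # Running mean of the weights seen at averaging steps.
                self.swa_buffer[name] = buf + (value - buf) / (self.n_avg + 1)
            self.n_avg += 1

    def swap_swa_sgd(self):
        for name in self.params:
            if name not in self.swa_buffer:
                # No average exists yet, so there is nothing to swap in.
                warnings.warn(
                    "SWA wasn't applied to param {}; skipping it".format(name))
                continue
            self.params[name], self.swa_buffer[name] = (
                self.swa_buffer[name], self.params[name])
```

With swa_start=2, calling swap_swa_sgd() after one step triggers the warning, while calling it after three steps swaps silently. In a real training loop the equivalent guard is the `step >= swa_start` check from the toy example above.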

janvainer commented 4 years ago

Thanks, solved my issue :)