murtazabasu opened this issue 4 years ago
Hi, I am having the same issue.
Just had the same issue. I think the warning occurs when you call opt.swap_swa_sgd()
before any stochastic weight averaging has actually taken place.
Here's my toy example:
```python
import torch
import torchcontrib

FAILURE_MODE = True


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(10)
        self.fc = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(self.bn(x))


model = Model()
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
swa_start = 2
opt = torchcontrib.optim.SWA(base_opt, swa_start=swa_start, swa_freq=1, swa_lr=0.05)

for step in range(6):
    print(step)
    opt.zero_grad()
    input = torch.randn(100, 10) + 5
    target = torch.randn(100, 1)
    loss_fn = lambda x, y: ((x - y) ** 2).mean()
    loss_fn(model(input), target).backward()
    opt.step()
    print(model.fc.weight[0, :3].detach())
    if FAILURE_MODE or step >= swa_start:
        opt.swap_swa_sgd()
    print(model.fc.weight[0, :3].detach())
    if FAILURE_MODE or step >= swa_start:
        opt.swap_swa_sgd()
```
If you set `FAILURE_MODE = False`, you don't get the error, since in that case no attempt is made to swap the SWA and SGD params before any SWA has actually taken place :smile:
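To make the failure mode concrete, here is a toy sketch of the logic involved (this is *not* torchcontrib's actual implementation, just a hypothetical `ToySWA` class I wrote for illustration): the averaged-weight buffer is only created once a step past `swa_start` has run, so swapping before that point has nothing to swap.

```python
class ToySWA:
    """Toy stand-in for an SWA optimizer wrapper (illustration only)."""

    def __init__(self, params, swa_start):
        self.params = params      # current "SGD" parameters (plain floats here)
        self.swa_start = swa_start
        self.swa_avg = None       # averaged params; created lazily after swa_start
        self.n_avg = 0            # number of snapshots folded into the average
        self.step_count = 0

    def step(self):
        self.step_count += 1
        if self.step_count >= self.swa_start:
            if self.swa_avg is None:
                # first snapshot: the average IS the current params
                self.swa_avg = list(self.params)
            else:
                # running mean: avg += (p - avg) / (n + 1)
                self.swa_avg = [a + (p - a) / (self.n_avg + 1)
                                for a, p in zip(self.swa_avg, self.params)]
            self.n_avg += 1

    def swap_swa_sgd(self):
        if self.swa_avg is None:
            # the situation this issue is about: nothing has been averaged yet
            print("warning: SWA buffer is empty, nothing to swap")
            return
        self.params, self.swa_avg = self.swa_avg, self.params


opt = ToySWA(params=[1.0], swa_start=2)
opt.swap_swa_sgd()   # before any averaging -> prints the warning
opt.step()           # step 1: before swa_start, no averaging happens
opt.step()           # step 2: first averaged snapshot is recorded
opt.swap_swa_sgd()   # now the swap succeeds
```

The same reasoning explains why guarding the swap with `step >= swa_start` (the `FAILURE_MODE = False` path above) avoids the warning.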
Thanks, solved my issue :)
Hello, I get this error and have no clue what I am doing wrong. Here's my code:

I get the following error: