Describe the issue
Hello, I have a question about adding a convergence loop for PSO. I read your convergence test scripts and adapted them for my work, as shown below.
To Reproduce
The code I wrote is:
# Configure training.
from torch_pso import ParticleSwarmOptimizer

lr = 1e-3  # low values help avoid spikes (not passed to the PSO optimizer below)
epochs = 1
optimizer = ParticleSwarmOptimizer(
    net.parameters(),
    inertial_weight=0.5,
    cognitive_coefficient=2,
    social_coefficient=2,
    num_particles=60,
    max_param_value=1,
    min_param_value=-1,
)
criterion = PinballLoss(quantile=0.50, reduction='mean')
num_batches = len(train_loader)

# Convergence settings: a batch counts as converged when its loss is within
# atol + rtol * |0| (i.e. just atol) of zero.
atol = 1e-4
rtol = 0
max_pso_iters = 100  # upper bound on optimizer.step() calls per batch
# NOTE(review): a pinball loss rarely reaches ~0 on noisy data, so this
# absolute test may never fire. Consider also stopping when the loss has not
# improved by more than some tolerance over several consecutive steps.
# NOTE(review): the same swarm (`optimizer`) is reused for every batch, so any
# best positions it remembers were scored on earlier batches -- confirm this
# per-batch optimisation is intended.


def _is_converged(loss):
    """Return True if the scalar tensor ``loss`` is within ``atol`` of zero.

    The reference zero is built with ``torch.zeros_like(loss)`` so it shares
    the loss's device and dtype; a CPU float32 ``torch.Tensor([0.])`` makes
    ``torch.allclose`` raise for a CUDA or float64 loss.
    """
    return torch.allclose(loss, torch.zeros_like(loss), atol=atol, rtol=rtol)


net.train()
epoch_losses = []
fig, ax = plt.subplots()
for epoch in tqdm_notebook(range(epochs)):  # For each epoch.
    total_loss = 0.0
    for X_ts, y_ts in train_loader:

        def closure():
            """Evaluate the criterion on the current batch at the net's current parameters."""
            # PSO is gradient-free, so no backward() is needed; zero_grad()
            # only clears any stale .grad buffers.
            optimizer.zero_grad()
            # Call the module rather than .forward() so registered hooks run.
            output = net(X_ts)
            # NOTE(review): squeeze() also drops a size-1 batch dimension;
            # confirm output and y_ts shapes still match for a 1-sample batch.
            return criterion(output.squeeze(), y_ts)

        # Score the current parameters once before running the swarm; this
        # check does not need an autograd graph.
        with torch.no_grad():
            loss = closure()

        if _is_converged(loss):
            # Skip PSO for this batch instead of exiting, which would abort
            # every remaining batch and epoch (and the final plot).
            converged = True
            print('Convergence achieved before PSO, tighten tolerance')
        else:
            converged = False
            for _ in range(max_pso_iters):
                # NOTE(review): assumes step() returns the loss tensor at the
                # swarm's best position -- confirm against torch_pso.
                loss = optimizer.step(closure)
                if _is_converged(loss):
                    converged = True
                    break  # stop early once the tolerance is met
            print(f'Loss after PSO {loss}')
            if converged:
                print('Convergence achieved after PSO')
            else:
                print('Convergence failed after PSO, Relax tolerance value')
        total_loss += loss.item()

    # Guard against an empty loader instead of dividing by zero.
    avg_loss = total_loss / num_batches if num_batches else float('nan')
    epoch_losses.append(avg_loss)

# marker='o' keeps a one-point history (e.g. epochs == 1) visible; a plain
# line plot of a single point draws nothing.
ax.plot(epoch_losses, marker='o')
ax.set_xlabel('epoch')
ax.set_ylabel('mean loss')
plt.show()
Describe the issue: Hello, I have a question about adding a convergence loop for PSO. I read your convergence test scripts and adapted them for my work, as shown below.
To Reproduce: The code I wrote is shown above.
Is the above convergence loop correct?