qthequartermasterman / torch_pso

Particle Swarm Optimization implemented using PyTorch Optimizer API
MIT License

Convergence loop for PSO #21

Open Simply-Adi opened 2 years ago

Simply-Adi commented 2 years ago

Describe the issue

Hello, I have a question about writing a convergence loop for PSO. I read your test scripts for convergence and adapted them for my work, as follows.

To Reproduce

The code I wrote is:

# Configure training.
import torch
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm as tqdm_notebook
from torch_pso import ParticleSwarmOptimizer

# `net`, `PinballLoss`, and `train_loader` are defined earlier in my script.
lr = 1e-3   # low values help avoid spikes (not passed to the PSO optimizer)
epochs = 1

optimizer = ParticleSwarmOptimizer(net.parameters(),
                                   inertial_weight=0.5,
                                   cognitive_coefficient=2,
                                   social_coefficient=2,
                                   num_particles=60,
                                   max_param_value=1,
                                   min_param_value=-1)
criterion = PinballLoss(quantile=0.50, reduction='mean')
num_batches = len(train_loader)
# Convergence tolerances: the loss is treated as converged when it is
# within atol of zero (see the torch.allclose checks below).
atol = 1e-4
rtol = 0

net.train()
epoch_losses = []
fig, ax = plt.subplots()

for epoch in tqdm_notebook(range(epochs)): # For each epoch.
    total_loss = 0
    batch_id = 0
    for X_ts, y_ts in train_loader:
        def closure():
            optimizer.zero_grad()
            output = net(X_ts)
            loss = criterion(output.squeeze(), y_ts)
            # loss.backward() would be required for SGD, but PSO is gradient-free.
            return loss

        # First forward pass: check the loss before running any PSO iterations.
        loss = closure()

        if torch.allclose(loss, torch.Tensor([0.]), atol=atol, rtol=rtol):
            converged = True
            print('Convergence achieved before PSO, tighten tolerance')
            exit()
        else:
            converged = False
            for pso_iter in range(100):
                loss = optimizer.step(closure)
                # print(f"batch id {batch_id} pso_iter {pso_iter} loss {loss}")
                if torch.allclose(loss, torch.Tensor([0.]), atol=atol, rtol=rtol):  # |loss| <= atol + rtol*|0|
                    converged = True
                    break  # stop early once the tolerance is met

            if not converged:
                print(f'Loss after PSO {loss}')
                print('Convergence failed after PSO, relax tolerance value')
            else:
                print(f'Loss after PSO {loss}')
                print('Convergence achieved after PSO')
        batch_id += 1
        total_loss += loss.item()
    avg_loss = total_loss / num_batches
    epoch_losses.append(avg_loss)
    ax.plot(epoch_losses)
plt.show()

Is the above convergence loop ok?
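
For reference, here is a stripped-down, self-contained version of the pattern I am following, on a toy problem (the linear model, random data, and MSE loss below are just placeholders for my actual net, PinballLoss, and train_loader):

import torch
from torch_pso import ParticleSwarmOptimizer

# Toy stand-ins for my real model, data, and loss.
toy_net = torch.nn.Linear(10, 1)
x = torch.rand(4, 10)
target = torch.rand(4, 1)
criterion = torch.nn.MSELoss()

optim = ParticleSwarmOptimizer(toy_net.parameters(),
                               inertial_weight=0.5,
                               cognitive_coefficient=2,
                               social_coefficient=2,
                               num_particles=60,
                               max_param_value=1,
                               min_param_value=-1)

def closure():
    # PSO is gradient-free, so there is no backward() call here.
    optim.zero_grad()
    return criterion(toy_net(x), target)

atol, rtol = 1e-4, 0
converged = False
for pso_iter in range(100):
    loss = optim.step(closure)  # step(closure) returns the loss from the closure
    if torch.allclose(loss, torch.zeros_like(loss), atol=atol, rtol=rtol):
        converged = True
        break

print(f'converged={converged}, loss={loss.item():.6f}')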