nnaisense / evotorch

Advanced evolutionary computation library built directly on top of PyTorch, created at NNAISENSE.
https://evotorch.ai
Apache License 2.0

Issue with `problem.make_net` Shape Compatibility #110

Status: Open. Duxo opened this issue 1 month ago.

Duxo commented 1 month ago

Description:

I am experiencing an issue with the `problem.make_net` method. When I try to build the trained network from the `center` entry of `searcher.status`, the call fails unless I first call `squeeze()` on that tensor.
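A quick way to confirm that the failure is about shape rather than about the number of values is to compare the center's shape with the network's parameter count. The sketch below is plain PyTorch and assumes, since the report does not state the exact shape, that the center carries an extra size-1 leading dimension; the small `nn.Sequential` network and the zero-filled `center` are hypothetical stand-ins for `network` and `searcher.status["center"]`:

```python
import torch
from torch import nn

# Hypothetical stand-in network; any nn.Module is inspected the same way.
network = nn.Sequential(nn.Linear(15, 64), nn.ReLU(), nn.Linear(64, 64))

# Number of scalars a flat solution vector must hold for this network.
num_params = sum(p.numel() for p in network.parameters())

# Hypothetical center with an extra size-1 leading dimension
# (assumed shape; a stand-in for searcher.status["center"]).
center = torch.zeros(1, num_params)

print(center.shape)                  # torch.Size([1, 5184])
print(center.squeeze().shape)        # torch.Size([5184])
print(center.numel() == num_params)  # True: same values, different shape
```

If the element counts match but the shapes differ, the problem lies in how the shape is handled, not in the search itself.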

Details:

Code to Reproduce:

```python
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import (
    GCNConv,
    global_mean_pool,
)
from evotorch.neuroevolution import NEProblem
from evotorch.algorithms import CMAES

class GCN_xs(torch.nn.Module):
    def __init__(self):
        super().__init__()

        hidden_dim = 64
        node_features = 15

        self.conv1 = GCNConv(node_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc1 = Linear(hidden_dim, hidden_dim)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = global_mean_pool(x, batch)  # one embedding per graph in the batch
        x = self.fc1(x)
        return x

def fitness(network):
    # Dummy fitness: the returned value does not matter for reproducing the issue
    return 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
network = GCN_xs()

problem = NEProblem(
    objective_sense="max",
    network=network,
    network_eval_func=fitness,
    device=device,
)
searcher = CMAES(problem, stdev_init=0.01)
searcher.run(0)  # run zero generations
trained = problem.make_net(searcher.status["center"])  # this call raises an error
```

Issue:

The code above raises an error at the call `problem.make_net(searcher.status["center"])`. However, if I change the last line to:

```python
trained = problem.make_net(searcher.status["center"].squeeze())
```

it works without any issues.
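Until a fix is available, a slightly more defensive workaround than `squeeze()` is to flatten with `reshape(-1)` and check the length explicitly. The helper `as_flat_solution` below is hypothetical (it is not part of EvoTorch), and it only assumes that `make_net` accepts a 1-D tensor, which the working `squeeze()` call suggests:

```python
import torch

def as_flat_solution(values: torch.Tensor, num_params: int) -> torch.Tensor:
    """Return `values` as a 1-D tensor of length `num_params` (hypothetical helper)."""
    # reshape(-1) flattens (1, n), (n, 1), (1, 1, n), ... to (n,),
    # and unlike squeeze() it never turns a 1-element vector into a 0-d scalar.
    flat = values.reshape(-1)
    if flat.numel() != num_params:
        raise ValueError(
            f"expected {num_params} values, got {flat.numel()} "
            f"(shape {tuple(values.shape)})"
        )
    return flat

print(as_flat_solution(torch.zeros(1, 4), 4).shape)  # torch.Size([4])
print(as_flat_solution(torch.zeros(4, 1), 4).shape)  # torch.Size([4])
print(torch.zeros(1, 1).squeeze().shape)             # torch.Size([])
print(as_flat_solution(torch.zeros(1, 1), 1).shape)  # torch.Size([1])

# Assumed usage with the names from the report:
# n = sum(p.numel() for p in network.parameters())
# trained = problem.make_net(as_flat_solution(searcher.status["center"], n))
```

The length check turns a silent shape mix-up into an explicit error message, which is easier to act on than a failure deep inside parameter loading.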

Expected Behavior:

`problem.make_net` should accept the `center` tensor exactly as it is returned in `searcher.status`, without the caller having to call `squeeze()` first.
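The expected behavior amounts to normalizing the input shape before loading the values into a copy of the network. The sketch below illustrates that idea in plain PyTorch with `torch.nn.utils.vector_to_parameters`; `net_from_vector` is a hypothetical name, and this is not how EvoTorch actually implements `make_net`:

```python
import copy

import torch
from torch import nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters

def net_from_vector(template: nn.Module, values: torch.Tensor) -> nn.Module:
    """Hypothetical shape-tolerant loader: copy `template`, fill it from `values`."""
    net = copy.deepcopy(template)  # leave the template network untouched
    ref = next(net.parameters())
    # Accept (n,), (1, n) or (n, 1); clone so the new network does not
    # share storage with the caller's tensor.
    flat = values.reshape(-1).to(device=ref.device, dtype=ref.dtype).clone()
    expected = sum(p.numel() for p in net.parameters())
    if flat.numel() != expected:
        raise ValueError(f"expected {expected} values, got {flat.numel()}")
    # Consecutive slices of `flat` are reshaped into each parameter in order.
    vector_to_parameters(flat, net.parameters())
    return net

template = nn.Linear(3, 2)  # 3 * 2 weights + 2 biases = 8 parameters
center = torch.arange(8, dtype=torch.float32).unsqueeze(0)  # shape (1, 8)
net = net_from_vector(template, center)
print(torch.equal(parameters_to_vector(net.parameters()), center.reshape(-1)))  # True
```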

engintoklu commented 1 month ago

Hello @Duxo,

Thank you very much for this very helpful feedback!

A pull request addressing this issue is here: https://github.com/nnaisense/evotorch/pull/111

Feel free to let us know whether or not this pull request correctly addresses the issue you encountered.

Thanks again!