thieu1995 / mealpy

A Collection Of The State-of-the-art Metaheuristic Algorithms In Python (Metaheuristic/Optimizer/Nature-inspired/Biology)
https://mealpy.readthedocs.io
GNU General Public License v3.0

[BUG]: NotImplementedError in mealpy.swarm_based.AVOA.OriginalAVOA.solve #136

Closed: vedik2002 closed this issue 8 months ago

vedik2002 commented 8 months ago

Description of the bug

My error

```
NotImplementedError                       Traceback (most recent call last)
<ipython-input-30-279fcffe97e1> in <cell line: 15>()
     13 
     14 
---> 15 FEDAVO(train,test,model,n_iter,local_epoch,tuning_epoch)

<ipython-input-29-37a29fd79b48> in FEDAVO(train, test, model, n_iter, local_epoch, tuning_epoch)
     65                     }
     66 
---> 67           model_hyper.solve(problem_iid)
     68 
     69 

/usr/local/lib/python3.10/dist-packages/mealpy/optimizer.py in solve(self, problem, mode, n_workers, termination, starting_solutions, seed)
    221             g_best: g_best, the best found agent, that hold the best solution and the best target. Access by: .g_best.solution, .g_best.target
    222         """
--> 223         self.check_problem(problem, seed)
    224         self.check_mode_and_workers(mode, n_workers)
    225         self.check_termination("start", termination, None)

/usr/local/lib/python3.10/dist-packages/mealpy/optimizer.py in check_problem(self, problem, seed)
    155         elif type(problem) == dict:
    156             problem["seed"] = seed
--> 157             self.problem = Problem(**problem)
    158         else:
    159             raise ValueError("problem needs to be a dict or an instance of Problem class.")

/usr/local/lib/python3.10/dist-packages/mealpy/utils/problem.py in __init__(self, bounds, minmax, **kwargs)
     26         self.__set_keyword_arguments(kwargs)
     27         self.set_bounds(bounds)
---> 28         self.__set_functions()
     29         self.logger = Logger(self.log_to, log_file=self.log_file).create_logger(name=f"{__name__}.{__class__.__name__}",
     30                                     format_str='%(asctime)s, %(levelname)s, %(name)s [line: %(lineno)d]: %(message)s')

/usr/local/lib/python3.10/dist-packages/mealpy/utils/problem.py in __set_functions(self)
     63         tested_solution = self.generate_solution(encoded=True)
     64         self.n_dims = len(tested_solution)
---> 65         result = self.obj_func(tested_solution)
     66         if type(result) in self.SUPPORTED_ARRAYS:
     67             result = np.array(result).flatten()

/usr/local/lib/python3.10/dist-packages/mealpy/utils/problem.py in obj_func(self, x)
     95             float: Function value of `x`.
     96         """
---> 97         raise NotImplementedError
     98 
     99     def get_name(self) -> str:

NotImplementedError:
```
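The last frames show the mechanism: `Problem.__init__` hands the extra dictionary entries to `__set_keyword_arguments(kwargs)`, and `__set_functions()` then calls `self.obj_func(...)`, which still resolves to the base-class stub that raises `NotImplementedError`. That is consistent with the `problem_iid` dict in `FEDAVO` supplying its objective under `"fit_func"`, while mealpy 3.x (3.0.1 is installed here) appears to look for it under `"obj_func"` (the 2.x releases used `"fit_func"`). The stdlib-only sketch below uses a hypothetical, simplified stand-in called `MiniProblem` (not mealpy's class). It assumes that `__set_keyword_arguments` simply calls `setattr(self, key, value)` for each entry, so an unrecognised key is stored as an unused attribute and `obj_func` keeps its raising default.

```python
# Hypothetical, simplified stand-in for mealpy's Problem class; it only
# mimics the two steps visible in the traceback, not the real implementation.
class MiniProblem:
    def __init__(self, bounds, minmax="min", **kwargs):
        self.bounds = bounds
        self.minmax = minmax
        # Assumed behaviour of __set_keyword_arguments(kwargs): copy every
        # extra entry onto the instance under its own name.
        for key, value in kwargs.items():
            setattr(self, key, value)
        # Mirrors __set_functions(): evaluate the objective once on a test
        # point (here simply the lower bounds).
        test_point = [lb for lb, _ in self.bounds]
        self.test_value = self.obj_func(test_point)

    def obj_func(self, x):
        # Base-class stub, reached only if no instance attribute shadows it.
        raise NotImplementedError


def objective(solution):
    # Toy objective: sum of squares of the three hyperparameters.
    return sum(v * v for v in solution)


# (lower, upper) pairs matching the three FloatVar ranges in problem_iid
bounds = [(0.01, 0.1), (0.001, 0.1), (0.001, 0.1)]

# 2.x-style key: stored as an unused attribute, so the stub still runs.
try:
    MiniProblem(bounds=bounds, minmax="min", fit_func=objective)
    old_key_failed = False
except NotImplementedError:
    old_key_failed = True

# 3.x-style key: the instance attribute shadows the stub method.
problem = MiniProblem(bounds=bounds, minmax="min", obj_func=objective)

print(old_key_failed)  # True
print(problem.test_value)
```

If this reading of the traceback is right, passing `objective_function` under the `"obj_func"` key should let `Problem` pick it up; that is an inference from the frames above, not a confirmed fix.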

Steps To Reproduce

  1. Go to Google Colab.

  2. My package versions are as follows absl-py==1.4.0 aiohttp==3.9.1 aiosignal==1.3.1 alabaster==0.7.16 albumentations==1.3.1 altair==4.2.2 anyio==3.7.1 appdirs==1.4.4 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 array-record==0.5.0 arviz==0.15.1 astropy==5.3.4 astunparse==1.6.3 async-timeout==4.0.3 atpublic==4.0 attrs==23.2.0 audioread==3.0.1 autograd==1.6.2 Babel==2.14.0 backcall==0.2.0 beautifulsoup4==4.11.2 bidict==0.22.1 bigframes==0.19.1 bleach==6.1.0 blinker==1.4 blis==0.7.11 blosc2==2.0.0 bokeh==3.3.3 bqplot==0.12.42 branca==0.7.0 build==1.0.3 CacheControl==0.13.1 cachetools==5.3.2 catalogue==2.0.10 certifi==2023.11.17 cffi==1.16.0 chardet==5.2.0 charset-normalizer==3.3.2 chex==0.1.7 click==8.1.7 click-plugins==1.1.1 cligj==0.7.2 cloudpickle==2.2.1 cmake==3.27.9 cmdstanpy==1.2.0 colorcet==3.0.1 colorlover==0.3.0 colour==0.1.5 community==1.0.0b1 confection==0.1.4 cons==0.4.6 contextlib2==21.6.0 contourpy==1.2.0 cryptography==41.0.7 cufflinks==0.17.3 cupy-cuda12x==12.2.0 cvxopt==1.3.2 cvxpy==1.3.2 cycler==0.12.1 cymem==2.0.8 Cython==3.0.8 dask==2023.8.1 datascience==0.17.6 db-dtypes==1.2.0 dbus-python==1.2.18 debugpy==1.6.6 decorator==4.4.2 defusedxml==0.7.1 diskcache==5.6.3 distributed==2023.8.1 distro==1.7.0 dlib==19.24.2 dm-tree==0.1.8 docutils==0.18.1 dopamine-rl==4.0.6 duckdb==0.9.2 earthengine-api==0.1.385 easydict==1.11 ecos==2.0.12 editdistance==0.6.2 eerepr==0.0.4 en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.6.0/en_core_web_sm-3.6.0-py3-none-any.whl#sha256=83276fc78a70045627144786b52e1f2728ad5e29e5e43916ec37ea9c26a11212 entrypoints==0.4 et-xmlfile==1.1.0 etils==1.6.0 etuples==0.3.9 exceptiongroup==1.2.0 fastai==2.7.13 fastcore==1.5.29 fastdownload==0.0.7 fastjsonschema==2.19.1 fastprogress==1.0.3 fastrlock==0.8.2 filelock==3.13.1 fiona==1.9.5 firebase-admin==5.3.0 Flask==2.2.5 flatbuffers==23.5.26 flax==0.7.5 folium==0.14.0 fonttools==4.47.2 frozendict==2.4.0 frozenlist==1.4.1 fsspec==2023.6.0 
future==0.18.3 gast==0.5.4 gcsfs==2023.6.0 GDAL==3.4.3 gdown==4.6.6 geemap==0.30.4 gensim==4.3.2 geocoder==1.38.1 geographiclib==2.0 geopandas==0.13.2 geopy==2.3.0 gin-config==0.5.0 glob2==0.7 google==2.0.3 google-ai-generativelanguage==0.4.0 google-api-core==2.11.1 google-api-python-client==2.84.0 google-auth==2.17.3 google-auth-httplib2==0.1.1 google-auth-oauthlib==1.2.0 google-cloud-aiplatform==1.39.0 google-cloud-bigquery==3.12.0 google-cloud-bigquery-connection==1.12.1 google-cloud-bigquery-storage==2.24.0 google-cloud-core==2.3.3 google-cloud-datastore==2.15.2 google-cloud-firestore==2.11.1 google-cloud-functions==1.13.3 google-cloud-iam==2.13.0 google-cloud-language==2.9.1 google-cloud-resource-manager==1.11.0 google-cloud-storage==2.8.0 google-cloud-translate==3.11.3 google-colab @ file:///colabtools/dist/google-colab-1.0.0.tar.gz#sha256=8a8adcd0618fcdab16ee0442d46a604e7f676885d2fde98c604abf6501495e8b google-crc32c==1.5.0 google-generativeai==0.3.2 google-pasta==0.2.0 google-resumable-media==2.7.0 googleapis-common-protos==1.62.0 googledrivedownloader==0.4 graphviz==0.20.1 greenlet==3.0.3 grpc-google-iam-v1==0.13.0 grpcio==1.60.0 grpcio-status==1.48.2 gspread==3.4.2 gspread-dataframe==3.3.1 gym==0.25.2 gym-notices==0.0.8 h5netcdf==1.3.0 h5py==3.9.0 holidays==0.41 holoviews==1.17.1 html5lib==1.1 httpimport==1.3.1 httplib2==0.22.0 huggingface-hub==0.20.2 humanize==4.7.0 hyperopt==0.2.7 ibis-framework==7.1.0 idna==3.6 imageio==2.31.6 imageio-ffmpeg==0.4.9 imagesize==1.4.1 imbalanced-learn==0.10.1 imgaug==0.4.0 importlib-metadata==7.0.1 importlib-resources==6.1.1 imutils==0.5.4 inflect==7.0.0 iniconfig==2.0.0 install==1.3.5 intel-openmp==2023.2.3 ipyevents==2.0.2 ipyfilechooser==0.6.0 ipykernel==5.5.6 ipyleaflet==0.18.1 ipython==7.34.0 ipython-genutils==0.2.0 ipython-sql==0.5.0 ipytree==0.2.2 ipywidgets==7.7.1 itsdangerous==2.1.2 jax==0.4.23 jaxlib @ 
https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.23+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl#sha256=8e42000672599e7ec0ea7f551acfcc95dcdd0e22b05a1d1f12f97b56a9fce4a8 jeepney==0.7.1 jieba==0.42.1 Jinja2==3.1.3 joblib==1.3.2 jsonpickle==3.0.2 jsonschema==4.19.2 jsonschema-specifications==2023.12.1 jupyter-client==6.1.12 jupyter-console==6.1.0 jupyter-server==1.24.0 jupyter_core==5.7.1 jupyterlab-widgets==3.0.9 jupyterlab_pygments==0.3.0 kaggle==1.5.16 kagglehub==0.1.5 keras==2.15.0 keyring==23.5.0 kiwisolver==1.4.5 langcodes==3.3.0 launchpadlib==1.10.16 lazr.restfulclient==0.14.4 lazr.uri==1.0.6 lazy_loader==0.3 libclang==16.0.6 librosa==0.10.1 lida==0.0.10 lightgbm==4.1.0 linkify-it-py==2.0.2 llmx==0.0.15a0 llvmlite==0.41.1 locket==1.0.0 logical-unification==0.4.6 lxml==4.9.4 malloy==2023.1067 Markdown==3.5.2 markdown-it-py==3.0.0 MarkupSafe==2.1.3 matplotlib==3.7.1 matplotlib-inline==0.1.6 matplotlib-venn==0.11.9 mdit-py-plugins==0.4.0 mdurl==0.1.2 mealpy==3.0.1 miniKanren==1.0.3 missingno==0.5.2 mistune==0.8.4 mizani==0.9.3 mkl==2023.2.0 ml-dtypes==0.2.0 mlxtend==0.22.0 more-itertools==10.1.0 moviepy==1.0.3 mpmath==1.3.0 msgpack==1.0.7 multidict==6.0.4 multipledispatch==1.0.0 multitasking==0.0.11 murmurhash==1.0.10 music21==9.1.0 natsort==8.4.0 nbclassic==1.0.0 nbclient==0.9.0 nbconvert==6.5.4 nbformat==5.9.2 nest-asyncio==1.5.9 networkx==3.2.1 nibabel==4.0.2 nltk==3.8.1 notebook==6.5.5 notebook_shim==0.2.3 numba==0.58.1 numexpr==2.8.8 numpy==1.23.5 oauth2client==4.1.3 oauthlib==3.2.2 opencv-contrib-python==4.8.0.76 opencv-python==4.8.0.76 opencv-python-headless==4.9.0.80 openpyxl==3.1.2 opfunu==1.0.1 opt-einsum==3.3.0 optax==0.1.8 orbax-checkpoint==0.4.4 osqp==0.6.2.post8 packaging==23.2 pandas==1.5.3 pandas-datareader==0.10.0 pandas-gbq==0.19.2 pandas-stubs==1.5.3.230304 pandocfilters==1.5.0 panel==1.3.6 param==2.0.1 parso==0.8.3 parsy==2.1 partd==1.4.1 pathlib==1.0.1 pathlib_abc==0.1.1 pathy==0.11.0 patsy==0.5.6 peewee==3.17.0 
pexpect==4.9.0 pickleshare==0.7.5 Pillow==9.4.0 pins==0.8.4 pip-tools==6.13.0 platformdirs==4.1.0 plotly==5.15.0 plotnine==0.12.4 pluggy==1.3.0 polars==0.17.3 pooch==1.8.0 portpicker==1.5.2 prefetch-generator==1.0.3 preshed==3.0.9 prettytable==3.9.0 proglog==0.1.10 progressbar2==4.2.0 prometheus-client==0.19.0 promise==2.3 prompt-toolkit==3.0.43 prophet==1.1.5 proto-plus==1.23.0 protobuf==3.20.3 psutil==5.9.5 psycopg2==2.9.9 ptyprocess==0.7.0 py-cpuinfo==9.0.0 py4j==0.10.9.7 pyarrow==10.0.1 pyarrow-hotfix==0.6 pyasn1==0.5.1 pyasn1-modules==0.3.0 pycocotools==2.0.7 pycparser==2.21 pyct==0.5.0 pydantic==1.10.13 pydata-google-auth==1.8.2 pydot==1.4.2 pydot-ng==2.0.0 pydotplus==2.0.2 PyDrive==1.3.1 PyDrive2==1.6.3 pyerfa==2.0.1.1 pygame==2.5.2 Pygments==2.16.1 PyGObject==3.42.1 PyJWT==2.3.0 pymc==5.7.2 pymystem3==0.2.0 PyOpenGL==3.1.7 pyOpenSSL==23.3.0 pyparsing==3.1.1 pyperclip==1.8.2 pyproj==3.6.1 pyproject_hooks==1.0.0 pyshp==2.3.1 PySocks==1.7.1 pytensor==2.14.2 pytest==7.4.4 python-apt==0.0.0 python-box==7.1.1 python-dateutil==2.8.2 python-louvain==0.16 python-slugify==8.0.1 python-utils==3.8.1 pytz==2023.3.post1 pyviz_comms==3.0.1 PyWavelets==1.5.0 PyYAML==6.0.1 pyzmq==23.2.1 qdldl==0.1.7.post0 qudida==0.0.4 ratelim==0.1.6 referencing==0.32.1 regex==2023.6.3 requests==2.31.0 requests-oauthlib==1.3.1 requirements-parser==0.5.0 rich==13.7.0 rpds-py==0.17.1 rpy2==3.4.2 rsa==4.9 safetensors==0.4.1 scikit-image==0.19.3 scikit-learn==1.2.2 scipy==1.11.4 scooby==0.9.2 scs==3.2.4.post1 seaborn==0.13.1 SecretStorage==3.3.1 Send2Trash==1.8.2 shapely==2.0.2 six==1.16.0 sklearn-pandas==2.2.0 smart-open==6.4.0 sniffio==1.3.0 snowballstemmer==2.2.0 sortedcontainers==2.4.0 soundfile==0.12.1 soupsieve==2.5 soxr==0.3.7 spacy==3.6.1 spacy-legacy==3.0.12 spacy-loggers==1.0.5 Sphinx==5.0.2 sphinxcontrib-applehelp==1.0.8 sphinxcontrib-devhelp==1.0.6 sphinxcontrib-htmlhelp==2.0.5 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.7 sphinxcontrib-serializinghtml==1.1.10 
SQLAlchemy==2.0.24 sqlglot==19.9.0 sqlparse==0.4.4 srsly==2.4.8 stanio==0.3.0 statsmodels==0.14.1 sympy==1.12 tables==3.8.0 tabulate==0.9.0 tbb==2021.11.0 tblib==3.0.0 tenacity==8.2.3 tensorboard==2.15.1 tensorboard-data-server==0.7.2 tensorflow==2.15.0 tensorflow-datasets==4.9.4 tensorflow-estimator==2.15.0 tensorflow-gcs-config==2.15.0 tensorflow-hub==0.15.0 tensorflow-io-gcs-filesystem==0.35.0 tensorflow-metadata==1.14.0 tensorflow-probability==0.22.0 tensorstore==0.1.45 termcolor==2.4.0 terminado==0.18.0 text-unidecode==1.3 textblob==0.17.1 tf-slim==1.1.0 thinc==8.1.12 threadpoolctl==3.2.0 tifffile==2023.12.9 tinycss2==1.2.1 tokenizers==0.15.0 toml==0.10.2 tomli==2.0.1 toolz==0.12.0 torch @ https://download.pytorch.org/whl/cu121/torch-2.1.0%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=0d4e8c52a1fcf5ed6cfc256d9a370fcf4360958fc79d0b08a51d55e70914df46 torchaudio @ https://download.pytorch.org/whl/cu121/torchaudio-2.1.0%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=676bda4042734eda99bc59b2d7f761f345d3cde0cad492ad34e3aefde688c6d8 torchdata==0.7.0 torchsummary==1.5.1 torchtext==0.16.0 torchvision @ https://download.pytorch.org/whl/cu121/torchvision-0.16.0%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=e76e78d0ad43636c9884b3084ffaea8a8b61f21129fbfa456a5fe734f0affea9 tornado==6.3.2 tqdm==4.66.1 traitlets==5.7.1 traittypes==0.2.1 transformers==4.35.2 triton==2.1.0 tweepy==4.14.0 typer==0.9.0 types-pytz==2023.3.1.1 types-setuptools==69.0.0.20240115 typing_extensions==4.5.0 tzlocal==5.2 uc-micro-py==1.0.2 uritemplate==4.1.1 urllib3==2.0.7 vega-datasets==0.9.0 wadllib==1.3.6 wasabi==1.1.2 wcwidth==0.2.13 webcolors==1.13 webencodings==0.5.1 websocket-client==1.7.0 Werkzeug==3.0.1 widgetsnbextension==3.6.6 wordcloud==1.9.3 wrapt==1.14.1 xarray==2023.7.0 xarray-einstats==0.6.0 xgboost==2.0.3 xlrd==2.0.1 xxhash==3.4.1 xyzservices==2023.10.1 yarl==1.9.4 yellowbrick==1.5 yfinance==0.2.35 zict==3.0.0 zipp==3.17.0

  3. This is the part of the code that raised the error:

    
```python
train, test = MNIST()

for client_id, train_loader in enumerate(train):
    # The inspection prints were commented out; `pass` keeps the loop valid.
    # print(f"Client {client_id + 1} - Training Data:")
    # for batch_idx, (data, target) in enumerate(train_loader):
    #     print(f"Batch {batch_idx + 1} - Data Shape: {data.shape}, Target Shape: {target.shape}")
    pass

model = MNISTCNN()

# Temporary values; these can be altered
n_iter = 10
local_epoch = 100
tuning_epoch = 50

FEDAVO(train, test, model, n_iter, local_epoch, tuning_epoch)
```

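```python
# Imports used below; MNIST, MNISTCNN, loss_classifier, loss_dataset and
# local_learning are the reporter's own helpers and are not shown here.
from copy import deepcopy

import torch.optim as optim
from tqdm import tqdm
from mealpy import FloatVar
from mealpy.swarm_based.AVOA import OriginalAVOA

def FEDAVO(train, test, model, n_iter, local_epoch, tuning_epoch):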

    loss_f=loss_classifier
    no_clients = len(train)
    n_samples=sum([len(db.dataset) for db in train])
    weights=([len(db.dataset)/n_samples for db in train])
    print("Clients' weights:",weights)

    loss_hist=[[float(loss_dataset(model,dl, loss_f).detach()) for dl in train]]
    #acc_hist=[[accuracy_dataset(model, dl) for dl in test]]
    #print(acc_hist)
    server_hist=[[tens_param.detach().cpu().numpy() for tens_param in list(model.parameters())]]

    models_hist = []

    server_loss=sum([weights[i]*loss_hist[-1][i] for i in range(len(weights))]) #sum the loss and accuracy of all the three clients
    #server_acc=sum([weights[i]*acc_hist[-1][i] for i in range(len(weights))])

    print(f'====> i: 0 Loss: {server_loss}')

    server_loss_list = [] #list to store the sum client of server loss and accuracy
    server_accuracy_list = []

    for i in tqdm(range(n_iter)):

        clients_params=[]
        clients_models=[]
        clients_losses=[]

        clients_param=[]
        clients_model=[]
        clients_loss =[]

        for k in range(no_clients):

          def objective_function(solution):

            lr = solution[0]
            momentum = solution[1]
            decay = solution[2]
            local_model = deepcopy(model)
            local_optimizer=optim.SGD(local_model.parameters(),lr=lr,momentum=momentum,weight_decay=decay)
            local_loss=local_learning(local_model,local_optimizer,train[k],local_epoch,loss_f)

            clients_loss.append(local_loss)
            # Get the parameter tensors of the trained local model
            list_params=list(local_model.parameters())
            list_params=[tens_param.detach() for tens_param in list_params]
            clients_param.append(list_params)
            clients_model.append(deepcopy(local_model))

            return local_loss

          model_hyper = OriginalAVOA(epoch=tuning_epoch,pop_size=100)

          problem_iid = {
                    "fit_func":objective_function,
                    "bounds":[FloatVar(lb=0.01,ub=0.1),FloatVar(lb=0.001,ub=0.1),FloatVar(lb=0.001,ub=0.1)],
                    "minmax": "min",
                    }

          # In mealpy 3.x, solve() returns the best agent found (g_best),
          # which exposes .solution and .target
          g_best = model_hyper.solve(problem_iid)

          min_loss = min(clients_loss)
          min_index = clients_loss.index(min_loss)

          clients_params.append(clients_param[min_index])
          clients_models.append(clients_model[min_index])
          clients_losses.append(clients_loss[min_index])

          clients_param=[]
          clients_model=[]
          clients_loss =[]

          best_lr, best_momentum, best_decay = g_best.solution
          print(f"Best Learning Rate: {best_lr}, Best Momentum: {best_momentum}, Best Decay Rate: {best_decay}")
          print(f"Lowest Loss: {g_best.target.fitness}")

          models_hist.append(clients_models)

          loss_hist+=[[float(loss_dataset(model, dl, loss_f).detach())
                    for dl in train]]
          #acc_hist+=[[accuracy_dataset(model, dl) for dl in test]]

          server_loss=sum([weights[i]*loss_hist[-1][i] for i in range(len(weights))])
          #server_acc=sum([weights[i]*acc_hist[-1][i] for i in range(len(weights))])

          # server_acc is only computed in the commented-out line above, so it
          # is omitted here to avoid a NameError
          print(f'====> i: {i+1} Loss: {server_loss}')
          #server_accuracy_list.append(server_acc)
          server_loss_list.append(server_loss)

          server_hist.append([tens_param.detach().cpu().numpy()
                    for tens_param in list(model.parameters())])
```

Additional Information

No response