luigibonati / mlcolvar

A unified framework for machine learning collective variables for enhanced sampling simulations
MIT License

Cannot compile some models to TorchScript with torch version 2 #126

Open jintuzhang opened 4 months ago

jintuzhang commented 4 months ago

Example input:

```python
import mlcolvar

cv = mlcolvar.cvs.DeepTDA(
    n_cvs=1,
    n_states=2,
    target_centers=[-10.0, 10.0],
    target_sigmas=[0.2, 0.2],
    layers=[4, 3, 2, 1]
)

cv.to_torchscript('model.ptc')
```
Errors:

```sh
Traceback (most recent call last):
  File "/compile.py", line 11, in <module>
    cv.to_torchscript('model.ptc')
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1429, in to_torchscript
    torchscript_module = torch.jit.script(self.eval(), **kwargs)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_script.py", line 1284, in script
    return torch.jit._recursive.create_script_module(
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 480, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 546, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 397, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 867, in try_compile_fn
    return torch.jit.script(fn, _rcb=rcb)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_script.py", line 1341, in script
    fn = torch._C._jit_script_compile(
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/annotations.py", line 366, in try_ann_to_type
    assert maybe_type, msg.format(repr(ann), repr(maybe_type))
AssertionError: Unsupported annotation typing.Union[list, torch.Tensor] could not be resolved because None could not be resolved.
```
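As far as I can tell, the assertion comes from the `typing.Union[list, torch.Tensor]` annotations used in the loss functions: under torch 2.x the TorchScript compiler cannot resolve a bare `list` inside a `Union`, while a fully parameterized `List[float]` (or a tensor) resolves fine. Here is a minimal sketch of that behaviour, written only for illustration (the function names are made up, not taken from mlcolvar):

```python
from typing import List, Union

import torch


def as_tensor_bad(x: Union[list, torch.Tensor]) -> torch.Tensor:
    # bare `list` inside the Union: scripting this raises the same
    # "Unsupported annotation typing.Union[list, torch.Tensor]" assertion
    if isinstance(x, torch.Tensor):
        return x
    return torch.tensor(x)


def as_tensor_good(x: Union[List[float], torch.Tensor]) -> torch.Tensor:
    # parameterized List[float] can be resolved, so this scripts cleanly
    if isinstance(x, torch.Tensor):
        return x
    return torch.tensor(x)


# torch.jit.script(as_tensor_bad)   # fails with the AssertionError above
scripted = torch.jit.script(as_tensor_good)
```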
conda list: ```sh # packages in environment at /calc/miniconda3/envs/TORCH: # # Name Version Build Channel _libgcc_mutex 0.1 conda_forge conda-forge _openmp_mutex 4.5 2_kmp_llvm conda-forge aiofiles 22.1.0 py310h06a4308_0 aiohttp 3.8.5 pypi_0 pypi aiosignal 1.3.1 pypi_0 pypi aiosqlite 0.18.0 py310h06a4308_0 annotated-types 0.5.0 pyhd8ed1ab_0 conda-forge anyio 3.6.2 pyhd8ed1ab_0 conda-forge argon2-cffi 21.3.0 pyhd3eb1b0_0 argon2-cffi-bindings 21.2.0 py310h5764c6d_3 conda-forge arrow 1.2.3 py310h06a4308_1 ase 3.22.1 pypi_0 pypi asttokens 2.2.1 pyhd8ed1ab_0 conda-forge astunparse 1.6.3 pyhd8ed1ab_0 conda-forge async-timeout 4.0.3 pypi_0 pypi attrs 22.2.0 pyh71513ae_0 conda-forge babel 2.12.1 pyhd8ed1ab_1 conda-forge backcall 0.2.0 pyhd3eb1b0_0 backoff 2.2.1 pyhd8ed1ab_0 conda-forge backports 1.1 pyhd3eb1b0_0 backports.functools_lru_cache 1.6.4 pyhd8ed1ab_0 conda-forge beautifulsoup4 4.12.2 py310h06a4308_0 blas 1.0 mkl bleach 6.0.0 pyhd8ed1ab_0 conda-forge blessed 1.19.1 pyhe4f9e05_2 conda-forge blosc 1.21.3 h6a678d5_0 bottleneck 1.3.7 py310h0a54255_0 conda-forge brotli 1.0.9 h166bdaf_8 conda-forge brotli-bin 1.0.9 h166bdaf_8 conda-forge brotlipy 0.7.0 py310h7f8727e_1002 bzip2 1.0.8 h7b6447c_0 c-ares 1.19.1 h5eee18b_0 c-blosc2 2.8.0 h6a678d5_0 ca-certificates 2023.7.22 hbcca054_0 conda-forge cachecontrol 0.12.11 py310h06a4308_1 certifi 2023.7.22 pyhd8ed1ab_0 conda-forge cffi 1.15.1 py310h5eee18b_3 cftime 1.6.2 pypi_0 pypi charset-normalizer 2.0.4 pyhd3eb1b0_0 cleo 2.0.1 py310h06a4308_0 click 8.1.7 unix_pyh707e725_0 conda-forge colorama 0.4.6 py310h06a4308_0 comm 0.1.2 py310h06a4308_0 contourpy 1.0.5 py310hdb19cb5_0 cpuonly 2.0 0 pytorch crashtest 0.4.1 py310h06a4308_0 croniter 1.4.1 pyhd8ed1ab_0 conda-forge cryptography 41.0.2 py310h22a60cf_0 cycler 0.11.0 pyhd3eb1b0_0 cyrus-sasl 2.1.28 h52b45da_1 dateutils 0.6.12 py_0 conda-forge dbus 1.13.18 hb2f20db_0 debugpy 1.6.7 py310h6a678d5_0 decorator 5.1.1 pyhd3eb1b0_0 deepdiff 6.3.1 pyhd8ed1ab_0 conda-forge defusedxml 0.7.1 pyhd8ed1ab_0 conda-forge distlib 0.3.7 pyhd8ed1ab_0 conda-forge dulwich 0.21.5 py310h2372a71_0 conda-forge e3nn 0.4.4 pypi_0 pypi einops 0.7.0 pypi_0 pypi entrypoints 0.4 py310h06a4308_0 exceptiongroup 1.1.3 pyhd8ed1ab_0 conda-forge executing 1.2.0 pyhd8ed1ab_0 conda-forge expat 2.5.0 hcb278e6_1 conda-forge expect 5.45.4 h555a92e_0 conda-forge fastapi 0.101.1 pyhd8ed1ab_0 conda-forge filelock 3.9.0 py310h06a4308_0 fontconfig 2.14.1 h4c34cd2_2 fonttools 4.39.3 py310h1fa729e_0 conda-forge freetype 2.12.1 hca18f0e_1 conda-forge frozenlist 1.4.0 pypi_0 pypi fsspec 2023.6.0 pyh1a96a4e_0 conda-forge giflib 5.2.1 h0b41bf4_3 conda-forge glib 2.69.1 he621ea3_2 gmp 6.2.1 h295c915_3 gmpy2 2.1.2 py310heeb90bb_0 gst-plugins-base 1.14.1 h6a678d5_1 gstreamer 1.14.1 h5eee18b_1 h11 0.14.0 pyhd8ed1ab_0 conda-forge h5py 3.9.0 py310he06866b_0 hdf5 1.12.1 h2b7332f_3 html5lib 1.1 pyhd3eb1b0_0 icu 58.2 hf484d3e_1000 conda-forge idna 3.4 py310h06a4308_0 importlib-metadata 6.8.0 pyha770c72_0 conda-forge importlib_metadata 6.8.0 hd8ed1ab_0 conda-forge iniconfig 2.0.0 pyhd8ed1ab_0 conda-forge inquirer 3.1.3 pyhd8ed1ab_0 conda-forge intel-openmp 2023.1.0 hdb19cb5_46305 ipykernel 6.25.0 py310h2f386ee_0 ipython 8.12.2 py310h06a4308_0 ipython_genutils 0.2.0 pyhd3eb1b0_1 itsdangerous 2.1.2 pyhd8ed1ab_0 conda-forge jaraco.classes 3.3.0 pyhd8ed1ab_0 conda-forge jedi 0.18.2 pyhd8ed1ab_0 conda-forge jeepney 0.8.0 pyhd8ed1ab_0 conda-forge jinja2 3.1.2 py310h06a4308_0 joblib 1.2.0 py310h06a4308_0 jpeg 9e h0b41bf4_3 conda-forge json5 0.9.6 pyhd3eb1b0_0 jsonschema 
4.17.3 py310h06a4308_0 jupyter_client 8.1.0 py310h06a4308_0 jupyter_core 5.3.0 py310h06a4308_0 jupyter_events 0.6.3 py310h06a4308_0 jupyter_server 1.23.6 pyhd8ed1ab_0 conda-forge jupyter_server_fileid 0.9.0 py310h06a4308_0 jupyter_server_ydoc 0.8.0 py310h06a4308_1 jupyter_ydoc 0.2.4 py310h06a4308_0 jupyterlab 3.6.3 py310h06a4308_0 jupyterlab_pygments 0.2.2 pyhd8ed1ab_0 conda-forge jupyterlab_server 2.22.0 py310h06a4308_0 kdepy 1.1.5 py310h2372a71_0 conda-forge keyring 23.13.1 py310h06a4308_0 kiwisolver 1.4.4 py310h6a678d5_0 krb5 1.20.1 h143b758_1 lcms2 2.15 hfd0df8a_0 conda-forge ld_impl_linux-64 2.38 h1181459_1 lerc 3.0 h295c915_0 libbrotlicommon 1.0.9 h166bdaf_8 conda-forge libbrotlidec 1.0.9 h166bdaf_8 conda-forge libbrotlienc 1.0.9 h166bdaf_8 conda-forge libclang 14.0.6 default_hc6dbbc7_1 libclang13 14.0.6 default_he11475f_1 libcups 2.4.2 h2d74bed_1 libcurl 8.2.1 h251f7ec_0 libdeflate 1.17 h5eee18b_0 libedit 3.1.20221030 h5eee18b_0 libev 4.33 h7f8727e_1 libevent 2.1.12 hdbd6064_1 libexpat 2.5.0 hcb278e6_1 conda-forge libffi 3.4.4 h6a678d5_0 libgcc-ng 12.2.0 h65d4601_19 conda-forge libgfortran-ng 12.2.0 h69a702a_19 conda-forge libgfortran5 12.2.0 h337968e_19 conda-forge libllvm14 14.0.6 hdb19cb5_3 libnghttp2 1.52.0 h2d74bed_1 libpng 1.6.39 h5eee18b_0 libpq 12.15 hdbd6064_1 libprotobuf 3.20.3 he621ea3_0 libsodium 1.0.18 h36c2ea0_1 conda-forge libssh2 1.10.0 hdbd6064_2 libstdcxx-ng 13.1.0 hfd8a6a1_0 conda-forge libtiff 4.5.1 h6a678d5_0 libuuid 1.41.5 h5eee18b_0 libwebp 1.2.4 h11a3e52_1 libwebp-base 1.2.4 h5eee18b_1 libxcb 1.15 h7f8727e_0 libxkbcommon 1.0.1 h5eee18b_1 libxml2 2.10.4 hcbfbd50_0 libxslt 1.1.37 h2085143_0 libzlib 1.2.13 h166bdaf_4 conda-forge lightning 2.0.7 pyhd8ed1ab_0 conda-forge lightning-cloud 0.5.37 pyhd8ed1ab_0 conda-forge lightning-utilities 0.9.0 pyhd8ed1ab_0 conda-forge llvm-openmp 16.0.1 h417c0b6_0 conda-forge llvmlite 0.40.1 pypi_0 pypi lockfile 0.12.2 py_1 conda-forge lz4-c 1.9.4 h6a678d5_0 lzo 2.10 h516909a_1000 conda-forge mace 0.3.2 pypi_0 pypi mace-layer 0.0.0 pypi_0 pypi mace-torch 0.3.4 pypi_0 pypi markdown-it-py 3.0.0 pyhd8ed1ab_0 conda-forge markupsafe 2.1.1 py310h7f8727e_0 matplotlib 3.7.2 py310h06a4308_0 matplotlib-base 3.7.2 py310h1128e8f_0 matplotlib-inline 0.1.6 py310h06a4308_0 matscipy 0.8.0 pypi_0 pypi mdtraj 1.9.7 py310hd8d60c7_1 conda-forge mdurl 0.1.0 py310h06a4308_0 mistune 2.0.5 pyhd8ed1ab_0 conda-forge mkl 2023.1.0 h213fc3f_46343 mkl-service 2.4.0 py310h5eee18b_1 mkl_fft 1.3.6 py310h1128e8f_1 mkl_random 1.2.2 py310h1128e8f_1 mlcolvar 1+unknown pypi_0 pypi more-itertools 10.1.0 pyhd8ed1ab_0 conda-forge mpc 1.1.0 h10f8cd9_1 mpfr 4.0.2 hb69a4c5_1 mpiplus 0+unknown pypi_0 pypi mpmath 1.3.0 py310h06a4308_0 msgpack-python 1.0.3 py310hd09550d_0 multidict 6.0.4 pypi_0 pypi munkres 1.1.4 py_0 mysql 5.7.24 h721c034_2 nbclassic 0.5.5 pyh8b2e9e2_0 conda-forge nbclient 0.7.3 pyhd8ed1ab_0 conda-forge nbconvert-core 7.3.0 pyhd8ed1ab_2 conda-forge nbformat 5.8.0 pyhd8ed1ab_0 conda-forge ncurses 6.4 h6a678d5_0 nest-asyncio 1.5.6 py310h06a4308_0 netcdf 66.0.2 pypi_0 pypi netcdf4 1.6.4 pypi_0 pypi networkx 3.1 py310h06a4308_0 ninja 1.10.2 h06a4308_5 ninja-base 1.10.2 hd09550d_5 notebook 6.5.4 pyha770c72_0 conda-forge notebook-shim 0.2.2 py310h06a4308_0 nspr 4.35 h6a678d5_0 nss 3.89.1 h6a678d5_0 numba 0.57.1 pypi_0 pypi numexpr 2.8.4 py310h85018f9_1 numpy 1.24.4 pypi_0 pypi openmm 8.0.0 py310h5728c26_1 openmm-plumed 1.0 py310h552f1b7_9 openmmtools 0.23.1 pypi_0 pypi openssl 3.1.2 hd590300_0 conda-forge opt-einsum 3.0.0 py_0 conda-forge opt_einsum_fx 0.1.4 
pyhd8ed1ab_0 conda-forge ordered-set 4.1.0 py310h06a4308_0 orjson 3.9.5 py310h1e2579a_0 conda-forge packaging 23.1 py310h06a4308_0 pandas 2.0.3 py310h1128e8f_0 pandocfilters 1.5.0 pyhd3eb1b0_0 parso 0.8.3 pyhd3eb1b0_0 pcre 8.45 h295c915_0 pexpect 4.8.0 pyhd3eb1b0_3 pickleshare 0.7.5 pyhd3eb1b0_1003 pillow 9.4.0 py310h6a678d5_0 pint 0.22 pypi_0 pypi pip 23.2.1 py310h06a4308_0 pkginfo 1.9.6 py310h06a4308_0 platformdirs 2.5.2 py310h06a4308_0 pluggy 1.2.0 pyhd8ed1ab_0 conda-forge plumed 2.9.0 pypi_0 pypi ply 3.11 py_1 conda-forge poetry 1.4.0 py310h06a4308_0 poetry-core 1.5.1 py310h06a4308_0 poetry-plugin-export 1.3.0 py310h4849bfd_0 prettytable 3.9.0 pypi_0 pypi prometheus_client 0.16.0 pyhd8ed1ab_0 conda-forge prompt-toolkit 3.0.38 pyha770c72_0 conda-forge psutil 5.9.0 py310h5eee18b_0 ptyprocess 0.7.0 pyhd3eb1b0_2 pure_eval 0.2.2 pyhd3eb1b0_0 py-cpuinfo 8.0.0 pyhd3eb1b0_1 pycparser 2.21 pyhd3eb1b0_0 pydantic 2.0.3 pyhd8ed1ab_1 conda-forge pydantic-core 2.3.0 py310hcb5633a_0 conda-forge pyg 2.3.0 py310_torch_2.0.0_cpu pyg pygments 2.15.1 py310h06a4308_1 pyjwt 2.8.0 pyhd8ed1ab_0 conda-forge pymbar 4.0.2 pypi_0 pypi pyopenssl 23.2.0 py310h06a4308_0 pyparsing 3.0.9 py310h06a4308_0 pyproject_hooks 1.0.0 py310h06a4308_0 pyqt 5.15.7 py310h6a678d5_1 pyqt5-sip 12.11.0 pypi_0 pypi pyrsistent 0.19.3 py310h1fa729e_0 conda-forge pysocks 1.7.1 py310h06a4308_0 pytables 3.8.0 py310hb8ae3fc_3 pytest 7.4.0 py310h06a4308_0 python 3.10.12 h955ad1f_0 python-build 0.10.0 pyhd8ed1ab_1 conda-forge python-dateutil 2.8.2 pyhd3eb1b0_0 python-editor 1.0.4 pyhd3eb1b0_0 python-fastjsonschema 2.16.3 pyhd8ed1ab_0 conda-forge python-installer 0.6.0 py310h06a4308_0 python-json-logger 2.0.7 py310h06a4308_0 python-multipart 0.0.6 py310h06a4308_0 python-tzdata 2023.3 pyhd3eb1b0_0 python_abi 3.10 2_cp310 conda-forge pytorch 2.0.1 cpu_py310hdc00b08_0 pytorch-lightning 2.0.7 pyhd8ed1ab_0 conda-forge pytorch-mutex 1.0 cpu pytorch pytorch-scatter 2.1.1 py310_torch_2.0.0_cpu pyg pytz 2023.3 pyhd8ed1ab_0 conda-forge pyyaml 6.0.1 py310h2372a71_0 conda-forge pyzmq 25.1.0 py310h6a678d5_0 qt-main 5.15.2 h7358343_9 qt-webengine 5.15.9 h9ab4d14_7 qtwebkit 5.212 h3fafdc1_5 rapidfuzz 2.13.7 py310h1128e8f_0 rdkit 2022.9.5 pypi_0 pypi readchar 4.0.5 pyhd8ed1ab_0 conda-forge readline 8.2 h5eee18b_0 requests 2.31.0 py310h06a4308_0 requests-toolbelt 0.10.1 pyhd8ed1ab_0 conda-forge rfc3339-validator 0.1.4 py310h06a4308_0 rfc3986-validator 0.1.1 py310h06a4308_0 rich 13.5.1 pyhd8ed1ab_0 conda-forge scikit-learn 1.3.0 py310h1128e8f_0 scipy 1.10.1 pypi_0 pypi secretstorage 3.3.3 py310hff52083_1 conda-forge send2trash 1.8.0 pyhd3eb1b0_1 setuptools 68.0.0 py310h06a4308_0 shellingham 1.5.3 pyhd8ed1ab_0 conda-forge sip 6.6.2 py310h6a678d5_0 six 1.16.0 pyhd3eb1b0_1 snappy 1.1.9 h295c915_0 sniffio 1.3.0 pyhd8ed1ab_0 conda-forge soupsieve 2.4 py310h06a4308_0 sqlite 3.41.2 h5eee18b_0 stack_data 0.6.2 pyhd8ed1ab_0 conda-forge starlette 0.27.0 py310h06a4308_0 starsessions 1.3.0 pyhd8ed1ab_0 conda-forge sympy 1.11.1 py310h06a4308_0 tbb 2021.8.0 hdb19cb5_0 terminado 0.17.1 py310h06a4308_0 threadpoolctl 2.2.0 pyh0d69192_0 tinycss2 1.2.1 py310h06a4308_0 tk 8.6.12 h1ccaba5_0 toml 0.10.2 pyhd3eb1b0_0 tomli 2.0.1 py310h06a4308_0 tomlkit 0.12.1 pyha770c72_0 conda-forge torch-ema 0.3 pypi_0 pypi torchmetrics 1.0.3 pyhd8ed1ab_0 conda-forge tornado 6.3.2 py310h5eee18b_0 tqdm 4.66.1 pyhd8ed1ab_0 conda-forge traitlets 5.9.0 pyhd8ed1ab_0 conda-forge trove-classifiers 2023.8.7 pyhd8ed1ab_0 conda-forge typing-extensions 4.7.1 py310h06a4308_0 typing_extensions 4.7.1 
py310h06a4308_0 tzdata 2023c h04d1e81_0 unicodedata2 15.0.0 py310h5eee18b_0 urllib3 1.26.16 py310h06a4308_0 uvicorn 0.23.2 py310hff52083_0 conda-forge virtualenv 20.17.1 py310h06a4308_0 wcwidth 0.2.6 pyhd8ed1ab_0 conda-forge webencodings 0.5.1 py310h06a4308_1 websocket-client 1.5.1 pyhd8ed1ab_0 conda-forge websockets 11.0.3 py310h2372a71_0 conda-forge wheel 0.38.4 py310h06a4308_0 xz 5.4.2 h5eee18b_0 y-py 0.5.9 py310h52d8a92_0 yaml 0.2.5 h7f98852_2 conda-forge yarl 1.9.2 pypi_0 pypi ypy-websocket 0.8.2 py310h06a4308_0 zeromq 4.3.4 h9c3ff4c_1 conda-forge zipp 3.15.0 pyhd8ed1ab_0 conda-forge zlib 1.2.13 h166bdaf_4 conda-forge zlib-ng 2.0.7 h5eee18b_0 zstd 1.5.5 hc292b87_0 ```

I tried some tinkering, and here is a version of the tda_loss module that does compile, although I'm not entirely sure about its correctness.

code:

```python
#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Target Discriminant Analysis Loss Function.
"""

__all__ = ["TDALoss", "tda_loss"]

# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from typing import Union, List, Tuple
from warnings import warn

import torch

# =============================================================================
# LOSS FUNCTIONS
# =============================================================================


class TDALoss(torch.nn.Module):
    """Compute a loss function as the distance from a simple Gaussian target distribution."""

    def __init__(
        self,
        n_states: int,
        target_centers: Union[List[float], torch.Tensor],
        target_sigmas: Union[List[float], torch.Tensor],
        alpha: float = 1.0,
        beta: float = 100.0,
    ):
        """Constructor.

        Parameters
        ----------
        n_states : int
            Number of states. The integer labels are expected to be in between 0
            and ``n_states-1``.
        target_centers : list or torch.Tensor
            Shape ``(n_states, n_cvs)``. Centers of the Gaussian targets.
        target_sigmas : list or torch.Tensor
            Shape ``(n_states, n_cvs)``. Standard deviations of the Gaussian targets.
        alpha : float, optional
            Centers_loss component prefactor, by default 1.
        beta : float, optional
            Sigmas loss component prefactor, by default 100.
        """
        super().__init__()
        self.n_states = n_states
        self.target_centers = target_centers
        self.target_sigmas = target_sigmas
        self.alpha = alpha
        self.beta = beta

    def forward(
        self, H: torch.Tensor, labels: torch.Tensor, return_loss_terms: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Compute the value of the loss function.

        Parameters
        ----------
        H : torch.Tensor
            Shape ``(n_batches, n_features)``. Output of the NN.
        labels : torch.Tensor
            Shape ``(n_batches,)``. Labels of the dataset.
        return_loss_terms : bool, optional
            If ``True``, the loss terms associated to the center and standard
            deviations of the target Gaussians are returned as well. Default is ``False``.

        Returns
        -------
        loss : torch.Tensor
            Loss value.
        loss_centers : torch.Tensor, optional
            Only returned if ``return_loss_terms is True``. The value of the loss
            term associated to the centers of the target Gaussians.
        loss_sigmas : torch.Tensor, optional
            Only returned if ``return_loss_terms is True``. The value of the loss
            term associated to the standard deviations of the target Gaussians.
        """
        return tda_loss(
            H,
            labels,
            self.n_states,
            self.target_centers,
            self.target_sigmas,
            self.alpha,
            self.beta,
            return_loss_terms,
        )


def tda_loss(
    H: torch.Tensor,
    labels: torch.Tensor,
    n_states: int,
    target_centers: Union[List[float], torch.Tensor],
    target_sigmas: Union[List[float], torch.Tensor],
    alpha: float = 1,
    beta: float = 100,
    return_loss_terms: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Compute a loss function as the distance from a simple Gaussian target distribution.

    Parameters
    ----------
    H : torch.Tensor
        Shape ``(n_batches, n_features)``. Output of the NN.
    labels : torch.Tensor
        Shape ``(n_batches,)``. Labels of the dataset.
    n_states : int
        The integer labels are expected to be in between 0 and ``n_states-1``.
    target_centers : list or torch.Tensor
        Shape ``(n_states, n_cvs)``. Centers of the Gaussian targets.
    target_sigmas : list or torch.Tensor
        Shape ``(n_states, n_cvs)``. Standard deviations of the Gaussian targets.
    alpha : float, optional
        Centers_loss component prefactor, by default 1.
    beta : float, optional
        Sigmas loss component prefactor, by default 100.
    return_loss_terms : bool, optional
        If ``True``, the loss terms associated to the center and standard
        deviations of the target Gaussians are returned as well. Default is ``False``.

    Returns
    -------
    loss : torch.Tensor
        Loss value.
    loss_centers : torch.Tensor, optional
        Only returned if ``return_loss_terms is True``. The value of the loss
        term associated to the centers of the target Gaussians.
    loss_sigmas : torch.Tensor, optional
        Only returned if ``return_loss_terms is True``. The value of the loss
        term associated to the standard deviations of the target Gaussians.
    """
    if not isinstance(target_centers, torch.Tensor):
        target_centers = torch.tensor(target_centers)
    if not isinstance(target_sigmas, torch.Tensor):
        target_sigmas = torch.tensor(target_sigmas)

    device = H.device
    target_centers = target_centers.to(device)
    target_sigmas = target_sigmas.to(device)

    loss_centers = torch.zeros_like(target_centers, device=device)
    loss_sigmas = torch.zeros_like(target_sigmas, device=device)

    for i in range(n_states):
        # check which elements belong to class i
        if not (labels == i).any():
            raise ValueError(
                f"State {i} was not represented in this batch! Either use bigger batch_size or a more equilibrated dataset composition!"
            )
        else:
            H_red = H[labels == i]

            # compute mean and standard deviation over the class i
            mu = torch.mean(H_red, 0)
            if len(torch.nonzero(labels == i)) == 1:
                warn(
                    f"There is only one sample for state {i} in this batch! Std is set to 0, this may affect the training! Either use bigger batch_size or a more equilibrated dataset composition!"
                )
                sigma = torch.tensor(0)
            else:
                sigma = torch.std(H_red, 0)

            # compute loss function contributions for class i
            loss_centers[i] = alpha * (mu - target_centers[i]).pow(2)
            loss_sigmas[i] = beta * (sigma - target_sigmas[i]).pow(2)

    # get total model loss
    loss_centers = torch.sum(loss_centers)
    loss_sigmas = torch.sum(loss_sigmas)
    loss = loss_centers + loss_sigmas

    if return_loss_terms:
        return loss, loss_centers, loss_sigmas
    return loss
```
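For reference, here is the kind of quick check I would run against the module above (not part of the original report; the inputs are made up for illustration, and it assumes `TDALoss` from the patched module is importable in the current scope):

```python
import torch

# two states, one CV; targets passed as plain Python lists
loss_fn = TDALoss(
    n_states=2,
    target_centers=[-10.0, 10.0],
    target_sigmas=[0.2, 0.2],
)

# should no longer trip the Union[list, torch.Tensor] annotation assertion
scripted = torch.jit.script(loss_fn)

H = torch.randn(8, 1)                            # fake NN output, one feature
labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])  # four samples per state
print(scripted(H, labels))
```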
EnricoTrizio commented 4 months ago

Thanks Jintu, it seems fine and we'll fix this