I tried some tinkering, and here is an example of a `tda_loss` module that compiles, although I'm not really sure about its correctness.
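If I read the docstrings right, the loss is meant to be the squared distance of each state's batch statistics from the corresponding target Gaussian parameters, summed over states (and over CVs):

$$\mathcal{L} \;=\; \sum_{k=0}^{n_\mathrm{states}-1}\left[\alpha\left(\mu_k-\mu_k^{\mathrm{target}}\right)^2 \;+\; \beta\left(\sigma_k-\sigma_k^{\mathrm{target}}\right)^2\right],$$

where $\mu_k$ and $\sigma_k$ are the mean and standard deviation of the NN output over the batch samples with label $k$.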
Code:
```python
#!/usr/bin/env python
# =============================================================================
# MODULE DOCSTRING
# =============================================================================
"""
Targeted Discriminant Analysis (TDA) loss function.
"""
__all__ = ["TDALoss", "tda_loss"]
# =============================================================================
# GLOBAL IMPORTS
# =============================================================================
from typing import Union, List, Tuple
from warnings import warn
import torch
# =============================================================================
# LOSS FUNCTIONS
# =============================================================================
class TDALoss(torch.nn.Module):
"""Compute a loss function as the distance from a simple Gaussian target distribution."""
def __init__(
self,
n_states: int,
target_centers: Union[List[float], torch.Tensor],
target_sigmas: Union[List[float], torch.Tensor],
alpha: float = 1.0,
beta: float = 100.0,
):
"""Constructor.
Parameters
----------
n_states : int
            Number of states. The integer labels are expected to be between 0
            and ``n_states-1``.
target_centers : list or torch.Tensor
Shape ``(n_states, n_cvs)``. Centers of the Gaussian targets.
target_sigmas : list or torch.Tensor
Shape ``(n_states, n_cvs)``. Standard deviations of the Gaussian targets.
alpha : float, optional
            Centers loss component prefactor, by default 1.
        beta : float, optional
            Sigmas loss component prefactor, by default 100.
"""
super().__init__()
self.n_states = n_states
self.target_centers = target_centers
self.target_sigmas = target_sigmas
self.alpha = alpha
self.beta = beta
def forward(
self, H: torch.Tensor, labels: torch.Tensor, return_loss_terms: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Compute the value of the loss function.
Parameters
----------
H : torch.Tensor
Shape ``(n_batches, n_features)``. Output of the NN.
labels : torch.Tensor
Shape ``(n_batches,)``. Labels of the dataset.
return_loss_terms : bool, optional
            If ``True``, the loss terms associated to the centers and standard
deviations of the target Gaussians are returned as well. Default
is ``False``.
Returns
-------
loss : torch.Tensor
Loss value.
loss_centers : torch.Tensor, optional
Only returned if ``return_loss_terms is True``. The value of the
loss term associated to the centers of the target Gaussians.
loss_sigmas : torch.Tensor, optional
Only returned if ``return_loss_terms is True``. The value of the
loss term associated to the standard deviations of the target Gaussians.
"""
return tda_loss(
H,
labels,
self.n_states,
self.target_centers,
self.target_sigmas,
self.alpha,
self.beta,
return_loss_terms,
)
def tda_loss(
H: torch.Tensor,
labels: torch.Tensor,
n_states: int,
target_centers: Union[List[float], torch.Tensor],
target_sigmas: Union[List[float], torch.Tensor],
alpha: float = 1,
beta: float = 100,
return_loss_terms: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
Compute a loss function as the distance from a simple Gaussian target distribution.
Parameters
----------
H : torch.Tensor
Shape ``(n_batches, n_features)``. Output of the NN.
labels : torch.Tensor
Shape ``(n_batches,)``. Labels of the dataset.
n_states : int
        Number of states. The integer labels are expected to be between 0 and ``n_states-1``.
target_centers : list or torch.Tensor
Shape ``(n_states, n_cvs)``. Centers of the Gaussian targets.
target_sigmas : list or torch.Tensor
Shape ``(n_states, n_cvs)``. Standard deviations of the Gaussian targets.
alpha : float, optional
        Centers loss component prefactor, by default 1.
    beta : float, optional
        Sigmas loss component prefactor, by default 100.
return_loss_terms : bool, optional
        If ``True``, the loss terms associated to the centers and standard deviations
of the target Gaussians are returned as well. Default is ``False``.
Returns
-------
loss : torch.Tensor
Loss value.
loss_centers : torch.Tensor, optional
Only returned if ``return_loss_terms is True``. The value of the loss
term associated to the centers of the target Gaussians.
loss_sigmas : torch.Tensor, optional
Only returned if ``return_loss_terms is True``. The value of the loss
term associated to the standard deviations of the target Gaussians.
"""
if not isinstance(target_centers, torch.Tensor):
target_centers = torch.tensor(target_centers)
if not isinstance(target_sigmas, torch.Tensor):
target_sigmas = torch.tensor(target_sigmas)
    device = H.device
    # match the device and dtype of the NN output (integer targets would otherwise
    # give integer loss buffers and break the assignments below)
    target_centers = target_centers.to(device=device, dtype=H.dtype)
    target_sigmas = target_sigmas.to(device=device, dtype=H.dtype)
    loss_centers = torch.zeros_like(target_centers)
    loss_sigmas = torch.zeros_like(target_sigmas)
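    # per-state loss terms, filled in entry by entry in the loop over states below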
for i in range(n_states):
# check which elements belong to class i
if not (labels == i).any():
raise ValueError(
f"State {i} was not represented in this batch! Either use bigger batch_size or a more equilibrated dataset composition!"
)
else:
H_red = H[labels == i]
# compute mean and standard deviation over the class i
mu = torch.mean(H_red, 0)
if len(torch.nonzero(labels == i)) == 1:
                warn(
                    f"There is only one sample for state {i} in this batch! Std is set to 0; this may affect the training! Either use a bigger batch_size or a more equilibrated dataset composition!"
                )
                # zero std with the same shape/dtype/device as mu (a plain torch.tensor(0) could end up on the wrong device)
                sigma = torch.zeros_like(mu)
else:
sigma = torch.std(H_red, 0)
            # compute loss function contributions for class i
loss_centers[i] = alpha * (mu - target_centers[i]).pow(2)
loss_sigmas[i] = beta * (sigma - target_sigmas[i]).pow(2)
# get total model loss
loss_centers = torch.sum(loss_centers)
loss_sigmas = torch.sum(loss_sigmas)
loss = loss_centers + loss_sigmas
if return_loss_terms:
return loss, loss_centers, loss_sigmas
return loss
```
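For reference, here is the kind of minimal call I would expect to work with the module above — the numbers are just placeholders (two states, one CV, five samples per state), not my actual input:

```python
import torch

# hypothetical smoke test: 2 states, 1 CV, targets at -1 and +1 with width 0.2
loss_fn = TDALoss(
    n_states=2,
    target_centers=[[-1.0], [1.0]],
    target_sigmas=[[0.2], [0.2]],
)

H = torch.randn(10, 1, requires_grad=True)             # stand-in for the NN output, shape (n_batches, n_cvs)
labels = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])  # five samples per state

loss, loss_centers, loss_sigmas = loss_fn(H, labels, return_loss_terms=True)
loss.backward()
print(loss.item(), loss_centers.item(), loss_sigmas.item())
```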
Example input:
Errors:
```sh
Traceback (most recent call last):
  File "/compile.py", line 11, in
```

conda list:
```sh
# packages in environment at /calc/miniconda3/envs/TORCH:
#
# Name  Version  Build  Channel
_libgcc_mutex 0.1 conda_forge conda-forge _openmp_mutex 4.5 2_kmp_llvm conda-forge aiofiles 22.1.0 py310h06a4308_0 aiohttp 3.8.5 pypi_0 pypi aiosignal 1.3.1 pypi_0 pypi aiosqlite 0.18.0 py310h06a4308_0 annotated-types 0.5.0 pyhd8ed1ab_0 conda-forge anyio 3.6.2 pyhd8ed1ab_0 conda-forge argon2-cffi 21.3.0 pyhd3eb1b0_0 argon2-cffi-bindings 21.2.0 py310h5764c6d_3 conda-forge arrow 1.2.3 py310h06a4308_1 ase 3.22.1 pypi_0 pypi asttokens 2.2.1 pyhd8ed1ab_0 conda-forge astunparse 1.6.3 pyhd8ed1ab_0 conda-forge async-timeout 4.0.3 pypi_0 pypi attrs 22.2.0 pyh71513ae_0 conda-forge babel 2.12.1 pyhd8ed1ab_1 conda-forge backcall 0.2.0 pyhd3eb1b0_0 backoff 2.2.1 pyhd8ed1ab_0 conda-forge backports 1.1 pyhd3eb1b0_0 backports.functools_lru_cache 1.6.4 pyhd8ed1ab_0 conda-forge beautifulsoup4 4.12.2 py310h06a4308_0 blas 1.0 mkl bleach 6.0.0 pyhd8ed1ab_0 conda-forge blessed 1.19.1 pyhe4f9e05_2 conda-forge blosc 1.21.3 h6a678d5_0 bottleneck 1.3.7 py310h0a54255_0 conda-forge brotli 1.0.9 h166bdaf_8 conda-forge brotli-bin 1.0.9 h166bdaf_8 conda-forge brotlipy 0.7.0 py310h7f8727e_1002 bzip2 1.0.8 h7b6447c_0 c-ares 1.19.1 h5eee18b_0 c-blosc2 2.8.0 h6a678d5_0 ca-certificates 2023.7.22 hbcca054_0 conda-forge cachecontrol 0.12.11 py310h06a4308_1 certifi 2023.7.22 pyhd8ed1ab_0 conda-forge cffi 1.15.1 py310h5eee18b_3 cftime 1.6.2 pypi_0 pypi charset-normalizer 2.0.4 pyhd3eb1b0_0 cleo 2.0.1 py310h06a4308_0 click 8.1.7 unix_pyh707e725_0 conda-forge colorama 0.4.6 py310h06a4308_0 comm 0.1.2 py310h06a4308_0 contourpy 1.0.5 py310hdb19cb5_0 cpuonly 2.0 0 pytorch crashtest 0.4.1 py310h06a4308_0 croniter 1.4.1 pyhd8ed1ab_0 conda-forge cryptography 41.0.2 py310h22a60cf_0 cycler 0.11.0 pyhd3eb1b0_0 cyrus-sasl 2.1.28 h52b45da_1 dateutils 0.6.12 py_0 conda-forge dbus 1.13.18 hb2f20db_0 debugpy 1.6.7 py310h6a678d5_0 decorator 5.1.1 pyhd3eb1b0_0 deepdiff 6.3.1 pyhd8ed1ab_0 conda-forge defusedxml 0.7.1 pyhd8ed1ab_0 conda-forge distlib 0.3.7 pyhd8ed1ab_0 conda-forge dulwich 0.21.5 py310h2372a71_0 conda-forge e3nn 0.4.4 pypi_0 pypi einops 0.7.0 pypi_0 pypi entrypoints 0.4 py310h06a4308_0 exceptiongroup 1.1.3 pyhd8ed1ab_0 conda-forge executing 1.2.0 pyhd8ed1ab_0 conda-forge expat 2.5.0 hcb278e6_1 conda-forge expect 5.45.4 h555a92e_0 conda-forge fastapi 0.101.1 pyhd8ed1ab_0 conda-forge filelock 3.9.0 py310h06a4308_0 fontconfig 2.14.1 h4c34cd2_2 fonttools 4.39.3 py310h1fa729e_0 conda-forge freetype 2.12.1 hca18f0e_1 conda-forge frozenlist 1.4.0 pypi_0 pypi fsspec 2023.6.0 pyh1a96a4e_0 conda-forge giflib 5.2.1 h0b41bf4_3 conda-forge glib 2.69.1 he621ea3_2 gmp 6.2.1 h295c915_3 gmpy2 2.1.2 py310heeb90bb_0 gst-plugins-base 1.14.1 h6a678d5_1 gstreamer 1.14.1 h5eee18b_1 h11 0.14.0 pyhd8ed1ab_0 conda-forge h5py 3.9.0 py310he06866b_0 hdf5 1.12.1 h2b7332f_3 html5lib 1.1 pyhd3eb1b0_0 icu 58.2 hf484d3e_1000 conda-forge idna 3.4 py310h06a4308_0 importlib-metadata 6.8.0 pyha770c72_0 conda-forge importlib_metadata 6.8.0 hd8ed1ab_0 conda-forge iniconfig 2.0.0 pyhd8ed1ab_0 conda-forge inquirer 3.1.3 pyhd8ed1ab_0 conda-forge intel-openmp 2023.1.0 hdb19cb5_46305 ipykernel 6.25.0 py310h2f386ee_0 ipython 8.12.2 py310h06a4308_0 ipython_genutils 0.2.0 pyhd3eb1b0_1 itsdangerous 2.1.2 pyhd8ed1ab_0 conda-forge jaraco.classes 3.3.0 pyhd8ed1ab_0 conda-forge jedi 0.18.2 pyhd8ed1ab_0 conda-forge jeepney 0.8.0 pyhd8ed1ab_0 conda-forge jinja2 3.1.2 py310h06a4308_0 joblib 1.2.0 py310h06a4308_0 jpeg 9e h0b41bf4_3 conda-forge json5 0.9.6 pyhd3eb1b0_0 jsonschema 4.17.3
py310h06a4308_0 jupyter_client 8.1.0 py310h06a4308_0 jupyter_core 5.3.0 py310h06a4308_0 jupyter_events 0.6.3 py310h06a4308_0 jupyter_server 1.23.6 pyhd8ed1ab_0 conda-forge jupyter_server_fileid 0.9.0 py310h06a4308_0 jupyter_server_ydoc 0.8.0 py310h06a4308_1 jupyter_ydoc 0.2.4 py310h06a4308_0 jupyterlab 3.6.3 py310h06a4308_0 jupyterlab_pygments 0.2.2 pyhd8ed1ab_0 conda-forge jupyterlab_server 2.22.0 py310h06a4308_0 kdepy 1.1.5 py310h2372a71_0 conda-forge keyring 23.13.1 py310h06a4308_0 kiwisolver 1.4.4 py310h6a678d5_0 krb5 1.20.1 h143b758_1 lcms2 2.15 hfd0df8a_0 conda-forge ld_impl_linux-64 2.38 h1181459_1 lerc 3.0 h295c915_0 libbrotlicommon 1.0.9 h166bdaf_8 conda-forge libbrotlidec 1.0.9 h166bdaf_8 conda-forge libbrotlienc 1.0.9 h166bdaf_8 conda-forge libclang 14.0.6 default_hc6dbbc7_1 libclang13 14.0.6 default_he11475f_1 libcups 2.4.2 h2d74bed_1 libcurl 8.2.1 h251f7ec_0 libdeflate 1.17 h5eee18b_0 libedit 3.1.20221030 h5eee18b_0 libev 4.33 h7f8727e_1 libevent 2.1.12 hdbd6064_1 libexpat 2.5.0 hcb278e6_1 conda-forge libffi 3.4.4 h6a678d5_0 libgcc-ng 12.2.0 h65d4601_19 conda-forge libgfortran-ng 12.2.0 h69a702a_19 conda-forge libgfortran5 12.2.0 h337968e_19 conda-forge libllvm14 14.0.6 hdb19cb5_3 libnghttp2 1.52.0 h2d74bed_1 libpng 1.6.39 h5eee18b_0 libpq 12.15 hdbd6064_1 libprotobuf 3.20.3 he621ea3_0 libsodium 1.0.18 h36c2ea0_1 conda-forge libssh2 1.10.0 hdbd6064_2 libstdcxx-ng 13.1.0 hfd8a6a1_0 conda-forge libtiff 4.5.1 h6a678d5_0 libuuid 1.41.5 h5eee18b_0 libwebp 1.2.4 h11a3e52_1 libwebp-base 1.2.4 h5eee18b_1 libxcb 1.15 h7f8727e_0 libxkbcommon 1.0.1 h5eee18b_1 libxml2 2.10.4 hcbfbd50_0 libxslt 1.1.37 h2085143_0 libzlib 1.2.13 h166bdaf_4 conda-forge lightning 2.0.7 pyhd8ed1ab_0 conda-forge lightning-cloud 0.5.37 pyhd8ed1ab_0 conda-forge lightning-utilities 0.9.0 pyhd8ed1ab_0 conda-forge llvm-openmp 16.0.1 h417c0b6_0 conda-forge llvmlite 0.40.1 pypi_0 pypi lockfile 0.12.2 py_1 conda-forge lz4-c 1.9.4 h6a678d5_0 lzo 2.10 h516909a_1000 conda-forge mace 0.3.2 pypi_0 pypi mace-layer 0.0.0 pypi_0 pypi mace-torch 0.3.4 pypi_0 pypi markdown-it-py 3.0.0 pyhd8ed1ab_0 conda-forge markupsafe 2.1.1 py310h7f8727e_0 matplotlib 3.7.2 py310h06a4308_0 matplotlib-base 3.7.2 py310h1128e8f_0 matplotlib-inline 0.1.6 py310h06a4308_0 matscipy 0.8.0 pypi_0 pypi mdtraj 1.9.7 py310hd8d60c7_1 conda-forge mdurl 0.1.0 py310h06a4308_0 mistune 2.0.5 pyhd8ed1ab_0 conda-forge mkl 2023.1.0 h213fc3f_46343 mkl-service 2.4.0 py310h5eee18b_1 mkl_fft 1.3.6 py310h1128e8f_1 mkl_random 1.2.2 py310h1128e8f_1 mlcolvar 1+unknown pypi_0 pypi more-itertools 10.1.0 pyhd8ed1ab_0 conda-forge mpc 1.1.0 h10f8cd9_1 mpfr 4.0.2 hb69a4c5_1 mpiplus 0+unknown pypi_0 pypi mpmath 1.3.0 py310h06a4308_0 msgpack-python 1.0.3 py310hd09550d_0 multidict 6.0.4 pypi_0 pypi munkres 1.1.4 py_0 mysql 5.7.24 h721c034_2 nbclassic 0.5.5 pyh8b2e9e2_0 conda-forge nbclient 0.7.3 pyhd8ed1ab_0 conda-forge nbconvert-core 7.3.0 pyhd8ed1ab_2 conda-forge nbformat 5.8.0 pyhd8ed1ab_0 conda-forge ncurses 6.4 h6a678d5_0 nest-asyncio 1.5.6 py310h06a4308_0 netcdf 66.0.2 pypi_0 pypi netcdf4 1.6.4 pypi_0 pypi networkx 3.1 py310h06a4308_0 ninja 1.10.2 h06a4308_5 ninja-base 1.10.2 hd09550d_5 notebook 6.5.4 pyha770c72_0 conda-forge notebook-shim 0.2.2 py310h06a4308_0 nspr 4.35 h6a678d5_0 nss 3.89.1 h6a678d5_0 numba 0.57.1 pypi_0 pypi numexpr 2.8.4 py310h85018f9_1 numpy 1.24.4 pypi_0 pypi openmm 8.0.0 py310h5728c26_1
```