atong01 / conditional-flow-matching

TorchCFM: a Conditional Flow Matching library
https://arxiv.org/abs/2302.00482
MIT License
1.27k stars 103 forks source link

Code for Normalized Path Energy calculation #140

Open WANG-CR opened 1 month ago

WANG-CR commented 1 month ago

Hello,

Thank you very much for this nice package and the papers. I wanted to kindly ask if the code for calculating the Normalized Path Energy for optimal transport used in Experiment 5.1 is provided?

atong01 commented 1 month ago

Hi!

The short answer is no. This metric is mostly computed in post-processing, as the calculation is somewhat inefficient.

See the code snippet below for how the baseline was calculated for each seed. To calculate the normalized path energy, we then take the path energy recorded in Wandb (in the runner directory) and divide it by this baseline.

# Load the cached baseline table if present; otherwise recompute it for every
# (dataset, seed) pair and cache it as "baseline_10000.pkl".
# NOTE: `pd` (pandas) must already be imported by the surrounding script/notebook.
try:
    baseline_df = pd.read_pickle("baseline_10000.pkl")
except FileNotFoundError:
    import os
    import sys

    import hydra
    import omegaconf
    import pyrootutils
    import torch

    sys.path.append("..")
    from src.models.components.optimal_transport import wasserstein

    cwd = os.path.abspath("")
    torch.manual_seed(42)
    # Fixed standard-normal source sample shared by the Gaussian-source systems.
    x0 = torch.randn(1000, 2)
    results = []

    # The project root and the seed list are the same for every iteration,
    # so resolve them once instead of once per (system, seed).
    root = pyrootutils.setup_root(cwd, pythonpath=True)
    seeds = [42, 43, 44, 45, 46]

    # Each result row is:
    #   (method, dataset, seed, dist**2 -> "Path Energy", dist1 -> "Path Length",
    #    baseline -> "2-Wasserstein")
    # where dist/dist1 are wasserstein(source, target_test) with power=2/1 and
    # baseline is wasserstein(target_val, target_test, power=2).

    # Gaussian source -> sklearn targets.
    for system in ["scurve", "moons"]:
        for seed in seeds:
            cfg = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "sklearn.yaml")
            cfg.system = system
            cfg.seed = seed
            cfg.train_val_test_split = [10000, 10000, 10000]
            datamodule = hydra.utils.instantiate(cfg)

            x1_test = datamodule.data_test.dataset[datamodule.data_test.indices]
            x1_val = datamodule.data_val.dataset[datamodule.data_val.indices]
            dist = wasserstein(x0, x1_test, power=2)
            dist1 = wasserstein(x0, x1_test, power=1)
            baseline = wasserstein(x1_val, x1_test, power=2)

            results.append(("Baseline", system, seed, dist**2, dist1, baseline))

    # Two-timepoint data: the source is the test split of timepoint 0 rather
    # than the Gaussian sample x0.
    for system in ["moon-8gaussians"]:
        for seed in seeds:
            cfg = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "twodim.yaml")
            cfg.system = system
            cfg.seed = seed
            cfg.train_val_test_split = [10000, 10000, 10000]

            datamodule = hydra.utils.instantiate(cfg)
            # split_timepoint_data[t][k]: timepoint t, split k (1 = val, 2 = test
            # as used here).
            dm_test = datamodule.split_timepoint_data[1][2]
            dm_test_zero = datamodule.split_timepoint_data[0][2]
            dm_val = datamodule.split_timepoint_data[1][1]
            x1_test = dm_test.dataset[dm_test.indices]
            x1_val = dm_val.dataset[dm_val.indices]
            x0_test = dm_test_zero.dataset[dm_test_zero.indices]
            dist = wasserstein(x0_test, x1_test, power=2)
            # Use the same source sample as `dist` above. The original snippet
            # passed the unrelated Gaussian `x0` here, so "Path Length" and
            # "Path Energy" were measured from different source distributions.
            dist1 = wasserstein(x0_test, x1_test, power=1)
            baseline = wasserstein(x1_val, x1_test, power=2)

            results.append(("Baseline", system, seed, dist**2, dist1, baseline))

    # Gaussian source -> torchdyn "gaussians" target with noise 1e-4.
    for system in ["gaussians"]:
        for seed in seeds:
            cfg = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "torchdyn.yaml")
            cfg.train_val_test_split = [10000, 10000, 10000]
            cfg.system = system
            cfg.seed = seed
            cfg.system_kwargs.noise = 1e-4
            datamodule = hydra.utils.instantiate(cfg)

            x1_test = datamodule.data_test.dataset[datamodule.data_test.indices]
            x1_val = datamodule.data_val.dataset[datamodule.data_val.indices]
            dist = wasserstein(x0, x1_test, power=2)
            dist1 = wasserstein(x0, x1_test, power=1)
            baseline = wasserstein(x1_val, x1_test, power=2)

            results.append(("Baseline", system, seed, dist**2, dist1, baseline))

    baseline_df = pd.DataFrame(
        results,
        columns=[
            "Method",
            "Dataset",
            "seed",
            "Path Energy",
            "Path Length",
            "2-Wasserstein",
        ],
    )
    # Cache before summarising so a failure in the summary cannot throw away
    # the computed table.
    baseline_df.to_pickle("baseline_10000.pkl")
    # Per-dataset means for inspection. numeric_only=True keeps the string
    # "Method" column out of the aggregation (it makes mean() raise on
    # pandas >= 2.0); the bare expression used previously displayed nothing.
    print(baseline_df.groupby("Dataset").mean(numeric_only=True))

Code for normalization:

# Normalized metrics, computed per dataset as the absolute relative deviation
# from the mean of that dataset's "Baseline" rows:
#     |m - m_baseline| / m_baseline
def normalize_metrics(df, metrics, baseline_df=None):
    """Append per-dataset baseline-normalized versions of ``metrics`` to ``df``.

    For every row, each metric ``m`` is turned into
    ``abs(m - b) / b`` where ``b`` is the mean of that metric over the rows
    with ``Method == "Baseline"`` and the same ``Dataset``.

    Args:
        df: DataFrame with ``"Method"`` and ``"Dataset"`` columns plus every
            column named in ``metrics``. It is not modified.
        metrics: List of metric column names to normalize.
        baseline_df: Optional DataFrame to take the ``"Baseline"`` rows from.
            Defaults to ``df`` itself.

    Returns:
        A new DataFrame with the columns and index of ``df`` followed by one
        ``"Normalized <metric>"`` column per metric. An existing column of
        that name is replaced, so the function can be re-applied safely.

    Raises:
        KeyError: If some dataset in ``df`` has no ``"Baseline"`` rows.
    """
    metrics = list(metrics)
    source = df if baseline_df is None else baseline_df

    # Restrict to the metric columns before averaging so unrelated non-numeric
    # columns never reach mean().
    baseline_rows = source[source["Method"] == "Baseline"]
    baseline_means = baseline_rows.groupby("Dataset")[metrics].mean()

    missing = set(df["Dataset"]) - set(baseline_means.index)
    if missing:
        raise KeyError(f"no Baseline rows for dataset(s): {sorted(missing, key=str)}")

    # One baseline row per row of df, in df's order. Working on the raw array
    # sidesteps index alignment, so any index on df (non-default, duplicated)
    # is handled and preserved.
    per_row_baseline = baseline_means.reindex(df["Dataset"]).to_numpy()
    normed = (df[metrics] - per_row_baseline).abs() / per_row_baseline
    normed.columns = [f"Normalized {metric}" for metric in metrics]

    # Drop stale normalized columns so the fresh ones win.
    stale = [col for col in normed.columns if col in df.columns]
    return pd.concat([df.drop(columns=stale), normed], axis=1)

# Rebind clean_df to the result, which carries the added "Normalized Path Energy" column.
clean_df = normalize_metrics(clean_df, ["Path Energy"])