arsedler9 / lfads-torch

A PyTorch implementation of Latent Factor Analysis via Dynamical Systems (LFADS) and AutoLFADS.
https://arxiv.org/abs/2309.01230

error with ray #21

Open rosskempner opened 2 months ago

rosskempner commented 2 months ago

Hi,

I am running the second part of the tutorial, the 2_run_pbt.py script.

Here is my script:


import os
import shutil
from datetime import datetime
from pathlib import Path

from ray import tune
from ray.tune import CLIReporter
from ray.tune.search.basic_variant import BasicVariantGenerator

from lfads_torch.extensions.tune import (
    BinaryTournamentPBT,
    HyperParam,
    ImprovementRatioStopper,
)
from lfads_torch.run_model import run_model

# ---------- OPTIONS ----------
PROJECT_STR = "pbt"
DATASET_STR = "rouse_multisession_PCR"
RUN_TAG = datetime.now().strftime("%y%m%d")
RUN_DIR = Path(os.getcwd())
HYPERPARAM_SPACE = {
    "model.lr_init": HyperParam(
        1e-4, 1e-3, explore_wt=0.3, enforce_limits=True, init=1e-3
    ),
    "model.dropout_rate": HyperParam(
        0.0, 0.6, explore_wt=0.3, enforce_limits=True, sample_fn="uniform"
    ),
    "model.train_aug_stack.transforms.0.cd_rate": HyperParam(
        0.01, 0.99, explore_wt=0.3, enforce_limits=True, init=0.5, sample_fn="uniform"
    ),
    "model.kl_co_scale": HyperParam(1e-5, 1e-3, explore_wt=0.8),
    "model.kl_ic_scale": HyperParam(1e-5, 1e-3, explore_wt=0.8),
    "model.l2_gen_scale": HyperParam(1e-5, 1e-0, explore_wt=0.8),
    "model.l2_con_scale": HyperParam(1e-4, 1e-0, explore_wt=0.8),
}
# ------------------------------

# Function to keep dropout and CD rates in-bounds
def clip_config_rates(config):
    return {k: min(v, 0.99) if "_rate" in k else v for k, v in config.items()}

init_space = {name: tune.sample_from(hp.init) for name, hp in HYPERPARAM_SPACE.items()}
# Set the mandatory config overrides to select datamodule and model
mandatory_overrides = {
    "datamodule": DATASET_STR,
    "model": DATASET_STR,
    "logger.wandb_logger.project": PROJECT_STR,
    "logger.wandb_logger.tags.1": DATASET_STR,
    "logger.wandb_logger.tags.2": RUN_TAG,
}

# Copy this script into the run directory

# Run the hyperparameter search
metric = "valid/recon_smth"
num_trials = 20
perturbation_interval = 15
burn_in_period = 50 + 15
analysis = tune.run(
    tune.with_parameters(
        run_model,
        config_path="../../../configs/pbt.yaml",
        do_posterior_sample=False,
    ),
    metric=metric,
    mode="min",
    name=RUN_DIR,
    stop=ImprovementRatioStopper(
        num_trials=num_trials,
        perturbation_interval=perturbation_interval,
        burn_in_period=burn_in_period,
        metric=metric,
        patience=4,
        min_improvement_ratio=5e-4,
    ),
    config={**mandatory_overrides, **init_space},
    resources_per_trial=dict(cpu=3, gpu=0.5),
    num_samples=num_trials,
    local_dir=RUN_DIR,
    search_alg=BasicVariantGenerator(random_state=0),
    scheduler=BinaryTournamentPBT(
        perturbation_interval=perturbation_interval,
        burn_in_period=burn_in_period,
        hyperparam_mutations=HYPERPARAM_SPACE,
    ),
    keep_checkpoints_num=1,
    verbose=1,
    progress_reporter=CLIReporter(
        metric_columns=[metric, "cur_epoch"],
        sort_by_metric=True,
    ),
    trial_dirname_creator=lambda trial: str(trial),
)
# Copy the best model to a new folder so it is easy to identify
best_model_dir = RUN_DIR / "best_model"
shutil.copytree(analysis.best_logdir, best_model_dir)
# Switch working directory to this folder (usually handled by tune)
os.chdir(best_model_dir)
# Load the best model and run posterior sampling (skip training)
best_ckpt_dir = best_model_dir / Path(analysis.best_checkpoint._local_path).name
run_model(
    overrides=mandatory_overrides,
    checkpoint_dir=best_ckpt_dir,
    config_path="../../../configs/pbt.yaml",
    do_train=False,
)

And here is my error message:


== Status ==
Current time: 2024-09-05 16:47:51 (running for 00:00:22.52)
Memory usage on this node: 52.9/1007.7 GiB 
PopulationBasedTraining: 0 checkpoints, 0 perturbs
Resources requested: 0/128 CPUs, 0/2 GPUs, 0.0/767.03 GiB heap, 0.0/186.26 GiB objects (0.0/1.0 accelerator_type:RTX)
Result logdir: /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs
Number of trials: 20/20 (19 ERROR, 1 PENDING)
+-----------------------+----------+-----------------------+----------------------+---------------------+---------------------+----------------------+----------------------+-----------------+------------------------+
| Trial name            | status   | loc                   |   model.dropout_rate |   model.kl_co_scale |   model.kl_ic_scale |   model.l2_con_scale |   model.l2_gen_scale |   model.lr_init |   ....train_aug_stack. |
|                       |          |                       |                      |                     |                     |                      |                      |                 |   transforms.0.cd_rate |
|-----------------------+----------+-----------------------+----------------------+---------------------+---------------------+----------------------+----------------------+-----------------+------------------------|
| run_model_0f720_00019 | PENDING  |                       |            0.183342  |         1.96506e-05 |         2.02864e-05 |          0.000313861 |          0.125441    |           0.001 |                    0.5 |
| run_model_0f720_00000 | ERROR    | 10.81.105.145:2546365 |            0.0998097 |         0.000145566 |         0.000562019 |          0.362005    |          0.260521    |           0.001 |                    0.5 |
| run_model_0f720_00001 | ERROR    | 10.81.105.145:2546872 |            0.428983  |         0.000678712 |         1.55435e-05 |          0.0184235   |          0.00489324  |           0.001 |                    0.5 |
| run_model_0f720_00002 | ERROR    | 10.81.105.145:2546875 |            0.103402  |         0.000216868 |         1.79223e-05 |          0.857996    |          0.0610244   |           0.001 |                    0.5 |
| run_model_0f720_00003 | ERROR    | 10.81.105.145:2546879 |            0.444794  |         0.000201164 |         0.000190564 |          0.177565    |          0.0100834   |           0.001 |                    0.5 |
| run_model_0f720_00004 | ERROR    | 10.81.105.145:2547180 |            0.373833  |         0.000137225 |         0.000604685 |          0.00033449  |          0.149631    |           0.001 |                    0.5 |
| run_model_0f720_00005 | ERROR    | 10.81.105.145:2547927 |            0.0448103 |         7.36163e-05 |         0.000395537 |          0.170101    |          0.00522518  |           0.001 |                    0.5 |
| run_model_0f720_00006 | ERROR    | 10.81.105.145:2547933 |            0.241651  |         1.58007e-05 |         2.34636e-05 |          0.581212    |          1.47931e-05 |           0.001 |                    0.5 |
| run_model_0f720_00007 | ERROR    | 10.81.105.145:2548169 |            0.412236  |         6.3678e-05  |         5.93763e-05 |          0.699554    |          0.398774    |           0.001 |                    0.5 |
| run_model_0f720_00008 | ERROR    | 10.81.105.145:2548173 |            0.117632  |         0.000613384 |         8.97283e-05 |          0.0157385   |          0.000186606 |           0.001 |                    0.5 |
| run_model_0f720_00009 | ERROR    | 10.81.105.145:2548890 |            0.122786  |         0.000559937 |         3.01349e-05 |          0.000128701 |          0.0237087   |           0.001 |                    0.5 |
| run_model_0f720_00010 | ERROR    | 10.81.105.145:2548893 |            0.395912  |         0.000156902 |         2.0803e-05  |          0.255673    |          0.269453    |           0.001 |                    0.5 |
| run_model_0f720_00011 | ERROR    | 10.81.105.145:2549739 |            0.324184  |         3.9222e-05  |         8.08373e-05 |          0.239034    |          0.000984247 |           0.001 |                    0.5 |
| run_model_0f720_00012 | ERROR    | 10.81.105.145:2549778 |            0.310792  |         3.46667e-05 |         0.00051653  |          0.005516    |          0.000158445 |           0.001 |                    0.5 |
| run_model_0f720_00013 | ERROR    | 10.81.105.145:2549785 |            0.19944   |         9.65847e-05 |         2.86928e-05 |          0.000929231 |          0.0169142   |           0.001 |                    0.5 |
| run_model_0f720_00014 | ERROR    | 10.81.105.145:2549795 |            0.400014  |         0.000282342 |         2.32923e-05 |          0.0177753   |          0.000776878 |           0.001 |                    0.5 |
| run_model_0f720_00015 | ERROR    | 10.81.105.145:2550830 |            0.412812  |         0.00033784  |         0.000113766 |          0.000269373 |          9.77298e-05 |           0.001 |                    0.5 |
| run_model_0f720_00016 | ERROR    | 10.81.105.145:2550831 |            0.339418  |         0.000107645 |         0.000920339 |          0.213078    |          0.00158166  |           0.001 |                    0.5 |
| run_model_0f720_00017 | ERROR    | 10.81.105.145:2550839 |            0.341545  |         7.79686e-05 |         2.2213e-05  |          0.000881869 |          0.00873363  |           0.001 |                    0.5 |
| run_model_0f720_00018 | ERROR    | 10.81.105.145:2550840 |            0.324184  |         0.000384731 |         3.67654e-05 |          0.00945664  |          0.00138483  |           0.001 |                    0.5 |
+-----------------------+----------+-----------------------+----------------------+---------------------+---------------------+----------------------+----------------------+-----------------+------------------------+
Number of errored trials: 19
+-----------------------+--------------+-------------------------------------------------------------------------------------------------------------------+
| Trial name            |   # failures | error file                                                                                                        |
|-----------------------+--------------+-------------------------------------------------------------------------------------------------------------------|
| run_model_0f720_00000 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00000/error.txt |
| run_model_0f720_00001 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00001/error.txt |
| run_model_0f720_00002 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00002/error.txt |
| run_model_0f720_00003 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00003/error.txt |
| run_model_0f720_00004 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00004/error.txt |
| run_model_0f720_00005 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00005/error.txt |
| run_model_0f720_00006 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00006/error.txt |
| run_model_0f720_00007 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00007/error.txt |
| run_model_0f720_00008 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00008/error.txt |
| run_model_0f720_00009 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00009/error.txt |
| run_model_0f720_00010 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00010/error.txt |
| run_model_0f720_00011 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00011/error.txt |
| run_model_0f720_00012 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00012/error.txt |
| run_model_0f720_00013 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00013/error.txt |
| run_model_0f720_00014 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00014/error.txt |
| run_model_0f720_00015 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00015/error.txt |
| run_model_0f720_00016 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00016/error.txt |
| run_model_0f720_00017 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00017/error.txt |
| run_model_0f720_00018 |            1 | /home/xx/Documents/xx_lab_projects/lfads_xx/tutorials/multisession/runs/run_model_0f720_00018/error.txt |
+-----------------------+--------------+-------------------------------------------------------------------------------------------------------------------+

(pid=2551608) WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
(pid=2551608) I0000 00:00:1725569274.455947 2551608 fork_posix.cc:77] Other threads are currently calling into gRPC, skipping fork() handlers
(pid=2551608) I0000 00:00:1725569274.490916 2551608 fork_posix.cc:77] Other threads are currently calling into gRPC, skipping fork() handlers
(pid=2551608) I0000 00:00:1725569274.671377 2551608 fork_posix.cc:77] Other threads are currently calling into gRPC, skipping fork() handlers
2024-09-05 16:47:55,277 ERROR serialization.py:371 -- Failed to unpickle serialized exception
Traceback (most recent call last):
  File "/home/xx/anaconda3/envs/lfads-torch/lib/python3.9/site-packages/ray/exceptions.py", line 46, in from_ray_exception
    return pickle.loads(ray_exception.serialized_exception)
TypeError: __init__() missing 1 required positional argument: 'missing_cfg_file'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xx/anaconda3/envs/lfads-torch/lib/python3.9/site-packages/ray/_private/serialization.py", line 369, in deserialize_objects
    obj = self._deserialize_object(data, metadata, object_ref)
  File "/home/xx/anaconda3/envs/lfads-torch/lib/python3.9/site-packages/ray/_private/serialization.py", line 275, in _deserialize_object
    return RayError.from_bytes(obj)
  File "/home/xx/anaconda3/envs/lfads-torch/lib/python3.9/site-packages/ray/exceptions.py", line 40, in from_bytes
    return RayError.from_ray_exception(ray_exception)
  File "/home/xx/anaconda3/envs/lfads-torch/lib/python3.9/site-packages/ray/exceptions.py", line 49, in from_ray_exception
    raise RuntimeError(msg) from e
RuntimeError: Failed to unpickle serialized exception
2024-09-05 16:47:55,277 ERROR trial_runner.py:993 -- Trial run_model_0f720_00019: Error processing event.
ray.tune.error._TuneNoNextExecutorEventError: Traceback (most recent call last):
  File "/home/xx/anaconda3/envs/lfads-torch/lib/python3.9/site-packages/ray/tune/execution/ray_trial_executor.py", line 1050, in get_next_executor_event
    future_result = ray.get(ready_future)
  File "/home/xx/anaconda3/envs/lfads-torch/lib/python3.9/site-packages/ray/_private/client_mode_hook.py", line 105, in wrapper
    return func(*args, **kwargs)
  File "/home/xx/anaconda3/envs/lfads-torch/lib/python3.9/site-packages/ray/_private/worker.py", line 2291, in get
    raise value
ray.exceptions.RaySystemError: System error: Failed to unpickle serialized exception
traceback: Traceback (most recent call last):
  File "/home/xx/anaconda3/envs/lfads-torch/lib/python3.9/site-packages/ray/exceptions.py", line 46, in from_ray_exception
    return pickle.loads(ray_exception.serialized_exception)
TypeError: __init__() missing 1 required positional argument: 'missing_cfg_file'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xx/anaconda3/envs/lfads-torch/lib/python3.9/site-packages/ray/_private/serialization.py", line 369, in deserialize_objects
    obj = self._deserialize_object(data, metadata, object_ref)
  File "/home/xx/anaconda3/envs/lfads-torch/lib/python3.9/site-packages/ray/_private/serialization.py", line 275, in _deserialize_object
    return RayError.from_bytes(obj)
  File "/home/xx/anaconda3/envs/lfads-torch/lib/python3.9/site-packages/ray/exceptions.py", line 40, in from_bytes
    return RayError.from_ray_exception(ray_exception)
  File "/home/xx/anaconda3/envs/lfads-torch/lib/python3.9/site-packages/ray/exceptions.py", line 49, in from_ray_exception
    raise RuntimeError(msg) from e
RuntimeError: Failed to unpickle serialized exception
arsedler9 commented 2 months ago

It seems like the workers are not able to find the config files:

TypeError: __init__() missing 1 required positional argument: 'missing_cfg_file'

Are you sure that the config paths are specified correctly relative to your working directory?
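For context, the argument name in that TypeError suggests the real exception is Hydra's MissingConfigException (its constructor requires a missing_cfg_file argument), which Ray fails to unpickle when shipping the error back to the driver. So the masked failure is most likely Hydra being unable to find pbt.yaml. A minimal sanity-check sketch, assuming run_model composes its config via hydra.initialize (which resolves config_path relative to the calling module, not your shell's working directory):

from pathlib import Path

import lfads_torch.run_model as rm

# Resolve the relative config path against the run_model module, since
# hydra.initialize interprets config_path relative to the file that calls
# it rather than the directory the script is launched from.
base = Path(rm.__file__).parent
cfg = (base / "../../../configs/pbt.yaml").resolve()
print(f"run_model module dir: {base}")
print(f"config resolves to: {cfg} (exists: {cfg.exists()})")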

rosskempner commented 2 months ago

I first tried aligning the relative config paths with the working directory that I printed from inside the run function of ray/tune/tune.py; that did not work. My second attempt was to follow your tutorial more closely and change only the RUN_DIR path to something sensible, since it seems you expected people to modify that path. That did not work either.
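A cleaner way to see the working directory each trial actually gets, instead of patching ray/tune/tune.py, might be a small wrapper around run_model (run_model_debug below is just a sketch, not part of lfads-torch):

import os

from lfads_torch.run_model import run_model

# Log each Ray worker's working directory and the config_path it received
# before composing the config; Tune moves every trial into its own trial
# directory, so a relative path that is valid at launch can dangle here.
def run_model_debug(overrides, config_path=None, **kwargs):
    print(f"[worker pid={os.getpid()}] cwd={os.getcwd()}")
    print(f"[worker] config_path={config_path}")
    return run_model(overrides, config_path=config_path, **kwargs)

# Swap it in when building the trainable:
# tune.with_parameters(run_model_debug, config_path="../configs/pbt.yaml", do_posterior_sample=False)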

Note that the program runs up to, and then stops at, the following lines in ray/tune/tune.py:

    if incomplete_trials:
        if raise_on_failed_trial and not state["signal"]:
            raise TuneError("Trials did not complete", incomplete_trials)
        else:
            logger.error("Trials did not complete: %s", incomplete_trials)

Do you have any ideas on how I can troubleshoot?
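The most concrete evidence I have so far is the per-trial error.txt files listed in the table above; a minimal sketch for dumping them (assuming the runs logdir shown in the status output):

from pathlib import Path

# Print each trial's traceback directly from disk, sidestepping the
# unpickling failure that hides the real exception on the driver.
for err in sorted(Path("runs").glob("run_model_*/error.txt")):
    print(f"===== {err} =====")
    print(err.read_text())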

Here is my minimally changed code, almost identical to your PBT script, with just one line changed:


import os
import random
import shutil
from datetime import datetime
from pathlib import Path

from ray import tune
from ray.tune import CLIReporter
from ray.tune.search.basic_variant import BasicVariantGenerator

from lfads_torch.extensions.tune import (
    BinaryTournamentPBT,
    HyperParam,
    ImprovementRatioStopper,
)
from lfads_torch.run_model import run_model

# ---------- OPTIONS ----------
PROJECT_STR = "pbt"
DATASET_STR = "rouse_multisession_PCR"
RUN_TAG = datetime.now().strftime("%y%m%d")
RUN_DIR = Path(f"run{RUN_TAG}{random.random()}")  # / PROJECT_STR / DATASET_STR / RUN_TAG
HYPERPARAM_SPACE = {
    "model.lr_init": HyperParam(
        1e-4, 1e-3, explore_wt=0.3, enforce_limits=True, init=1e-3
    ),
    "model.dropout_rate": HyperParam(
        0.0, 0.6, explore_wt=0.3, enforce_limits=True, sample_fn="uniform"
    ),
    "model.train_aug_stack.transforms.0.cd_rate": HyperParam(
        0.01, 0.99, explore_wt=0.3, enforce_limits=True, init=0.5, sample_fn="uniform"
    ),
    "model.kl_co_scale": HyperParam(1e-5, 1e-3, explore_wt=0.8),
    "model.kl_ic_scale": HyperParam(1e-5, 1e-3, explore_wt=0.8),
    "model.l2_gen_scale": HyperParam(1e-5, 1e-0, explore_wt=0.8),
    "model.l2_con_scale": HyperParam(1e-4, 1e-0, explore_wt=0.8),
}
# ------------------------------

# Function to keep dropout and CD rates in-bounds
def clip_config_rates(config):
    return {k: min(v, 0.99) if "_rate" in k else v for k, v in config.items()}

init_space = {name: tune.sample_from(hp.init) for name, hp in HYPERPARAM_SPACE.items()}
# Set the mandatory config overrides to select datamodule and model
mandatory_overrides = {
    "datamodule": DATASET_STR,
    "model": DATASET_STR,
    "logger.wandb_logger.project": PROJECT_STR,
    "logger.wandb_logger.tags.1": DATASET_STR,
    "logger.wandb_logger.tags.2": RUN_TAG,
}
RUN_DIR.mkdir(parents=True)
# Copy this script into the run directory
print("__file__" , __file__)
shutil.copyfile(__file__, RUN_DIR / Path(__file__).name)
# Run the hyperparameter search
metric = "valid/recon_smth"
num_trials = 20
perturbation_interval = 15
burn_in_period = 50 + 15
analysis = tune.run(
    tune.with_parameters(
        run_model,
        config_path="../configs/pbt.yaml",
        do_posterior_sample=False,
    ),
    metric=metric,
    mode="min",
    name=RUN_DIR.name,
    stop=ImprovementRatioStopper(
        num_trials=num_trials,
        perturbation_interval=perturbation_interval,
        burn_in_period=burn_in_period,
        metric=metric,
        patience=4,
        min_improvement_ratio=5e-4,
    ),
    config={**mandatory_overrides, **init_space},
    resources_per_trial=dict(cpu=3, gpu=0.5),
    num_samples=num_trials,
    local_dir=RUN_DIR.parent,
    search_alg=BasicVariantGenerator(random_state=0),
    scheduler=BinaryTournamentPBT(
        perturbation_interval=perturbation_interval,
        burn_in_period=burn_in_period,
        hyperparam_mutations=HYPERPARAM_SPACE,
    ),
    keep_checkpoints_num=1,
    verbose=1,
    progress_reporter=CLIReporter(
        metric_columns=[metric, "cur_epoch"],
        sort_by_metric=True,
    ),
    trial_dirname_creator=lambda trial: str(trial),
)
# Copy the best model to a new folder so it is easy to identify
best_model_dir = RUN_DIR / "best_model"
shutil.copytree(analysis.best_logdir, best_model_dir)
# Switch working directory to this folder (usually handled by tune)
os.chdir(best_model_dir)
# Load the best model and run posterior sampling (skip training)
best_ckpt_dir = best_model_dir / Path(analysis.best_checkpoint._local_path).name
run_model(
    overrides=mandatory_overrides,
    checkpoint_dir=best_ckpt_dir,
    config_path="../configs/pbt.yaml",
    do_train=False,
)