Open · Benjamin-Etheredge opened this issue 1 year ago
A temporary workaround for this issue is to declare a TensorBoard logger ahead of the MLflow one, like so:
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel

cli = LightningCLI(
    BoringModel,
    BoringDataModule,
    trainer_defaults=dict(
        max_epochs=1,
        logger=[
            {
                "class_path": "pytorch_lightning.loggers.TensorBoardLogger",
                "init_args": {
                    "save_dir": "tb_logs",
                },
            },
            "pytorch_lightning.loggers.MLFlowLogger",
        ],
    ),
)
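Note on why the ordering matters: the trainer's log_dir is derived from the first logger in the list, so putting TensorBoardLogger first gives the CLI's SaveConfigCallback a real local directory to write to, whereas MLFlowLogger's save_dir is None with a remote tracking URI.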
@Benjamin-Etheredge Here is my workaround, which still leverages the goodness of the CLI module and its YAML file.
import yaml
from pytorch_lightning import Trainer
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.loggers import MLFlowLogger

cli = LightningCLI(
    LightningToneClassifier,
    ToneDataModule,
    run=False,
)

with open("lightning/trainer_config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Swap the logger from the YAML config for a programmatically built MLFlowLogger
config["trainer"]["logger"] = MLFlowLogger(
    experiment_name="xxxx",
    tracking_uri="xxxx",
    log_model=True,
)

train_dataloader, val_dataloader = prepare_fit_dataloader(cli)
trainer = Trainer(**config["trainer"])
trainer.logger.log_hyperparams(config)
trainer.fit(cli.model, train_dataloader, val_dataloader)
Hi @vincentwu0730, thank you for your workaround! Can you share what prepare_fit_dataloader is?
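(prepare_fit_dataloader is never defined in the thread; presumably it pulls the dataloaders out of the datamodule the CLI already instantiated. A hypothetical sketch, purely an assumption:)

def prepare_fit_dataloader(cli):
    # Hypothetical helper, not from the thread: set up the CLI's datamodule
    # for the "fit" stage and hand back its train/val dataloaders.
    cli.datamodule.setup(stage="fit")
    return cli.datamodule.train_dataloader(), cli.datamodule.val_dataloader()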
The issue surfaces through LightningCLI because it accesses the log dir, but the origin of the problem, as suspected by @Benjamin-Etheredge, is that save_dir on MLFlowLogger returns None when tracking is not done locally:
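Roughly, that property behaves like this (a paraphrase from memory, not a verbatim quote of the snippet referenced above):

@property
def save_dir(self) -> Optional[str]:
    # Only a local "file:" tracking URI maps to a directory on disk;
    # a remote tracking server (e.g. http://...) yields None.
    if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
        return self._tracking_uri[len(LOCAL_FILE_URI_PREFIX):]
    return None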
Two possible solutions that come to my mind to address this:
I can suggest another solution: implement a custom save-config class that saves the config to MLflow as an artifact instead of saving it locally. If logging remotely, it makes sense to also save the config in the same place.
A realization of @mauvilsa's idea:
import torch
from lightning import LightningModule, Trainer
from lightning.pytorch.cli import LightningCLI, SaveConfigCallback

class MLFlowSaveConfigCallback(SaveConfigCallback):
    def __init__(self, parser, config, config_filename='config.yaml', overwrite=False, multifile=False):
        # save_to_log_dir=False skips writing to trainer.log_dir, which is None here
        super().__init__(parser, config, config_filename, overwrite, multifile, save_to_log_dir=False)

    def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        # Convert Namespace to dict
        config_dict = vars(self.config)
        # Log parameters to MLflow
        pl_module.logger.log_hyperparams(config_dict)

def cli_compile_main():
    cli = LightningCLI(datamodule_class=PRDataModule, run=False, save_config_callback=MLFlowSaveConfigCallback)
    compiled_model = torch.compile(cli.model)
    cli.trainer.fit(compiled_model, datamodule=cli.datamodule)
    cli.trainer.test(datamodule=cli.datamodule)
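One caveat with the sketch above: vars(self.config) converts only the top level of the jsonargparse Namespace, so nested groups remain Namespace objects; if log_hyperparams chokes on those, self.config.as_dict() (jsonargparse's recursive conversion) may serialize more cleanly.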
A slight modification of @terbed's version if you want to save the file as YAML:
import tempfile
from pathlib import Path

from lightning import LightningModule, Trainer
from lightning.pytorch.cli import SaveConfigCallback

class MLFlowSaveConfigCallback(SaveConfigCallback):
    def __init__(self, parser, config, config_filename='config.yaml', overwrite=False, multifile=False):
        super().__init__(parser, config, config_filename, overwrite, multifile, save_to_log_dir=False)

    def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        if trainer.is_global_zero:
            # Dump the config to a temporary YAML file, then upload it to MLflow as an artifact
            with tempfile.TemporaryDirectory() as tmp_dir:
                config_path = Path(tmp_dir) / 'config.yaml'
                self.parser.save(
                    self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
                )
                trainer.logger.experiment.log_artifact(local_path=str(config_path),
                                                       run_id=trainer.logger.run_id)
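This assumes the trainer's logger is an MLFlowLogger, since the callback reaches for logger.experiment and logger.run_id. A minimal sketch of wiring it up, with MyModel and MyDataModule as placeholders for your own classes:

from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.loggers import MLFlowLogger

cli = LightningCLI(
    MyModel,
    MyDataModule,
    run=False,
    save_config_callback=MLFlowSaveConfigCallback,
    trainer_defaults={
        "logger": MLFlowLogger(experiment_name="my-exp", tracking_uri="http://mlflow.example.com"),
    },
)
cli.trainer.fit(cli.model, datamodule=cli.datamodule)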
Bug description
Running with LightningCLI, the MLflow logger, and the MLFLOW_TRACKING_URI environment variable set causes an assertion failure during logging. I think using a remote tracking server means no local log files get created, which the CLI doesn't like. I suspect it's a similar issue to #12748.
How to reproduce the bug
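(Not part of the original report, which left this section empty; a minimal sketch consistent with the description above, assuming any non-local tracking URI:)

import os
os.environ["MLFLOW_TRACKING_URI"] = "http://localhost:5000"  # assumption: a remote server

from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel

# With a remote tracking URI, MLFlowLogger.save_dir is None, so the CLI's
# SaveConfigCallback has no local log dir to write the config to and asserts.
# Run with: python repro.py fit
cli = LightningCLI(
    BoringModel,
    BoringDataModule,
    trainer_defaults={"logger": "pytorch_lightning.loggers.MLFlowLogger"},
)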
Error messages and logs
Environment
More info
No response
cc @carmocca @mauvilsa