ayulockin opened this issue 9 months ago
> [!TIP]
> I'll email you at ayusht@wandb.com when I complete this pull request!
Here are the GitHub Actions logs prior to making any changes (commit ee5fd38):

```
1/1 ✓ Checking src/lightning/pytorch/loggers/wandb.py for syntax errors... ✓ src/lightning/pytorch/loggers/wandb.py has no syntax errors!
```
Sandbox passed on the latest `master`, so sandbox checks will be enabled for this issue.
I found the following snippets in your repository. I will now analyze these snippets and come up with a plan.
src/lightning/pytorch/loggers/wandb.py
✓ https://github.com/ayulockin/pytorch-lightning/commit/a0f21497e0f689e5bafa45718de93abbb5e02920
Modify src/lightning/pytorch/loggers/wandb.py with contents:
β’ Modify the `log_hyperparams` method of the `WandbLogger` class to prevent the second instantiation of models from callable parameters.
β’ Before calling `_sanitize_callable_params`, check if any of the parameters are callable and related to model creation (e.g., a factory method). If so, store the callable's reference without calling it, or if it must be called, ensure that the result is cached and reused instead of creating a new instance.
β’ Update the documentation of the `log_hyperparams` method to explain the new behavior with callable parameters, especially those that instantiate models.
β’ Add a unit test to verify that when a factory method is passed as a hyperparameter, it is not called again during the logging process.
```diff
@@ -415,10 +415,23 @@
     @override
     @rank_zero_only
-    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:  # type: ignore[override]
+    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], exclude_keys: Optional[List[str]] = None) -> None:  # type: ignore[override]
+        """Log hyperparameters to Weights & Biases.
+
+        If some parameters are callable (e.g. a model factory), they can be excluded from being called
+        and logged by listing them in the ``exclude_keys`` argument. This is useful when the callable
+        creates large models that should not be instantiated more than necessary.
+
+        Args:
+            params: Dictionary containing the hyperparameters.
+            exclude_keys: Optional list of keys to exclude from logging when callable.
+        """
         params = _convert_params(params)
-        params = _sanitize_callable_params(params)
+        params = _sanitize_callable_params(params, exclude_keys=exclude_keys)
         self.experiment.config.update(params, allow_val_change=True)

     @override
     @rank_zero_only
```
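A hypothetical, self-contained sketch of the behavior the patch above aims for. The `_StubExperiment` class and the `_sanitize_callable_params_sketch` function are simplified stand-ins for illustration, not the real `wandb` or Lightning APIs:

```python
# Hypothetical sketch: callables listed in exclude_keys are logged by name
# instead of being invoked. _StubExperiment mimics wandb's run object.
from typing import Any, Dict, List, Optional


class _StubExperiment:
    """Minimal stand-in for wandb's run object (holds a config dict)."""

    def __init__(self) -> None:
        self.config: Dict[str, Any] = {}


def _sanitize_callable_params_sketch(
    params: Dict[str, Any], exclude_keys: Optional[List[str]] = None
) -> Dict[str, Any]:
    """Resolve zero-argument callables, unless their key is excluded,
    in which case only the callable's name is kept."""
    excluded = set(exclude_keys or [])
    sanitized: Dict[str, Any] = {}
    for key, value in params.items():
        if callable(value) and key in excluded:
            sanitized[key] = getattr(value, "__qualname__", str(value))
        elif callable(value):
            sanitized[key] = value()  # a factory would run again here
        else:
            sanitized[key] = value
    return sanitized


def make_model() -> str:
    return "a very large model"  # stand-in for an expensive factory


experiment = _StubExperiment()
experiment.config.update(
    _sanitize_callable_params_sketch(
        {"model_factory": make_model, "lr": 1e-3},
        exclude_keys=["model_factory"],
    )
)
print(experiment.config["model_factory"])  # make_model (the name, not a new model)
```

With `exclude_keys=["model_factory"]`, only the factory's name reaches the config; without it, the branch that calls `value()` would build the model a second time.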
src/lightning/pytorch/loggers/wandb.py
Check src/lightning/pytorch/loggers/wandb.py with contents:
Ran GitHub Actions for a0f21497e0f689e5bafa45718de93abbb5e02920:
I have finished reviewing the code for completeness. I did not find errors for sweep/nccl_timeout_or_gpu_ooms_when_using_wand.
💡 To recreate the pull request, edit the issue title or description. To tweak the pull request, leave a comment on the pull request.
This is an automated message generated by Sweep AI.
Bug description
Hello! I found this weird interaction that took me a while to debug, so hopefully someone finds it useful or it's possible to fix something in Lightning.
When constructing large models, it's recommended to use `configure_model`. To configure the model creation outside of Lightning, I've been using factories, so that a fully configured factory can just make a model under the strategy context (e.g. DeepSpeed). Additionally, I've been using `self.save_hyperparameters()` and the `wandb` logger for convenience. I found that after a certain model size, my setup started hanging. It turned out that the `_sanitize_callable_params` function inside `log_hyperparams` of `WandbLogger` calls my factory again, temporarily creating yet another copy of the model. I can't quite find docs on callable parameters for Modules. Is it a bug or a feature? Why would the callable be resolved a second time?
Workaround: `self.save_hyperparameters(ignore="model_factory")`
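A minimal, self-contained sketch of why this happens. The `sanitize_callable_params` function below is a simplified analog of Lightning's `_sanitize_callable_params` (not the actual implementation), and `model_factory` is a hypothetical stand-in for a factory that builds a large model:

```python
# Simplified analog of the sanitization step: zero-argument callables in
# the hyperparameters are invoked to obtain a loggable value.
from typing import Any, Dict

call_count = {"n": 0}


def model_factory() -> object:
    """Stand-in for a factory that builds a large model."""
    call_count["n"] += 1
    return object()  # imagine a multi-GB nn.Module here


def sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
    out: Dict[str, Any] = {}
    for key, value in params.items():
        if callable(value):
            out[key] = value()  # the factory runs AGAIN here
        else:
            out[key] = value
    return out


model = model_factory()                       # intended, first instantiation
hparams = {"model_factory": model_factory, "lr": 1e-3}
sanitize_callable_params(hparams)             # hidden second instantiation
print(call_count["n"])  # 2
```

With a real model, that hidden second call is what spikes rank 0's GPU memory during `log_hyperparams`; ignoring the factory via `save_hyperparameters(ignore=...)` keeps it out of the sanitized dict entirely.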
What version are you seeing the problem on?
v2.1
How to reproduce the bug
Error messages and logs
NCCL hangs for me because the rank 0 GPU reaches 99% memory capacity, but it could probably also lead to OOMs.
Environment
Current environment
```
#- Lightning Component: LightningModule, WandbLogger
#- PyTorch Lightning Version (e.g., 1.5.0): 2.1.2
#- PyTorch Version (e.g., 2.0): 2.1.2
#- Python version (e.g., 3.9): 3.9
```
More info
No response
Suggested Solution:
This issue appears to be caused by unintended side-effects of the `_sanitize_callable_params` function used by the `WandbLogger` in PyTorch Lightning. A fix can be implemented by modifying the `WandbLogger` to avoid creating unnecessary copies of large models. Here's a suggested plan of action:

1. Review the `_sanitize_callable_params` function: look closely at the code to understand exactly how it behaves with callable parameters.
2. Update the `WandbLogger`: avoid evaluating the factory function a second time. This could involve setting a flag to indicate whether the factory has already been run, or directly setting the model attribute and instantiating it only once.
3. Incorporate the workaround (`self.save_hyperparameters(ignore="model_factory")`) into the code as a temporary fix until a more permanent solution can be implemented.

Checklist
- [X] Modify `src/lightning/pytorch/loggers/wandb.py` ✓ https://github.com/ayulockin/pytorch-lightning/commit/a0f21497e0f689e5bafa45718de93abbb5e02920
- [X] Running GitHub Actions for `src/lightning/pytorch/loggers/wandb.py` ✓
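The name-substitution idea from the suggested plan could be sketched as below. The `stringify_callables` helper is hypothetical (not part of Lightning's API); it pre-processes hyperparameters so a later sanitization pass has nothing left to invoke:

```python
# Hypothetical helper: replace callable hyperparameters with their
# qualified names so a later sanitization pass never invokes them.
from typing import Any, Dict


def stringify_callables(params: Dict[str, Any]) -> Dict[str, Any]:
    return {
        key: getattr(value, "__qualname__", repr(value)) if callable(value) else value
        for key, value in params.items()
    }


def expensive_factory() -> None:
    raise RuntimeError("should never run during logging")


safe = stringify_callables({"model_factory": expensive_factory, "lr": 1e-3})
print(safe["model_factory"])  # expensive_factory
```

Logging `safe` instead of the raw hyperparameters records the factory's name in the run config without ever building a second model.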