unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

[BUG] issues suppressing output in NBeats fit call using `pl_trainer_kwargs` #933

Closed: TamerAbdelmigid closed this issue 2 years ago.

TamerAbdelmigid commented 2 years ago

Describe the bug

I wanted to completely suppress the output of model creation and fitting. While doing so, I found that re-running the notebook cell that creates the forecaster raises an error when a callback is passed in `pl_trainer_kwargs` of an `NBEATSModel`.

No error occurs if I pass `verbose=False` to `fit()`, or if I set `"progress_bar_refresh_rate": 0` in `pl_trainer_kwargs`. Side note: setting `"enable_progress_bar": False` in `pl_trainer_kwargs` has no effect.

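To Reproduce

Run the model-creation code below a second time (for example, re-running the cell with a changed parameter) while reusing the same `progress` object:

from darts.models import NBEATSModel
from pytorch_lightning.callbacks.progress import TQDMProgressBar

# refresh_rate=0 disables the progress bar
progress = TQDMProgressBar(refresh_rate=0)

model_nbeats = NBEATSModel(
    input_chunk_length=15,
    output_chunk_length=1,
    generic_architecture=False,
    num_stacks=10,
    num_blocks=4,
    num_layers=10,
    layer_widths=512,
    n_epochs=10,
    nr_epochs_val_period=1,
    batch_size=800,
    model_name="nbeats_run",
    # the same callback instance is passed every time the cell is re-run
    pl_trainer_kwargs={"callbacks": [progress]},
)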

Expected behavior

Re-running the cell that creates the model, with a changed parameter, should not raise an error.


Additional context

I want to suppress the output entirely because I create the model inside a function in a Bayesian optimization loop. The full error message is too long to include here, but the error is `TypeError: __deepcopy__() takes 1 positional argument but 2 were given`.

dennisbader commented 2 years ago

Hi @TamerAbdelmigid. You need to create a new progress bar object, `progress = TQDMProgressBar(refresh_rate=0)`, each time before you initialize a model.

Concerning `{"enable_progress_bar": False}`: this setting is currently overwritten by the `verbose` argument of `fit()` and `predict()`, whether it is `True` (the default) or `False`. The `verbose` parameter is deprecated and will be removed in a future darts version.