Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0
27.96k stars 3.35k forks source link

Odd AttributeError when calling model.fit() #6772

Closed pnmartinez closed 3 years ago

pnmartinez commented 3 years ago

🐛 Bug

Having a really esoteric bug when using model.fit().

I've tried running with fast_dev_run = True, and in that case, everything goes smoothly.

Please check the logs below, along with a snippet of the initialization code used. Maybe you can find some inconsistency in it.

Environment

Console output

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-46-6221476b56fa> in <module>
    116 
    117         ## TRAINING
--> 118         trainer.fit(
    119             net,
    120             train_dataloader=train_dataloader,

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    472         self.call_hook("on_before_accelerator_backend_setup", model)
    473         self.accelerator.setup(self, model)
--> 474         self.setup_trainer(model)
    475 
    476         # ----------------------------

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in setup_trainer(self, model)
    421             self.logger.log_hyperparams(model.hparams_initial)
    422             self.logger.log_graph(model)
--> 423             self.logger.save()
    424 
    425     def fit(

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py in wrapped_fn(*args, **kwargs)
     38     def wrapped_fn(*args, **kwargs):
     39         if rank_zero_only.rank == 0:
---> 40             return fn(*args, **kwargs)
     41 
     42     return wrapped_fn

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/loggers/tensorboard.py in save(self)
    231         # save the metatags file if it doesn't exist and the log directory exists
    232         if self._fs.isdir(dir_path) and not self._fs.isfile(hparams_file):
--> 233             save_hparams_to_yaml(hparams_file, self.hparams)
    234 
    235     @rank_zero_only

~/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/saving.py in save_hparams_to_yaml(config_yaml, hparams)
    397     for k, v in hparams.items():
    398         try:
--> 399             yaml.dump(v)
    400         except TypeError:
    401             warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")

~/anaconda3/lib/python3.8/site-packages/yaml/__init__.py in dump(data, stream, Dumper, **kwds)
    288     If stream is None, return the produced string instead.
    289     """
--> 290     return dump_all([data], stream, Dumper=Dumper, **kwds)
    291 
    292 def safe_dump_all(documents, stream=None, **kwds):

~/anaconda3/lib/python3.8/site-packages/yaml/__init__.py in dump_all(documents, stream, Dumper, default_style, default_flow_style, canonical, indent, width, allow_unicode, line_break, encoding, explicit_start, explicit_end, version, tags, sort_keys)
    276         dumper.open()
    277         for data in documents:
--> 278             dumper.represent(data)
    279         dumper.close()
    280     finally:

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent(self, data)
     25 
     26     def represent(self, data):
---> 27         node = self.represent_data(data)
     28         self.serialize(node)
     29         self.represented_objects = {}

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_data(self, data)
     50             for data_type in data_types:
     51                 if data_type in self.yaml_multi_representers:
---> 52                     node = self.yaml_multi_representers[data_type](self, data)
     53                     break
     54             else:

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_object(self, data)
    340         if not args and not listitems and not dictitems \
    341                 and isinstance(state, dict) and newobj:
--> 342             return self.represent_mapping(
    343                     'tag:yaml.org,2002:python/object:'+function_name, state)
    344         if not listitems and not dictitems  \

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_mapping(self, tag, mapping, flow_style)
    116         for item_key, item_value in mapping:
    117             node_key = self.represent_data(item_key)
--> 118             node_value = self.represent_data(item_value)
    119             if not (isinstance(node_key, ScalarNode) and not node_key.style):
    120                 best_style = False

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_data(self, data)
     50             for data_type in data_types:
     51                 if data_type in self.yaml_multi_representers:
---> 52                     node = self.yaml_multi_representers[data_type](self, data)
     53                     break
     54             else:

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_object(self, data)
    354         if dictitems:
    355             value['dictitems'] = dictitems
--> 356         return self.represent_mapping(tag+function_name, value)
    357 
    358     def represent_ordered_dict(self, data):

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_mapping(self, tag, mapping, flow_style)
    116         for item_key, item_value in mapping:
    117             node_key = self.represent_data(item_key)
--> 118             node_value = self.represent_data(item_value)
    119             if not (isinstance(node_key, ScalarNode) and not node_key.style):
    120                 best_style = False

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_data(self, data)
     46         data_types = type(data).__mro__
     47         if data_types[0] in self.yaml_representers:
---> 48             node = self.yaml_representers[data_types[0]](self, data)
     49         else:
     50             for data_type in data_types:

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_list(self, data)
    197         #            break
    198         #if not pairs:
--> 199             return self.represent_sequence('tag:yaml.org,2002:seq', data)
    200         #value = []
    201         #for item_key, item_value in data:

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_sequence(self, tag, sequence, flow_style)
     90         best_style = True
     91         for item in sequence:
---> 92             node_item = self.represent_data(item)
     93             if not (isinstance(node_item, ScalarNode) and not node_item.style):
     94                 best_style = False

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_data(self, data)
     46         data_types = type(data).__mro__
     47         if data_types[0] in self.yaml_representers:
---> 48             node = self.yaml_representers[data_types[0]](self, data)
     49         else:
     50             for data_type in data_types:

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_tuple(self, data)
    284 
    285     def represent_tuple(self, data):
--> 286         return self.represent_sequence('tag:yaml.org,2002:python/tuple', data)
    287 
    288     def represent_name(self, data):

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_sequence(self, tag, sequence, flow_style)
     90         best_style = True
     91         for item in sequence:
---> 92             node_item = self.represent_data(item)
     93             if not (isinstance(node_item, ScalarNode) and not node_item.style):
     94                 best_style = False

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_data(self, data)
     50             for data_type in data_types:
     51                 if data_type in self.yaml_multi_representers:
---> 52                     node = self.yaml_multi_representers[data_type](self, data)
     53                     break
     54             else:

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_object(self, data)
    354         if dictitems:
    355             value['dictitems'] = dictitems
--> 356         return self.represent_mapping(tag+function_name, value)
    357 
    358     def represent_ordered_dict(self, data):

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_mapping(self, tag, mapping, flow_style)
    116         for item_key, item_value in mapping:
    117             node_key = self.represent_data(item_key)
--> 118             node_value = self.represent_data(item_value)
    119             if not (isinstance(node_key, ScalarNode) and not node_key.style):
    120                 best_style = False

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_data(self, data)
     46         data_types = type(data).__mro__
     47         if data_types[0] in self.yaml_representers:
---> 48             node = self.yaml_representers[data_types[0]](self, data)
     49         else:
     50             for data_type in data_types:

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_dict(self, data)
    205 
    206     def represent_dict(self, data):
--> 207         return self.represent_mapping('tag:yaml.org,2002:map', data)
    208 
    209     def represent_set(self, data):

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_mapping(self, tag, mapping, flow_style)
    116         for item_key, item_value in mapping:
    117             node_key = self.represent_data(item_key)
--> 118             node_value = self.represent_data(item_value)
    119             if not (isinstance(node_key, ScalarNode) and not node_key.style):
    120                 best_style = False

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_data(self, data)
     50             for data_type in data_types:
     51                 if data_type in self.yaml_multi_representers:
---> 52                     node = self.yaml_multi_representers[data_type](self, data)
     53                     break
     54             else:

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_object(self, data)
    344         if not listitems and not dictitems  \
    345                 and isinstance(state, dict) and not state:
--> 346             return self.represent_sequence(tag+function_name, args)
    347         value = {}
    348         if args:

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_sequence(self, tag, sequence, flow_style)
     90         best_style = True
     91         for item in sequence:
---> 92             node_item = self.represent_data(item)
     93             if not (isinstance(node_item, ScalarNode) and not node_item.style):
     94                 best_style = False

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_data(self, data)
     50             for data_type in data_types:
     51                 if data_type in self.yaml_multi_representers:
---> 52                     node = self.yaml_multi_representers[data_type](self, data)
     53                     break
     54             else:

~/anaconda3/lib/python3.8/site-packages/yaml/representer.py in represent_object(self, data)
    329         if dictitems is not None:
    330             dictitems = dict(dictitems)
--> 331         if function.__name__ == '__newobj__':
    332             function = args[0]
    333             args = args[1:]

AttributeError: 'str' object has no attribute '__name__'

Model initialization

I am using pytorch_forecasting on top of PyTorch Lightning. The model is initialized as below:

#### PARAMETERS INIT
# Reproduction script from the issue report: builds the datasets, runs a
# learning-rate search with a throwaway Trainer/model, then creates the real
# Trainer (TensorBoard logger, early stopping, checkpointing) and calls fit().
# NOTE(review): training_df, dataset, training_cutoff, context_length and
# prediction_length are not defined in this snippet; they must come from
# earlier notebook cells.

# Training dataset built from training_df; "ma" is the target column
training = TimeSeriesDataSet(
    training_df,
    time_idx="time_idx",
    target="ma",
    # a NaNLabelEncoder fitted on the group_ids column is passed as the encoder
    categorical_encoders={"group_ids": NaNLabelEncoder().fit(training_df.group_ids)},
    group_ids=["group_ids"],
    # only unknown variable is "ma" - N-Beats cannot take any additional variables (covariates)
    time_varying_unknown_reals=["ma"],
    max_encoder_length=context_length,
    max_prediction_length=prediction_length
)

# Validation set derived from the training dataset's configuration, plus the
# train/validation dataloaders (both use the same batch size)
validation = TimeSeriesDataSet.from_dataset(training, dataset, min_prediction_idx=training_cutoff + 1)
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size) #, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size) #, num_workers=0)

#### FIND BEST LEARNING RATE
# pl.seed_everything(42)
# Temporary trainer used only for the learning-rate search below; it is
# replaced by the "actual" trainer further down.
trainer = pl.Trainer(gpus=0, gradient_clip_val=0.1)

# Model initialization to search optimal lr
net = NBeats.from_dataset(
    training, 
    learning_rate=3e-2, 
    weight_decay=1e-2, 
    widths=[32, 512], 
    backcast_loss_ratio=0.7
)

# find optimal learning rate
res = trainer.tuner.lr_find(
    net, 
    train_dataloader=train_dataloader, 
    val_dataloaders=val_dataloader, 
    min_lr=1e-5
)
print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()
# NOTE(review): this assignment has no lasting effect in this script, because
# `net` is deleted and re-created below before training.
net.hparams.learning_rate = res.suggestion()

## TRAINING SETTINGS AND ACTUAL MODEL INIT

# Early Stopping callback
early_stop_callback = EarlyStopping(
    monitor="val_forecast_loss", 
    min_delta=1e-4, 
    patience=100, # non-improvement steps allowed until early stopping
    verbose=True, 
    mode="min"
)

# log to tensorboard
# NOTE(review): per the traceback in this report, the AttributeError is raised
# from this logger's save() -> save_hparams_to_yaml -> yaml.dump while dumping
# the model's hparams.
tensor_log_dir = "lightning_logs"
logger = TensorBoardLogger( 
                        tensor_log_dir, 
#                                 default_hp_metric = 'val_forecast_loss' # replaces "hp_metric" value on "metrics" dict
)  
#         logger.log_metrics(net.logging_metrics)

# checkpoint every time next metric is improved
checkpoint_callback = ModelCheckpoint(
    dirpath = 'nbeats_checkpoints', 
    monitor="val_forecast_loss"
)

# Actual trainer object
trainer = pl.Trainer(
    max_epochs=10_000,
    gpus=0,
    weights_summary="top",
    gradient_clip_val=0.1,
    callbacks=[early_stop_callback, checkpoint_callback],
    limit_train_batches=30,
    # NOTE(review): the report states that everything runs fine with
    # fast_dev_run = True; as posted, this snippet still has it enabled, so it
    # presumably has to be removed/set to False to reproduce the error - confirm.
    fast_dev_run = True, # use this to test: run a full train, val and test loop using 1 batch(es).
    default_root_dir='nbeats_checkpoints',
    logger = logger
)

# init actual model to be used
# (the model from the lr search is discarded; a fresh one is built with the
# suggested learning rate)
del net
net = NBeats.from_dataset(
    training,
    learning_rate=res.suggestion(),
    log_interval=-1,# use -1 suggestion for debugging
    log_val_interval=1,
    weight_decay=1e-2,
    widths=[32, 512],
    backcast_loss_ratio=0.7, # weight given to the 
)                            # backcast loss in the overall loss

## TRAINING 
# This is the call that raises the AttributeError shown in the console output
# above (trainer.fit -> setup_trainer -> logger.save).
trainer.fit(
    net,
    train_dataloader=train_dataloader,
    val_dataloaders=val_dataloader,
)
pnmartinez commented 3 years ago

It seems there's a compatibility issue between the newest PyTorch Lightning and PyTorch Forecasting.

Downgrading PytorchLightning from 1.2.6 to 1.1.8, as indicated in this answer on StackOverflow, solved the issue.