miheHMR opened this issue 2 years ago
I have a similar issue: I'm trying to train a model with multiple targets, but I don't know how to set the target normalizer and group_ids. Thanks!
@jhbale11 I added a new column as group_id and initialized it with 0, because I have only one group. I haven't configured a normalizer yet.

I made some other changes and they seem to be working. In TemporalFusionTransformer.from_dataset it looks like output_size needs to be a list:

output_size=[7, 7, 7, 7], # 4 target variables

and for the loss calculation I used a MultiLoss:

loss=MultiLoss([QuantileLoss(), QuantileLoss(), QuantileLoss(), QuantileLoss()]),
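Putting the pieces from this comment together, here is a minimal end-to-end sketch. The toy frame, the column names `target_a`/`target_b`, and the encoder/decoder lengths are my own assumptions, not from this thread:

```python
import numpy as np
import pandas as pd
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import GroupNormalizer, MultiNormalizer
from pytorch_forecasting.metrics import MultiLoss, QuantileLoss

# toy frame with two continuous targets (names invented for illustration)
n = 200
df = pd.DataFrame(
    {
        "time_idx": np.arange(n),
        "target_a": np.random.randn(n),
        "target_b": np.random.randn(n),
    }
)
df["group_id"] = 0  # single dummy group, as described above

training = TimeSeriesDataSet(
    df,
    time_idx="time_idx",
    target=["target_a", "target_b"],  # multiple targets are passed as a list
    group_ids=["group_id"],
    max_encoder_length=24,
    max_prediction_length=6,
    time_varying_unknown_reals=["target_a", "target_b"],
    # one normalizer per target, wrapped in a MultiNormalizer
    target_normalizer=MultiNormalizer(
        [GroupNormalizer(groups=["group_id"]), GroupNormalizer(groups=["group_id"])]
    ),
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    output_size=[7, 7],  # 7 quantiles per target, matching QuantileLoss defaults
    loss=MultiLoss([QuantileLoss(), QuantileLoss()]),
)
```

The key point is that with multiple targets everything becomes per-target: target, target_normalizer, output_size, and loss must all have the same length and ordering.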
Thanks a lot!!!!

I have a further question about multiple target variables. I'm trying to get outputs for multiple targets with a softmax, meaning the targets' outputs should sum to 1. What should I do then?

Thanks again for your kindness :)
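(Side note, since no answer to the softmax question appears below: if the goal is that, at each time step, the point forecasts across the targets form a probability distribution, one hypothetical post-hoc normalization, not a built-in pytorch-forecasting feature, would be:)

```python
import torch

# Hypothetical post-processing sketch: given one point forecast per target
# (here faked with random tensors), renormalize across targets so that the
# values for the same time step sum to 1.
n_targets, batch, horizon = 4, 32, 6
per_target_predictions = [torch.randn(batch, horizon) for _ in range(n_targets)]

stacked = torch.stack(per_target_predictions, dim=-1)  # (batch, horizon, n_targets)
probs = torch.softmax(stacked, dim=-1)                 # softmax across the targets
assert torch.allclose(probs.sum(dim=-1), torch.ones(batch, horizon))
```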
Were you able to figure this out? I am running into issues with multiple targets, and now I'm stuck at the trainer.fit step. I am working on a multi-class classification problem: I have 7 different target variables and would like the sum of their probabilities, together with (1 - the sum of all 7 probabilities), to add up to 1. I followed the suggestions in this thread (MultiLoss and output_size as a list). The output is below:
```
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

TypeError                                 Traceback (most recent call last)
Input In [45], in <cell line: 2>()
      1 # fit network
----> 2 trainer.fit(
      3     tft,
      4     train_dataloaders=train_dataloader,
      5     val_dataloaders=val_dataloader,
      6 )

File ~/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:768, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    767 self.strategy.model = model
--> 768 self._call_and_handle_interrupt(
    769     self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    770 )

File ~/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:721, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    719     return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
    720 else:
--> 721     return trainer_fn(*args, **kwargs)

File ~/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:809, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    805 ckpt_path = ckpt_path or self.resume_from_checkpoint
    806 self._ckpt_path = self.__set_ckpt_path(
    807     ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
    808 )
--> 809 results = self._run(model, ckpt_path=self.ckpt_path)

File ~/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1234, in Trainer._run(self, model, ckpt_path)
   1232 self._checkpoint_connector.resume_end()
-> 1234 results = self._run_stage()

File ~/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1321, in Trainer._run_stage(self)
   1319 if self.predicting:
   1320     return self._run_predict()
-> 1321 return self._run_train()

File ~/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1343, in Trainer._run_train(self)
   1342 with isolate_rng():
-> 1343     self._run_sanity_check()

File ~/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1411, in Trainer._run_sanity_check(self)
   1410 with torch.no_grad():
-> 1411     val_loop.run()

File ~/.local/lib/python3.8/site-packages/pytorch_lightning/loops/base.py:204, in Loop.run(self, *args, **kwargs)
    203     self.on_advance_start(*args, **kwargs)
--> 204     self.advance(*args, **kwargs)

File ~/.local/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:154, in EvaluationLoop.advance(self, *args, **kwargs)
    152 if self.num_dataloaders > 1:
    153     kwargs["dataloader_idx"] = dataloader_idx
--> 154 dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)

File ~/.local/lib/python3.8/site-packages/pytorch_lightning/loops/base.py:204, in Loop.run(self, *args, **kwargs)
    203     self.on_advance_start(*args, **kwargs)
--> 204     self.advance(*args, **kwargs)

File ~/.local/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:127, in EvaluationEpochLoop.advance(self, data_fetcher, dl_max_batches, kwargs)
    126 # lightning module methods
--> 127 output = self._evaluation_step(**kwargs)

File ~/.local/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:222, in EvaluationEpochLoop._evaluation_step(self, **kwargs)
    221 else:
--> 222     output = self.trainer._call_strategy_hook("validation_step", *kwargs.values())

File ~/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1763, in Trainer._call_strategy_hook(self, hook_name, *args, **kwargs)
   1762 with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1763     output = fn(*args, **kwargs)

File ~/.local/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py:344, in Strategy.validation_step(self, *args, **kwargs)
    343 with self.precision_plugin.val_step_context():
--> 344     return self.model.validation_step(*args, **kwargs)

File ~/.local/lib/python3.8/site-packages/pytorch_forecasting/models/base_model.py:413, in BaseModel.validation_step(self, batch, batch_idx)
    412 x, y = batch
--> 413 log, out = self.step(x, y, batch_idx)

File ~/.local/lib/python3.8/site-packages/pytorch_forecasting/models/base_model.py:547, in BaseModel.step(self, x, y, batch_idx, **kwargs)
    545     loss = loss * (1 + monotinicity_loss)
    546 else:
--> 547     out = self(x, **kwargs)

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)

File ~/.local/lib/python3.8/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py:503, in TemporalFusionTransformer.forward(self, x)
    502 return self.to_network_output(
--> 503     prediction=self.transform_output(output, target_scale=x["target_scale"]),
    504     encoder_attention=attn_output_weights[..., :max_encoder_length],
    505     decoder_attention=attn_output_weights[..., max_encoder_length:],
    506     static_variables=static_variable_selection,
    507     encoder_variables=encoder_sparse_weights,
    508     decoder_variables=decoder_sparse_weights,
    509     decoder_lengths=decoder_lengths,
    510     encoder_lengths=encoder_lengths,
    511 )

File ~/.local/lib/python3.8/site-packages/pytorch_forecasting/models/base_model.py:336, in BaseModel.transform_output(self, prediction, target_scale)
    335 if isinstance(self.loss, MultiLoss):
--> 336     out = self.loss.rescale_parameters(
    337         prediction,
    338         target_scale=target_scale,
    339         encoder=self.output_transformer.normalizers,  # need to use normalizer per encoder
    340     )

File ~/.local/lib/python3.8/site-packages/pytorch_forecasting/metrics.py:452, in MultiLoss.__getattr__ ...

File ~/.local/lib/python3.8/site-packages/pytorch_forecasting/metrics.py:75, in Metric.rescale_parameters(self, parameters, target_scale, encoder)
---> 75 return encoder(dict(prediction=parameters, target_scale=target_scale))

TypeError: 'list' object is not callable
```
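The final frames show MultiLoss delegating rescale_parameters to a single sub-metric, which then receives the whole list of normalizers as encoder and tries to call it. I can't tell from the thread alone whether this is a bug in that pytorch-forecasting version or a configuration mismatch, but it is worth verifying that everything is defined per target and that the lengths line up. A hypothetical sanity check, reusing the training/tft names from the sketch earlier in this thread:

```python
from pytorch_forecasting.data import MultiNormalizer
from pytorch_forecasting.metrics import MultiLoss

# with multiple targets, normalizers, losses and output sizes
# must all be lists of the same length, in the same target order
targets = training.target if isinstance(training.target, list) else [training.target]
assert isinstance(training.target_normalizer, MultiNormalizer)
assert len(training.target_normalizer.normalizers) == len(targets)
assert isinstance(tft.loss, MultiLoss)
assert len(tft.loss.metrics) == len(targets)
assert len(tft.hparams.output_size) == len(targets)
```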
I'm trying to create a dataset with multiple targets. In my case it looks like this:

I get errors in several places.

1. During the initialization of the TimeSeriesDataSet

My data types:

If I remove the categorical target, it works. Is it possible to mix categorical and continuous target values? (See the sketch after this list.)

2. During the initialization of the TemporalFusionTransformer

Error:

The message looks clear :)

At this place in class TemporalFusionTransformer:

I'm thankful for any help. Maybe an example exists somewhere?
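Regarding question 1 above: as far as I know, mixing a continuous and a categorical target is possible by giving one normalizer and one loss per target, in matching order. A minimal sketch, assuming made-up column names `value`/`state` and three classes:

```python
import numpy as np
import pandas as pd
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import GroupNormalizer, MultiNormalizer, NaNLabelEncoder
from pytorch_forecasting.metrics import CrossEntropy, MultiLoss, QuantileLoss

# toy frame with one continuous and one categorical target (names are assumptions)
n = 200
df = pd.DataFrame(
    {
        "time_idx": np.arange(n),
        "value": np.random.randn(n),
        "state": np.random.choice(["low", "mid", "high"], size=n),
    }
)
df["group_id"] = 0  # single dummy group, as earlier in the thread

training = TimeSeriesDataSet(
    df,
    time_idx="time_idx",
    target=["value", "state"],  # mixed continuous + categorical targets
    group_ids=["group_id"],
    max_encoder_length=24,
    max_prediction_length=6,
    time_varying_unknown_reals=["value"],
    time_varying_unknown_categoricals=["state"],
    # one normalizer per target, in the same order as the targets
    target_normalizer=MultiNormalizer(
        [GroupNormalizer(groups=["group_id"]), NaNLabelEncoder()]
    ),
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    # 7 quantiles for "value"; one logit per class ("low"/"mid"/"high") for "state"
    output_size=[7, 3],
    loss=MultiLoss([QuantileLoss(), CrossEntropy()]),
)
```

The continuous target gets a GroupNormalizer with QuantileLoss, while the categorical target gets a NaNLabelEncoder with CrossEntropy and an output_size equal to its number of classes.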