xingyuuchen / tri-depth

[WACV 2023] Self-Supervised Monocular Depth Estimation: Solving the Edge-Fattening Problem
GNU General Public License v3.0

How can I plug the triplet loss you proposed into my own depth estimation model? #8

Open breknddone opened 1 year ago

breknddone commented 1 year ago

As the title says. Thanks!

xingyuuchen commented 1 year ago

Hello @breknddone,

Thanks for your interest in our work!

To plug the triplet loss into your own model, you can follow the code comments. Plugging it in is really simple!

This git commit contains the guidance (I added two lines of comments to the code that you can follow).
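Below is a rough illustration of what plugging in an extra loss term usually amounts to: the new term is computed alongside the existing ones, weighted, and added to the total. This is a generic sketch in plain Python, not the repository's code. The names base_losses, sgt_loss_fn and sgt_weight are hypothetical, the default weight is an arbitrary placeholder (not a value from the paper), and in a real model the losses would be tensors rather than floats.

```python
from typing import Callable, Dict

# Hypothetical signature for a loss term: (inputs, outputs) -> scalar loss.
LossFn = Callable[[dict, dict], float]


def compute_disp_losses(
    inputs: dict,
    outputs: dict,
    base_losses: Dict[str, LossFn],
    sgt_loss_fn: LossFn,
    sgt_weight: float = 0.1,  # placeholder weight, tune for your setup
) -> Dict[str, float]:
    """Collect the existing loss terms, then add one weighted extra term."""
    # Existing terms (e.g. photometric, smoothness) are computed as before.
    loss_dict = {name: fn(inputs, outputs) for name, fn in base_losses.items()}
    # The plugged-in term is just another weighted entry in the dict...
    loss_dict["sgt_loss"] = sgt_weight * sgt_loss_fn(inputs, outputs)
    # ...and is included in the total that gets backpropagated.
    loss_dict["total"] = sum(loss_dict.values())
    return loss_dict
```

Note that the extra term will read whatever keys it needs from inputs, so the data pipeline must provide them; otherwise the lookup fails with a KeyError.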

breknddone commented 1 year ago

Hello @xingyuuchen, I added the two functions compute_sgt_loss and compute_affinity, but training fails with the following error:

Traceback (most recent call last):
  File "train.py", line 74, in <module>
    main()
  File "train.py", line 70, in main
    trainer.fit(model, loader)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 609, in fit
    self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1112, in _run
    results = self._run_stage()
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1191, in _run_stage
    self._run_train()
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1214, in _run_train
    self.fit_loop.run()
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 213, in advance
    batch_output = self.batch_loop.run(kwargs)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(optimizers, kwargs)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 202, in advance
    result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 249, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 379, in _optimizer_step
    using_lbfgs=is_lbfgs,
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1356, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/core/module.py", line 1754, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py", line 169, in step
    step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/strategies/strategy.py", line 235, in optimizer_step
    optimizer, model=model, optimizer_idx=opt_idx, closure=closure, **kwargs
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 119, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/torch/optim/lr_scheduler.py", line 68, in wrapper
    return wrapped(*args, **kwargs)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/torch/optim/optimizer.py", line 140, in wrapper
    out = func(*args, **kwargs)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/torch/optim/optimizer.py", line 23, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/torch/optim/adam.py", line 183, in step
    loss = closure()
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 105, in _wrap_closure
    closure_result = closure()
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 149, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 135, in closure
    step_output = self._step_fn()
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 419, in _training_step
    training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1494, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/user/.conda/envs/rnw/lib/python3.7/site-packages/pytorch_lightning/strategies/strategy.py", line 378, in training_step
    return self.model.training_step(*args, **kwargs)
  File "/home/user/code/RNW/models/mono2.py", line 54, in training_step
    disp_loss_dict = self.compute_disp_losses(batch_data, outputs)
  File "/home/user/code/RNW/models/mono2.py", line 203, in compute_disp_losses
    sgt_loss = self.compute_sgt_loss(inputs, outputs)
  File "/home/user/code/RNW/models/mono2.py", line 88, in compute_sgt_loss
    seg_target = inputs['seg', 0, 0]
KeyError: ('seg', 0, 0)

It seems I can't get this to work. My datasets are not KITTI but RobotCar and nuScenes, so I don't know how to handle this case. I hope you can help. Thank you!

breknddone commented 1 year ago

Here is the preprocessing code for the RobotCar and nuScenes datasets.

xingyuuchen commented 1 year ago

The KeyError occurs because you tried to access a key that does not exist in inputs. You can step through the input-loading process with a debugger (e.g. PyCharm) to find the step at which your segmentation is not stored under the key you requested, i.e. ('seg', 0, 0), or, worse, is not loaded at all.
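Besides stepping through with a debugger, a quick check at the top of the loss function can show which keys the batch actually contains. The helper below is a hypothetical sketch (check_required_keys is not part of the repository). It raises a KeyError that lists the available keys, which makes it obvious whether the segmentation is stored under a different key or missing entirely.

```python
from typing import Hashable, Iterable, Mapping


def check_required_keys(inputs: Mapping, required: Iterable[Hashable]) -> None:
    """Raise a KeyError listing the available keys if any required key is missing.

    Hypothetical helper: call it before indexing e.g. inputs['seg', 0, 0].
    """
    missing = [key for key in required if key not in inputs]
    if missing:
        # Sort by repr so mixed key types (tuples, strings) can be ordered.
        available = sorted(inputs.keys(), key=repr)
        raise KeyError(f"missing keys {missing}; available keys: {available}")
```

For example, calling check_required_keys(inputs, [('seg', 0, 0)]) at the start of compute_sgt_loss would fail with a message that includes every key the dataloader did provide.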

breknddone commented 1 year ago

I checked my code, and the segmentation is indeed not loaded. How should I modify it?

xingyuuchen commented 10 months ago

If you use other datasets, you must customize the dataloader class to match that dataset's specific data format, as I did in this directory (there is a separate dataset.py file for each dataset).
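As a minimal sketch of the idea, assuming you already have segmentation maps for your RobotCar or nuScenes frames (for example, precomputed by an off-the-shelf segmentation model; this is an assumption, not something the repository provides for these datasets), a wrapper can add the missing entry under ('seg', 0, 0). SegAugmentedDataset and load_seg are hypothetical names, and the (name, frame_id, scale) key layout is assumed from the monodepth2-style keys in the traceback.

```python
from typing import Any, Callable, Dict, Tuple

# Key the loss looks up: (name, frame_id, scale), monodepth2-style (assumed).
SEG_KEY: Tuple[str, int, int] = ("seg", 0, 0)


class SegAugmentedDataset:
    """Wrap an existing map-style dataset and add a segmentation entry.

    Assumptions (hypothetical, adapt to your data):
    - base_dataset[idx] returns a dict of inputs.
    - load_seg(idx) returns the segmentation map of the target frame at
      full resolution, already in the format your loss expects.
    """

    def __init__(self, base_dataset, load_seg: Callable[[int], Any]):
        self.base_dataset = base_dataset
        self.load_seg = load_seg

    def __len__(self) -> int:
        return len(self.base_dataset)

    def __getitem__(self, idx: int) -> Dict:
        # Shallow copy so the base dataset's sample dict is not mutated.
        inputs = dict(self.base_dataset[idx])
        inputs[SEG_KEY] = self.load_seg(idx)
        return inputs
```

Since it implements __len__ and __getitem__, such an object can be passed to a PyTorch DataLoader like any map-style dataset. One caveat: if your dataset applies random flips, crops or resizing to the images, the segmentation must receive exactly the same transform, which this wrapper does not do. In that case it is usually simpler to load the segmentation inside your own dataset's __getitem__, next to the color images.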