RoyYang0714 closed this 1 year ago.
The Callback base class does not need `every_n_epochs` or `num_epochs`; the trainer should decide, e.g., when to run the validation loop. Currently, the validation loop can be executed without the evaluator running afterwards. Hence, I'd propose moving `every_n_epochs` to the trainer, as in PL (PyTorch Lightning). Callbacks that should only run occasionally (e.g. the visualizer) could implement this themselves instead of relying on the base class.
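A minimal sketch of the proposal in plain Python. The classes `Callback`, `Evaluator`, `Visualizer` and `Trainer` here are hypothetical illustrations, not the vis4d API. The trainer owns the validation schedule (the `check_val_every_n_epoch` name is borrowed from PyTorch Lightning's `Trainer`) and calls every callback after each validation loop it runs, while an occasional callback such as the visualizer keeps its own frequency via an assumed `every_n_validations` parameter.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import List


class Callback:
    """Hypothetical callback base class: no every_n_epochs / num_epochs here."""

    def on_validation_epoch_end(self, epoch: int) -> None:
        """Hook the trainer calls after each validation loop it runs."""


@dataclass
class Evaluator(Callback):
    """Runs after every validation loop, so validation is never left unevaluated."""

    evaluated_epochs: List[int] = field(default_factory=list)

    def on_validation_epoch_end(self, epoch: int) -> None:
        self.evaluated_epochs.append(epoch)


@dataclass
class Visualizer(Callback):
    """Occasional callback: decides by itself to act only every N validation runs."""

    every_n_validations: int = 1  # hypothetical, callback-specific frequency
    visualized_epochs: List[int] = field(default_factory=list)
    _num_calls: int = field(default=0, init=False)

    def on_validation_epoch_end(self, epoch: int) -> None:
        self._num_calls += 1
        if self._num_calls % self.every_n_validations == 0:
            self.visualized_epochs.append(epoch)


@dataclass
class Trainer:
    """Owns the schedule (num_epochs, check_val_every_n_epoch), similar to PL."""

    num_epochs: int
    check_val_every_n_epoch: int
    callbacks: List[Callback]

    def fit(self) -> None:
        for epoch in range(self.num_epochs):
            # ... one training epoch would run here ...
            if (epoch + 1) % self.check_val_every_n_epoch == 0:
                # ... the validation loop would run here; every callback then
                # sees its end, so the evaluator always follows validation.
                for callback in self.callbacks:
                    callback.on_validation_epoch_end(epoch)
```

With `Trainer(num_epochs=6, check_val_every_n_epoch=2, callbacks=[Evaluator(), Visualizer(every_n_validations=2)])`, validation runs after epochs 1, 3 and 5 (0-indexed); the evaluator records all three, while the visualizer only acts after epoch 3.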
┌──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐
│ Trainer │
│ │
│ ┌──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ │
│ │ Training loop │ │
│ │ │ │
│ │ ┌───────────────────┐ batch ┌──────────────────────┐ │ │
│ │ │ Train Data Loader ├──┬────► Train Data Connector │ │ │
│ │ └───────────────────┘ │ └──────────┬───────────┘ │ │
│ │ │ │ key map │ │
│ │ │ ┌───▼───┐ │ │
│ │ │ │ Model │ │ │
│ │ │ └───┬───┘ ┌───────────────────────────────────────────────────────────────────┐ │ │
│ │ │ │ │ LossModule │ │ │
│ │ │ │ Pred │ ┌─────────────────────────────────────────────┐ │ │ │
│ │ │ │ │ │ Loss_i │ │ │ │
│ │ │ │ │ │ │ │ │ │
│ │ │ ┌────────┴───────► │ pred─┐ ┌────────────────┐key map┌─────────┐ │ │ │ │
│ │ │ │ │ │ ├─►Loss Connector i├───────►loss op i├─┼─►loss_i x weight_i│ │ │
│ │ Trainer State ├───── │ ───────────────► │ data─┘ └────────────────┘ └─────────┘ │ │ │ │
│ │ │ │ │ │ │ │ │ │ │
│ │ │ │ │ │ └─────────────────────────────────────────────┘ │ │ │
│ │ │ │ │ │ . │ │ │
│ │ │ │ │ │ . │ │ │
│ │ │ │ │ │ . │ │ │
│ │ │ │ │ └─────────────────────────────┬─────────────────────────────────────┘ │ │
│ │ ┌────────────▼───────▼──────▼──────────┐ │ │ │
│ │ │ Callbacks │ │ │ │
│ │ │ ┌ -- -- -- -- -- -- -- - ┐ │ ▼ │ │
│ │ │ |Train Callback Connector| │ losses │ │
│ │ │ └ - -- -- -- -- -- -- -- ┘ │ │ │ │
│ │ │ │ ┌──▼──┐ │ │
│ │ └────────────────┬─────────────────────┘ │ Sum │ │ │
│ │ │ key map & key args └──┬──┘ │ │
│ │ ┌ ▼ ┐ │ │ │
│ │ ┌──────────┐ ┌─────────┐ ▼ ┌──────────┐ │ │
│ │ │Visualizer│ │Evaluator│ ...... total_loss─────► Backward │ │ │
│ │ └──────────┘ └─────────┘ └──────────┘ │ │
│ │ └ ┘ │ │
│ │ │ │
│ │ │ │
│ │ │ │
│ └──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ │
│ │
└──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
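The LossModule branch of the diagram (a key map per loss, `loss_i x weight_i`, then a sum into `total_loss`) can be sketched roughly as follows. This is an assumption-laden illustration, not vis4d's implementation: `LossConnector`, `LossDefinition` and `LossModule` below are hypothetical stand-ins, losses are plain floats instead of tensors, and the backward pass is only indicated in a comment.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class LossConnector:
    """Hypothetical: maps each loss-op argument to a (source, key) pair.

    The source is either "pred" (model output) or "data" (the batch).
    """

    key_mapping: Dict[str, Tuple[str, str]]

    def __call__(self, pred: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
        sources = {"pred": pred, "data": data}
        return {arg: sources[src][key] for arg, (src, key) in self.key_mapping.items()}


@dataclass
class LossDefinition:
    """One Loss_i box of the diagram: a loss op, its connector and its weight."""

    loss_op: Callable[..., float]
    connector: LossConnector
    weight: float = 1.0


class LossModule:
    """Computes weight_i * loss_i for every loss and sums them into total_loss."""

    def __init__(self, losses: Dict[str, LossDefinition]) -> None:
        self.losses = losses

    def __call__(
        self, pred: Dict[str, Any], data: Dict[str, Any]
    ) -> Tuple[float, Dict[str, float]]:
        weighted: Dict[str, float] = {}
        for name, loss in self.losses.items():
            kwargs = loss.connector(pred, data)  # key map: pred/data -> loss op args
            weighted[name] = loss.weight * loss.loss_op(**kwargs)
        total_loss = sum(weighted.values())
        # With PyTorch tensors, total_loss.backward() would follow here.
        return total_loss, weighted


def l1_loss(prediction: List[float], target: List[float]) -> float:
    """Mean absolute error over paired elements."""
    return sum(abs(p - t) for p, t in zip(prediction, target)) / len(target)


def mse_loss(prediction: List[float], target: List[float]) -> float:
    """Mean squared error over paired elements."""
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)
```

For example, an L1 box loss with weight 2.0 and an unweighted MSE score loss, fed `pred = {"boxes": [1.0, 2.0], "scores": [0.5]}` and `data = {"gt_boxes": [1.0, 4.0], "gt_scores": [1.0]}`, give weighted losses of 2.0 and 0.25 and a `total_loss` of 2.25.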
Also, we should rename `class_config` to `class_cfg` or similar.
Should we rename `class_config` in this PR? We should also open the required follow-up PRs, or add the minor missing things to the roadmap.
Other than that LGTM :)
I think we will have an abbreviation PR that renames many things. Maybe let's do it there?
This PR aims to make the config more user-friendly for the release.
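Several of the Config items in this PR concern how `instantiate_classes` turns config dictionaries into objects. A rough sketch of that idea follows, assuming each class node is a mapping with `class_path` and `init_args` keys and that plain `dict`s stand in for `ml_collections.ConfigDict` (reference resolution and value mode are not modelled). The function is hypothetical and is not vis4d's code.

```python
from __future__ import annotations

import importlib
from typing import Any


def instantiate_classes(config: Any, **kwargs: Any) -> Any:
    """Recursively turn {"class_path": ..., "init_args": ...} nodes into objects.

    Hypothetical stand-in: plain dicts play the role of ml_collections.ConfigDict,
    and references are assumed to be resolved already. kwargs only apply to the
    top-level node being instantiated.
    """
    if isinstance(config, dict):
        if "class_path" in config:
            module_name, _, class_name = config["class_path"].rpartition(".")
            clazz = getattr(importlib.import_module(module_name), class_name)
            # Nested class nodes inside init_args are instantiated first.
            init_args = {
                key: instantiate_classes(value)
                for key, value in (config.get("init_args") or {}).items()
            }
            # Merge call-time kwargs with init_args; this also covers an empty
            # init_args where all arguments arrive as kwargs. kwargs win on clashes.
            return clazz(**{**init_args, **kwargs})
        # A plain dictionary: instantiate each value.
        return {key: instantiate_classes(value) for key, value in config.items()}
    if isinstance(config, list):
        return [instantiate_classes(value) for value in config]
    return config
```

Calling it on `{"class_path": "fractions.Fraction", "init_args": {}}` with `numerator=3, denominator=4` builds `Fraction(3, 4)` even though `init_args` is empty, which mirrors the "empty `init_args` with `kwargs`" case.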
Config
- `vis4d.config` hierarchy.
- `vis4d.config.config_dict`
- `FieldConfigDict` and only use it for parameter links (e.g. hyper-parameters, data, output_dir…); otherwise, use `ml_collections.ConfigDict`. Also, make sure it is in `value_mode` before `instantiate_classes`.
- `DelayedInstantiator` → work with `instantiate` for different cases.
- `instantiate_classes` → Check whether the input is a `ConfigDict`. If the config is a `FieldConfig`, ensure it is in value mode. Handle the case where `init_args` is empty but `kwargs` are given, so that both are instantiated together.
- `copy_and_resolve_references` → Fix typing. Copy and resolve references to `ConfigDict` instead of `FieldConfigDict` in value mode. Handle dictionaries.
- `_instantiate_classes` → Fix typing. Handle dictionaries.
Loss Module
- `vis4d.engine.loss` to `vis4d.engine.loss_module`. It maps the input keys from the prediction and data to each loss function's input keys and provides control over loss weighting.
- `vis4d.op.loss` / `nn.Module`.
- `LossConnector` for key mapping.
Data Connector
- `StaticDataConnector`, `DataConnectorInfo`
- `vis4d.engine.connectors` and separate DataConnector → Resolve the Loss Module input keys, reuse connectors for callbacks, and separate the train and test connectors (e.g. `DataConnector` can now be used for training and `MultiSensorDataConnector` for inference):
  - `DataConnector`: Used for the Trainer (train / test data connector).
  - `LossConnector`: Used for the Loss Module.
  - `CallbackConnector`: Used for Callbacks (train / test connector).
  - `MultiSensorDataConnector`
  - `MultiSensorLossConnector`
  - `MultiSensorCallbackConnector`
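To illustrate why separate train and test connectors help, here is a hypothetical sketch; none of these classes are vis4d's actual implementations, and the nested `{sensor: {key: value}}` batch layout is an assumption. A `DataConnector` renames batch keys to model inputs for training, a `MultiSensorDataConnector` gathers the same key from several sensors for inference, and a `CallbackConnector` reuses the key-map idea to feed callbacks from predictions and data.

```python
from __future__ import annotations

from typing import Any, Dict, List, Mapping, Sequence, Tuple


class DataConnector:
    """Hypothetical: selects batch entries and renames them to model inputs."""

    def __init__(self, key_mapping: Mapping[str, str]) -> None:
        self.key_mapping = dict(key_mapping)  # model argument -> batch key

    def __call__(self, batch: Mapping[str, Any]) -> Dict[str, Any]:
        return {arg: batch[key] for arg, key in self.key_mapping.items()}


class MultiSensorDataConnector(DataConnector):
    """Hypothetical: same key map, but collects the entry from every sensor."""

    def __init__(self, key_mapping: Mapping[str, str], sensors: Sequence[str]) -> None:
        super().__init__(key_mapping)
        self.sensors = list(sensors)

    def __call__(self, batch: Mapping[str, Any]) -> Dict[str, List[Any]]:
        # The batch is assumed to be nested as {sensor: {key: value}}.
        return {
            arg: [batch[sensor][key] for sensor in self.sensors]
            for arg, key in self.key_mapping.items()
        }


class CallbackConnector:
    """Hypothetical: picks callback inputs from predictions or the batch."""

    def __init__(self, key_mapping: Mapping[str, Tuple[str, str]]) -> None:
        self.key_mapping = dict(key_mapping)  # callback arg -> (source, key)

    def __call__(self, pred: Mapping[str, Any], data: Mapping[str, Any]) -> Dict[str, Any]:
        sources = {"pred": pred, "data": data}
        return {arg: sources[src][key] for arg, (src, key) in self.key_mapping.items()}
```

Because the trainer only sees "a connector", swapping the training `DataConnector` for a `MultiSensorDataConnector` at inference time needs no change to the model or callbacks in this sketch.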
Callbacks
- `vis4d.engine`.
- `show` / `save_to_disk` per batch.
Dataset
- `sample_names` and `original_images` for visualization.
Visualize module:
Optim module
- `vis4d.optim` and `vis4d.engine.opt` as `vis4d.engine.optim`.
Bug Fix