Deci-AI / super-gradients

Easily train or fine-tune SOTA computer vision models with one open source training library. The home of Yolo-NAS.
https://www.supergradients.com
Apache License 2.0

How to use the knowledge distillation on yolo_nas_pose? #1926

Closed: adam0217 closed this issue 5 months ago

adam0217 commented 5 months ago

💡 Your Question

I followed the "how_to_use_knowledge_distillation_for_classification.ipynb" notebook and tried to distill the model, but I get the following error: [error screenshot not preserved]

The following is my code:

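```python
# Import paths assumed for the super-gradients version in use; adjust if yours differs.
from super_gradients.training import Trainer, KDTrainer, models, dataloaders, training_hyperparams
from super_gradients.training.losses import KDLogitsLoss
from super_gradients.training.losses.label_smoothing_cross_entropy_loss import CrossEntropyLoss
from super_gradients.training.datasets.pose_estimation_datasets import DEKRTargetsGenerator

checkpoint_dir = "checkpoints"  # placeholder: set to your own checkpoint root directory

trainer = Trainer(experiment_name="yolonaspose_test1", ckpt_root_dir=checkpoint_dir)

train_dataloader = dataloaders.get(
    "coco2017_pose_train",
    dataloader_params={"batch_size": 4, "num_workers": 2, "drop_last": True, "shuffle": True},
    dataset_params={"target_generator": DEKRTargetsGenerator},
)
val_dataloader = dataloaders.get(
    "coco2017_pose_val",
    dataloader_params={"batch_size": 4, "num_workers": 2, "drop_last": True, "shuffle": True},
    dataset_params={"target_generator": DEKRTargetsGenerator},
)

# Teacher: large pose model; student: small pose model (both COCO-pose pretrained).
pretrained_model = models.get("yolo_nas_pose_l", pretrained_weights="coco_pose").cuda()
student_model = models.get("yolo_nas_pose_s", pretrained_weights="coco_pose").cuda()

kd_params = {
    "max_epochs": 3,
    "lr_cooldown_epochs": 0,
    "lr_warmup_epochs": 0,
    # Taken from the classification notebook: CrossEntropyLoss expects class logits,
    # which a pose model does not produce (see the reply below).
    "loss": KDLogitsLoss(distillation_loss_coeff=0.8, task_loss_fn=CrossEntropyLoss()),
    "loss_logging_items_names": ["Loss", "Task Loss", "Distillation Loss"],
}

training_params = training_hyperparams.get("coco_kd_test", overriding_params=kd_params)
experiment_name = "kd_coco_ltos"

kd_trainer = KDTrainer(experiment_name=experiment_name, ckpt_root_dir=checkpoint_dir)
kd_trainer.train(
    training_params=training_params,
    student=student_model,
    teacher=pretrained_model,
    kd_architecture="kd_module",
    train_loader=train_dataloader,
    valid_loader=val_dataloader,
)
```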
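For context, a logit-distillation loss such as `KDLogitsLoss` mixes a task loss with a divergence between the student's and the teacher's class distributions, weighted by `distillation_loss_coeff`. Below is a minimal pure-Python sketch of that standard objective (KL divergence between softmax outputs, no temperature), assuming a single sample with one logit per class; it is an illustration, not super-gradients' exact implementation. It makes the underlying assumption visible: both terms need per-class logits and a class-index target, which is what a classifier produces.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of class logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target_index):
    """Task loss: negative log-probability of the ground-truth class."""
    return -math.log(softmax(logits)[target_index])


def kl_divergence(teacher_logits, student_logits):
    """KL(teacher || student) between the two class distributions."""
    p = softmax(teacher_logits)
    q = softmax(student_logits)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kd_logits_loss(student_logits, teacher_logits, target_index, distillation_loss_coeff=0.8):
    """Sketch of a logit-KD loss for one sample.

    Returns (total, task, distillation), mirroring the logged
    "Loss", "Task Loss" and "Distillation Loss" items.
    """
    if len(student_logits) != len(teacher_logits):
        raise ValueError("student and teacher must predict the same classes")
    task = cross_entropy(student_logits, target_index)
    distill = kl_divergence(teacher_logits, student_logits)
    total = (1 - distillation_loss_coeff) * task + distillation_loss_coeff * distill
    return total, task, distill
```

Usage: `kd_logits_loss([2.0, 0.5, -1.0], [1.5, 0.7, -0.5], target_index=0)`. Nothing in this formulation has a place for keypoint coordinates or per-pose confidences, which is why it does not carry over to a pose model unchanged.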

And the yaml file:

```yaml
defaults:
  - training_hyperparams: coco2017_yolo_nas_pose_train_params
  - dataset_params: coco_pose_estimation_yolo_nas_mosaic_heavy_dataset_params
  - checkpoint_params: default_checkpoint_params
  - _self_
  - variable_setup

train_dataloader: coco2017_pose_train
val_dataloader: coco2017_pose_val

resume: True
training_hyperparams:
  resume: ${resume}
  loss: KDLogitsLoss
  criterion_params:
    distillation_loss_coeff: 0.8
    task_loss_fn: # classification loss, carried over from the classification notebook
      _target_: super_gradients.training.losses.label_smoothing_cross_entropy_loss.CrossEntropyLoss

arch_params:
  teacher_input_adapter:
    _target_: super_gradients.training.utils.kd_trainer_utils.NormalizationAdapter
    mean_original: [0.485, 0.456, 0.406]
    std_original: [0.229, 0.224, 0.225]
    mean_required: [0.5, 0.5, 0.5]
    std_required: [0.5, 0.5, 0.5]

student_arch_params:
  num_classes: ${dataset_params.num_joints}

teacher_arch_params:
  num_classes: ${dataset_params.num_joints}

teacher_checkpoint_params:
  load_backbone: False # whether to load only the backbone part of the checkpoint
  checkpoint_path: https://sghub.deci.ai/models/yolo_nas_l_coco.pth
  strict_load: # key-matching strictness when loading the checkpoint's weights
    _target_: super_gradients.training.sg_trainer.StrictLoad
    value: key_matching
  pretrained_weights: coco_pose

checkpoint_params:
  teacher_pretrained_weights: coco_pose

student_checkpoint_params:
  load_backbone: False # whether to load only the backbone part of the checkpoint
  checkpoint_path: https://sghub.deci.ai/models/yolo_nas_s_coco.pth
  strict_load: # key-matching strictness when loading the checkpoint's weights
    _target_: super_gradients.training.sg_trainer.StrictLoad
    value: key_matching
  pretrained_weights: coco_pose # a string describing the dataset of the pretrained weights (for example "imagenet")

run_teacher_on_eval: True
experiment_name: kd_coco_ltos
multi_gpu: DDP
num_gpus: 1

architecture: kd_module
student_architecture: yolo_nas_pose_s
teacher_architecture: yolo_nas_pose_l
```

Versions

No response

shaydeci commented 5 months ago

There is currently no distillation loss specific to pose estimation. You are using the standard cross-entropy loss, which cannot be used for this task because a pose model's outputs have a different format than a classifier's logits. Contributions are always welcome.
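As a purely hypothetical sketch of what a pose-specific distillation term could compare instead of class distributions: the function `pose_distillation_loss` below is a made-up name, not a super-gradients API. It assumes both models' outputs have already been decoded into the same number of poses, each a list of K keypoints as `(x, y, confidence)`, and that student and teacher poses are already matched one-to-one. It computes a teacher-confidence-weighted squared error on keypoint coordinates. A real YOLO-NAS-Pose loss would also have to deal with the raw head outputs, pose matching and the existing pose task loss, all of which this sketch leaves out.

```python
from typing import List, Tuple

# Hypothetical decoded format: (x, y, confidence) per keypoint.
Keypoint = Tuple[float, float, float]
Pose = List[Keypoint]


def pose_distillation_loss(student_poses: List[Pose], teacher_poses: List[Pose]) -> float:
    """Hypothetical KD term: squared coordinate error between matched student and
    teacher keypoints, weighted by teacher confidence and normalized by total weight."""
    if len(student_poses) != len(teacher_poses):
        raise ValueError("poses must be matched one-to-one")
    weighted_error = 0.0
    total_weight = 0.0
    for s_pose, t_pose in zip(student_poses, teacher_poses):
        if len(s_pose) != len(t_pose):
            raise ValueError("student and teacher must predict the same keypoints")
        for (sx, sy, _), (tx, ty, t_conf) in zip(s_pose, t_pose):
            # Keypoints the teacher is unsure about contribute less.
            weighted_error += t_conf * ((sx - tx) ** 2 + (sy - ty) ** 2)
            total_weight += t_conf
    return weighted_error / total_weight if total_weight > 0 else 0.0


def combined_pose_kd_loss(
    task_loss: float,
    student_poses: List[Pose],
    teacher_poses: List[Pose],
    distillation_loss_coeff: float = 0.8,
) -> float:
    """Mix an existing pose task loss (given as a number) with the hypothetical KD term,
    using the same weighting scheme as a logit-distillation loss."""
    distill = pose_distillation_loss(student_poses, teacher_poses)
    return (1 - distillation_loss_coeff) * task_loss + distillation_loss_coeff * distill
```

Weighting by teacher confidence is one design choice among several; it keeps occluded or low-confidence teacher keypoints from pulling the student toward noisy targets.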