Tsinghua-MARS-Lab / StateTransformer


Diffusion decoder related 3 #109

Closed JingzheShi closed 1 year ago

JingzheShi commented 1 year ago

Add a model called scratch/pretrain-diffusion_KP_decoder_gpt: it is the same as TrajectoryGPT, but it uses a diffusion key point decoder to generate key points instead of generating them auto-regressively. Adjustments include:

  1. diffusion_KP_model.py: adds a class called TrajectoryGPTFeatureGen. It is the same as TrajectoryGPT, except that it can save the keypoin_hidden_feature as the condition for training the diffusion key point decoder separately (expected to be used from a pretrained TrajectoryGPT). It also adds TrajectoryGPTDiffusionKPDecoder, the main model for diffusion_KP_decoder_gpt.
  2. transformer4planning.diffusion_decode.py: implements the diffusion key point decoder class used by the models above.
  3. runner_diffusionKPdecoder.py: used to train the diffusion key point decoder separately on the features saved by TrajectoryGPTFeatureGen.
  4. runner.py: adjusted so it can serve as an entrance for TrajectoryGPTFeatureGen to generate and save features, and as an entrance to train TrajectoryGPTDiffusionKPDecoder.
  5. transformer4planning.utils: adds some args to model_args.
  6. requirements.txt: einops and wandb are used in some of this code.
  7. trainer.py: some adjustments are made to the generation process: when generating and saving features with the TrajectoryGPTFeatureGen class, we do not need to evaluate any metrics, since the class is expected to be used with a pretrained backbone.
  8. model.py: changed the build_model function so it can build the two new models.
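The actual decoder implementation lives in the repository files listed above. As a rough, framework-free sketch of the training objective such a diffusion key point decoder would optimize (DDPM-style noise prediction conditioned on the backbone's saved hidden feature), with all names here illustrative rather than the repo's actual API:

```python
import numpy as np

def make_noise_schedule(T=100, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule; returns betas and cumulative alpha products."""
    betas = np.linspace(beta_start, beta_end, T)
    alphas = 1.0 - betas
    return betas, np.cumprod(alphas)

def q_sample(x0, t, alpha_bars, noise):
    """Forward diffusion: corrupt clean key points x0 to noise level t."""
    ab = alpha_bars[t]
    return np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * noise

def diffusion_kp_loss(denoiser, x0, cond, alpha_bars, rng):
    """One training step's loss: the denoiser predicts the added noise
    given the noisy key points, the timestep, and the hidden feature
    `cond` saved by the feature-generation pass of the backbone."""
    T = len(alpha_bars)
    t = int(rng.integers(T))               # random diffusion timestep
    noise = rng.standard_normal(x0.shape)  # target the denoiser must recover
    x_t = q_sample(x0, t, alpha_bars, noise)
    pred = denoiser(x_t, t, cond)
    return float(np.mean((pred - noise) ** 2))
```

Because the loss only needs (key points, hidden feature) pairs, the decoder can be trained offline on features dumped to disk, which is why the feature-saving pass and the decoder training can be separate runs.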
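At inference time the decoder replaces auto-regressive key point generation with iterative denoising. A minimal sketch of that reverse process, again conditioned on the backbone feature and using illustrative names only:

```python
import numpy as np

def sample_keypoints(denoiser, cond, shape, betas, alpha_bars, rng):
    """Reverse diffusion: start from Gaussian noise and iteratively
    denoise into key points, conditioned on the backbone feature `cond`.
    Uses the standard DDPM ancestral update with the predicted noise."""
    alphas = 1.0 - betas
    x = rng.standard_normal(shape)  # pure noise at t = T-1
    for t in range(len(betas) - 1, -1, -1):
        eps = denoiser(x, t, cond)  # predicted noise at this step
        coef = betas[t] / np.sqrt(1.0 - alpha_bars[t])
        x = (x - coef * eps) / np.sqrt(alphas[t])
        if t > 0:  # no noise added on the final step
            x += np.sqrt(betas[t]) * rng.standard_normal(shape)
    return x
```

All key points come out of one denoising pass over the whole set, rather than one key point per forward call as in the auto-regressive decoder.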
JingzheShi commented 1 year ago

This code is written for, and expected to be used with, the TrajectoryGPT model and the nuPlan dataset (k=1) only.