Aaditya-Prasad / consistency-policy

[RSS 2024] Consistency Policy: Accelerated Visuomotor Policies via Consistency Distillation
https://consistency-policy.github.io/
MIT License

CTM Loss Not Actually Used in the Student Training? #4

Closed · Zhendong-Wang closed this 1 month ago

Zhendong-Wang commented 1 month ago

Refer to the code here: https://github.com/Aaditya-Prasad/consistency-policy/blob/4cc328ebca2299f25c7b96c8bc04866d69cfe210/consistency_policy/student/ctm_policy.py#L482C13-L495C58

# t -> s
pred = self._forward(self.model, noise_traj, times, stops, 
                        local_cond=local_cond, global_cond=global_cond)

# u -> s
with torch.no_grad():
    target = self._forward(self.model_ema, u_noise_traj, u_times, stops,
                        local_cond=local_cond, global_cond=global_cond)

with torch.no_grad():
    start = torch.tensor([self.noise_scheduler.time_min], device = trajectory.device).expand(times.shape)

    pred = self._forward(self.model_ema, pred, stops, start, 
                        local_cond=local_cond, global_cond=global_cond)

    target = self._forward(self.model_ema, target, stops, start,
                        local_cond=local_cond, global_cond=global_cond)

loss = Huber_Loss(pred, target, delta = self.delta, weights=weights)

total_loss["ctm"] = loss * self.losses["ctm"]

The pred and target are both re-assigned inside torch.no_grad(), so this loss variable will not have a grad_fn for backpropagation.

When I run the code and simply change the line at https://github.com/Aaditya-Prasad/consistency-policy/blob/4cc328ebca2299f25c7b96c8bc04866d69cfe210/configs/ctmp_square.yaml#L82 to losses: [["ctm"], [1]], it raises a "no grad_fn" error.

Zhendong-Wang commented 1 month ago

The error can also be reproduced easily with:

import torch
import torch.nn as nn
import torch.nn.functional as F

dim = 10
x = torch.randn((1, dim))
y = torch.zeros_like(x)

model = nn.Linear(dim, dim)
model_ema = nn.Linear(dim, dim)

pred = model(x)
print(f'pred: {pred}')
with torch.no_grad():
    pred = model_ema(pred)
print(f'pred: {pred}')
loss = F.mse_loss(pred, y)
print(f'loss inside of no_grad: {loss}')
loss.backward()
print(f'backward successful!')
$ python test_grad.py
pred: tensor([[-0.2391,  0.0476, -1.7864,  0.4620, -0.3655,  1.0930, -0.1256,  0.9729,
          0.6107, -1.2118]], grad_fn=<AddmmBackward0>)
pred: tensor([[ 0.1322, -0.6358,  0.1975,  0.2054, -0.7371,  0.4366, -0.5071,  0.1329,
          0.3269, -0.1779]])
loss inside of no_grad: 0.16503188014030457
Traceback (most recent call last):
  File "/home/zhendongw/research/fast_dp_robotics_dev/test_grad.py", line 19, in <module>
    loss.backward()
  File "/home/zhendongw/miniconda3/envs/umi-gpu/lib/python3.9/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/home/zhendongw/miniconda3/envs/umi-gpu/lib/python3.9/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
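
For contrast, a minimal variant of the same toy setup where only the EMA/frozen branch is wrapped in torch.no_grad(); here the prediction keeps its grad_fn and backward() succeeds:

import torch
import torch.nn as nn
import torch.nn.functional as F

dim = 10
x = torch.randn((1, dim))

model = nn.Linear(dim, dim)
model_ema = nn.Linear(dim, dim)

pred = model(x)              # stays in the autograd graph
with torch.no_grad():
    target = model_ema(x)    # only the target branch is detached

loss = F.mse_loss(pred, target)
loss.backward()                        # works: loss has a grad_fn through pred
print(model.weight.grad is not None)   # True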
Aaditya-Prasad commented 1 month ago

Yes, I (actually Kevin) realized this as well, and fixed it in the recent commit.

I am unsure when this was introduced, because I actually ablated the CTM loss (by running just CTM with weight 1, no DSM) back when I was producing results, and it performed fine then. Thus I am also re-running Consistency Policy on our sim tasks to check performance.

Aaditya-Prasad commented 1 month ago

see https://github.com/Aaditya-Prasad/consistency-policy/commit/3faa563091defb01dfc99a5fdeb5423e4cecaa33
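
For reference, here is one way the quoted block can be restructured so the CTM loss keeps its grad_fn; this is only a sketch based on the snippet above and may differ from what the linked commit actually does. The key point is that the student branch (pred, including its second pass through the EMA model) must stay outside torch.no_grad(), so gradients flow back into self.model through pred, while the target branch is fully detached.

# sketch only; variable names are taken from the snippet quoted above

# t -> s (student prediction, keeps its graph)
pred = self._forward(self.model, noise_traj, times, stops,
                     local_cond=local_cond, global_cond=global_cond)

start = torch.tensor([self.noise_scheduler.time_min],
                     device=trajectory.device).expand(times.shape)

# s -> time_min through the EMA model, NOT under no_grad:
# the EMA weights carry no gradient, but gradients still flow through the input pred
pred = self._forward(self.model_ema, pred, stops, start,
                     local_cond=local_cond, global_cond=global_cond)

# target branch (u -> s -> time_min) is fully detached
with torch.no_grad():
    target = self._forward(self.model_ema, u_noise_traj, u_times, stops,
                           local_cond=local_cond, global_cond=global_cond)
    target = self._forward(self.model_ema, target, stops, start,
                           local_cond=local_cond, global_cond=global_cond)

loss = Huber_Loss(pred, target, delta=self.delta, weights=weights)
total_loss["ctm"] = loss * self.losses["ctm"]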

Zhendong-Wang commented 1 month ago

Thanks for the response! Could you help verify the performance of the new code? Thanks!

Aaditya-Prasad commented 1 month ago

Yes. I am distilling Consistency Policy again, starting with the Square and ToolHang tasks since these were the hardest. This will take some time since I am also training the teachers from scratch again.

I am evaluating the same way as in our paper (take the highest performing checkpoint and evaluate it 200x on a fixed seed) and will confirm if performance matches/exceeds what we reported.
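
In case it helps anyone reproducing this, a rough sketch of that protocol (fixed per-rollout seeds, N rollouts, mean success); rollout and policy_fn here are hypothetical stand-ins, not the repo's API:

import numpy as np

def rollout(policy_fn, seed):
    # placeholder for running one simulated episode and returning success (0.0 or 1.0)
    rng = np.random.default_rng(seed)
    return float(policy_fn(rng.standard_normal(10)) > 0)

def evaluate(policy_fn, n_rollouts=200, base_seed=100000):
    # rollout i always uses base_seed + i, so every checkpoint is scored
    # on the same fixed set of episodes
    scores = [rollout(policy_fn, base_seed + i) for i in range(n_rollouts)]
    return float(np.mean(scores))  # "test mean score"

print(evaluate(lambda obs: obs.sum()))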

Aaditya-Prasad commented 1 month ago

Testing has been taking some time because I've been having issues on the cluster I use, which has forced me to restart runs.

On square, 50 epochs of EDM -> 300 epochs of Consistency Policy training led to a 91% test mean score and a 100% train mean score over 200 evaluations. This was with a distillation weight of 8 and a dropout probability of 0.3. Increasing the distillation loss weight and adding more regularization (or just training for longer) would likely improve performance, but I did not keep testing this because I didn't do much tuning for the Diffusion Policy baselines either.

For comparison, training EDM or DDIM for 250 epochs each and sampling with 1 step (which means 2 forward passes for EDM using the Heun scheduler) nets a 0% success rate for both. I would report the 350-epoch results (to match epochs with the Consistency Policy student + teacher) but those runs haven't gotten there yet.

The toolhang teacher is still training.

Zhendong-Wang commented 1 month ago

Thanks for the update!!! Could you share your configs for training pusht, square, tool_hang, and transport? Will you share your pretrained checkpoints? I found that these tasks have a big variance in evaluation, and changing the seed can change the performance; 200 evaluations is sometimes not big enough.

BTW, do you know whether Consistency Distillation works in the DDPM case?

Aaditya-Prasad commented 1 month ago

I've found this variance as well -- let me think about the best way to do this. A few points:

  1. We did not train or test transport. We trained/tested lift, can, square, tool hang, pusht, kitchen.
  2. I didn't do a good job with keeping old runs/results/configs; this was completely my mistake and part of the reason for this whole issue. This was my first real research project and I'm carrying lessons from this forward.
  3. Policy configs did not vary too much between tasks. I tuned configs for Square and ToolHang since these were the hardest tasks, and I used the same policy values for the other tasks.
  4. I'll upload the teacher, baseline, and student ckpts + cfgs for Square (which I just trained and reported results for). I'll do the same for toolhang as it gets done. I'll also try to evaluate toolhang on more rounds (maybe 500 instead of 200); this will likely take a while though since this is the longest task. If needed, I can retrain policies for the other tasks, but again I think these are the most indicative.

DDPM adds stochasticity to the reverse diffusion process by injecting a bit of fresh noise into each denoising step. This breaks Consistency Distillation, which requires not only the same marginal distribution at t=0 (which DDPM satisfies) but also deterministic trajectories between t=T and t=0 (which DDPM breaks).
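
As a toy illustration of that last point (this is not the repo's sampler code, just a schematic denoiser and two rollout styles): a deterministic, ODE-style rollout always lands on the same endpoint from a given x_T, while a DDPM-style rollout injects fresh noise at every step, so even two rollouts from the same x_T diverge. Consistency Distillation needs the former.

import torch

torch.manual_seed(0)
x_T = torch.randn(4)

def denoise(x, t):
    # stand-in for a trained denoiser; its exact form doesn't matter here
    return x * (1 - 0.1 * t)

def ode_rollout(x, steps=10):
    # deterministic update: the whole trajectory is a function of x alone
    for t in reversed(range(steps)):
        x = denoise(x, t / steps)
    return x

def ddpm_rollout(x, steps=10, sigma=0.1):
    # DDPM-style update: fresh noise is added at every step
    for t in reversed(range(steps)):
        x = denoise(x, t / steps) + sigma * torch.randn_like(x)
    return x

print(torch.allclose(ode_rollout(x_T), ode_rollout(x_T)))    # True: same endpoint every run
print(torch.allclose(ddpm_rollout(x_T), ddpm_rollout(x_T)))  # False: trajectories diverge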

Aaditya-Prasad commented 1 month ago

For now, this is a link to a zip file containing teacher, baseline, and student ckpts and cfgs for square: https://drive.google.com/file/d/1jG7HjDPu4qXs0tMRESTKJxeWpqgR-mL6/view?usp=drive_link.

Note that I was incorrect above about the student dropout and distillation weight; the config in the zip is correct.

I was hoping to use git lfs for this but the checkpoints are over the storage limit. When I have time, I'll pick a more permanent solution like adding these to the website.

Zhendong-Wang commented 1 month ago

Thanks for providing these!!! I will have a look.

Zhendong-Wang commented 1 month ago

Hey Aaditya,

I am trying to reproduce Consistency Policy on my end, and I first tried to pretrain the teacher models. These are the curves I obtained. They all perform much worse than the reported values and the values in Diffusion Policy. For tool_hang especially, it should reach a success rate of around 0.83, but I am getting 0.46. Do you have insights on this? I used configs/edm_square.yaml and configs/edm_th.yaml for square and tool_hang, and copied the square settings to the other tasks (pusht, square_mh, transport_mh, and transport_ph). I also carefully checked that the image crop sizes are consistent with the previous Diffusion Policy configs.

[screenshot: EDM teacher success-rate curves across tasks]
Aaditya-Prasad commented 1 month ago

When you say "the values in Diffusion Policy", are you referring to the DDPM solver? Diffusion Policy has no reported results for the EDM solver. Diffusion Policy reports a 0.5 success rate on tool hang with their testing methodology, while we report a 0.79 success rate (again, for DDPM). However, we did get EDM to a ~0.8 success rate (at 50 evaluations, not a rigorous test).

On EDM success rate:

The main thing you should do is train for longer. Here are the success curves from an older EDM run, again with test_mean_score computed over 50 evaluations. Training took a long time; you can see that we only reached a ~0.8 success rate after 300-350 epochs. *edited, see note at bottom

[screenshot: success-rate curves from an older EDM run]

I'm currently checking distillation with a 100-epoch EDM teacher (which achieved a 0.0 success rate with 1 bin, since I wanted to sanity-check EDM-only performance). I'll keep this thread updated on how long a teacher needs to train before distillation performs well; I've found that toolhang requires a lot of time to train both the teacher and the student.

Also, just wanted to note that I used a 50 epoch EDM teacher to distill the Square task and that seemed to perform well, though of course teachers trained for longer might help by a few percentage points.

*I originally thought that run used 100 bins instead of 80; this was incorrect. It used 80 bins with a 3rd-order solver, i.e. 3 forward passes per bin (compared to 2 forward passes for Heun). This might have contributed to the higher success rates. You can try this if you'd like, though I'd recommend just training for longer and testing distillation.
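
To make the forward-pass accounting above concrete, here is a generic Heun (2nd-order) ODE step, not the repo's noise scheduler: it evaluates the model-derived slope twice per bin, so 80 bins cost about 160 model calls, and a 3rd-order solver about 240.

import torch

def heun_step(deriv, x, t, t_next):
    # deriv(x, t) stands in for the denoiser-derived slope dx/dt
    d1 = deriv(x, t)                          # forward pass 1: Euler slope
    x_euler = x + (t_next - t) * d1
    d2 = deriv(x_euler, t_next)               # forward pass 2: corrector slope
    return x + (t_next - t) * 0.5 * (d1 + d2)

# tiny usage with a toy linear ODE dx/dt = -x
x = torch.ones(3)
x = heun_step(lambda x_, t_: -x_, x, t=1.0, t_next=0.9)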

Zhendong-Wang commented 1 month ago

Thanks for sharing the curve! All I am doing now is trying to reproduce the results from the Consistency Policy paper. I used the configs provided in the GitHub repo (edm_th.yaml trains for 400 epochs by default), which should usually reproduce the table results. For example, tool_hang should be around 0.79, and I agree there might be some variance.

[screenshot: results table from the Consistency Policy paper]

As for the values in Diffusion Policy, I mean this table, with 0.73 on tool_hang, 0.84 on square_mh, and 0.69 on transport_mh. The peak-performance values might not be useful since they use even fewer seeds (like 28). If EDM works worse than the original DDPM, then there is no reason to switch to EDM for the teacher models. You can see from my previous figure that transport_mh and square_mh also work badly with EDM.

[screenshot: results table from the Diffusion Policy paper]

Aaditya-Prasad commented 1 month ago

EDM is used as the teacher model because it learns the PF-ODE, i.e. it doesn't inject noise into the reverse diffusion process. DDPM is not suitable as a teacher model because it does not have deterministic trajectories. We make no claim in the paper about EDM's performance compared to DDPM; EDM is only introduced so that we can use distillation.

We do make the claim that training a teacher EDM and then distilling it into a Consistency Policy gives the single-step and 3-step results in the table. Note that this is separate from the success rates you will see from EDM itself: EDM's training is separate from its sampler. To make teacher training faster, I have been training "1-step EDMs", which will reach a 0.0 success rate even after 200-300 epochs. You can increase the number of bins, the order of the sampler, or other noise-scheduler hyperparameters to get higher success rates out of EDM, which can be helpful when you're comparing different checkpoints to distill from.