Hi! Thank you for your question! We chose not to use soft category-aware matching in the BN layer but instead applied it in the Conv layer, as calculating category-aware statistics involves significant additional computational overhead:
category_dd_var = self.category_running_dd_var_list[self.targets.long()].mean(0)
category_dd_mean = self.category_running_dd_mean_list[self.targets.long()].mean(0)
category_patch_var = self.category_running_patch_var_list[self.targets.long()].mean(0)
category_patch_mean = self.category_running_patch_mean_list[self.targets.long()].mean(0)
r_feature += (torch.norm(category_dd_var - (self.dd_var + dd_var - dd_var.detach()), 2) + \
              torch.norm(category_dd_mean - (self.dd_mean + dd_mean - dd_mean.detach()), 2) + \
              torch.norm(category_patch_mean - (self.patch_mean + patch_mean - patch_mean.detach()), 2) + \
              torch.norm(category_patch_var - (self.patch_var + patch_var - patch_var.detach()), 2)) * 0.5
self.r_feature = r_feature
(lines 493-503)
Regarding the second question: the rationale behind this design is that such matching is more precise. For instance, attempting to fit a Gaussian mixture with three components using a single Gaussian distribution leads to a suboptimal solution, whereas fitting it with another three-component Gaussian mixture gives a significantly more accurate fit.
Each component of the Gaussian mixture distribution we work with represents a class. This approach enables the creation of higher-quality synthetic datasets. The use of global matching is motivated by the observation that images synthesized with global matching contain condensed (crucial) information but lack diversity. When the IPC is small, global matching plays a crucial role in the final performance during post-evaluation. However, as the IPC increases, data diversity becomes more important, making local matching (Form #2) critical.
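To make the intuition above concrete, here is a minimal numerical sketch (not taken from the EDC code; the class means, sample counts, and variable names are made up for illustration) showing that pooled single-Gaussian statistics of a three-component mixture are much looser than per-component, category-aware statistics:

```python
import torch

torch.manual_seed(0)
# three well-separated "classes", each drawn from its own unit-variance Gaussian
class_means = torch.tensor([-5.0, 0.0, 5.0])
samples = torch.cat([m + torch.randn(1000) for m in class_means])

# single-Gaussian fit: pooled statistics blur the three modes together
print(samples.mean().item(), samples.std().item())      # mean ~0, std ~4.2 (inflated by the gaps between modes)

# category-aware fit: per-component statistics recover each mode tightly
for i, m in enumerate(class_means):
    chunk = samples[i * 1000:(i + 1) * 1000]
    print(i, chunk.mean().item(), chunk.std().item())    # each roughly (m, 1.0)
```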
Thanks for your detailed reply! However, I'm still confused about the code implementation of Soft Category-Aware Matching. As described in the paper, we match $\mu$ and $\sigma$ for each category.
Therefore, the code implementation I have in mind is as follows:
# ignoring the SDS-like trick for clarity
r_feature += torch.stack([torch.norm(self.category_running_dd_mean_list[i] - dd_mean[self.targets == i]) for i in range(num_classes)]).sum()
r_feature += torch.stack([torch.norm(self.category_running_dd_var_list[i] - dd_var[self.targets == i]) for i in range(num_classes)]).sum()
However, the code in EDC seems to match using only the overall mean feature across all classes:
mean = input_0.mean([0, 2, 3]) # mean feature of all distilled images
...
category_dd_mean = self.category_running_dd_mean_list[self.targets.long()].mean(0) # mean feature of each class's mean feature
...
Perhaps I misunderstood the meaning in your paper or missed some important information. I would be very grateful if you could provide further clarification.
I have some questions regarding two points from your previous response.
First, both the BNFeatureHook and ConvFeatureHook modules seem to use the Soft Category-Aware Matching mechanism, so I'm wondering why you mentioned that this mechanism is only applied in the Conv module.
Second, while a GMM is used for modeling, with each category representing a Gaussian component, why is the final matching ultimately performed using the overall mean of all classes instead of the individual component mean of each class?
category_dd_var = self.category_running_dd_var_list[self.targets.long()].mean(0) # mean of all classes
category_dd_mean = self.category_running_dd_mean_list[self.targets.long()].mean(0) # mean of all classes
I would greatly appreciate any clarification you could provide, and thank you very much for taking the time.
The reason for writing the code in this form:
mean = input_0.mean([0, 2, 3]) # mean feature of all distilled images
category_dd_mean = self.category_running_dd_mean_list[self.targets.long()].mean(0) # mean feature of each class's mean feature
There are some differences between this part of the implementation and the description in the paper. In the paper, the loss for each class is computed first and then averaged; here, the mean of the target statistics is computed first and the loss is then calculated once. We use the latter approach because it is significantly faster. Additionally, we ensure that the image synthesis process traverses each batch class by class. For instance, if the batch size is 50, there are 5 classes (0, 1, 2, 3, and 4) with 10 images per class, reaching the theoretical minimum number of classes per batch. In our formal experiments, the batch size per single GPU was set to 20, which ensures consistency with the paper as long as IPC is greater than or equal to 20.
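For concreteness, here is a simplified sketch of the two formulations. The names (`running_mean_per_class`, `feats`) and shapes are illustrative assumptions, not the repo code:

```python
import torch

num_classes, feat_dim, per_class = 5, 64, 10
running_mean_per_class = torch.randn(num_classes, feat_dim)           # target statistic, one row per class
targets = torch.arange(num_classes).repeat_interleave(per_class)      # batch traverses classes 0..4, 10 images each
feats = torch.randn(num_classes * per_class, feat_dim, requires_grad=True)  # stand-in for pooled per-image features

# Paper-style: one norm per class against that class's batch statistic, then average over classes
loss_per_class = torch.stack([
    torch.norm(running_mean_per_class[c] - feats[targets == c].mean(0), 2)
    for c in range(num_classes)
]).mean()

# Repo-style: average the per-class target statistics over the batch labels first,
# then a single norm against the overall batch statistic (no per-class loop, hence faster)
target_mean = running_mean_per_class[targets.long()].mean(0)
loss_mean_first = torch.norm(target_mean - feats.mean(0), 2)

# The two losses are not identical in general, which is the "difference from the paper"
# mentioned above; the second form is preferred for its lower computational cost.
```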
In response to this question:
First, both the BNFeatureHook and ConvFeatureHook modules appear to use the Soft Category-Aware Matching mechanism, which raises the question of why this mechanism was mentioned as being applied only in the Conv module.
We attempted adding Soft Category-Aware Matching to both modules, but the improvement over adding it solely to the ConvFeatureHook was negligible. The slight performance increase of approximately 0.X% did not justify the additional computational cost. Simply adding Soft Category-Aware Matching to the ConvFeatureHook is sufficient.
I really appreciate your response. It has been very helpful! Thank you so much for your reply and for helping to clarify my questions.
As described in the paper, EDC uses soft category-aware matching by comparing the feature means of different categories. However, the code in Branch_full_ImageNet_1k/recover/utils.py does not seem to compute the mean of the distilled data per category. I'm having a bit of trouble understanding the rationale behind using category_dd_mean together with the overall mean of the distilled data (mean) when calculating the loss. Could you explain the reasoning behind this design choice? Thanks so much!