(Closed: sanmulab closed this issue 2 years ago.)
Step 1: pre-compute the argsort for the LVIS dataset (category indices ordered by instance count, from most to least frequent):
def get_argsort():
    """Return category indices ordered by descending instance count.

    The counts come from ``get_instance_count()`` (presumably one entry per
    LVIS category -- confirm against its definition).  They are sorted
    ascending with ``np.argsort`` and the order is then reversed, so the most
    frequent category comes first.
    """
    counts = np.array(get_instance_count())
    ascending_order = np.argsort(counts)
    descending_order = ascending_order[::-1]
    return descending_order
Step 2: when training is done, save the collected gradients (stored as attributes of the EQLv2 loss).
Step 3: load the saved data, then plot it:
import numpy as np
# The plotting API lives in the ``pyplot`` submodule; the top-level
# ``matplotlib`` package has no ``plot`` function.
import matplotlib.pyplot as plt

# NOTE(review): ``argsort`` is expected to be the result of get_argsort()
# (step 1), and ``pos_neg_grad`` presumably the per-category
# positive/negative gradient ratio loaded from the dumped statistics
# (step 2).  Both must be defined before this line runs.
# Indexing with ``argsort`` orders the categories from most to least frequent
# along the x axis.
plt.plot(np.arange(argsort.shape[0]), pos_neg_grad[argsort])
Thanks to the author! I think I roughly understand the process for visualizing the gradient ratio. One question: how should I save the positive/negative gradient information — should I just print it to the log? It would be even better if you could provide a complete code file.
You could write a hook like this:
@HOOKS.register_module()
class GradientCollectHook(Hook):
    """Dump the gradient statistics accumulated by the classification loss.

    At the end of training, every rank pickles the ``pos_grad`` and
    ``neg_grad`` attributes of ``roi_head.bbox_head.loss_cls`` (e.g. the
    EQLv2 loss) to ``<output_dir>/grad_rank<rank>.pkl``.  Rank 0 then sums the
    per-rank arrays element-wise and writes the result to
    ``<output_dir>/grad_all.pkl``.

    Args:
        output_dir: Directory the pickles are written to; created if missing.
    """

    def __init__(self, output_dir='grads'):
        super().__init__()
        self.output_dir = output_dir
        self.rank, self.world_size = get_dist_info()
        os.makedirs(output_dir, exist_ok=True)
        logger = get_root_logger()
        logger.info(f"set up {self.__class__.__name__}")

    def before_run(self, runner):
        # Cache a handle to the loss module so after_run can read its
        # accumulated statistics.  The ``.module`` hop assumes the model is
        # wrapped (e.g. by a DataParallel-style wrapper) -- NOTE(review):
        # confirm this holds for your runner configuration.
        self.loss_module = runner.model.module.roi_head.bbox_head.loss_cls

    def after_run(self, runner):
        logger = get_root_logger()
        logger.info("after train: dump gradient statistics")
        save_dict = {
            'pos_grad': self.loss_module.pos_grad.cpu().numpy(),
            'neg_grad': self.loss_module.neg_grad.cpu().numpy(),
        }
        with open(self._rank_file(self.rank), 'wb') as f:
            pickle.dump(save_dict, f)
        # All ranks must finish writing before rank 0 starts reading.
        self._barrier()
        if self.rank == 0:
            self._merge_rank_dumps()
        # Keep the other ranks from exiting before the merge is done.
        self._barrier()
        logger.info('dump grad data finished')

    def _rank_file(self, rank):
        """Return the path of the pickle dumped by ``rank``."""
        return os.path.join(self.output_dir, f'grad_rank{rank}.pkl')

    def _barrier(self):
        """Synchronize ranks; skipped when running on a single process.

        ``get_dist_info()`` reports ``world_size == 1`` for non-distributed
        runs, where ``dist.barrier()`` would fail because no process group
        exists.
        """
        if self.world_size > 1:
            dist.barrier()

    def _merge_rank_dumps(self):
        """Sum every rank's dump element-wise and write ``grad_all.pkl``."""
        all_data = {'pos_grad': [], 'neg_grad': []}
        # Read each rank's OWN file; reading only the current rank's file
        # would count rank 0's statistics world_size times.
        for rank in range(self.world_size):
            with open(self._rank_file(rank), 'rb') as f:
                data = pickle.load(f)
            for key in all_data:
                all_data[key].append(data[key])
        all_data = {
            key: np.sum(np.stack(arrays), axis=0)
            for key, arrays in all_data.items()
        }
        with open(os.path.join(self.output_dir, 'grad_all.pkl'), 'wb') as f:
            pickle.dump(all_data, f)
Hope this answers your question.
OK, thanks.
Hello! I would like to know how to draw a line graph of the gradient ratio for each category. Could you provide the relevant visualization code? I also want to demonstrate the effectiveness of the method through the gradient ratio, but I don't know how to visualize it. Thank you!