microsoft / otdd

Optimal Transport Dataset Distance
MIT License

Cov is NaN in the flow example in the README #21

Open · ChenChengKuan opened this issue 2 years ago

ChenChengKuan commented 2 years ago

Hello, when I run the flow example in the README, I get the error below (after fixing the issues mentioned in #20).
After tracing the code, I found it is caused by a NaN produced by `compute_label_stats` (line 383 of flows.py).
I tried using `eigen_correction='constant'` to ensure the covariance matrix is PSD, but it didn't help. A workaround I found is setting `diagonal_cov = True`.
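
For anyone hitting the same thing, here is a minimal NumPy sketch of one generic way NaNs can appear when a covariance matrix is post-processed (this is not otdd's code; `sqrtm_psd` and `sqrt_diag` are hypothetical helpers, and I have not confirmed this is exactly what happens at line 383). A rank-deficient full covariance can pick up slightly negative eigenvalues from round-off, and taking their square root yields NaN. Clamping the eigenvalues fixes that case, and a diagonal covariance sidesteps it because variances are non-negative. No eigenvalue correction can repair a matrix whose entries are already NaN, which might explain why `eigen_correction` did not help here.

```python
import numpy as np

def sqrtm_psd(cov, floor=None):
    """Matrix square root of a symmetric matrix via eigendecomposition.

    Hypothetical helper for illustration, not otdd's implementation.
    With floor=None, negative eigenvalues go straight into sqrt (-> NaN);
    otherwise eigenvalues are clamped to at least `floor` first.
    """
    cov = (cov + cov.T) / 2  # enforce exact symmetry before eigh
    w, V = np.linalg.eigh(cov)
    if floor is not None:
        w = np.maximum(w, floor)
    with np.errstate(invalid="ignore"):
        sqrt_w = np.sqrt(w)  # NaN wherever w < 0
    return (V * sqrt_w) @ V.T  # V @ diag(sqrt_w) @ V.T

def sqrt_diag(cov):
    """Diagonal approximation: keep only the variances, then take elementwise sqrt."""
    return np.sqrt(np.clip(np.diag(cov), 0.0, None))

# Two perfectly correlated features plus a tiny negative perturbation of the
# kind floating-point round-off can introduce: eigenvalues are 2 - 1e-9 and -1e-9.
cov = np.array([[1.0, 1.0], [1.0, 1.0]]) - 1e-9 * np.eye(2)

print(np.isnan(sqrtm_psd(cov)).any())             # True
print(np.isnan(sqrtm_psd(cov, floor=0.0)).any())  # False
print(np.isnan(sqrt_diag(cov)).any())             # False
```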

I would like to confirm whether setting `diagonal_cov = True` is a valid way to deal with this issue. A side question: do xonly, xonly-attached, and xyaug correspond to fd, jd-fl, and jd-vl in the paper? Thanks in advance for your time.
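
Some context for the `diagonal_cov` question. I am assuming, as I understand the OTDD paper, that each label's feature distribution is approximated by a Gaussian and compared with the closed-form 2-Wasserstein (Bures) distance, W2² = ‖m1 − m2‖² + tr(S1 + S2 − 2(S2^{1/2} S1 S2^{1/2})^{1/2}). The sketch below (hypothetical function names, not otdd's API) shows how that term changes with diagonal covariances. For diagonal matrices the two formulas agree exactly, but the diagonal version ignores correlations, so two classes that differ only in how their features co-vary look identical to it.

```python
import numpy as np

def _sqrtm_psd(A):
    """Symmetric PSD matrix square root via eigh, eigenvalues clamped at 0."""
    w, V = np.linalg.eigh((A + A.T) / 2)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T

def w2_gaussian_full(m1, S1, m2, S2):
    """Squared 2-Wasserstein distance between N(m1, S1) and N(m2, S2)."""
    r2 = _sqrtm_psd(S2)
    cross = _sqrtm_psd(r2 @ S1 @ r2)
    return float(np.sum((m1 - m2) ** 2) + np.trace(S1 + S2 - 2 * cross))

def w2_gaussian_diag(m1, v1, m2, v2):
    """Same distance when both covariances are diagonal with variances v1, v2."""
    return float(np.sum((m1 - m2) ** 2) + np.sum((np.sqrt(v1) - np.sqrt(v2)) ** 2))

m1, m2 = np.array([0.0, 0.0]), np.array([1.0, 0.0])

# Diagonal covariances: both formulas give 3.0 (up to round-off).
v1, v2 = np.array([1.0, 4.0]), np.array([4.0, 1.0])
print(round(w2_gaussian_full(m1, np.diag(v1), m2, np.diag(v2)), 6))  # 3.0
print(w2_gaussian_diag(m1, v1, m2, v2))                              # 3.0

# Same means and variances, opposite correlation: only the full version sees it.
S1 = np.array([[1.0, 0.9], [0.9, 1.0]])
S2 = np.array([[1.0, -0.9], [-0.9, 1.0]])
print(round(w2_gaussian_full(m1, S1, m1, S2), 3))                     # 2.256
print(w2_gaussian_diag(m1, np.diag(S1), m1, np.diag(S2)))             # 0.0
```

So, under that assumption, `diagonal_cov=True` is cheaper and numerically safer, but it is an approximation that can underestimate differences between classes whose features are correlated differently.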

```
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_48265/2326306459.py in <module>
     22                           device='cpu'
     23                           )
---> 24 d,out = flow.flow()

~/Desktop/kuan/otdd/otdd/pytorch/flows.py in flow(self, tol)
    477             pbar.set_description(f'Flow Step {iter}/{len(self.times)}, F_t={obj:8.2f}')
    478             self.callback.on_step_begin(self.otdd, iter)
--> 479             obj = self.step(iter)
    480             logger.info(f't={t:8.2f}, F(a_t)={obj:8.2f}') # Although things have been updated, this is obj of time t still
    481             self.history.append(obj)

~/Desktop/kuan/otdd/otdd/pytorch/flows.py in step(self, iter)
    454         if self.otdd.inner_ot_method != 'exact':
    455             logger.info('Performing stats update...')
--> 456             self.stats_update()
    457 
    458         if self.compute_coupling == 'every_iteration':

~/Desktop/kuan/otdd/otdd/pytorch/flows.py in stats_update(self)
    393                                         )
    394             if torch.isnan(self.otdd.Covs[0]).any():
--> 395                 pdb.set_trace(header='Nans in Cov Matrices')
    396 
    397 

TypeError: set_trace() got an unexpected keyword argument 'header'
```
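
A note on the `TypeError` itself: the keyword-only `header` argument of `pdb.set_trace` was added in Python 3.7, so on an older interpreter the debugging line in `stats_update` presumably fails before the debugger ever opens. The underlying signal is just that the `torch.isnan(self.otdd.Covs[0]).any()` check fired. Below is a hypothetical, version-independent stand-in for that branch (pure Python on a list-of-lists matrix for portability; `find_nans` and `check_cov` are my own illustrative names, not otdd functions) that raises a descriptive error instead of entering the debugger.

```python
import math
import sys

def find_nans(matrix):
    """Return (row, col) positions of NaN entries in a list-of-lists matrix."""
    return [(i, j) for i, row in enumerate(matrix)
            for j, v in enumerate(row) if math.isnan(v)]

def check_cov(name, matrix):
    """Hypothetical stand-in for the `pdb.set_trace(header=...)` NaN branch.

    Raising FloatingPointError works on any Python version and still stops
    the run at the first covariance that contains NaNs.
    """
    bad = find_nans(matrix)
    if bad:
        version = f"{sys.version_info.major}.{sys.version_info.minor}"
        raise FloatingPointError(f"NaNs in {name} at {bad[:5]} (Python {version})")

ok = [[1.0, 0.5], [0.5, 1.0]]
bad = [[1.0, float("nan")], [float("nan"), 1.0]]

check_cov("Covs[0]", ok)  # passes silently
try:
    check_cov("Covs[0]", bad)
except FloatingPointError as err:
    print(err)
```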