locuslab / torchdeq

Modern Fixed Point Systems using PyTorch
MIT License

Fixed-point is not returned when indexing is set #7

Open BurgerAndreas opened 5 months ago


It seems that the best fixed-point estimate, z_star = lowest_xest, is only added to the returned trajectory when the indexed trajectory is otherwise empty. So when one specifies indexing, one does not always get the best fixed-point estimate.

Relevant Code

From the Broyden solver

# Store the solution at the specified index
if indexing and (nstep+1) in indexing:
    indexing_list.append(lowest_xest)

# ...

# at least return the lowest value when enabling ``indexing``
if indexing and not indexing_list:
    indexing_list.append(lowest_xest)

info = solver_stat_from_info(stop_mode, lowest_dict, trace_dict, lowest_step_dict)
return lowest_xest, indexing_list, info

Note that DEQIndexing ignores the best fixed-point estimate z_star = lowest_xest (the first return value):

_, trajectory, info = self._solve_fixed_point()

Example

If the solver runs past an index (nstep > indexing), the value lowest_xest has at that step is added to the trajectory. The final lowest_xest is added only if nothing was added to the trajectory. Does that mean the trajectory sometimes contains the best fixed-point estimate lowest_xest and sometimes does not?

Scenario 1: indexing=[8], nstep=5 -> trajectory contains fp_5
Scenario 2: indexing=[8], nstep=10 -> trajectory contains fp_8

Shouldn't the trajectory contain [fp_5, fp_8] (assuming fp_8 is the better estimate)?
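To make the two scenarios concrete, here is a toy Python re-creation of the bookkeeping in the Broyden snippet. It is only a sketch: mock_indexed_solve, the residual lists, the string labels fp_k and the early-stop tolerance are hypothetical stand-ins, not torchdeq code.

```python
def mock_indexed_solve(residuals, indexing, tol=1e-3):
    """Hypothetical mock of the solver's indexing bookkeeping (not torchdeq code).

    residuals[k] is the residual of the estimate produced at step k + 1;
    each estimate is represented by the label "fp_<step>".
    """
    lowest_xest, lowest_res = None, float("inf")
    indexing_list = []
    for nstep, res in enumerate(residuals):
        # Track the best estimate seen so far
        if res < lowest_res:
            lowest_xest, lowest_res = f"fp_{nstep + 1}", res
        # Store the current best estimate at the requested index
        if indexing and (nstep + 1) in indexing:
            indexing_list.append(lowest_xest)
        # Assumed early stop once the residual is below tol
        if res < tol:
            break
    # Fallback: the final best estimate is added only if no index was reached
    if indexing and not indexing_list:
        indexing_list.append(lowest_xest)
    return lowest_xest, indexing_list


# Scenario 1: converges after 5 steps, index 8 is never reached
print(mock_indexed_solve([1.0, 0.5, 0.1, 0.01, 1e-4], indexing=[8]))
# Scenario 2: runs 10 steps, index 8 is reached, fp_10 ends up as the best
print(mock_indexed_solve(
    [1.0, 0.5, 0.2, 0.1, 0.05, 0.02, 0.01, 0.008, 0.005, 0.002], indexing=[8]))
```

The first call prints ('fp_5', ['fp_5']) and the second ('fp_10', ['fp_8']). In scenario 2 the best estimate fp_10 only appears in the first return value, which DEQIndexing discards, so it never reaches the caller.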

Note that indexing defaults to indexing=[f_max_iter] if not specified otherwise, in which case the best fixed-point estimate is added to the trajectory. So the problem only arises if one specifies indexing or n_states, e.g. to implement the fixed-point correction loss. It is also not a problem in DEQSliced.
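One possible change, sketched on a toy model (all names are hypothetical, and this is not necessarily how the maintainers would fix it), is to append the final lowest_xest whenever it is not already the last trajectory entry. The identity check `is not` is used because on real tensors `==` compares elementwise. Callers that expect exactly one entry per index may need to account for the extra entry.

```python
def mock_indexed_solve_patched(residuals, indexing, tol=1e-3):
    """Hypothetical mock of the bookkeeping with the proposed fallback change."""
    lowest_xest, lowest_res = None, float("inf")
    indexing_list = []
    for nstep, res in enumerate(residuals):
        # Track the best estimate seen so far
        if res < lowest_res:
            lowest_xest, lowest_res = f"fp_{nstep + 1}", res
        # Store the current best estimate at the requested index
        if indexing and (nstep + 1) in indexing:
            indexing_list.append(lowest_xest)
        # Assumed early stop once the residual is below tol
        if res < tol:
            break
    # Proposed change: always end the trajectory with the final best estimate,
    # without duplicating it if the last index already stored that same object
    if indexing and (not indexing_list or indexing_list[-1] is not lowest_xest):
        indexing_list.append(lowest_xest)
    return lowest_xest, indexing_list


steps = [1.0, 0.5, 0.2, 0.1, 0.05, 0.02, 0.01, 0.008, 0.005, 0.002]
print(mock_indexed_solve_patched(steps, indexing=[8]))   # user-specified index
print(mock_indexed_solve_patched(steps, indexing=[10]))  # like the default [f_max_iter]
```

With indexing=[8] this prints ('fp_10', ['fp_8', 'fp_10']), so the best estimate is preserved. With indexing=[10] it prints ('fp_10', ['fp_10']), so the default case is unchanged.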