MachineLearningLifeScience / stochman

Algorithms for computations on random manifolds made easier
Apache License 2.0

[Improvement] Ensure the graph is connected before connecting geodesics #28

Open a-pouplin opened 1 year ago

a-pouplin commented 1 year ago

Is your feature related to a problem? When plotting geodesics with connecting_geodesic, the function assumes that the graph is connected. When it is not, networkx.shortest_path eventually raises an error that users may find cryptic.

Describe the solution you would like: Either when discretising the manifold or before plotting the geodesics, a check could be added that the graph is connected, for example: assert nx.is_connected(graph), "Graph not connected". Additional guidance or best practices for ensuring the graph is connected could also be documented.
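A minimal sketch of such a check, assuming the discretised manifold is stored as an undirected networkx graph. The helper name `checked_shortest_path` and the toy graph are hypothetical and not part of stochman. It raises a `ValueError` instead of using `assert`, since assertions are skipped when Python runs with `-O`.

```python
import networkx as nx


def checked_shortest_path(graph, source, target, weight="weight"):
    """Return the shortest path, failing early with a readable message.

    Hypothetical helper for illustration; not part of stochman.
    """
    if not nx.is_connected(graph):
        n_components = nx.number_connected_components(graph)
        raise ValueError(
            f"Graph is not connected ({n_components} components); "
            "a path between the requested nodes may not exist."
        )
    return nx.shortest_path(graph, source=source, target=target, weight=weight)


# Toy graph: nodes 0-1-2 form one component, nodes 3-4 another.
graph = nx.Graph()
graph.add_weighted_edges_from([(0, 1, 1.0), (1, 2, 2.0), (3, 4, 1.0)])

try:
    checked_shortest_path(graph, 0, 4)
except ValueError as err:
    message = str(err)  # readable explanation instead of a late failure

graph.add_edge(2, 3, weight=0.5)  # bridge the two components
path = checked_shortest_path(graph, 0, 4)
```

Checking once, right after the graph is built, would avoid repeating the connectivity test on every geodesic query.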

a-pouplin commented 1 year ago

Another improvement would be to check that the shapes are as expected. For example, in the DiscretizedManifold.fit() function:

with torch.no_grad():
   weight = model.curve_length(line(t))
   assert weight.shape == bs, f"model.curve_length should return a {bs} shape object but found {weight.shape}."
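A hedged sketch of the same shape check as a standalone helper, assuming `bs` above denotes an integer batch size and that `curve_length` is expected to return one length per curve. The name `check_weight_shape` is hypothetical. Note that a `torch.Size` never compares equal to a plain integer, so the comparison is made against the tuple `(batch_size,)`.

```python
import torch


def check_weight_shape(weight, batch_size):
    """Raise a readable error unless `weight` has shape (batch_size,).

    Hypothetical helper for illustration; not part of stochman.
    """
    expected = (batch_size,)
    # torch.Size is a tuple subclass, so it compares directly with a tuple.
    if weight.shape != expected:
        raise ValueError(
            f"model.curve_length should return shape {expected}, "
            f"but found {tuple(weight.shape)}."
        )
    return weight


# One length per curve in the batch: passes.
ok = check_weight_shape(torch.zeros(4), batch_size=4)

# An extra trailing dimension is reported explicitly.
try:
    check_weight_shape(torch.zeros(4, 1), batch_size=4)
except ValueError as err:
    shape_error = str(err)
```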

a-pouplin commented 1 year ago

Another question regarding DiscretizedManifold.fit(): a graph is created based on two points obtained from a curve:

t = torch.linspace(0, 1, 2)

(...)

with torch.no_grad():
   weight = model.curve_length(line(t))

and this method relies mainly on a metric tensor to compute the graph (curve_length depends on inner_product, which depends on metric). Yet when the metric tensor is not easily accessible, one might want to compute the curve length from the derivative $\dot{\gamma}_t$ of the curve: $L[\gamma] = \int_0^1 \lVert \dot{\gamma}_t \rVert \, dt$. The derivative can only be approximated well if the curve is discretised finely enough.

The question is: would it make sense to add an argument (e.g. num_curve_points) to discretise the curve and use the derivatives to compute the expected metric (or a Finsler metric, for example)?

def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise=0.0, num_curve_points=2):

   (...)

   t = torch.linspace(0, 1, num_curve_points)
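To illustrate why the number of curve points matters, here is a rough sketch of a finite-difference length estimate, $L[\gamma] \approx \sum_i \lVert \gamma_{t_{i+1}} - \gamma_{t_i} \rVert$, using the Euclidean norm on sampled points. The helper `discretized_curve_length` and the quarter-circle test curve are hypothetical and not part of stochman. In practice the samples would presumably be the curve after it has been mapped through the model.

```python
import math

import torch


def discretized_curve_length(points):
    """Approximate curve length from samples of shape (num_curve_points, dim).

    Sums the Euclidean norms of consecutive differences, a finite-difference
    stand-in for the integral of the speed. Hypothetical helper.
    """
    segments = points[1:] - points[:-1]
    return torch.linalg.norm(segments, dim=-1).sum()


def quarter_circle(num_curve_points):
    """Sample a unit quarter circle, whose true length is pi / 2."""
    t = torch.linspace(0, 1, num_curve_points, dtype=torch.float64)
    angle = 0.5 * math.pi * t
    return torch.stack([torch.cos(angle), torch.sin(angle)], dim=-1)


# With two points only the chord is measured, which underestimates the length.
coarse = discretized_curve_length(quarter_circle(2))
# With many points the estimate approaches pi / 2.
fine = discretized_curve_length(quarter_circle(200))
```

With `num_curve_points=2` the estimate is just the straight-line distance between the endpoints, so a larger value would be needed for any derivative-based length to be meaningful.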