a-pouplin opened this issue 1 year ago
Another improvement would be to check that the output sizes are as expected, for example in the `DiscretizedManifold.fit()` function:
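```python
with torch.no_grad():
    weight = model.curve_length(line(t))
    # weight.shape is a torch.Size (a tuple), so compare it to (bs,) rather than to the int bs
    assert weight.shape == (bs,), f"model.curve_length should return a tensor of shape ({bs},) but found {tuple(weight.shape)}."
```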
Another question regarding `DiscretizedManifold.fit()`: the graph weights are computed from only two points obtained from each curve:
```python
t = torch.linspace(0, 1, 2)
# (...)
with torch.no_grad():
    weight = model.curve_length(line(t))
```
This method relies on the metric tensor to compute the graph (`curve_length` depends on `inner_product`, which depends on `metric`). Yet when the metric tensor is not easily accessible, one might want to compute the curve length from the derivatives ($\dot{\gamma}$) of the curve: $L[\gamma] = \int_0^1 \|\dot{\gamma}_t\| \, dt$. These derivatives can only be estimated accurately if the curve is discretised finely enough.
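To make the discretisation argument concrete: sampling $n$ equally spaced times $t_i = i/(n-1)$ with $\Delta t = 1/(n-1)$ and replacing $\dot{\gamma}$ by forward differences gives the estimate below. If the norm depends on position, it must be evaluated at some point of each segment (its midpoint, for instance). For $n = 2$ the sum has a single term, so the norm is evaluated only once along the whole curve.

$$
L[\gamma] \approx \sum_{i=0}^{n-2} \left\| \frac{\gamma_{t_{i+1}} - \gamma_{t_i}}{\Delta t} \right\| \Delta t,
\qquad n = 2:\quad L[\gamma] \approx \left\| \gamma_1 - \gamma_0 \right\|.
$$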
The question is: would it make sense to add an argument (e.g. `num_curve_points`) to discretise the curve and use its derivatives to compute the expected metric (or, for example, a Finsler metric)?
```python
def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise=0.0, num_curve_points=2):
    # (...)
    t = torch.linspace(0, 1, num_curve_points)
```
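To illustrate why `num_curve_points` matters, here is a minimal NumPy sketch, independent of any existing library code, that estimates the length of a straight segment with finite differences under a position-dependent norm. The names `discretized_curve_length`, `straight_line` and `conformal_norm` are hypothetical, and the norm (a Euclidean norm scaled by $1 + \|x\|^2$) is only an assumed stand-in for a metric that varies in space.

```python
import numpy as np

# Endpoints of a hypothetical graph edge
START = np.array([-1.0, 0.0])
END = np.array([1.0, 0.0])


def discretized_curve_length(curve, norm_fn, num_curve_points=2):
    """Finite-difference estimate of L[gamma] = int_0^1 ||d gamma / dt|| dt.

    curve:   callable mapping times of shape (n,) to points of shape (n, d)
    norm_fn: callable (x, v) -> (m,) giving the length of velocity v at point x
    """
    t = np.linspace(0.0, 1.0, num_curve_points)
    points = curve(t)                                    # (n, d)
    dt = np.diff(t)                                      # (n - 1,)
    velocities = np.diff(points, axis=0) / dt[:, None]   # forward differences
    midpoints = 0.5 * (points[1:] + points[:-1])         # evaluate the norm mid-segment
    return float(np.sum(norm_fn(midpoints, velocities) * dt))


def straight_line(t):
    # Straight segment from START to END, like the line between two grid nodes
    return START + t[:, None] * (END - START)


def conformal_norm(x, v):
    # Assumed position-dependent norm: Euclidean norm scaled by (1 + ||x||^2)
    return np.linalg.norm(v, axis=1) * (1.0 + np.sum(x ** 2, axis=1))


coarse = discretized_curve_length(straight_line, conformal_norm, num_curve_points=2)
fine = discretized_curve_length(straight_line, conformal_norm, num_curve_points=1000)
print(coarse)  # 2.0
print(fine)    # close to the exact length 8/3
```

With two points the norm is evaluated only at the segment midpoint, where the scaling factor is 1, so the estimate is 2.0; the finer discretisation recovers $\int_{-1}^{1} (1 + x^2)\,dx = 8/3$, because it sees the metric growing towards the endpoints.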
Is your feature related to a problem? When plotting geodesics with `connecting_geodesic`, the function assumes that the graph is connected. When it is not, `networkx.shortest_path` eventually throws an error that users may find cryptic.

Describe the solution you would like: either when discretising the manifold or before plotting the geodesics, an explicit check (raising a clear error or warning) could be added to verify that the graph is connected: `assert nx.is_connected(graph), "Graph not connected"`. Additional guidance or best practices could also be documented to help users ensure the graph is connected.
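A minimal sketch of such a check, assuming the discretised manifold is stored as an undirected `networkx.Graph` whose edges carry a `weight` attribute. Both helpers, `check_graph_connected` and `shortest_path_checked`, are hypothetical and not part of any existing API: the first warns (or raises) and reports the number of connected components, the second tests `nx.has_path` before calling `nx.shortest_path`, so a disconnected pair fails with an explicit message.

```python
import warnings

import networkx as nx


def check_graph_connected(graph, raise_error=False):
    """Hypothetical helper: warn (or raise) if the discretisation graph is disconnected."""
    if nx.is_connected(graph):
        return True
    n_components = nx.number_connected_components(graph)
    msg = (f"The discretised graph is not connected ({n_components} components); "
           "geodesics between nodes in different components cannot be computed.")
    if raise_error:
        raise ValueError(msg)
    warnings.warn(msg, RuntimeWarning)
    return False


def shortest_path_checked(graph, source, target):
    """Hypothetical helper: fail with a clear message instead of a NetworkX exception."""
    if not nx.has_path(graph, source, target):
        raise ValueError(f"No path between nodes {source} and {target}: "
                         "they lie in different connected components of the graph.")
    return nx.shortest_path(graph, source, target, weight="weight")


# Usage on a toy graph with two components
graph = nx.Graph()
graph.add_edge(0, 1, weight=1.0)
graph.add_edge(2, 3, weight=1.0)
connected = check_graph_connected(graph)   # emits a RuntimeWarning and returns False
path = shortest_path_checked(graph, 0, 1)  # nodes in the same component: [0, 1]
```

Warning by default and raising only on request keeps the check non-breaking for existing code, while the per-pair check catches the case that actually fails at plotting time.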