eahenle opened this issue 1 year ago
Edit: this particular example seems to be a rough edge between Flux and Pluto rather than an issue with GeometricFlux.
Another example:
model = Chain(
GCNConv(1024=>256, relu),
x -> (node_feature(x), global_feature(x)),
(nf, gf) -> (softmax(nf), identity.(gf))
) # this works
Expected behavior:
model(some_featuredgraph) # ❌ this should work, but does not
Actual result: UndefVarError: node_feature not defined. This is confusing, because node_feature(some_featuredgraph) works without issue in the same Pluto notebook.
Defining a named function get_nf_gf solves this:
get_nf_gf(fg) = (node_feature(fg), global_feature(fg))
model = Chain(
GCNConv(1024=>256, relu),
get_nf_gf,
(nf, gf) -> (softmax(nf), identity.(gf))
)
model(some_featuredgraph) # ✔️
The documentation also contains code that does not run: the node feature matrix is transposed relative to the layout the layer expects.
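A minimal sketch of the kind of fix involved, assuming the layer expects node features laid out as (num_features, num_nodes), one column per node, as is usual in Flux. The sizes and variable names below are hypothetical, and the snippet needs no packages:

```julia
# Assumption: features must be (num_features, num_nodes), one column per node.
num_nodes, num_features = 4, 3

# Hypothetical feature matrix built the "wrong" way: one row per node.
X_rows = rand(Float32, num_nodes, num_features)

# permutedims returns a new (num_features, num_nodes) Matrix,
# placing each node's features in its own column.
X_cols = permutedims(X_rows)

@assert size(X_cols) == (num_features, num_nodes)
@assert X_cols[:, 2] == X_rows[2, :]   # node 2's features are now column 2
```

permutedims is used instead of transpose because it returns a plain Matrix rather than a lazy Transpose wrapper.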
Example given:
Correct code: