phlippe / CategoricalNF

Official repository for "Categorical Normalizing Flows via Continuous Transformations"
https://arxiv.org/abs/2006.09790
MIT License

function __getitem__() in zinc250k.py #1

Closed ling-cai closed 3 years ago

ling-cai commented 4 years ago

I am not sure why you use `nodes = nodes + (nodes == -1) # Setting padding to 0`. I assume a node index of -1 marks virtual nodes. After this line, the label of every virtual node becomes 0, which collides with the true node type 0 in your data, so you can no longer distinguish that node type from virtual nodes. Did I misunderstand something in your code? I would appreciate it if you could clarify this for me.

phlippe commented 4 years ago

Hi, that's correct: after that line, all virtual nodes have an index of 0. This is necessary because these indices are later processed by the mixture embedding layer. We can still distinguish virtual nodes from "real" nodes using the length that `__getitem__` returns as its third element. The first `length` nodes are "real", and the remaining ones are virtual.
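To make the point above concrete, here is a minimal sketch in plain Python rather than the repository's actual tensor code. The helper names `pad_to_zero` and `real_node_mask` and the example values are hypothetical, chosen only for illustration. It shows that the padding entries become 0 while a mask derived from the length still separates real nodes from virtual ones.

```python
def pad_to_zero(nodes):
    # Same arithmetic as the quoted line, applied element-wise:
    # (n == -1) is True (1) for padding, so -1 + 1 -> 0; other entries are unchanged.
    return [n + (n == -1) for n in nodes]


def real_node_mask(length, max_nodes):
    # The first `length` positions are real nodes, the rest are virtual (padding).
    return [i < length for i in range(max_nodes)]


# Hypothetical example: 3 real nodes (types 0, 2, 1) followed by 2 padding slots.
nodes = [0, 2, 1, -1, -1]
length = 3

padded = pad_to_zero(nodes)                # [0, 2, 1, 0, 0]
mask = real_node_mask(length, len(nodes))  # [True, True, True, False, False]

# Node type 0 at position 0 is real; the zeros at positions 3 and 4 are masked out.
real_types = [t for t, keep in zip(padded, mask) if keep]
```

Under these assumptions, the ambiguity raised in the question only exists if the node indices are read without the length; any downstream step that applies the mask never confuses padding with node type 0.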