InfluenceFunctional opened 1 year ago
This is still a work in progress, but many of the recent changes have been motivated by the need for conditional generation. One example is the Batch class, where some things are handled differently for conditional environments.
@AlexandraVolokhova: does the molecule environment already include the possibility of incorporating conditions?
I'm planning to work on this probably all day today: 1) molecular crystal environment, 2) graph-conditioned policy, 3) conditional data loading and management, 4) molecular crystal proxy/oracle.
I'm still struggling to push to the remote branch. I had been assigning conditions to envs as follows:
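```python
# One env per condition; chaining assumes reset() and set_condition() both return the env
envs = [self.env.copy().reset(idx).set_condition(train_loader.dataset[idx]) for idx in range(n_envs)]
```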
I will look at the Batch class and see if that approach makes more sense.
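For the chained list comprehension above to work, `reset()` and `set_condition()` must each return the environment itself; otherwise the list fills with `None`. A minimal sketch of that contract, using a hypothetical `ToyEnv` class (not the repository's actual environment API) and a plain list of SMILES strings standing in for `train_loader.dataset`:

```python
import copy


class ToyEnv:
    """Hypothetical environment illustrating the chaining contract."""

    def __init__(self):
        self.id = None
        self.state = []
        self.condition = None

    def copy(self):
        # Independent copy so each env can hold its own condition.
        return copy.deepcopy(self)

    def reset(self, env_id=None):
        self.id = env_id
        self.state = []
        return self  # returning self is what makes chaining possible

    def set_condition(self, condition):
        self.condition = condition
        return self  # without this, the comprehension would collect None


# Hypothetical stand-in for train_loader.dataset
dataset = ["CCO", "c1ccccc1", "CC(=O)O"]
base_env = ToyEnv()
n_envs = len(dataset)
envs = [base_env.copy().reset(idx).set_condition(dataset[idx]) for idx in range(n_envs)]
```

Because each env is a deep copy, `base_env` itself stays unconditioned and the envs do not share state.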
I have just sent you an invitation to join the repo as a collaborator. I think that will solve the issue with pushing to the remote branch.
Got it, and it's working. See my work in progress here.
I'm doing it roughly for now to help me understand the moving parts.
Unless I'm missing it, I don't see a method for conditional generation anywhere. For my purposes, the workflow would be: 1) load conditions in batches from a dataloader; 2) assign a condition to each env; 3) encode the condition with a graph model (in this case the condition is a molecule and its encoding is a vector); 4) concatenate the condition encoding to the GFN policy input; 5) train as normal.
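Steps 1 to 4 of that workflow can be sketched in plain Python as follows (step 5, training, is unchanged and not shown). Everything here is hypothetical: `batched` stands in for a dataloader, `encode_molecule` is a toy stand-in for a graph model that mean-pools per-atom feature vectors, and `policy_input` simply concatenates the env's state features with the condition encoding. A real implementation would use tensors and a learned GNN.

```python
from typing import Iterator, List, Sequence

Vector = List[float]


def batched(items: Sequence, batch_size: int) -> Iterator[Sequence]:
    """Step 1: yield conditions in fixed-size batches, like a dataloader."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def encode_molecule(atom_features: List[Vector]) -> Vector:
    """Step 3 (toy): mean-pool per-atom features into one fixed-size vector.
    Stands in for a graph neural network."""
    n_atoms = len(atom_features)
    dim = len(atom_features[0])
    return [sum(atom[d] for atom in atom_features) / n_atoms for d in range(dim)]


def policy_input(state_features: Vector, condition_encoding: Vector) -> Vector:
    """Step 4: concatenate the condition encoding to the policy input."""
    return state_features + condition_encoding


# Each condition is a molecule given as a list of per-atom feature vectors.
conditions = [
    [[1.0, 0.0], [0.0, 1.0]],
    [[2.0, 2.0]],
    [[1.0, 1.0], [3.0, 1.0]],
]

inputs = []
for batch in batched(conditions, batch_size=2):
    # Step 2: one env per condition; here an env is just its (empty) state features.
    env_states = [[0.0, 0.0, 0.0] for _ in batch]
    for state, molecule in zip(env_states, batch):
        inputs.append(policy_input(state, encode_molecule(molecule)))
```

Pooling to a fixed-size vector is what lets molecules with different atom counts feed a policy with a fixed input width.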
If possible, it would also be ideal if the conditioning model could be updated during training along with the policy model, though if necessary I could probably find a way to pretrain one that is at least 'ok'. I haven't read deeply enough into the conditional GFlowNet work to know what's optimal here. In my case, the distribution of high-scoring samples is both very sharp and extremely sensitive to the conditions.
In evaluation mode, for speed, we could call the conditioning model once and reuse the same encoding at every generation step. For training, particularly if the conditioning model is being updated, it is probably fine to call it together with the policy at every action.
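A minimal sketch of that trade-off, assuming a hypothetical `ConditionedPolicy` wrapper whose `train(mode)` method loosely mirrors the PyTorch train/eval convention. In evaluation mode the encoding is computed once per condition and cached; in training mode it is recomputed at every step, so that in a real autograd framework gradients could reach an encoder that is being updated. It also assumes conditions are hashable (e.g. SMILES strings) so they can serve as cache keys. The toy encoder counts its calls to make the difference visible.

```python
from typing import Callable, Dict, Hashable, List

Vector = List[float]


class ConditionedPolicy:
    def __init__(self, encoder: Callable[[Hashable], Vector]):
        self.encoder = encoder
        self.training = True
        self._cache: Dict[Hashable, Vector] = {}

    def train(self, mode: bool = True) -> None:
        self.training = mode
        self._cache.clear()  # stale encodings must not survive encoder updates

    def encode(self, condition: Hashable) -> Vector:
        if self.training:
            # Training: recompute every step, since the encoder may change.
            return self.encoder(condition)
        if condition not in self._cache:
            # Evaluation: encode once, then reuse for all generation steps.
            self._cache[condition] = self.encoder(condition)
        return self._cache[condition]

    def forward(self, state: Vector, condition: Hashable) -> Vector:
        return state + self.encode(condition)  # concatenated policy input


calls = {"n": 0}


def toy_encoder(condition: Hashable) -> Vector:
    # Hypothetical encoder: counts calls and returns a 1-d "encoding".
    calls["n"] += 1
    return [float(len(str(condition)))]


policy = ConditionedPolicy(toy_encoder)
for _ in range(5):
    policy.forward([0.0], "CCO")
train_calls = calls["n"]  # 5: one encoder call per step

policy.train(False)
calls["n"] = 0
for _ in range(5):
    policy.forward([0.0], "CCO")
eval_calls = calls["n"]  # 1: cached after the first step
```

Clearing the cache on every mode switch is a deliberately conservative choice: it guarantees that an encoder updated during training is never served from a stale evaluation cache.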
I have started playing with this locally, but I don't want to conflict with any planned format.