gnina / libmolgrid

Comprehensive library for fast, GPU accelerated molecular gridding for deep learning workflows
https://gnina.github.io/libmolgrid/
Apache License 2.0

Question: access or enforce grid center in Gridmaker forward() #47

Closed jpjanet closed 4 years ago

jpjanet commented 4 years ago

Hi,

This seems like a really neat tool. I have been experimenting along the lines of

maker = molgrid.GridMaker(resolution=res, dimension=dimension)
provider = molgrid.ExampleProvider()
provider.populate('my file')
examples = provider.next_batch(batch_size)
_ = maker.forward(examples, input_tensor)

Might be a silly question, but is it possible to obtain the center of the grid created by each forward call, and/or to enforce a consistent grid center across batches?

From the documentation I believe I might be able to implement this with a Transform object, but I can't quite get the datatypes to match up when passing molgrid.Transform() and example batches to maker.forward().

Any advice is greatly appreciated!

Best, JP

dkoes commented 4 years ago

The batched version of forward dynamically computes a distinct center for each example in the batch based on the centroid of the last coordinate set and applies a random transformation (if requested) to each example.
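To make "centroid" concrete, here is a small pure-Python illustration. It assumes the centroid is the unweighted mean of the atom positions in a coordinate set; the function name and the toy coordinates are made up for illustration, and this is not libmolgrid's own code.

```python
def centroid(coords):
    """Unweighted mean of a non-empty sequence of (x, y, z) points."""
    n = len(coords)
    if n == 0:
        raise ValueError("need at least one coordinate")
    # Average each axis independently.
    return tuple(sum(p[axis] for p in coords) / n for axis in range(3))


# Toy "last coordinate set" with three atoms (hypothetical data).
ligand = [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0), (1.0, 3.0, 0.0)]
center = centroid(ligand)
```

Because every example has its own atoms, each one ends up with its own center, which is why grids from the batched call are not aligned to a common origin.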

If you want to control the center/transformation, you will have to loop over the batch and call forward on each example individually. For example:

gmaker.forward(ex, molgrid.Transform(center), mgridout.cpu())
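A minimal sketch of that per-example loop might look like the following. The helper name forward_with_fixed_center and the demo function are hypothetical. The forward(ex, transform, grid) and Transform(center) calls come from the line above, while float3, grid_dimensions, num_types and MGrid4f are used as I understand the Python bindings and should be checked against the libmolgrid docs. The resolution, dimension and origin values are placeholders.

```python
def forward_with_fixed_center(gmaker, examples, transform, out_grids):
    """Grid each example on its own so all of them share one transform/center.

    gmaker    -- object with a forward(example, transform, grid) method
    examples  -- iterable of examples (e.g. from ExampleProvider.next_batch)
    transform -- a single transform reused for every example
    out_grids -- one preallocated output grid per example
    """
    for ex, grid in zip(examples, out_grids):
        gmaker.forward(ex, transform, grid)
    return out_grids


def demo(path, batch_size=4):
    """Hypothetical end-to-end usage; requires libmolgrid and a real types file."""
    import molgrid  # imported lazily so the helper above has no hard dependency

    provider = molgrid.ExampleProvider()
    provider.populate(path)
    examples = provider.next_batch(batch_size)

    gmaker = molgrid.GridMaker(resolution=0.5, dimension=23.5)  # placeholder values
    dims = gmaker.grid_dimensions(provider.num_types())

    # Assumption: float3 is the center type accepted by Transform.
    fixed = molgrid.Transform(molgrid.float3(0.0, 0.0, 0.0))

    # One 4D managed grid per example; .cpu() gives the view passed to forward.
    mgrids = [molgrid.MGrid4f(*dims) for _ in range(batch_size)]
    forward_with_fixed_center(gmaker, examples, fixed, [g.cpu() for g in mgrids])
    return mgrids
```

Reusing one Transform object keeps the origin identical for every example and every batch, at the cost of giving up the single batched call.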

If you just want to know what center is being used, call the center method on each example's last coordinate set: ex.coord_sets[-1].center()
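If it helps, the lookup can be wrapped for a whole batch. This is only a sketch: batch_centers is a hypothetical helper name, and it relies on nothing beyond the coord_sets[-1].center() call mentioned above.

```python
def batch_centers(examples):
    """Return the center of the last coordinate set of every example.

    These are the centers the batched forward would compute by default,
    before any random transformation is applied.
    """
    return [ex.coord_sets[-1].center() for ex in examples]
```

Calling batch_centers(examples) right after provider.next_batch(batch_size) gives one center per example, in batch order, so it can be logged alongside the grids.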