uw-ipd / tmol

TMol
Apache License 2.0

Constraints #305

Open jflat06 opened 1 month ago

jflat06 commented 1 month ago

You can create a constraint set and add constraints to it with:

    add_constraints(your_constraint_function, atom_indices, parameters)

atom_indices: an integer tensor (int32 in the example below) of shape (n_constraints x n_atoms x 3), where the last dimension holds (pose_ind, res_ind, atom_ind). n_atoms must be at least 2 and at most 4.

parameters: an (n_constraints x 6) float tensor for storing up to six parameters per constraint.
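Mis-shaped inputs are easy to produce, so checking both tensors before calling add_constraints can save debugging time. This is a minimal sketch: check_constraint_inputs is a hypothetical helper, not part of tmol, and the integer-dtype requirement is an assumption based on the int32 tensor used in the example below.

```python
import torch


def check_constraint_inputs(atom_indices: torch.Tensor, params: torch.Tensor) -> None:
    """Hypothetical helper: validate tensors before passing them to add_constraints."""
    # atom_indices: (n_constraints x n_atoms x 3), last dim = (pose_ind, res_ind, atom_ind)
    if atom_indices.dim() != 3 or atom_indices.shape[2] != 3:
        raise ValueError("atom_indices must have shape (n_constraints, n_atoms, 3)")
    # Each constraint references between 2 and 4 atoms.
    if not 2 <= atom_indices.shape[1] <= 4:
        raise ValueError("each constraint must reference 2 to 4 atoms")
    # Indices must be integers (assumption: any integer dtype is acceptable).
    if atom_indices.dtype.is_floating_point:
        raise ValueError("atom_indices must be an integer tensor")
    # params: one row of 6 parameters per constraint
    if tuple(params.shape) != (atom_indices.shape[0], 6):
        raise ValueError("params must have shape (n_constraints, 6)")


# Passes silently for the shapes used in the example below.
check_constraint_inputs(
    torch.zeros((2, 2, 3), dtype=torch.int32),
    torch.zeros((2, 6), dtype=torch.float32),
)
```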

your_constraint_function: this must take two arguments, (atoms, params). atoms will be an (n_constraints x 4 x 3) tensor of atom coordinates. If you supplied fewer than 4 atom indices per constraint, you must not read the extra atoms (they will be filled with garbage). params will be the parameter tensor you supplied when you added the constraints. The function should return one score per constraint (the example below returns params[:, 0]).
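For something closer to a real restraint than the toy example below, here is a sketch of a harmonic distance constraint, k * (d - d0)^2, written to this calling convention. The name harmonic_distance and the parameter layout (column 0 = target distance d0, column 1 = force constant k) are assumptions for illustration. It reads only atoms 0 and 1, so it is safe to use with 2-atom constraints.

```python
import torch


def harmonic_distance(atoms: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
    """Score k * (d - d0)^2 for each constraint.

    atoms:  (n_constraints x 4 x 3) coordinates; only atoms 0 and 1 are read.
    params: (n_constraints x 6); assumed layout: [:, 0] = d0, [:, 1] = k.
    """
    dist = torch.linalg.norm(atoms[:, 1] - atoms[:, 0], dim=-1)  # (n_constraints,)
    d0 = params[:, 0]
    k = params[:, 1]
    return k * (dist - d0) ** 2  # one score per constraint


atoms = torch.zeros((2, 4, 3))
atoms[0, 1] = torch.tensor([3.0, 4.0, 0.0])  # distance 5 from the origin
atoms[1, 1] = torch.tensor([0.0, 0.0, 2.0])  # distance 2 from the origin
params = torch.zeros((2, 6))
params[:, 0] = 4.0  # target distance d0
params[:, 1] = 1.0  # force constant k
print(harmonic_distance(atoms, params))  # tensor([1., 4.])
```

Because the function is vectorized over the constraint dimension, a single call scores every constraint in the batch.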

You can call add_constraints multiple times; each new batch of constraints is concatenated onto the existing set.
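The concatenation behavior can be pictured with the simplified container below. ToyConstraintSet is hypothetical and CPU-only, written only to illustrate the idea; it is not tmol's implementation. It assumes that shorter atom lists are padded to 4 atoms (here with -1) and that each batch keeps its own scoring function.

```python
import torch


class ToyConstraintSet:
    """Hypothetical sketch: batches of constraints concatenated along dim 0."""

    def __init__(self):
        self.atom_indices = torch.zeros((0, 4, 3), dtype=torch.int32)
        self.params = torch.zeros((0, 6), dtype=torch.float32)
        self.batches = []  # (fn, start, end) row range for each add_constraints call

    def add_constraints(self, fn, atom_indices, params):
        n, n_atoms, _ = atom_indices.shape
        # Pad every constraint to 4 atoms; padded slots (-1) must never be read by fn.
        padded = torch.full((n, 4, 3), -1, dtype=torch.int32)
        padded[:, :n_atoms] = atom_indices
        start = self.atom_indices.shape[0]
        # torch.cat allocates new storage, so later edits to the caller's
        # tensors do not change constraints that were already added.
        self.atom_indices = torch.cat([self.atom_indices, padded])
        self.params = torch.cat([self.params, params.to(torch.float32)])
        self.batches.append((fn, start, start + n))


cs = ToyConstraintSet()
idx = torch.zeros((2, 2, 3), dtype=torch.int32)
prm = torch.zeros((2, 6))
cs.add_constraints(lambda atoms, params: params[:, 0], idx, prm)
cs.add_constraints(lambda atoms, params: params[:, 0], idx, prm)
print(cs.atom_indices.shape)  # torch.Size([4, 4, 3])
```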

An example:

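    import torch

    # Assumes a constraint set object named `constraints` already exists.
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def constfn(atoms, params):
        # Toy scoring function: ignore the coordinates and use parameter 0
        # as the score of each constraint.
        return params[:, 0]

    # Two constraints with two atoms each; each atom is (pose_ind, res_ind, atom_ind).
    cnstr_atoms = torch.zeros((2, 2, 3), dtype=torch.int32, device=torch_device)
    cnstr_params = torch.zeros((2, 6), dtype=torch.float32, device=torch_device)

    cnstr_atoms[0, 0] = torch.tensor([0, 0, 0])  # constraint 0, atom 0 -> pose 0, res 0, atom 0
    cnstr_atoms[0, 1] = torch.tensor([0, 1, 1])  # constraint 0, atom 1 -> pose 0, res 1, atom 1
    cnstr_params[0, 0] = 2  # constraint 0, parameter 0 -> 2

    cnstr_atoms[1, 0] = torch.tensor([1, 0, 0])  # constraint 1: pose 1, res 0, atom 0
    cnstr_atoms[1, 1] = torch.tensor([1, 1, 1])  # constraint 1: pose 1, res 1, atom 1
    cnstr_params[1, 0] = 4

    constraints.add_constraints(constfn, cnstr_atoms, cnstr_params)

    # Second batch: overwrite the same tensors and add them again.
    cnstr_atoms[0, 0] = torch.tensor([0, 1, 0])  # pose 0: res 1 <-> res 2
    cnstr_atoms[0, 1] = torch.tensor([0, 2, 1])
    cnstr_params[0, 0] = 20

    cnstr_atoms[1, 0] = torch.tensor([1, 0, 0])  # pose 1: res 0 <-> res 2
    cnstr_atoms[1, 1] = torch.tensor([1, 2, 1])
    cnstr_params[1, 0] = 40

    constraints.add_constraints(constfn, cnstr_atoms, cnstr_params)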

The resulting constraint scores, one residue-by-residue matrix per pose, are:

    [[[[ 0.  1.  0.  0.  0.]
       [ 1.  0. 10.  0.  0.]
       [ 0. 10.  0.  0.  0.]
       [ 0.  0.  0.  0.  0.]
       [ 0.  0.  0.  0.  0.]]

      [[ 0.  2. 20.  0.  0.]
       [ 2.  0.  0.  0.  0.]
       [20.  0.  0.  0.  0.]
       [ 0.  0.  0.  0.  0.]
       [ 0.  0.  0.  0.  0.]]]]

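Right now, each constraint's score is split evenly between the residues of its first and second atoms: for example, the score of 2 for the first constraint appears as 1 at positions [0, 1] and [1, 0] of pose 0's matrix. A third or fourth atom does not affect where the score lands. You can use this convention to control where scores end up by ordering the atoms in your atom index tensor appropriately. I'm not sure that's ideal, but it may be OK for now.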
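The sketch below reproduces the matrices above under this even-split convention and shows how atom order moves the score. residue_pair_scores is a hypothetical helper written for illustration, not tmol's implementation; it assumes all atoms of a constraint belong to the same pose (the pose of atom 0 is used) and omits the leading size-1 dimension seen in the output above.

```python
import torch


def residue_pair_scores(atom_indices, scores, n_poses, n_res):
    """Hypothetical sketch of the score-placement convention.

    atom_indices: (n_constraints x n_atoms x 3) of (pose_ind, res_ind, atom_ind)
    scores:       (n_constraints,) total score of each constraint
    Returns (n_poses x n_res x n_res); each score is split evenly between the
    residues of atoms 0 and 1 and written symmetrically. Other atoms are ignored.
    """
    out = torch.zeros((n_poses, n_res, n_res), dtype=scores.dtype)
    pose = atom_indices[:, 0, 0].long()
    res_a = atom_indices[:, 0, 1].long()
    res_b = atom_indices[:, 1, 1].long()
    half = scores / 2
    # accumulate=True sums contributions that land on the same cell
    out.index_put_((pose, res_a, res_b), half, accumulate=True)
    out.index_put_((pose, res_b, res_a), half, accumulate=True)
    return out


# The four constraints from the example; constfn returns params[:, 0].
atoms = torch.tensor(
    [
        [[0, 0, 0], [0, 1, 1]],
        [[1, 0, 0], [1, 1, 1]],
        [[0, 1, 0], [0, 2, 1]],
        [[1, 0, 0], [1, 2, 1]],
    ],
    dtype=torch.int32,
)
scores = torch.tensor([2.0, 4.0, 20.0, 40.0])
print(residue_pair_scores(atoms, scores, n_poses=2, n_res=5))  # matches the two 5 x 5 blocks above

# A 3-atom angle constraint (vertex in residue 1) with score 6.0.
# A-B-C and C-B-A describe the same angle, but the score moves.
fwd = torch.tensor([[[0, 0, 5], [0, 1, 1], [0, 2, 3]]])
rev = torch.tensor([[[0, 2, 3], [0, 1, 1], [0, 0, 5]]])
score = torch.tensor([6.0])
print(residue_pair_scores(fwd, score, 1, 3)[0])  # 3.0 at (0, 1) and (1, 0)
print(residue_pair_scores(rev, score, 1, 3)[0])  # 3.0 at (1, 2) and (2, 1)
```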

jflat06 commented 11 hours ago

@fdimaio

This should be ready for review. I've left a number of notes/questions in the code where I'd like feedback.

Also, I looked at the failing tests, and it seems like something may be wrong with the testing server:

    ⚠️ Warning: Checkout failed! Error running /usr/bin/git remote set-url origin https://github.com/uw-ipd/tmol: exit status 128 (Attempt 3/3 Retrying in 2s)
    🚨 Error: Error running /usr/bin/git remote set-url origin https://github.com/uw-ipd/tmol: exit status 128