proteneer / timemachine

Differentiate all the things!

Test additional MCS pruning heuristics if `connected_core = True` #1310

Closed proteneer closed 2 months ago

proteneer commented 2 months ago

This PR adds an additional pruning heuristic that checks whether a given atom-mapping will result in a disconnected core. The algorithm works as follows:

  1. Given a graph G with nodes N and edges E, partition N into three groups: A, the mapped vertices; B, the de-mapped vertices; and C, the unvisited vertices. A de-mapped vertex is a vertex that has been visited but is explicitly omitted from the atom-mapping.
  2. Let P be the union of A and C. Note that the subgraph of G induced by P may be disconnected.
  3. Enumerate the connected components of the subgraph of G induced by P: cc_1, cc_2, ...
  4. If A is a subset of a single component cc_i, then the atom-mapping can still be connected downstream. Otherwise, A has already been fragmented across multiple connected components and the atom-mapping is guaranteed to result in a disconnected core.
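
The four steps above can be sketched in plain Python. This is a minimal illustration, not the code in the PR: the adjacency-dict graph representation and the names `connected_components` and `core_is_fragmented` are assumptions made for the example, and only the standard library is used.

```python
from collections import deque


def connected_components(adjacency, allowed):
    """Yield the connected components of the subgraph induced by `allowed`."""
    seen = set()
    for start in allowed:
        if start in seen:
            continue
        component = {start}
        seen.add(start)
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for nbr in adjacency[node]:
                # Only walk edges whose endpoints both lie in `allowed`.
                if nbr in allowed and nbr not in seen:
                    seen.add(nbr)
                    component.add(nbr)
                    queue.append(nbr)
        yield component


def core_is_fragmented(adjacency, mapped, unvisited):
    """Return True if the mapped set A can no longer end up in one connected core."""
    if not mapped:
        return False  # nothing mapped yet, so nothing can be fragmented
    possible = mapped | unvisited  # step 2: P = A union C
    # Steps 3-4: A must fit entirely inside one component of the induced subgraph.
    return not any(mapped <= cc for cc in connected_components(adjacency, possible))


# Hypothetical 5-atom chain 0-1-2-3-4 (illustrative data only).
chain = {0: {1}, 1: {0, 2}, 2: {1, 3}, 3: {2, 4}, 4: {3}}

# Atoms 0 and 4 are mapped; 1, 2, 3 are unvisited and can still bridge them.
bridgeable = core_is_fragmented(chain, mapped={0, 4}, unvisited={1, 2, 3})

# Atom 2 has been de-mapped, so it is in neither set and the chain is cut.
fragmented = core_is_fragmented(chain, mapped={0, 4}, unvisited={1, 3})
```

In the first call the search branch is kept, because the unvisited atoms could still join the two mapped atoms; in the second it can be pruned, because the de-mapped atom separates them for good.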

maxentile commented 2 months ago

I wasn't able to detect any bugs yet, but I also don't yet understand this optimization in sufficient detail to review it confidently.

I have tagged in @jkausrelay to take a look as well, since this change will bypass `remove_disconnected_components`:

https://github.com/proteneer/timemachine/blob/ea87cb1349a40ae4ad4a89c0b675ca9cc00eb948/timemachine/fe/atom_mapping.py#L491-L492

proteneer commented 2 months ago

Some line_profiler results: about 67% of the time (34.7% + 32.5% across the two `_graph_is_disconnected` calls) is now spent checking connectivity.

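Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   350       166         35.6      0.2      0.0      if connected_core: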
   351                                                   # process g1 using atom_map_1_to_2_information
   352       166         68.7      0.4      0.0          g1_mapped_nodes = set()
   353       166         45.1      0.3      0.0          g1_unmapped_nodes = set()
   354       166         44.0      0.3      0.0          g1_unvisited_nodes = set()
   355      8798       2182.9      0.2      1.0          for a1, a2 in enumerate(atom_map_1_to_2):
   356      8632       1852.9      0.2      0.8              if a1 < layer:
   357                                                           # visited nodes
   358      1871        524.9      0.3      0.2                  if a2 == UNMAPPED:
   359       163         57.8      0.4      0.0                      g1_unmapped_nodes.add(a1)
   360                                                           else:
   361      1708        583.8      0.3      0.3                      g1_mapped_nodes.add(a1)
   362                                                       else:
   363      6761       2141.7      0.3      1.0                  g1_unvisited_nodes.add(a1)
   364                                                           # are there unvisited nodes that can bridge the core?
   365                                           
   366       166      75912.5    457.3     34.7          if _graph_is_disconnected(g1, g1_mapped_nodes, g1_unmapped_nodes, g1_unvisited_nodes):
   367         2          0.4      0.2      0.0              return
   368                                           
   369                                                   # g2 is a little trickier to process, we need to look at the priority idxs as well
   370       164         60.1      0.4      0.0          g2_mapped_nodes = set()
   371     10988       2464.6      0.2      1.1          for a2, a1 in enumerate(atom_map_2_to_1):
   372     10824       2550.2      0.2      1.2              if a1 != UNMAPPED:
   373      1698        536.3      0.3      0.2                  g2_mapped_nodes.add(a2)
   374       164         50.1      0.3      0.0          g2_unvisited_nodes = set()
   375                                                   # look up priority_idxs of remaining atoms
   376      6834       1335.2      0.2      0.6          for a2_list in priority_idxs[layer:]:
   377     24907       5508.1      0.2      2.5              for a2 in a2_list:
   378     18237       4194.8      0.2      1.9                  if a2 not in g2_mapped_nodes:
   379     16325       5437.5      0.3      2.5                      g2_unvisited_nodes.add(a2)
   380       164       1019.8      6.2      0.5          g2_unmapped_nodes = set(range(g2.n_vertices)).difference(g2_mapped_nodes).difference(g2_unvisited_nodes)
   381                                           
   382       164      71091.5    433.5     32.5          if _graph_is_disconnected(g2, g2_mapped_nodes, g2_unmapped_nodes, g2_unvisited_nodes):
   383                                                       return
   384                                           
   385       164        103.4      0.6      0.0      mcs_result.nodes_visited += 1
   386       164         43.6      0.3      0.0      n_a = g1.n_vertices
   387                                           
   388                                               # leaf-node, every atom has been mapped
   389       164         57.0      0.3      0.0      if layer == n_a:
   390         1          0.5      0.5      0.0          if num_edges == threshold:
   391         1          2.8      2.8      0.0              mcs_result.all_maps.append(copy.copy(atom_map_1_to_2))
   392         1          5.4      5.4      0.0              mcs_result.all_marcs.append(copy.copy(marcs))
   393         1          0.3      0.3      0.0              mcs_result.num_edges = num_edges
   394         1          0.2      0.2      0.0          return
   395                                           
   396       843        248.3      0.3      0.1      for jdx in priority_idxs[layer]:
   397       680        270.6      0.4      0.1          if atom_map_2_to_1[jdx] == UNMAPPED:  # optimize later
   398       555        423.1      0.8      0.2              atom_map_add(atom_map_1_to_2, atom_map_2_to_1, layer, jdx)
   399      1110       3573.2      3.2      1.6              if enforce_core_core and not _verify_core_is_connected(
   400       555        141.7      0.3      0.1                  g1, g2, layer, jdx, atom_map_1_to_2, atom_map_2_to_1
   401                                                       ):
   402       293         69.4      0.2      0.0                  pass
   403       262       9882.3     37.7      4.5              elif not filter_fxn(atom_map_1_to_2):
   404         3          0.7      0.2      0.0                  pass
   405                                                       else:
   406       259       5797.5     22.4      2.7                  new_marcs = refine_marcs(g1, g2, layer, jdx, marcs)
   407       518        616.7      1.2      0.3                  recursion(
   408       259         45.5      0.2      0.0                      g1,
   409       259         38.1      0.1      0.0                      g2,
   410       259         38.1      0.1      0.0                      atom_map_1_to_2,
   411       259         38.3      0.1      0.0                      atom_map_2_to_1,
   412       259         69.0      0.3      0.0                      layer + 1,
   413       259         41.8      0.2      0.0                      new_marcs,
   414       259         38.6      0.1      0.0                      mcs_result,
   415       259         36.5      0.1      0.0                      priority_idxs,
   416       259         37.4      0.1      0.0                      max_visits,
   417       259         37.1      0.1      0.0                      max_cores,
   418       259         37.0      0.1      0.0                      threshold,
   419       259         36.0      0.1      0.0                      enforce_core_core,
   420       259         36.4      0.1      0.0                      connected_core,
   421       259         36.6      0.1      0.0                      filter_fxn,
   422                                                           )
   423       555        566.4      1.0      0.3              atom_map_pop(atom_map_1_to_2, atom_map_2_to_1, layer, jdx)

Of that, the vast majority (88% of the time inside `_graph_is_disconnected`) is spent enumerating connected components:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   295                                           @profile
   296                                           def _graph_is_disconnected(g, mapped_nodes, demapped_nodes, unvisited_nodes):
   297                                           
   298       330        195.7      0.6      0.1      assert type(mapped_nodes) == set
   299       330        115.6      0.4      0.1      assert type(demapped_nodes) == set
   300       330        117.6      0.4      0.1      assert type(unvisited_nodes) == set
   301                                           
   302                                               # check intersections
   303       330       2004.1      6.1      1.3      assert mapped_nodes.union(demapped_nodes).union(unvisited_nodes) == set(range(g.n_vertices))
   304       330        291.6      0.9      0.2      assert len(mapped_nodes.intersection(demapped_nodes)) == 0
   305       330        188.8      0.6      0.1      assert len(mapped_nodes.intersection(unvisited_nodes)) == 0
   306       330        209.2      0.6      0.1      assert len(demapped_nodes.intersection(unvisited_nodes)) == 0
   307                                           
   308       330        375.3      1.1      0.3      all_possible_nodes = mapped_nodes.union(unvisited_nodes)
   309                                           
   310       330        157.9      0.5      0.1      if len(mapped_nodes) > 0:
   311       314      13191.1     42.0      8.8          sg_all_possible = g.nxg.subgraph(all_possible_nodes)
   312       314        901.9      2.9      0.6          sg_ccs = nx.connected_components(sg_all_possible)
   313                                                   # see if all the mapped nodes belong to the same connected component
   314       323     131789.0    408.0     88.0          for cc in sg_ccs:
   315       321        199.8      0.6      0.1              if mapped_nodes.issubset(cc):
   316       312         75.3      0.2      0.1                  return False
   317         2          0.3      0.2      0.0          return True
   318                                               else:
   319        16          3.6      0.2      0.0          return False
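
Since the profile points at component enumeration as the hot spot, one possible alternative is a single breadth-first search that starts from any mapped node and stops as soon as every mapped node has been reached. The sketch below is an assumption-laden illustration, not part of the PR: it uses a plain adjacency dict instead of the `g.nxg` graph, the name `mapped_nodes_are_split` is hypothetical, and whether it is actually faster inside timemachine has not been measured here.

```python
from collections import deque


def mapped_nodes_are_split(adjacency, mapped, unvisited):
    """Return True if some mapped node is unreachable from another mapped node
    when walking only through mapped or unvisited nodes."""
    if not mapped:
        return False  # same convention as the profiled function: nothing to split
    allowed = mapped | unvisited
    start = next(iter(mapped))
    seen = {start}
    remaining = len(mapped) - 1  # mapped nodes not reached yet
    queue = deque([start])
    # Stop early once every mapped node has been found.
    while queue and remaining:
        node = queue.popleft()
        for nbr in adjacency[node]:
            if nbr in allowed and nbr not in seen:
                seen.add(nbr)
                if nbr in mapped:
                    remaining -= 1
                queue.append(nbr)
    return remaining > 0


# Hypothetical 5-atom chain 0-1-2-3-4 (illustrative data only).
chain = {0: {1}, 1: {0, 2}, 2: {1, 3}, 3: {2, 4}, 4: {3}}
still_connectable = mapped_nodes_are_split(chain, mapped={0, 4}, unvisited={1, 2, 3})
already_split = mapped_nodes_are_split(chain, mapped={0, 4}, unvisited={1, 3})
```

This gives the same answer as checking whether the mapped set is a subset of one connected component, because all mapped nodes share a component exactly when they are all reachable from one of them.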