rxn4chemistry / rxnfp

Reaction fingerprints, atlases and classification. Code complementing our Nature Machine Intelligence publication on "Mapping the space of chemical reactions using attention-based neural networks" (http://rdcu.be/cenmd).
https://rxn4chemistry.github.io/rxnfp/
MIT License

Bugs in batch generation of fingerprints by rxnfp #17

Closed wangxr0526 closed 11 months ago

wangxr0526 commented 2 years ago

This work is really great, but I ran into a small problem when using rxnfp. In generate_fingerprints (https://github.com/rxn4chemistry/rxnfp/blob/master/rxnfp/transformer_fingerprints.py#L134), the last batch of data seems to be discarded whenever len(rxns) % batch_size != 0, so the returned array has fewer fingerprints than there are input reactions.
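To make the symptom concrete, here is a minimal sketch of the counting involved. It assumes, as the report implies, that the number of batches is computed with floor division; `count_processed` and its `round_up` flag are hypothetical names used only for illustration and are not part of rxnfp.

```python
from itertools import islice


def count_processed(n_items: int, batch_size: int, round_up: bool) -> int:
    """Count how many items a batched loop actually consumes."""
    items = iter(range(n_items))
    # Floor division ignores a trailing, partially filled batch.
    n_batches = n_items // batch_size
    if round_up and n_items % batch_size != 0:
        # Add one extra batch for the remainder.
        n_batches += 1
    processed = 0
    for _ in range(n_batches):
        processed += len(list(islice(items, batch_size)))
    return processed


print(count_processed(10, 4, round_up=False))  # 8: the last 2 items are skipped
print(count_processed(10, 4, round_up=True))   # 10: every item is consumed
```

When the input length is an exact multiple of the batch size, both variants consume everything, which is probably why the problem is easy to miss with batch_size=1.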

This change may fix the problem:

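from itertools import islice
from typing import List

import numpy as np
from tqdm import tqdm

# FingerprintGenerator is assumed to be in scope, as in rxnfp/transformer_fingerprints.py.
def generate_fingerprints(rxns: List[str], fingerprint_generator: FingerprintGenerator, batch_size=1) -> np.ndarray:
    fps = []
    # Round the batch count up so a final, partially filled batch is not dropped.
    n_batches = len(rxns) // batch_size
    if len(rxns) % batch_size != 0:
        n_batches += 1
    emb_iter = iter(rxns)
    for i in tqdm(range(n_batches)):
        # islice yields fewer than batch_size items for the last batch if needed.
        batch = list(islice(emb_iter, batch_size))

        fps_batch = fingerprint_generator.convert_batch(batch)

        fps += fps_batch
    return np.array(fps)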
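One way to check a fix like this without loading a model is to run the batching loop against a stub. The sketch below is an assumption-laden illustration, not rxnfp code: `DummyGenerator` is a hypothetical stand-in that only mimics a `convert_batch` method, and `fingerprints_in_batches` is a variant that loops until the iterator is exhausted instead of precomputing the batch count. It returns a plain list to stay dependency-free, whereas the function above wraps the result in `np.array`.

```python
from itertools import islice
from typing import Iterable, List


class DummyGenerator:
    """Hypothetical stand-in for a fingerprint generator (not the rxnfp class)."""

    def convert_batch(self, rxns: List[str]) -> List[List[float]]:
        # Fake one-dimensional "fingerprint": the length of the reaction string.
        return [[float(len(rxn))] for rxn in rxns]


def fingerprints_in_batches(
    rxns: Iterable[str], generator: DummyGenerator, batch_size: int = 1
) -> List[List[float]]:
    """Convert reactions batch by batch until the input is exhausted."""
    fps: List[List[float]] = []
    rxn_iter = iter(rxns)
    while True:
        batch = list(islice(rxn_iter, batch_size))
        if not batch:
            # Nothing left, including any partial final batch.
            break
        fps += generator.convert_batch(batch)
    return fps


# Five placeholder reaction strings with batch_size=2 leave a remainder of one.
example_rxns = ["CCO>>CC=O", "CC>>C", "C>>CC", "CCC>>CC", "O>>OO"]
result = fingerprints_in_batches(example_rxns, DummyGenerator(), batch_size=2)
print(len(result) == len(example_rxns))  # True
```

Because this variant never computes a batch count, it also works for inputs whose length is not known in advance, at the cost of losing the total that tqdm uses for its progress bar.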