MolecularAI / REINVENT4

AI molecular design tool for de novo design, scaffold hopping, R-group replacement, linker design and molecule optimization.
Apache License 2.0

Looking for advice on adjusting emphasis on filters vs. scores in RL #128

Closed: lee-jin-gyu96 closed this issue 2 months ago

lee-jin-gyu96 commented 2 months ago

Hello, I've been trying to run staged learning with the LibInvent model and an extensive suite of filters and scores.

The main issue is that learning pushes very heavily toward generating molecules that pass all the filters, but neglects improving the scores of the molecules that do pass them.

This tendency persists despite tweaking the learning rate, sigma, batch size, and step count, and despite my attempt to exclude filtered-out molecules from the loss calculation (>'s added for emphasis):

...
            # Compute the filter mask to prevent computation of scores in the
            # second loop over the non-filter components below
            for component in self.components.filters:
                transform_result = compute_transform(
                    component.component_type,
                    component.params,
                    smilies,
                    component.cache,
                    invalid_mask,
                    valid_mask,
                )

                for scores in transform_result.transformed_scores:
                    valid_mask = np.logical_and(scores, valid_mask)
                # NOTE: filters are NOT also used as components as in REINVENT3

                filters_to_report.append(transform_result)

>          # After running all filters, check which molecules were filtered out despite having valid smiles
>          filtered_mask = np.logical_and(init_valid_mask, np.logical_not(valid_mask))
...
>      # Exclude molecules that were filtered out
>      if self.exclude_zeroes:
>          final_scores[filtered_mask] = np.nan

        return ScoreResults(smilies, final_scores, completed_components)
...
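For anyone following the same approach, here is a self-contained NumPy sketch of the masking idea above, outside of REINVENT4. It assumes each filter yields one boolean array (True = passes) and that `init_valid_mask` is the SMILES-validity mask captured before the filter loop; the helper name `mask_filtered_scores` is hypothetical and not part of the REINVENT4 API. Invalid SMILES keep their original score, so the agent is still penalized for producing them.

```python
import numpy as np


def mask_filtered_scores(scores, valid_mask, filter_results):
    """Hypothetical helper (not REINVENT4 API).

    Returns a copy of `scores` where molecules with valid SMILES that failed
    at least one filter are set to NaN, plus the mask of those molecules.

    scores:         float array of aggregated scores, one per SMILES
    valid_mask:     bool array, True where the SMILES parsed
    filter_results: list of bool arrays, True where a molecule passes
    """
    init_valid_mask = np.asarray(valid_mask, dtype=bool)
    passed = init_valid_mask.copy()
    for result in filter_results:
        # A molecule survives only if it passes every filter
        passed = np.logical_and(passed, result)

    # Valid molecules that at least one filter rejected
    filtered_mask = np.logical_and(init_valid_mask, np.logical_not(passed))

    # Work on a float copy so the caller's array is left untouched
    masked = np.asarray(scores, dtype=float).copy()
    masked[filtered_mask] = np.nan
    return masked, filtered_mask
```

Invalid molecules are deliberately left out of `filtered_mask`, so their (typically zero) scores still reach the loss.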

So basically, I'm trying to keep track of the indices of valid molecules that were filtered out so that I can overwrite their scores with np.nan. If I've understood the code correctly, this will result in those molecules' scores not being included in the loss calculation in RLReward:

...
        nan_idx = torch.isnan(scores)
        scores_nonnan = scores[~nan_idx]
        agent_lls = -agent_nlls[~nan_idx]  # negated because we need to minimize
        prior_lls = -prior_nlls[~nan_idx]

        loss, augmented_lls = self._strategy(
            agent_lls,
            scores_nonnan,
            prior_lls,
            self._sigma,
        )
...
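To see what the NaN masking does to the loss, here is a NumPy re-sketch of the snippet above. It assumes a DAP-style strategy of the kind used in REINVENT, where the augmented log-likelihood is `prior_lls + sigma * scores` and the per-molecule loss is its squared difference from the agent log-likelihood; `dap_loss` and `reward_loss` are hypothetical names, and this is not the library's PyTorch code.

```python
import numpy as np


def dap_loss(agent_lls, scores, prior_lls, sigma):
    """Assumed DAP-style loss: pull the agent toward prior_lls + sigma * scores."""
    augmented_lls = prior_lls + sigma * scores
    loss = (augmented_lls - agent_lls) ** 2
    return loss, augmented_lls


def reward_loss(agent_nlls, prior_nlls, scores, sigma):
    """Drop NaN-scored molecules, then average the DAP-style loss over the rest.

    Returns (mean loss, number of molecules kept). If every score is NaN the
    mean is taken over an empty array and is NaN.
    """
    scores = np.asarray(scores, dtype=float)
    keep = ~np.isnan(scores)
    # Negate NLLs to get log-likelihoods, keeping only non-NaN entries
    agent_lls = -np.asarray(agent_nlls, dtype=float)[keep]
    prior_lls = -np.asarray(prior_nlls, dtype=float)[keep]
    loss, _ = dap_loss(agent_lls, scores[keep], prior_lls, sigma)
    return loss.mean(), int(keep.sum())
```

Under this assumed form, masking only changes which molecules enter the average. How strongly score differences pull the agent away from the prior is governed by `sigma`.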

I've even tried removing the most complex score components (e.g., docking scores), but I haven't noticed much improvement.

The filters I'm using are quite standard, as far as I'm aware: Lilly, PAINS, and a series of property filters (e.g., molecular weight, number of HBAs/HBDs, max ring size, etc.).
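A property filter of this kind reduces to a per-molecule range check that yields a boolean pass mask. The sketch below assumes the property values have already been computed (in practice, for example, with RDKit descriptors); the function name, property keys, and thresholds are hypothetical and for illustration only.

```python
import numpy as np


def property_filter(props, ranges):
    """Hypothetical range filter.

    props:  list of dicts, one per molecule, mapping property name -> value
    ranges: dict mapping property name -> (low, high), both inclusive
    Returns a bool array, True where every property lies within its range.
    """
    mask = np.ones(len(props), dtype=bool)
    for name, (low, high) in ranges.items():
        values = np.array([p[name] for p in props], dtype=float)
        # A molecule fails as soon as any single property is out of range
        mask &= (values >= low) & (values <= high)
    return mask
```

Each property adds another condition that must hold, so stacking many ranges into one filter tightens it just as stacking separate filters would.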

I should also mention that the average score of molecules that pass the filters does start off relatively high, at around 0.7-0.8, but I was hoping to increase that average to the 0.9-1.0 range.

tl;dr

I fully understand if this is beyond the scope of help I should expect via a GitHub issue (i.e., feel free to close this!), but any advice would be appreciated.

Thank you.

halx commented 2 months ago

Hi,

many thanks for your interest in REINVENT and welcome to the community!

First of all, I have to ask whether you actually implemented those scoring components as filters yourself, because, if I recall correctly, there is currently only one filter. Using too many filters may not be a good idea. Some of the components are mainly for demonstration purposes: the Lilly MedChem rules are probably fine, but the Lilly PAINS implementation is rather hackish and does not correspond to the approach Lilly actually uses.
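One way to see why many filters can hurt is that pass rates compound. The sketch below assumes, unrealistically, that filters reject molecules independently; the function names are hypothetical. Under that assumption the fraction of a batch that survives every filter is the product of the individual pass rates, so fewer molecules per step carry a score signal.

```python
import math


def joint_pass_rate(pass_rates):
    """Fraction passing all filters, assuming independent filters."""
    return math.prod(pass_rates)


def expected_survivors(batch_size, pass_rates):
    """Expected number of molecules per batch that pass every filter."""
    return batch_size * joint_pass_rate(pass_rates)
```

For example, ten filters that each pass 90% of molecules give `0.9 ** 10`, about 0.35, so only about 45 molecules of a 128-molecule batch would survive. Correlated filters would behave differently, but the direction is the same.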

The total score depends on the transforms and aggregation functions you are using, so there is no guarantee that the total score converges close to zero.
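To make the dependence on aggregation concrete, here is a plain-Python sketch of two common ways to combine transformed component scores: a weighted arithmetic mean and a weighted geometric mean. The function names are hypothetical, and the component scores are assumed to be already transformed into [0, 1]. The point is only that one component stuck at a low value caps the total, however well training goes on the others.

```python
import math


def weighted_arithmetic_mean(scores, weights):
    """Sum of w * s divided by the sum of weights."""
    return sum(w * s for s, w in zip(scores, weights)) / sum(weights)


def weighted_geometric_mean(scores, weights):
    """exp(sum(w * log s) / sum(w)); any zero component makes the total zero.

    Assumes scores are non-negative (already transformed into [0, 1]).
    """
    if any(s == 0 for s in scores):
        return 0.0
    total = sum(weights)
    return math.exp(sum(w * math.log(s) for s, w in zip(scores, weights)) / total)
```

With three equally weighted components at 1.0, 1.0 and 0.5, the arithmetic mean is about 0.83 and the geometric mean about 0.79, so the total cannot approach 1.0 until the weak component improves.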

Cheers, Hannes.

lee-jin-gyu96 commented 2 months ago

Hello, thank you for the response.

Yes, I did implement the filters from scratch. Upon further investigation, it does seem that having too many filters hurt performance, as certain subsets of the filters do seem to work.

Thank you for the help :)