GFNOrg / torchgfn

GFlowNet library
https://torchgfn.readthedocs.io/en/latest/

Simplification of `SubTBGFlowNet.loss` and `SubTBGFlowNet.get_scores` #132

Closed: marpaia closed this 1 year ago

marpaia commented 1 year ago

This PR attempts to tackle #117: Simplify SubTBGFlowNet.get_scores and SubTBGFlowNet.loss

Most importantly, in this PR, SubTBGFlowNet.loss is reduced to:

    def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]:
        # Get all scores and masks from the trajectories.
        scores, flattening_masks = self.get_scores(env, trajectories)
        flattening_mask = torch.cat(flattening_masks)
        all_scores = torch.cat(scores, 0)

        if self.weighting == "DB":
            # Longer trajectories contribute more to the loss
            return scores[0][~flattening_masks[0]].pow(2).mean()

        elif self.weighting == "geometric":
            # The position i of the following 1D tensor represents the number of sub-
            # trajectories of length i in the batch.
            # n_sub_trajectories = torch.maximum(
            #     trajectories.when_is_done - torch.arange(3).unsqueeze(-1),
            #     torch.tensor(0),
            # ).sum(1)

            # The following tensor's k-th entry (0-indexed) represents the mean of all
            # losses of sub-trajectories of length k + 1.
            per_length_losses = torch.stack(
                [
                    length_scores[~length_mask].pow(2).mean()
                    for length_scores, length_mask in zip(scores, flattening_masks)
                ]
            )
            max_len = trajectories.max_length
            L = self.lamda
            ratio = (1 - L) / (1 - L**max_len)
            weights = ratio * (
                L ** torch.arange(max_len, device=per_length_losses.device)
            )
            assert (weights.sum() - 1.0).abs() < 1e-5, f"{weights.sum()}"
            return (per_length_losses * weights).sum()

        elif self.weighting == "equal_within":
            contributions = self.get_equal_within_contributions(trajectories)

        elif self.weighting == "equal":
            contributions = self.get_equal_contributions(trajectories)

        elif self.weighting == "TB":
            contributions = self.get_tb_contributions(trajectories, all_scores)

        elif self.weighting == "ModifiedDB":
            contributions = self.get_modified_db_contributions(trajectories)

        elif self.weighting == "geometric_within":
            contributions = self.get_geometric_within_contributions(trajectories)
        else:
            raise ValueError(f"Unknown weighting method {self.weighting}")

        flat_contributions = contributions[~flattening_mask]
        assert (
            flat_contributions.sum() - 1.0
        ).abs() < 1e-5, f"{flat_contributions.sum()}"
        losses = flat_contributions * all_scores[~flattening_mask].pow(2)
        return losses.sum()
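
To make the "geometric" branch concrete, here is a minimal pure-Python sketch of the same weighting arithmetic, using plain lists instead of tensors. It assumes per_length_losses[k] already holds the mean squared score of sub-trajectories of length k + 1; the helper names geometric_weights and geometric_weighted_loss are hypothetical, not part of torchgfn. It also adds a lamda == 1 special case (uniform weights), which the branch above does not cover: there, (1 - L) / (1 - L**max_len) is 0/0.

```python
from typing import List


def geometric_weights(lamda: float, max_len: int) -> List[float]:
    """lamda**k for k = 0..max_len - 1, rescaled to sum to 1."""
    if lamda == 1.0:
        # Limit of the closed form as lamda -> 1: every length gets the same weight.
        return [1.0 / max_len] * max_len
    ratio = (1 - lamda) / (1 - lamda**max_len)
    return [ratio * lamda**k for k in range(max_len)]


def geometric_weighted_loss(per_length_losses: List[float], lamda: float) -> float:
    """Hypothetical helper mirroring the arithmetic of the "geometric" branch."""
    weights = geometric_weights(lamda, len(per_length_losses))
    # Same sanity check as in loss(): the weights form a probability distribution.
    assert abs(sum(weights) - 1.0) < 1e-5, sum(weights)
    return sum(w * loss for w, loss in zip(weights, per_length_losses))


# Usage: three sub-trajectory lengths with lamda = 0.5.
print(geometric_weights(0.5, 3))  # approximately [4/7, 2/7, 1/7]
print(geometric_weighted_loss([1.0, 2.0, 4.0], 0.5))  # approximately 12/7
```

With lamda < 1, shorter sub-trajectories get exponentially more weight, which is why the weights decay with the index.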

The only additional change I debated was whether to encapsulate that final loss calculation in a nested helper function (a closure over flattening_mask and all_scores), like this:

        def loss_from_contributions(contributions):
            # Closes over flattening_mask and all_scores from the enclosing loss().
            flat_contributions = contributions[~flattening_mask]
            assert (
                flat_contributions.sum() - 1.0
            ).abs() < 1e-5, f"{flat_contributions.sum()}"
            losses = flat_contributions * all_scores[~flattening_mask].pow(2)
            return losses.sum()

And then in the last few branches, you'd have:

        elif self.weighting == "equal_within":
            contributions = self.get_equal_within_contributions(trajectories)
            return loss_from_contributions(contributions)

        elif self.weighting == "equal":
            contributions = self.get_equal_contributions(trajectories)
            return loss_from_contributions(contributions)

        elif self.weighting == "TB":
            contributions = self.get_tb_contributions(trajectories, all_scores)
            return loss_from_contributions(contributions)

        elif self.weighting == "ModifiedDB":
            contributions = self.get_modified_db_contributions(trajectories)
            return loss_from_contributions(contributions)

        elif self.weighting == "geometric_within":
            contributions = self.get_geometric_within_contributions(trajectories)
            return loss_from_contributions(contributions)

        else:
            raise ValueError(f"Unknown weighting method {self.weighting}")

This is nice because every branch then returns from the function directly, rather than the last several falling through to the shared code at the end. It does incur a bit of copy-pasta, though, and I generally try to avoid that kind of function nesting when I can. So I decided to leave the higher-level control flow as is, but I'd be happy to make this change if y'all would prefer.
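
For comparison, the nesting concern can also be avoided by passing the closed-over values explicitly to a module-level helper, so each branch could still return directly. The sketch below is a hypothetical, list-based illustration (the helper name, argument order and example numbers are assumptions, not torchgfn code): contributions, scores and the padding mask are flat lists, and masked entries are skipped just as ~flattening_mask does above.

```python
from typing import Sequence


def loss_from_contributions(
    contributions: Sequence[float],
    scores: Sequence[float],
    flattening_mask: Sequence[bool],
) -> float:
    """Weighted squared-score loss over the entries that are NOT masked out.

    Module-level counterpart of the nested helper: instead of closing over
    flattening_mask and all_scores, it receives them as arguments.
    """
    kept = [not m for m in flattening_mask]
    flat_contributions = [c for c, k in zip(contributions, kept) if k]
    # The contributions of the surviving sub-trajectories must form a distribution.
    assert abs(sum(flat_contributions) - 1.0) < 1e-5, sum(flat_contributions)
    flat_scores = [s for s, k in zip(scores, kept) if k]
    return sum(c * s**2 for c, s in zip(flat_contributions, flat_scores))


# Usage: the last entry is padding (masked), so its large score is ignored.
print(loss_from_contributions([0.5, 0.5, 0.0], [1.0, 3.0, 100.0], [False, False, True]))
# 0.5 * 1 + 0.5 * 9 = 5.0
```

Each branch would then end with a call such as loss_from_contributions(contributions, all_scores, flattening_mask), trading the closure for a slightly longer call site.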

There are some additional simplifications to get_scores as well, but they're less controversial.

NB: this PR moves a lot of code around, so the inline diff viewer isn't very helpful. The side-by-side viewer is better but still not great; for the easiest review, I'd suggest reading the old version of each function alongside its new version.

saleml commented 1 year ago

Thanks a lot for this PR! I like the current structure, where some branches don't return anything themselves: it makes clear that the final loss computation belongs to the enclosing scope.

I left two comments about two docstrings. I'll merge as soon as it's resolved.

marpaia commented 1 year ago

Hey @saleml, great catch on those comments! Those args docs were left over from an earlier version of the refactor where I was passing those variables around instead of calculating them off of trajectories. That is all fixed now 😄

saleml commented 1 year ago

Thanks. Also, for consistency, would you mind making the tensor shapes explicit with TensorType?

1 - For get_..._contributions, the output type is TT['max_len * (1 + max_len) / 2', 'n_trajectories'].

2 - Fix the output type of cumulative_logprobs to TT["max_length + 1", "n_trajectories"]

3 -

    def calculate_preds(
        self,
        log_pf_trajectories_cum: TT["max_length + 1", "n_trajectories"],
        log_state_flows: TT["max_length", "n_trajectories"],
        i: int,
    ) -> TT["max_length + 1 - i", "n_trajectories"]:

4 -

    def calculate_targets(
        self,
        trajectories: Trajectories,
        preds: TT["max_length + 1 - i", "n_trajectories"],
        log_pb_trajectories_cum: TT["max_length + 1", "n_trajectories"],
        log_state_flows: TT["max_length", "n_trajectories"],
        is_terminal_mask: TT["max_length", "n_trajectories"],
        sink_states_mask: TT["max_length", "n_trajectories"],
        full_mask: TT["max_length", "n_trajectories"],
        i: int,
    ) -> TT["max_length + 1 - i", "n_trajectories"]:

5 -

    def calculate_log_state_flows(
        self,
        env: Env,
        trajectories: Trajectories,
        log_pf_trajectories: TT["max_length", "n_trajectories"],
    ) -> TT["max_length", "n_trajectories"]:

6 -

    def calculate_masks(
        self,
        log_state_flows: TT["max_length", "n_trajectories"],
        trajectories: Trajectories,
    ) -> Tuple[
        TT["max_length", "n_trajectories"],
        TT["max_length", "n_trajectories"],
        TT["max_length", "n_trajectories"],
    ]:
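
The suggested return type of calculate_preds, TT["max_length + 1 - i", "n_trajectories"], follows from how many start positions a length-i sub-trajectory has. The sketch below checks that count for a single trajectory using plain lists. It assumes, as the annotations suggest but without confirming torchgfn's exact implementation, that the cumulative forward log-probabilities carry a leading zero (hence length max_length + 1) and that the prediction for the sub-trajectory starting at step t is log F(s_t) plus the summed forward log-probabilities of its i steps. The function name calculate_preds_single is hypothetical.

```python
from itertools import accumulate
from typing import List


def calculate_preds_single(
    log_pf: List[float], log_state_flows: List[float], i: int
) -> List[float]:
    """Predictions for every length-i sub-trajectory of ONE trajectory.

    log_pf[t] is log P_F(s_{t+1} | s_t) and log_state_flows[t] is log F(s_t), both
    of length max_length. Returns max_length + 1 - i values, one per start step t.
    """
    max_length = len(log_pf)
    assert len(log_state_flows) == max_length and 1 <= i <= max_length
    # Cumulative sums with a leading zero: length max_length + 1.
    log_pf_cum = [0.0] + list(accumulate(log_pf))
    return [
        log_state_flows[t] + (log_pf_cum[t + i] - log_pf_cum[t])
        for t in range(max_length + 1 - i)
    ]


# Usage: a length-3 trajectory yields 3, 2 and 1 predictions for i = 1, 2, 3.
log_pf = [-1.0, -2.0, -3.0]
log_flows = [0.0, 10.0, 20.0]
for i in range(1, 4):
    print(i, calculate_preds_single(log_pf, log_flows, i))
# 1 [-1.0, 8.0, 17.0]
# 2 [-3.0, 5.0]
# 3 [-6.0]
```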

marpaia commented 1 year ago

Awesome, that's a great call. I added a ContributionsTensor type alias so that it can be reused more easily:

ContributionsTensor = TT["max_len * (1 + max_len) / 2", "n_trajectories"]

But I'm happy to just proliferate direct use of TT if you think that adds more indirection than value.
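
The leading dimension of ContributionsTensor is the total number of sub-trajectories in a trajectory of length max_len: a length-i sub-trajectory has max_len + 1 - i possible start positions, and summing that over i = 1..max_len gives max_len * (1 + max_len) / 2. The short check below (plain Python, with a hypothetical helper name) confirms the closed form against the explicit sum; since the count is an integer, it uses // rather than the / that appears in the annotation string.

```python
def n_sub_trajectories(max_len: int) -> int:
    """Number of (start, length) pairs with length i in 1..max_len.

    A length-i sub-trajectory can start at max_len + 1 - i positions.
    """
    return sum(max_len + 1 - i for i in range(1, max_len + 1))


for max_len in range(1, 8):
    # Closed form used in the ContributionsTensor annotation.
    assert n_sub_trajectories(max_len) == max_len * (1 + max_len) // 2

print([n_sub_trajectories(n) for n in range(1, 6)])  # [1, 3, 6, 10, 15]
```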

marpaia commented 1 year ago

Reading through the code for this file, it seems that many of the explicit parameter types on a given function are directly tied to the return type of another function. In those cases, I updated the annotations to use a FooBarTensor-style alias defined as a constant at the top of the file and shared throughout the class. Hopefully this makes the TT types easier to audit and fix in the future. That said, unless we use typeguard, these annotations aren't actually enforced. Perhaps that's something we should enable, or at least see whether we can check in CI?
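
To illustrate the enforcement point: in plain Python, annotations (shared aliases included) are only metadata, so a function that returns the wrong shape goes through silently. The sketch below uses typing.Annotated on lists as a stand-in for torchtyping's TT, an assumption made so it runs without torch; the alias and function names are hypothetical. For actual runtime checks, torchtyping's README describes calling patch_typeguard() and decorating functions with typeguard's @typechecked, which a CI test run could enable.

```python
from itertools import accumulate
from typing import Annotated, List, get_type_hints

# Shared aliases, defined once at the top of a module (stand-ins for TT[...] aliases).
PerStep = Annotated[List[float], "max_length"]
CumulativePerStep = Annotated[List[float], "max_length + 1"]


def buggy_cumulative(log_probs: PerStep) -> CumulativePerStep:
    # Bug: forgets the leading zero, so the result has max_length entries,
    # not the max_length + 1 promised by the annotation.
    return list(accumulate(log_probs))


result = buggy_cumulative([1.0, 2.0])  # no error is raised
print(len(result))  # 2, despite the "max_length + 1" annotation

# The shape string is still recoverable as metadata, e.g. for a custom CI check.
hints = get_type_hints(buggy_cumulative, include_extras=True)
print(hints["return"].__metadata__)  # ('max_length + 1',)
```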

josephdviviano commented 11 months ago

I think this PR broke typing:

https://github.com/saleml/torchgfn/issues/137

Tests are also failing:

============================================== ERRORS ===============================================
_____________________________ ERROR collecting testing/test_gflownet.py _____________________________
../../.local/lib/python3.10/site-packages/_pytest/runner.py:341: in from_call
    result: Optional[TResult] = func()
../../.local/lib/python3.10/site-packages/_pytest/runner.py:372: in <lambda>
    call = CallInfo.from_call(lambda: list(collector.collect()), "collect")
../../.local/lib/python3.10/site-packages/_pytest/python.py:531: in collect
    self._inject_setup_module_fixture()
../../.local/lib/python3.10/site-packages/_pytest/python.py:545: in _inject_setup_module_fixture
    self.obj, ("setUpModule", "setup_module")
../../.local/lib/python3.10/site-packages/_pytest/python.py:310: in obj
    self._obj = obj = self._getobj()
../../.local/lib/python3.10/site-packages/_pytest/python.py:528: in _getobj
    return self._importtestmodule()
../../.local/lib/python3.10/site-packages/_pytest/python.py:617: in _importtestmodule
    mod = import_path(self.path, mode=importmode, root=self.config.rootpath)
../../.local/lib/python3.10/site-packages/_pytest/pathlib.py:565: in import_path
    importlib.import_module(module_name)
../../miniconda3/envs/torchgfn/lib/python3.10/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
<frozen importlib._bootstrap>:1050: in _gcd_import
    ???
<frozen importlib._bootstrap>:1027: in _find_and_load
    ???
<frozen importlib._bootstrap>:1006: in _find_and_load_unlocked
    ???
<frozen importlib._bootstrap>:688: in _load_unlocked
    ???
../../.local/lib/python3.10/site-packages/_pytest/assertion/rewrite.py:178: in exec_module
    exec(co, module.__dict__)
testing/test_gflownet.py:3: in <module>
    from gfn.gflownet import FMGFlowNet, TBGFlowNet
src/gfn/gflownet/__init__.py:1: in <module>
    from .base import GFlowNet, PFBasedGFlowNet, TrajectoryBasedGFlowNet
src/gfn/gflownet/base.py:78: in <module>
    class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]):
../../miniconda3/envs/torchgfn/lib/python3.10/typing.py:312: in inner
    return func(*args, **kwds)
../../miniconda3/envs/torchgfn/lib/python3.10/typing.py:1345: in __class_getitem__
    _check_generic(cls, params, len(cls.__parameters__))
../../.local/lib/python3.10/site-packages/typing_extensions.py:152: in _check_generic
    raise TypeError(f"{cls} is not a generic class")
E   TypeError: <class 'gfn.gflownet.base.PFBasedGFlowNet'> is not a generic class
___________________ ERROR collecting testing/test_parametrizations_and_losses.py ____________________
../../.local/lib/python3.10/site-packages/_pytest/runner.py:341: in from_call
    result: Optional[TResult] = func()
../../.local/lib/python3.10/site-packages/_pytest/runner.py:372: in <lambda>
    call = CallInfo.from_call(lambda: list(collector.collect()), "collect")
../../.local/lib/python3.10/site-packages/_pytest/python.py:531: in collect
    self._inject_setup_module_fixture()
../../.local/lib/python3.10/site-packages/_pytest/python.py:545: in _inject_setup_module_fixture
    self.obj, ("setUpModule", "setup_module")
../../.local/lib/python3.10/site-packages/_pytest/python.py:310: in obj
    self._obj = obj = self._getobj()
../../.local/lib/python3.10/site-packages/_pytest/python.py:528: in _getobj
    return self._importtestmodule()
../../.local/lib/python3.10/site-packages/_pytest/python.py:617: in _importtestmodule
    mod = import_path(self.path, mode=importmode, root=self.config.rootpath)
../../.local/lib/python3.10/site-packages/_pytest/pathlib.py:565: in import_path
    importlib.import_module(module_name)
../../miniconda3/envs/torchgfn/lib/python3.10/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
<frozen importlib._bootstrap>:1050: in _gcd_import
    ???
<frozen importlib._bootstrap>:1027: in _find_and_load
    ???
<frozen importlib._bootstrap>:1006: in _find_and_load_unlocked
    ???
<frozen importlib._bootstrap>:688: in _load_unlocked
    ???
../../.local/lib/python3.10/site-packages/_pytest/assertion/rewrite.py:178: in exec_module
    exec(co, module.__dict__)
testing/test_parametrizations_and_losses.py:5: in <module>
    from gfn.gflownet import (
src/gfn/gflownet/__init__.py:1: in <module>
    from .base import GFlowNet, PFBasedGFlowNet, TrajectoryBasedGFlowNet
src/gfn/gflownet/base.py:78: in <module>
    class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]):
../../miniconda3/envs/torchgfn/lib/python3.10/typing.py:312: in inner
    return func(*args, **kwds)
../../miniconda3/envs/torchgfn/lib/python3.10/typing.py:1345: in __class_getitem__
    _check_generic(cls, params, len(cls.__parameters__))
../../.local/lib/python3.10/site-packages/typing_extensions.py:152: in _check_generic
    raise TypeError(f"{cls} is not a generic class")
E   TypeError: <class 'gfn.gflownet.base.PFBasedGFlowNet'> is not a generic class
====================================== short test summary info ======================================
ERROR testing/test_gflownet.py - TypeError: <class 'gfn.gflownet.base.PFBasedGFlowNet'> is not a generic class
ERROR testing/test_parametrizations_and_losses.py - TypeError: <class 'gfn.gflownet.base.PFBasedGFlowNet'> is not a generic class
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Interrupted: 2 errors during collection !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
========================================= 2 errors in 0.77s =========================================

josephdviviano commented 11 months ago

I think this is the fix: https://github.com/saleml/torchgfn/pull/139
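
The TypeError in the traceback is what typing raises when a class is subscripted, as in PFBasedGFlowNet[Trajectories], but has no type parameters of its own; one common way to get there is subclassing a Generic base without re-declaring its TypeVar. The minimal reproduction below uses hypothetical class names and is not taken from torchgfn or from the linked fix; it only shows this general failure mode and one conventional remedy.

```python
from typing import Generic, TypeVar

T = TypeVar("T")


class Base(Generic[T]):
    """Generic over the kind of training sample it consumes."""


class NotGeneric(Base):
    """Subclasses Base WITHOUT re-declaring T, so it has no type parameters."""


class StillGeneric(Base[T]):
    """Re-declares T, so it can be subscripted again."""


class Samples:
    """Stand-in for a container type such as Trajectories."""


try:
    # Mirrors `class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories])`.
    class Broken(NotGeneric[Samples]):
        pass

    broken_error = None
except TypeError as err:
    broken_error = err
    # On Python 3.10 this is the same "... is not a generic class" message.
    print(err)


class Fixed(StillGeneric[Samples]):  # subscripting works once T is re-declared
    pass
```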