GFNOrg / torchgfn

GFlowNet library
https://torchgfn.readthedocs.io/en/latest/

Nit: Tensor padding utility function (DRY) #154

Closed: josephdviviano closed this 6 months ago

josephdviviano commented 6 months ago

In trajectories.py we have the following:

    # TODO: This should be a single reused function
    # The size of self needs to grow to match other along dim=0.
    if self_shape[0] < other_shape[0]:
        pad_dim = required_first_dim - self_shape[0]
        pad_dim_full = (pad_dim,) + tuple(self_shape[1:])
        output_padding = torch.full(
            pad_dim_full,
            fill_value=-float("inf"),
            dtype=self.estimator_outputs.dtype,  # TODO: This isn't working! Hence the cast below...
            device=self.estimator_outputs.device,
        )
        self.estimator_outputs = torch.cat(
            (self.estimator_outputs, output_padding),
            dim=0,
        )

This logic appears multiple times in the library and could be abstracted into a utility function.
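One possible shape for such a shared helper, as a minimal sketch only: `pad_dim0` and `pad_pair_dim0` are hypothetical names, not part of the torchgfn API. It reuses the `torch.full` + `torch.cat` pattern from the snippet in trajectories.py, and its default fill of -inf mirrors that snippet, which assumes a floating-point dtype (an integer or bool tensor would need a different `fill_value`).

```python
import torch


def pad_dim0(
    tensor: torch.Tensor, target_len: int, fill_value: float = -float("inf")
) -> torch.Tensor:
    """Hypothetical helper: grow `tensor` along dim 0 to `target_len` rows.

    New rows are filled with `fill_value`. The padding takes its dtype and
    device from `tensor`. A tensor that is already long enough is returned
    unchanged (the same object, no copy).
    """
    pad_len = target_len - tensor.shape[0]
    if pad_len <= 0:
        return tensor
    # Padding block: pad_len rows, trailing dims copied from the input.
    padding = torch.full(
        (pad_len,) + tuple(tensor.shape[1:]),
        fill_value,
        dtype=tensor.dtype,
        device=tensor.device,
    )
    return torch.cat((tensor, padding), dim=0)


def pad_pair_dim0(
    a: torch.Tensor, b: torch.Tensor, fill_value: float = -float("inf")
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: pad whichever tensor is shorter so both match on dim 0."""
    n = max(a.shape[0], b.shape[0])
    return pad_dim0(a, n, fill_value), pad_dim0(b, n, fill_value)


# Usage: grow a (2, 3) tensor to 5 rows; rows 2..4 are filled with -inf.
padded = pad_dim0(torch.zeros(2, 3), 5)
print(padded.shape)  # torch.Size([5, 3])
```

Returning the input untouched when no padding is needed lets a call site drop its own `if self_shape[0] < other_shape[0]` guard. The pair variant covers the case where either side may be the shorter one.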

josephdviviano commented 6 months ago

Closing, as we have a PR.