datamol-io / molfeat

molfeat - the hub for all your molecular featurizers
https://molfeat.datamol.io
Apache License 2.0

Update the collate_fn for PYGGraphTransformer #92

Closed: zhu0619 closed this issue 6 months ago

zhu0619 commented 7 months ago

Motivation

In the latest version of torch_geometric (2.4.0), Collater was refactored and now requires a dataset argument.

Therefore, the current implementation of PYGGraphTransformer.get_collate_fn, which builds a Collater without a dataset, no longer works.

Pitch

A dataset argument should also be added to get_collate_fn and forwarded to Collater:

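```python
from functools import partial
from typing import List, Optional, Sequence, Union
from torch_geometric.data import Dataset
from torch_geometric.data.data import BaseData
from torch_geometric.data.datapipes import DatasetAdapter
from torch_geometric.loader.dataloader import Collater

def get_collate_fn(  # method of PYGGraphTransformer
    self,
    dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
    follow_batch: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
    return_pair: Optional[bool] = True,
    **kwargs,
):
    # torch_geometric >= 2.4.0: Collater needs the dataset it will collate from
    collator = Collater(dataset=dataset, follow_batch=follow_batch, exclude_keys=exclude_keys)
    return partial(self._collate_batch, collator=collator, return_pair=return_pair)
```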
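If molfeat should also keep working with torch_geometric releases before 2.4.0, where Collater takes only follow_batch and exclude_keys, one option is to pass dataset only when the installed Collater accepts it. This is a minimal sketch that assumes signature inspection is acceptable; make_collater is a hypothetical helper, not part of molfeat or PyG.

```python
import inspect
from typing import Any, List, Optional


def make_collater(
    collater_cls: type,
    dataset: Any,
    follow_batch: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
):
    """Hypothetical helper: build a Collater across torch_geometric versions.

    Passes `dataset` only if the constructor declares it (torch_geometric >= 2.4.0);
    older constructors receive just `follow_batch` and `exclude_keys`.
    """
    params = inspect.signature(collater_cls).parameters
    kwargs = {"follow_batch": follow_batch, "exclude_keys": exclude_keys}
    if "dataset" in params:
        kwargs["dataset"] = dataset
    return collater_cls(**kwargs)
```

Inside get_collate_fn this would replace the direct constructor call, e.g. make_collater(Collater, dataset, follow_batch, exclude_keys).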


maclandrol commented 7 months ago

Thanks @zhu0619, can you make a PR?

maclandrol commented 6 months ago

@zhu0619, can you document somewhere the alternative of using the PyG DataLoader object directly?

It actually removes the need to provide any collate function.
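With the PyG DataLoader, the list of graphs is passed in directly, e.g. torch_geometric.loader.DataLoader(graphs, batch_size=32), and the loader builds its own Collater (including the dataset) internally. The pure-Python sketch below is only a conceptual illustration, not PyG's code, of what that collation amounts to: merging several graphs into one disjoint graph by concatenating node features, shifting each graph's edge indices by the number of nodes already placed, and recording a batch vector and ptr offsets. The dict-based graph layout and the batch_graphs name are hypothetical.

```python
from typing import Dict, List

# Hypothetical graph layout: {"x": [[node features], ...], "edge_index": [[src, ...], [dst, ...]]}
Graph = Dict[str, list]


def batch_graphs(graphs: List[Graph]) -> Dict[str, list]:
    """Merge graphs into one disjoint graph, in the spirit of PyG's Batch.from_data_list."""
    x: List[list] = []
    src: List[int] = []
    dst: List[int] = []
    batch: List[int] = []  # graph id of every node
    ptr: List[int] = [0]  # node offset at which each graph starts
    for i, g in enumerate(graphs):
        offset = ptr[-1]  # number of nodes already placed in the batch
        x.extend(g["x"])
        # Shift edge endpoints so they point at this graph's nodes in the merged graph
        src.extend(s + offset for s in g["edge_index"][0])
        dst.extend(d + offset for d in g["edge_index"][1])
        batch.extend([i] * len(g["x"]))
        ptr.append(offset + len(g["x"]))
    return {"x": x, "edge_index": [src, dst], "batch": batch, "ptr": ptr}


g1 = {"x": [[1.0], [2.0]], "edge_index": [[0], [1]]}
g2 = {"x": [[3.0], [4.0], [5.0]], "edge_index": [[0, 1], [1, 2]]}
merged = batch_graphs([g1, g2])
# merged["edge_index"] == [[0, 2, 3], [1, 3, 4]], merged["batch"] == [0, 0, 1, 1, 1]
```

Because this bookkeeping is handled by the PyG loader itself, no custom collate function is needed when graphs are fed to it directly.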