pyg-team / pytorch-frame

Tabular Deep Learning Library for PyTorch
https://pytorch-frame.readthedocs.io
MIT License
505 stars, 53 forks

Remove CUDA synchronizations by slicing input tensor with `int` instead of CUDA tensors in `nn.LinearEmbeddingEncoder` #432

Closed akihironitta closed 1 month ago

akihironitta commented 1 month ago

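`start_idx` and `end_idx`, used in `feat.values[:, start_idx:end_idx]`, are CUDA tensors rather than Python `int`s. Slicing with them forces their values to be copied back to the host, which triggers a CUDA synchronization on every slice.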
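A minimal sketch of the idea, not the actual `nn.LinearEmbeddingEncoder` code: the names `emb_dims`, `offsets`, and `values` are hypothetical stand-ins, and the per-column widths are made up for illustration. It assumes the slice boundaries can be computed once on the host and kept as plain Python `int`s.

```python
import torch

# Hypothetical per-column embedding widths (not taken from the library).
emb_dims = [3, 5, 2]

# Compute the slice boundaries once on the host, as plain Python ints.
offsets = [0]
for dim in emb_dims:
    offsets.append(offsets[-1] + dim)

device = "cuda" if torch.cuda.is_available() else "cpu"
values = torch.randn(4, offsets[-1], device=device)

# Slicing with ints only creates views, so no device-to-host copy is needed.
slices = [values[:, offsets[i]:offsets[i + 1]] for i in range(len(emb_dims))]

# For contrast: boundaries stored in a tensor on the device. Using them as
# slice bounds converts them to Python ints first, which on CUDA requires
# waiting for the device and copying the values back to the host.
offsets_on_device = torch.tensor(offsets, device=device)
start_idx, end_idx = offsets_on_device[0], offsets_on_device[1]
slow_slice = values[:, start_idx:end_idx]
```

Both forms return the same data; only the `int` version avoids the implicit host read. On a CUDA build, `torch.cuda.set_sync_debug_mode("warn")` may help confirm which line triggers the synchronization.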