pytorch / nestedtensor

[Prototype] Tools for the concurrent manipulation of variably sized Tensors.
BSD 3-Clause "New" or "Revised" License

Small optimizations for MHA #465

Closed cpuhrsch closed 2 years ago

cpuhrsch commented 2 years ago

- Use TensorImpl dtype and device
- Remove _height from EfficientSizeNode
- Collapse dims before adding padding