pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Lower Multinomial without Replacement #4865

Open cowanmeg opened 1 year ago

cowanmeg commented 1 year ago

🚀 Feature

Lower `torch.multinomial` without replacement (`replacement=False`) to XLA.

Motivation

Currently, we lower multinomial with replacement (https://github.com/pytorch/xla/pull/4848), but we want to accelerate both paths.

Some relevant discussion on tensorflow that might be helpful: https://github.com/tensorflow/tensorflow/issues/9260
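One technique often used to sample without replacement is the Gumbel-top-k trick. It needs only elementwise math plus a top-k, and both usually map well onto XLA. The idea is to add independent Gumbel(0, 1) noise to each log-weight and keep the indices of the k largest perturbed values. The result has the same distribution as drawing one index at a time, each in proportion to its weight among the indices not yet drawn. Below is a minimal pure-Python sketch of that math. It is an assumption for illustration, not the PyTorch/XLA implementation, and the function name `multinomial_without_replacement` is hypothetical. In a tensor framework, the loop would become elementwise ops over the whole probability tensor followed by something like `torch.topk`.

```python
import math
import random
from typing import List, Optional, Sequence


def multinomial_without_replacement(
    weights: Sequence[float],
    num_samples: int,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Hypothetical helper: return num_samples distinct indices into weights.

    The result is distributed like sequential sampling without replacement:
    each next index is drawn with probability proportional to its weight
    among the indices not drawn yet (Gumbel-top-k trick).
    """
    rng = rng if rng is not None else random.Random()
    if any(w < 0 for w in weights):
        raise ValueError("weights must be non-negative")
    nonzero = sum(1 for w in weights if w > 0)
    if not 0 <= num_samples <= nonzero:
        raise ValueError("num_samples must be between 0 and the number of nonzero weights")

    keys = []
    for i, w in enumerate(weights):
        if w == 0:
            continue  # a zero-weight category can never be drawn
        u = rng.random()  # uniform in [0, 1)
        while u == 0.0:  # Gumbel noise needs U strictly inside (0, 1)
            u = rng.random()
        gumbel = -math.log(-math.log(u))  # Gumbel(0, 1) sample
        keys.append((math.log(w) + gumbel, i))

    # Top-k of the perturbed log-weights is the sample, already in draw order.
    keys.sort(reverse=True)
    return [i for _, i in keys[:num_samples]]


rng = random.Random(0)
sample = multinomial_without_replacement([0.5, 0.0, 0.3, 0.2], 2, rng)
# `sample` holds 2 distinct indices; index 1 (weight 0.0) can never appear.
```

Sorting here stands in for a top-k. On a device, a fixed-size top-k keeps the output shape static, which matters for XLA compilation.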

JackCaoG commented 1 year ago

Can this be closed?

cowanmeg commented 1 year ago

No, the current implementation falls back to the CPU when `replacement=False`.

wonjoolee95 commented 1 year ago

@cowanmeg, I'll assign this to you as part of cleaning up our unassigned issues. No immediate need to work on this as far as I can tell, so feel free to put it in your backlog and prioritize as necessary.