Closed: davidweichiang closed this issue 2 months ago.
Very easy, I think. You can follow `real_einsum_forward` as a guide: https://github.com/bdusell/semiring-einsum/blob/master/torch_semiring_einsum/real_forward.py
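```python
def _callback(compute_sum):
    # The caller passes in compute_sum; the three functions handed to it
    # define the semiring: in-place addition, summation over a block of
    # dimensions, and in-place multiplication.
    return compute_sum(add_in_place, sum_block, multiply_in_place)
```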
You would just need to replace `add_in_place` with an implementation of `or_in_place` (see https://pytorch.org/docs/stable/generated/torch.logical_or.html), `sum_block` with an implementation of `or_block` (a sum followed by `> 0`?), and `multiply_in_place` with an implementation of `and_in_place` (see https://pytorch.org/docs/stable/generated/torch.logical_and.html).
Oh, good news: your einsum already works on Boolean tensors, so nothing needs to be implemented. (But maybe it should be documented?)
Neat! Yeah, that's worth documenting. Maybe there would be a speed advantage to implementing another version that uses `logical_or` and `logical_and`?
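One way to explore the speed question, sketched in plain PyTorch rather than through the library's einsum: compute a Boolean matrix product C[i, k] = OR_j (A[i, j] AND B[j, k]) once with `logical_and` plus an OR-reduction (`any`), and once by the real-arithmetic route of multiplying, summing, and testing `> 0`. The function names are hypothetical, the dense (i, j, k) intermediate ignores the library's block-wise evaluation, and no timings are claimed here; any difference will depend on hardware, tensor sizes, and PyTorch version.

```python
import time

import torch


def bool_matmul_logical(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Broadcast to shape (i, j, k), AND elementwise, then OR-reduce over j.
    prod = torch.logical_and(A.unsqueeze(2), B.unsqueeze(0))
    return prod.any(dim=1)


def bool_matmul_via_sum(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Real-semiring route: bool * bool is AND, the sum counts the witnesses j
    # (as int64), and "> 0" turns the count back into a Boolean.
    prod = A.unsqueeze(2) * B.unsqueeze(0)
    return prod.sum(dim=1) > 0


def time_it(fn, *args, repeat: int = 10) -> float:
    # Average wall-clock seconds per call; a rough harness, not a benchmark suite.
    start = time.perf_counter()
    for _ in range(repeat):
        fn(*args)
    return (time.perf_counter() - start) / repeat
```

For example, with `A = torch.rand(200, 300) < 0.1` and `B = torch.rand(300, 100) < 0.1`, both functions should return the same tensor, and `time_it(bool_matmul_logical, A, B)` versus `time_it(bool_matmul_via_sum, A, B)` gives a first rough comparison on a given machine.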
How hard would it be to add the Boolean semiring?