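Context for future reference: when using the Metal backend, the built-in dropout helper produced the following error because it scaled the mask while it was still `U8`, before casting it to the input dtype, and the Metal backend did not implement the contiguous affine op for `U8`.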
```
thread 'multi_head_attn::multihead_block::tests::test_multiheadattnblock_forward' panicked at src/multi_head_attn/multihead_block.rs:275:14:
called `Result::unwrap()` on an `Err` value: Msg("Metal contiguous affine U8 not implemented")
```
I can confirm that this works now.
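The ordering that avoids the error (build the keep-mask, cast it to the input dtype, and only then scale it) can be sketched in plain, dependency-free Rust. This is not the framework's actual implementation or tensor API: `Lcg` and `dropout` are hypothetical names, slices stand in for tensors, and the toy RNG exists only to keep the example self-contained.

```rust
/// Hypothetical toy linear congruential RNG, used only so the sketch has no crate dependencies.
struct Lcg(u64);

impl Lcg {
    /// Returns a pseudo-random f32 in [0, 1).
    fn next_f32(&mut self) -> f32 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Keep the top 24 bits so the value is exactly representable as f32.
        ((self.0 >> 40) as f32) / (1u32 << 24) as f32
    }
}

/// Hypothetical inverted-dropout sketch over a slice (not the library's API).
fn dropout(xs: &[f32], drop_p: f32, rng: &mut Lcg) -> Result<Vec<f32>, String> {
    if !(0.0..1.0).contains(&drop_p) {
        return Err(format!("dropout probability has to be in [0, 1), got {drop_p}"));
    }
    let scale = 1.0 / (1.0 - drop_p);
    // Step 1: the comparison yields a u8 mask (1 = keep, 0 = drop).
    let mask_u8: Vec<u8> = xs.iter().map(|_| (rng.next_f32() >= drop_p) as u8).collect();
    // Step 2: cast the mask to the input's element type (f32) first.
    // Step 3: scale in f32. Scaling the u8 mask directly is the step the Metal backend rejected.
    let mask: Vec<f32> = mask_u8.iter().map(|&m| m as f32 * scale).collect();
    // Apply the mask element-wise.
    Ok(xs.iter().zip(&mask).map(|(x, m)| x * m).collect())
}

fn main() {
    let mut rng = Lcg(42);
    let xs = vec![1.0f32; 8];
    let ys = dropout(&xs, 0.5, &mut rng).unwrap();
    // Every element is either dropped (0.0) or kept and rescaled by 1 / (1 - p).
    println!("{ys:?}");
}
```

With `drop_p = 0.5`, every output element is either `0.0` or `2.0`, so the expected value of each element matches the input. Doing the arithmetic only after the cast means the scale step never runs on an integer dtype.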