modularml / mojo

The Mojo Programming Language
https://docs.modular.com/mojo/manual/

[Feature Request] Function multiversioning #3651

Open owenhilyard opened 1 month ago

owenhilyard commented 1 month ago

What is your request?

What I would like is an equivalent capability to Clang's "target" function attribute: the ability to compile a function (and everything it calls) multiple times for different feature sets (e.g., base x86_64, SSE4, AVX, AVX2, AVX-512) and have the runtime pick the best version for the processor it is running on.
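For readers less familiar with the Clang feature being referenced, here is a minimal C sketch of the idea. It assumes GCC or Clang on x86-64. It writes by hand what compiler-driven multiversioning automates: the same loop is compiled twice, once with __attribute__((target("avx2"))), and __builtin_cpu_supports picks a copy at runtime. The function names are illustrative only. With Clang's own multiversioning, the compiler would generate the dispatch itself.

```c
#include <stddef.h>
#include <stdint.h>

/* Baseline copy: compiled for whatever -march the whole build uses. */
static uint32_t sum_u32_default(const uint32_t *xs, size_t n) {
    uint32_t total = 0;
    for (size_t i = 0; i < n; i++) total += xs[i];
    return total;
}

#if defined(__x86_64__) && (defined(__GNUC__) || defined(__clang__))
#define HAVE_AVX2_VARIANT 1
/* Same loop, but the target attribute lets the compiler use AVX2
   (e.g. 256-bit auto-vectorization) in this copy only. */
__attribute__((target("avx2")))
static uint32_t sum_u32_avx2(const uint32_t *xs, size_t n) {
    uint32_t total = 0;
    for (size_t i = 0; i < n; i++) total += xs[i];
    return total;
}
#endif

/* Runtime pick: only call the AVX2 copy on CPUs that report AVX2. */
uint32_t sum_u32(const uint32_t *xs, size_t n) {
#ifdef HAVE_AVX2_VARIANT
    if (__builtin_cpu_supports("avx2")) return sum_u32_avx2(xs, n);
#endif
    return sum_u32_default(xs, n);
}
```

The binary stays portable to any x86-64 CPU. Machines with AVX2 still get the wider code path.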

As an example, consider the following function:

```mojo
fn _aes_encrypt_buffer_in_place[buffer_len: Int](data: UnsafePointer[UInt8], key: SIMD[DType.uint64, 2]):
    constrained[buffer_len % 16 == 0, "buffer_len must be padded to 128 bits for in-place encryption."]()

    # Compile-time branch: each specialized copy keeps only the branch its target supports.
    @parameter
    if has_avx512vl() and has_vaes() and has_aesni():
        # 512-bit main loop, then 256- and 128-bit drain loops for the tail.
        aes_rounds[buffer_len, "llvm.x86.aesni.512", 512](data, key)
        aes_rounds_cleanup[buffer_len, "llvm.x86.aesni.256", 256](data, key)
        aes_rounds_cleanup[buffer_len, "llvm.x86.aesni", 128](data, key)
    elif has_vaes() and has_aesni():
        aes_rounds[buffer_len, "llvm.x86.aesni.256", 256](data, key)
        aes_rounds_cleanup[buffer_len, "llvm.x86.aesni", 128](data, key)
    elif has_aesni():
        aes_rounds[buffer_len, "llvm.x86.aesni", 128](data, key)
    else:
        aes_rounds_software_fallback[buffer_len](data, key)

# Proposed API: emit one copy per feature string and dispatch to the best one at runtime.
specialize_on_target_features["aes_encrypt_buffer_in_place", _aes_encrypt_buffer_in_place, "arch=x86-64", "arch=x86-64,aes", "arch=x86-64,aes,vaes", "arch=x86-64,aes,vaes,avx512vl"]()
```

aes_rounds runs the main encryption loop, and aes_rounds_cleanup acts as a drain loop. On some processors the wider instructions are slightly more expensive (the 512-bit penalty), so it is worth draining the tail with narrower instructions instead of issuing extra wide instructions over padded data. This is a hot-loop function for anything using TLS with an AES cipher suite, so ideally it should be fast. At present, if you compile for portability, you always get the software fallback. With this feature, compiling for portability would ideally produce 4 separate functions, one for each branch, with the has_x functions reflecting the target that each copy was specialized for. I don't have any strong feelings about the syntax, so whatever works best for the compiler internals.
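The drain-loop structure described above can be sketched in portable C. This is an illustration only. The hypothetical xor_block stands in for one pass of AES rounds and is not encryption. The 64/32/16-byte widths correspond to the 512/256/128-bit paths in the example. The bulk of the buffer goes through the widest blocks, and the tail is drained with narrower ones instead of being padded to a full 512-bit block.

```c
#include <stddef.h>
#include <stdint.h>

/* Hypothetical stand-in for a round pass over one block: XOR with the key.
   Real code would run AES instructions of the matching width here. */
static void xor_block(uint8_t *data, size_t width, const uint8_t key[16]) {
    for (size_t i = 0; i < width; i++) data[i] ^= key[i % 16];
}

typedef struct {
    size_t blocks_512, blocks_256, blocks_128;
} DrainCounts;

/* len must be a multiple of 16 bytes (128 bits), like the constrained[] check. */
DrainCounts encrypt_with_drain(uint8_t *data, size_t len, const uint8_t key[16]) {
    DrainCounts c = {0, 0, 0};
    size_t off = 0;
    /* Main loop: 64-byte (512-bit) blocks. */
    while (len - off >= 64) { xor_block(data + off, 64, key); off += 64; c.blocks_512++; }
    /* Drain: fewer than 64 bytes remain, so at most one 32-byte block. */
    while (len - off >= 32) { xor_block(data + off, 32, key); off += 32; c.blocks_256++; }
    /* Drain: fewer than 32 bytes remain, so at most one 16-byte block. */
    while (len - off >= 16) { xor_block(data + off, 16, key); off += 16; c.blocks_128++; }
    return c;
}
```

Take a 112-byte buffer. It is processed as one block at each width (64 + 32 + 16) rather than being padded to 128 bytes to fit a second 512-bit block.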

What is your motivation for this change?

I want this feature because the current has_x functions in std are globally scoped: they describe the single target the whole program is compiled for. That means parts of std cannot do this kind of runtime feature detection unless a parallel set of runtime_has_x functions is created. That is possible, but many std functions would then need to either check a global table or query the CPU's feature registers (CPUID on x86) and do some bit operations, which is not ideal for hot-loop code. For example, if std gains TLS support, it will likely need AES, which can use hardware instructions when they are present. A portable application will be compiled for base ARMv8 or x86-64-v1, which lack those instructions even though most server CPUs support them. This feature would at least let hot loops take better advantage of the hardware. Absolute maximum performance could be achieved by putting this over main or a similar high-level function, effectively building the whole program for each feature combination you care about, at the cost of substantially longer compile times and a larger binary.
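One way to avoid a feature check on every call is to resolve the best variant once and cache a function pointer. This is roughly what ifunc-based multiversioning does at load time, and the C sketch below illustrates it. The feature bits, the variant functions, and names such as resolve_variant are hypothetical. A real program would populate the bits from CPUID (for example via __builtin_cpu_supports) rather than take them as an argument. The table mirrors the four specializations listed in the request, ordered best-first.

```c
#include <stddef.h>

/* Hypothetical feature bits; a real program would fill them from CPUID. */
enum { FEAT_AES = 1 << 0, FEAT_VAES = 1 << 1, FEAT_AVX512VL = 1 << 2 };

typedef const char *(*variant_fn)(void);

/* Placeholder bodies standing in for the four specialized copies. */
static const char *v_baseline(void) { return "x86-64"; }
static const char *v_aes(void) { return "x86-64,aes"; }
static const char *v_vaes(void) { return "x86-64,aes,vaes"; }
static const char *v_avx512(void) { return "x86-64,aes,vaes,avx512vl"; }

/* Best-first: the first entry whose required bits are all present wins. */
static const struct {
    unsigned required;
    variant_fn fn;
} kVariants[] = {
    {FEAT_AES | FEAT_VAES | FEAT_AVX512VL, v_avx512},
    {FEAT_AES | FEAT_VAES, v_vaes},
    {FEAT_AES, v_aes},
    {0, v_baseline},
};

variant_fn resolve_variant(unsigned features) {
    for (size_t i = 0; i < sizeof kVariants / sizeof kVariants[0]; i++) {
        unsigned req = kVariants[i].required;
        if ((features & req) == req) return kVariants[i].fn;
    }
    return v_baseline; /* not reached: the last entry requires nothing */
}

/* Resolved once at startup; the hot path only makes an indirect call. */
static variant_fn g_impl = v_baseline;

void init_dispatch(unsigned features) { g_impl = resolve_variant(features); }

const char *current_variant(void) { return g_impl(); }
```

Call init_dispatch once at startup. After that, each call costs one indirect call and no per-iteration feature test. The static selection discussed below would remove even that indirection.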

Any other details?

It would be nice if there were a way to pick the function variant statically, for instance with -march=native, -mtune=native, or some other "--specialize-at-compile-time" flag. Selecting the variant at compile time means no performance is lost to indirection when compiling from source, and inlining can still occur.
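C compilers already support this kind of static selection: they expose the build's target features as predefined macros, so a variant can be chosen in the preprocessor with no indirect call. This sketch assumes GCC or Clang, which define __AES__, __VAES__, and __AVX512VL__ when those features are enabled. That happens, for example, under -march=native on a CPU that has them. The variant strings are only labels for the four hypothetical specializations.

```c
/* Chosen entirely at compile time from the compiler's feature macros,
   so calls can be direct and inlined. */
#if defined(__AES__) && defined(__VAES__) && defined(__AVX512VL__)
#define SELECTED_VARIANT "x86-64,aes,vaes,avx512vl"
#elif defined(__AES__) && defined(__VAES__)
#define SELECTED_VARIANT "x86-64,aes,vaes"
#elif defined(__AES__)
#define SELECTED_VARIANT "x86-64,aes"
#else
#define SELECTED_VARIANT "x86-64" /* baseline when no AES features are enabled */
#endif

const char *selected_variant(void) { return SELECTED_VARIANT; }
```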

martinvuyk commented 1 month ago

I think this is part of the future goals of mojo package: producing a target-agnostic, semi-compiled binary (one that keeps the IR) for distribution. It was mentioned in this community meeting moment.