rustwasm / wasm-bindgen

Facilitating high-level interactions between Wasm modules and JavaScript
https://rustwasm.github.io/docs/wasm-bindgen/
Apache License 2.0

How viable is it to use `repr(C)` structs as WASM primitives? #4232

Open RunDevelopment opened 1 week ago

RunDevelopment commented 1 week ago

Background

Right now, the WasmAbi trait works using WASM primitives. Primitives are Rust types that map directly to a WASM ABI type: i32, u32, i64, u64, f32, f64, and (). () is interesting here because its ABI is nothing: it doesn't appear in the function signature at all. The WasmAbi trait uses this quirk to allow for a variable number of primitives (up to 4) by filling the unused slots with ().
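
To make the () trick concrete, here is a simplified, self-contained sketch. The trait shapes (WasmPrimitive, and WasmAbi with Prim1 to Prim4, split and join) are modeled on the description above and on the draft later in this thread, not copied from wasm-bindgen's source, so treat them as illustrative. A type that needs fewer than four primitives fills the remaining slots with (), which contributes nothing to the generated Wasm signature.

```rust
/// Marker for Rust types that map onto a single Wasm value (or, for `()`, onto nothing).
trait WasmPrimitive: Copy {}
impl WasmPrimitive for () {}
impl WasmPrimitive for u32 {}
impl WasmPrimitive for i32 {}

/// Illustrative four-slot ABI trait: every type is lowered to exactly four
/// primitive slots, and unused slots are `()`.
trait WasmAbi: Sized {
    type Prim1: WasmPrimitive;
    type Prim2: WasmPrimitive;
    type Prim3: WasmPrimitive;
    type Prim4: WasmPrimitive;
    fn split(self) -> (Self::Prim1, Self::Prim2, Self::Prim3, Self::Prim4);
    fn join(p1: Self::Prim1, p2: Self::Prim2, p3: Self::Prim3, p4: Self::Prim4) -> Self;
}

/// One primitive: one real slot, three `()` slots.
impl WasmAbi for u32 {
    type Prim1 = u32;
    type Prim2 = ();
    type Prim3 = ();
    type Prim4 = ();
    fn split(self) -> (u32, (), (), ()) {
        (self, (), (), ())
    }
    fn join(p1: u32, _: (), _: (), _: ()) -> u32 {
        p1
    }
}

/// A pair of primitives: two real slots, two `()` slots.
impl WasmAbi for (u32, i32) {
    type Prim1 = u32;
    type Prim2 = i32;
    type Prim3 = ();
    type Prim4 = ();
    fn split(self) -> (u32, i32, (), ()) {
        (self.0, self.1, (), ())
    }
    fn join(a: u32, b: i32, _: (), _: ()) -> (u32, i32) {
        (a, b)
    }
}
```

An exported function taking a (u32, i32) would then declare four parameters in Rust, but only the u32 and i32 would show up in the Wasm signature.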

This design has the major limitation that ALL Rust types must be representable with at most 4 WASM primitives to be WASM-compatible. This makes it very difficult to support tuples as value types, since they can easily exceed the 4-primitive limit. Raising the limit to some number N isn't a full solution either, since we can always create nested tuples that require more primitives.
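
The limit problem can be seen by counting primitives at the type level. FlatCount below is a hypothetical helper written only for this illustration (it is not part of wasm-bindgen): a pair flattens to the sum of its parts, so adding one more level of nesting is enough to push a type past any fixed N.

```rust
/// Hypothetical helper: how many Wasm value slots a type flattens to.
trait FlatCount {
    const COUNT: usize;
}
// `()` contributes nothing, matching the quirk described above.
impl FlatCount for () {
    const COUNT: usize = 0;
}
impl FlatCount for i32 {
    const COUNT: usize = 1;
}
impl FlatCount for i64 {
    const COUNT: usize = 1;
}
impl FlatCount for f64 {
    const COUNT: usize = 1;
}
// A pair flattens to the concatenation of its two parts.
impl<A: FlatCount, B: FlatCount> FlatCount for (A, B) {
    const COUNT: usize = A::COUNT + B::COUNT;
}

/// The fixed-slot scheme from the text.
const MAX_SLOTS: usize = 4;

/// True if `T` can be represented with at most `MAX_SLOTS` primitives.
fn fits_in_four_slots<T: FlatCount>() -> bool {
    T::COUNT <= MAX_SLOTS
}
```

With N = 4, ((i32, i64), (f64, i32)) still fits, but ((i32, i64), ((f64, i32), i64)) needs five slots; raising N only moves the boundary.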

Question

My question is whether we can use repr(C) structs as primitives in the WASM ABI. Or, more precisely: is it safe to implement WasmPrimitive for a repr(C) struct whose fields all implement WasmPrimitive?

Right now, compiling the following Rust code:

#[repr(C)]
pub struct Bar {
    pub a: i32,
    pub b: i64,
}

#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), no_mangle)]
pub unsafe extern "C" fn my_test_fn(a: Bar, b: Bar) {}

yields the following WASM ABI:

(module $reference_test.wasm
  (type (;0;) (func (param i32 i64 i32 i64)))
  (func $my_test_fn (;0;) (type 0) (param i32 i64 i32 i64))
  (memory (;0;) 17)
  (export "memory" (memory 0))
  (export "my_test_fn" (func $my_test_fn))
  (@custom "target_features" (after code) "\02+\0fmutable-globals+\08sign-ext")
)

Note that I specifically chose i32 and i64 fields to see how the ABI handles padding. It seems to handle it well, but I don't know whether this is guaranteed.
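
The padding question is really about memory layout, which repr(C) does pin down. A quick check, assuming a target where i64 is 8-byte aligned (true for wasm32 and x86_64) and Rust 1.77 or newer for std::mem::offset_of!: Bar is 16 bytes, with 4 bytes of padding after a, and the flattened (param i32 i64 ...) signature simply doesn't carry those bytes. Whether extern "C" is guaranteed to keep passing the struct this way is a question about the calling convention, not the layout.

```rust
#[repr(C)]
pub struct Bar {
    pub a: i32,
    pub b: i64,
}

/// Returns (size, alignment, offset of `a`, offset of `b`) for `Bar`.
/// On a target where `i64` is 8-byte aligned, `b` starts at offset 8,
/// so bytes 4..8 are padding.
fn bar_layout() -> (usize, usize, usize, usize) {
    (
        std::mem::size_of::<Bar>(),
        std::mem::align_of::<Bar>(),
        std::mem::offset_of!(Bar, a),
        std::mem::offset_of!(Bar, b),
    )
}
```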

Motivation

If this were possible, it would be trivial to support tuples as value types in the ABI with helper types like this:

/// A "tuple" with a stable ABI.
#[repr(C)]
struct WasmTuple2<T1, T2> {
    item1: T1,
    item2: T2,
}
unsafe impl<T1: WasmPrimitive, T2: WasmPrimitive> WasmPrimitive for WasmTuple2<T1, T2> {}

This would also let us simplify the WasmAbi trait, since a single primitive would be enough while still allowing arbitrarily complex ABIs. E.g. Option<T> could define its own helper type or reuse WasmTuple2<u32, T> as its ABI. Similarly, Result<T, E> could use WasmTuple3<u32, T, E>.
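
A sketch of the value-level encoding this would enable, assuming the WasmTuple2 helper above plus an analogous WasmTuple3. The functions option_into_abi, option_from_abi, result_into_abi and result_from_abi are hypothetical, and the Default bounds (used to fill the unused payload slot) are an assumption of this sketch, not part of the proposal. It only covers the conversion; whether these structs can safely be passed by value across extern "C" is exactly the open question.

```rust
/// A "tuple" with a C-compatible layout (as proposed above).
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
struct WasmTuple2<T1, T2> {
    item1: T1,
    item2: T2,
}

/// Three-field variant of the same idea.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
struct WasmTuple3<T1, T2, T3> {
    item1: T1,
    item2: T2,
    item3: T3,
}

/// Encode `Option<T>` as (discriminant, payload); `None` uses `T::default()` as filler.
fn option_into_abi<T: Copy + Default>(value: Option<T>) -> WasmTuple2<u32, T> {
    match value {
        Some(v) => WasmTuple2 { item1: 1, item2: v },
        None => WasmTuple2 { item1: 0, item2: T::default() },
    }
}

fn option_from_abi<T: Copy>(abi: WasmTuple2<u32, T>) -> Option<T> {
    if abi.item1 != 0 {
        Some(abi.item2)
    } else {
        None
    }
}

/// Encode `Result<T, E>` as (discriminant, ok payload, err payload).
fn result_into_abi<T: Copy + Default, E: Copy + Default>(
    value: Result<T, E>,
) -> WasmTuple3<u32, T, E> {
    match value {
        Ok(v) => WasmTuple3 { item1: 0, item2: v, item3: E::default() },
        Err(e) => WasmTuple3 { item1: 1, item2: T::default(), item3: e },
    }
}

fn result_from_abi<T: Copy, E: Copy>(abi: WasmTuple3<u32, T, E>) -> Result<T, E> {
    if abi.item1 == 0 {
        Ok(abi.item2)
    } else {
        Err(abi.item3)
    }
}
```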

Liamolucko commented 1 week ago

The fact that the "C" ABI currently flattens structs out into their primitives is a bug, since it doesn't match the standard WASM C ABI. See #3454, rust-lang/rust#71871. We need to avoid relying on what it does so that rustc can fix its C ABI without breaking wasm-bindgen.
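
To illustrate the difference for the Bar example, here is a rough model of the two parameter-lowering rules. The standard rule encoded here is our reading of the WebAssembly tool-conventions Basic C ABI (an aggregate whose only content is a single scalar is passed as that scalar; any other aggregate is passed indirectly through an i32 pointer into linear memory), treated as an assumption; empty structs and return values are out of scope. ValType, Field and both functions are hypothetical names for this sketch.

```rust
/// Wasm value types used in this example.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ValType {
    I32,
    I64,
    F64,
}

/// A field of a C struct: either a scalar or a nested struct.
enum Field {
    Scalar(ValType),
    Struct(Vec<Field>),
}

/// Recursively collect every scalar in declaration order.
fn collect_scalars(fields: &[Field], out: &mut Vec<ValType>) {
    for field in fields {
        match field {
            Field::Scalar(t) => out.push(*t),
            Field::Struct(inner) => collect_scalars(inner, out),
        }
    }
}

/// Current behavior described in this thread: each scalar becomes its own parameter.
fn flattened_params(fields: &[Field]) -> Vec<ValType> {
    let mut out = Vec::new();
    collect_scalars(fields, &mut out);
    out
}

/// Assumed standard rule: a single-scalar struct is passed as that scalar,
/// anything else as one i32 pointer. Empty structs are not modeled here.
fn basic_c_abi_params(fields: &[Field]) -> Vec<ValType> {
    let flat = flattened_params(fields);
    if flat.len() == 1 {
        flat
    } else {
        vec![ValType::I32]
    }
}
```

Under the flattening behavior each Bar argument becomes i32 i64, as in the module above; under the standard rule as modeled here, my_test_fn would instead take two i32 pointers.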

I had an old draft of a way to support arbitrarily many primitives sitting around on my laptop, though; I've just finished it off. Here it is:

struct FalseTy;
struct TrueTy;
trait BoolTy {
    type PrimSel<A: WasmPrimitive, B: WasmPrimitive>: WasmPrimitive;
    type ListSel<A: PrimitiveList, B: PrimitiveList>: PrimitiveList;

    fn split_tuple<A: PrimitiveList, B: PrimitiveList>(
        tuple: (A, B),
    ) -> (Prim<(A, B)>, Rest<(A, B)>)
    where
        A::Prim: WasmPrimitive<IsUnit = Self>;

    fn join_tuple<A: PrimitiveList, B: PrimitiveList>(
        prim: <(A, B) as PrimitiveList>::Prim,
        rest: <(A, B) as PrimitiveList>::Rest,
    ) -> (A, B)
    where
        A::Prim: WasmPrimitive<IsUnit = Self>;

    fn split<T: PrimitiveList>(x: &T) -> (Prim1<T>, Prim2<T>, Prim3<T>, Prim4<T>)
    where
        Prim<Rest<Rest<Rest<Rest<T>>>>>: WasmPrimitive<IsUnit = Self>;

    fn join<T: PrimitiveList>(a: Prim1<T>, b: Prim2<T>, c: Prim3<T>, d: Prim4<T>) -> T
    where
        Prim<Rest<Rest<Rest<Rest<T>>>>>: WasmPrimitive<IsUnit = Self>;
}

impl BoolTy for FalseTy {
    type PrimSel<A: WasmPrimitive, B: WasmPrimitive> = B;
    type ListSel<A: PrimitiveList, B: PrimitiveList> = B;

    fn split_tuple<A: PrimitiveList, B: PrimitiveList>((a, b): (A, B)) -> (A::Prim, (A::Rest, B))
    where
        A::Prim: WasmPrimitive<IsUnit = FalseTy>,
    {
        let (a_prim, a_rest) = a.split();
        (a_prim, (a_rest, b))
    }

    fn join_tuple<A: PrimitiveList, B: PrimitiveList>(prim: A::Prim, rest: (A::Rest, B)) -> (A, B)
    where
        A::Prim: WasmPrimitive<IsUnit = FalseTy>,
    {
        (A::join(prim, rest.0), rest.1)
    }

    fn split<T: PrimitiveList>(x: &T) -> (&T, (), (), ())
    where
        Prim<Rest<Rest<Rest<Rest<T>>>>>: WasmPrimitive<IsUnit = Self>,
    {
        (x, (), (), ())
    }

    fn join<T: PrimitiveList>(a: &T, _: (), _: (), _: ()) -> T
    where
        Prim<Rest<Rest<Rest<Rest<T>>>>>: WasmPrimitive<IsUnit = Self>,
    {
        *a
    }
}
impl BoolTy for TrueTy {
    type PrimSel<A: WasmPrimitive, B: WasmPrimitive> = A;
    type ListSel<A: PrimitiveList, B: PrimitiveList> = A;

    fn split_tuple<A: PrimitiveList, B: PrimitiveList>((_, b): (A, B)) -> (B::Prim, B::Rest)
    where
        A::Prim: WasmPrimitive<IsUnit = Self>,
    {
        b.split()
    }

    fn join_tuple<A: PrimitiveList, B: PrimitiveList>(prim: B::Prim, rest: B::Rest) -> (A, B)
    where
        A::Prim: WasmPrimitive<IsUnit = TrueTy>,
    {
        (A::empty(), B::join(prim, rest))
    }

    fn split<T: PrimitiveList>(x: &T) -> (Prim1<T>, Prim2<T>, Prim3<T>, Prim4<T>)
    where
        Prim<Rest<Rest<Rest<Rest<T>>>>>: WasmPrimitive<IsUnit = Self>,
    {
        let (a, rest) = (*x).split();
        let (b, rest) = rest.split();
        let (c, rest) = rest.split();
        let (d, _) = rest.split();
        (a, b, c, d)
    }

    fn join<T: PrimitiveList>(a: Prim1<T>, b: Prim2<T>, c: Prim3<T>, d: Prim4<T>) -> T
    where
        Prim<Rest<Rest<Rest<Rest<T>>>>>: WasmPrimitive<IsUnit = Self>,
    {
        let rest = PrimitiveList::join(d, PrimitiveList::empty());
        let rest = PrimitiveList::join(c, rest);
        let rest = PrimitiveList::join(b, rest);
        PrimitiveList::join(a, rest)
    }
}

type PrimSel<Sel, A, B> = <Sel as BoolTy>::PrimSel<A, B>;
type ListSel<Sel, A, B> = <Sel as BoolTy>::ListSel<A, B>;

trait WasmPrimitive: Copy {
    type IsUnit: BoolTy;
    /// If `Self = ()`, returns `()`, otherwise panics.
    fn unit() -> Self {
        unreachable!()
    }
}
impl WasmPrimitive for () {
    type IsUnit = TrueTy;
    fn unit() {}
}
impl WasmPrimitive for u32 {
    type IsUnit = FalseTy;
}
impl<T> WasmPrimitive for &T {
    type IsUnit = FalseTy;
}
// etc...

trait PrimitiveList: Copy {
    type Prim: WasmPrimitive;
    type Rest: PrimitiveList;
    /// If this `PrimitiveList` is empty, returns the only valid instance of it
    /// (since it should be all `()`s); otherwise panics.
    fn empty() -> Self;
    fn split(self) -> (Self::Prim, Self::Rest);
    fn join(prim: Self::Prim, rest: Self::Rest) -> Self;
}

type IsUnit<T> = <T as WasmPrimitive>::IsUnit;

type Prim<T> = <T as PrimitiveList>::Prim;
type Rest<T> = <T as PrimitiveList>::Rest;

type NaivePrim1<T> = Prim<T>;
type NaivePrim2<T> = Prim<Rest<T>>;
type NaivePrim3<T> = Prim<Rest<Rest<T>>>;
type NaivePrim4<T> = Prim<Rest<Rest<Rest<T>>>>;
type NaivePrim5<T> = Prim<Rest<Rest<Rest<Rest<T>>>>>;

type IsEmpty<T> = IsUnit<NaivePrim1<T>>;

impl<T: WasmPrimitive> PrimitiveList for T {
    type Prim = T;
    type Rest = ();

    fn empty() -> Self {
        T::unit()
    }

    fn split(self) -> (Self, ()) {
        (self, ())
    }

    fn join(prim: Self, _: ()) -> Self {
        prim
    }
}

// This is a general way to concatenate two `PrimitiveList`s, but shouldn't be
// what we actually use for argument types, since we need the `PrimitiveList` to
// be `#[repr(C)]` in the long case so we can have JS fill it in and pass a
// pointer.
impl<T: PrimitiveList, U: PrimitiveList> PrimitiveList for (T, U) {
    type Prim = PrimSel<IsEmpty<T>, Prim<U>, Prim<T>>;
    type Rest = ListSel<IsEmpty<T>, Rest<U>, (Rest<T>, U)>;

    fn empty() -> Self {
        (T::empty(), U::empty())
    }

    fn split(self) -> (Self::Prim, Self::Rest) {
        <IsEmpty<T>>::split_tuple(self)
    }

    fn join(prim: Self::Prim, rest: Self::Rest) -> Self {
        <IsEmpty<T>>::join_tuple(prim, rest)
    }
}

#[derive(Clone, Copy)]
#[repr(C)]
struct WasmTuple2<T, U>(T, U);

impl<T: PrimitiveList, U: PrimitiveList> PrimitiveList for WasmTuple2<T, U> {
    type Prim = Prim<(T, U)>;
    type Rest = Rest<(T, U)>;

    fn empty() -> Self {
        Self(T::empty(), U::empty())
    }

    fn split(self) -> (Self::Prim, Self::Rest) {
        (self.0, self.1).split()
    }

    fn join(prim: Self::Prim, rest: Self::Rest) -> Self {
        let (a, b) = <(T, U)>::join(prim, rest);
        Self(a, b)
    }
}

#[derive(Clone, Copy)]
#[repr(C)]
struct WasmTuple3<A, B, C>(A, B, C);

impl<A: PrimitiveList, B: PrimitiveList, C: PrimitiveList> PrimitiveList for WasmTuple3<A, B, C> {
    type Prim = Prim<(A, (B, C))>;
    type Rest = Rest<(A, (B, C))>;

    fn empty() -> Self {
        Self(A::empty(), B::empty(), C::empty())
    }

    fn split(self) -> (Self::Prim, Self::Rest) {
        (self.0, (self.1, self.2)).split()
    }

    fn join(prim: Self::Prim, rest: Self::Rest) -> Self {
        let (a, (b, c)) = <(A, (B, C))>::join(prim, rest);
        Self(a, b, c)
    }
}

type IsShort<T> = <NaivePrim5<T> as WasmPrimitive>::IsUnit;

type Prim1<'a, T> = PrimSel<IsShort<T>, NaivePrim1<T>, &'a T>;
type Prim2<T> = PrimSel<IsShort<T>, NaivePrim2<T>, ()>;
type Prim3<T> = PrimSel<IsShort<T>, NaivePrim3<T>, ()>;
type Prim4<T> = PrimSel<IsShort<T>, NaivePrim4<T>, ()>;

fn split<T: PrimitiveList>(x: &T) -> (Prim1<T>, Prim2<T>, Prim3<T>, Prim4<T>) {
    <IsShort<T>>::split(x)
}
fn join<T: PrimitiveList>(a: Prim1<T>, b: Prim2<T>, c: Prim3<T>, d: Prim4<T>) -> T {
    <IsShort<T>>::join(a, b, c, d)
}

fn main() {
    let a = 1234;

    let x = ((1, 2), (&a, (4, 5)));
    let y = ((1, 2), (&a, 4));
    let z = 1;

    assert_eq!(split(&x), (&x, (), (), ()));
    assert_eq!(split(&y), (1, 2, &a, 4));
    assert_eq!(split(&z), (1, (), (), ()));

    assert_eq!(join::<((u32, u32), (&i32, (u32, u32)))>(&x, (), (), ()), x);
    assert_eq!(join::<((u32, u32), (&i32, u32))>(1, 2, &a, 4), y);
    assert_eq!(join::<u32>(1, (), (), ()), z);
}

It's inspired by the way the component model works: if there are at most 16 primitives (I've used 4 here to keep things short, but 16 probably makes more sense), they're passed as arguments; if there are more, they're all passed by pointer instead.
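
The threshold rule can be modeled in a few lines. DRAFT_MAX_FLAT = 4 matches the draft above and COMPONENT_MAX_FLAT = 16 matches the limit mentioned here; Lowering, lower and wasm_param_count are hypothetical names. The spill is all-or-nothing: past the limit, no primitive is passed directly, which is why the draft's split returns (&x, (), (), ()) for the five-primitive x but (1, 2, &a, 4) for the four-primitive y.

```rust
/// Limit used by the draft (kept small for readability).
const DRAFT_MAX_FLAT: usize = 4;
/// Limit suggested in the comment.
const COMPONENT_MAX_FLAT: usize = 16;

#[derive(Debug, PartialEq)]
enum Lowering {
    /// Each of the `n` primitives becomes its own Wasm parameter.
    Flat(usize),
    /// Every argument goes through memory; the Wasm function takes one pointer.
    ByPointer,
}

/// Decide how `primitive_count` flattened primitives are passed.
fn lower(primitive_count: usize, max_flat: usize) -> Lowering {
    if primitive_count <= max_flat {
        Lowering::Flat(primitive_count)
    } else {
        Lowering::ByPointer
    }
}

/// Number of parameters in the resulting Wasm signature.
fn wasm_param_count(lowering: &Lowering) -> usize {
    match lowering {
        Lowering::Flat(n) => *n,
        Lowering::ByPointer => 1,
    }
}
```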

The primary reason for passing the whole thing by pointer, rather than only the primitives past the 16th, is that it's easier to implement: we can just pass a pointer to the original type instead of somehow generating a #[repr(C)] struct with arbitrarily many fields (which I don't think is possible).

The idea is that the macro would create a #[repr(C)] struct containing all the arguments of a function, implement PrimitiveList for it, and then use PrimX<T> as the arguments of the extern "C" function.
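
A hand-written sketch of what such a macro might generate for the long case, skipping the PrimitiveList machinery. The function sum5, the struct Sum5Args and the shim __wbg_sum5 are all hypothetical names; a real export would also need #[no_mangle] (or an explicit export name), which is omitted so the sketch can run natively. With five u32 arguments the draft's 4-primitive limit is exceeded, so the shim receives a single pointer to an argument struct that JS would have filled in.

```rust
/// The user's function (hypothetical example with five `u32` arguments).
fn sum5(a: u32, b: u32, c: u32, d: u32, e: u32) -> u32 {
    a + b + c + d + e
}

/// What a macro might generate: one `#[repr(C)]` struct holding every argument,
/// so its layout in linear memory is predictable for the JS side.
#[repr(C)]
#[derive(Clone, Copy)]
struct Sum5Args {
    a: u32,
    b: u32,
    c: u32,
    d: u32,
    e: u32,
}

/// Hypothetical shim for the "too many primitives" case.
///
/// # Safety
/// `args` must point to a valid, initialized `Sum5Args`.
unsafe extern "C" fn __wbg_sum5(args: *const Sum5Args) -> u32 {
    // Copy the argument struct out of memory, then forward to the real function.
    let Sum5Args { a, b, c, d, e } = unsafe { *args };
    sum5(a, b, c, d, e)
}
```

For four or fewer primitives, the generated shim would instead take Prim1<T> through Prim4<T> as parameters and rebuild the argument struct with join.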

This solution requires GATs (generic associated types), though, which would mean bumping the MSRV to 1.65; that's a bit of a problem given #4038. It's also not exactly a simplification.