wgsl-tooling-wg / wesl-spec

A portable and modular superset of WGSL
BSD 3-Clause "New" or "Revised" License

Module Generics Lowering #40

Open stefnotch opened 1 month ago

stefnotch commented 1 month ago

Generic modules are a fairly convenient way of introducing powerful generics without requiring full WGSL type checking. For example, instead of repeatedly writing my_generic_func<u32>(3), one could import a generic module that is specialized for u32.

Background info on monomorphization

Normal generics, that is, generic functions and generic structs, get monomorphized. During monomorphization, each instance also gets a mangled name that encodes (item name, first generic argument, second generic argument, ...).

If we have code like the following

fn foo<A, B>() {
  return bar<A>() + bar<B>();
}
fn bar<A>() {
    return ...
}

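then foo<f32, u32>() results in three instances: (foo, f32, u32), (bar, f32) and (bar, u32).

Meanwhile foo<u32, u32>() results in only two, (foo, u32, u32) and (bar, u32), because both calls to bar share the same instance.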
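A minimal sketch of this monomorphization in Python, assuming a hypothetical table GENERIC_FNS that records each function's generic parameters and the generic calls it makes. Instances are deduplicated by their (name, generics...) key; the mangle helper is an illustrative scheme, not the one WESL uses.

```python
from collections import deque

# Hypothetical call graph for the foo/bar example: each generic function maps to
# (its generic parameter names, the generic calls it makes in terms of those names).
GENERIC_FNS = {
    "foo": (["A", "B"], [("bar", ["A"]), ("bar", ["B"])]),
    "bar": (["A"], []),
}


def mangle(key: tuple[str, ...]) -> str:
    """Illustrative mangling scheme only: join (name, arg1, arg2, ...) with '__'."""
    return "__".join(key)


def monomorphize(entry: str, args: list[str]) -> list[tuple[str, ...]]:
    """Return every instance reachable from entry<args>, keyed as
    (name, first generic, second generic, ...). Identical keys are emitted once."""
    emitted: dict[tuple[str, ...], None] = {}  # insertion-ordered set
    worklist = deque([(entry, tuple(args))])
    while worklist:
        name, concrete = worklist.popleft()
        key = (name, *concrete)
        if key in emitted:
            continue  # an identical instance already exists, so reuse it
        emitted[key] = None
        params, calls = GENERIC_FNS[name]
        subst = dict(zip(params, concrete))
        for callee, callee_params in calls:
            worklist.append((callee, tuple(subst[p] for p in callee_params)))
    return list(emitted)


print(monomorphize("foo", ["f32", "u32"]))
# [('foo', 'f32', 'u32'), ('bar', 'f32'), ('bar', 'u32')]
print(monomorphize("foo", ["u32", "u32"]))
# [('foo', 'u32', 'u32'), ('bar', 'u32')]
```

The worklist plus key set is what keeps foo<u32, u32> down to a single copy of bar.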

A good implementation of monomorphization will also replace aliases with the actual type.

alias cat = u32;
let result = bar<cat>() + bar<u32>();
// should only result in one copy of bar
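A small Python sketch of alias resolution before keying, assuming a hypothetical ALIASES table collected from alias declarations. It only handles plain type names. Compound types such as array<cat, 4> would need the same resolution applied recursively to their components.

```python
# Hypothetical alias table gathered from `alias` declarations.
ALIASES = {"cat": "u32"}


def resolve(ty: str, aliases: dict[str, str]) -> str:
    """Follow an alias chain (e.g. alias a = b; alias b = u32;) to the underlying type."""
    visited = set()
    while ty in aliases:
        if ty in visited:
            raise ValueError(f"cyclic alias involving {ty!r}")
        visited.add(ty)
        ty = aliases[ty]
    return ty


def instance_key(name: str, args: list[str], aliases: dict[str, str]) -> tuple[str, ...]:
    """Key an instance on resolved types, so bar<cat> and bar<u32> coincide."""
    return (name, *(resolve(a, aliases) for a in args))


keys = {instance_key("bar", ["cat"], ALIASES), instance_key("bar", ["u32"], ALIASES)}
print(keys)  # {('bar', 'u32')}, so only one copy of bar is emitted
```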

Generic Module Monomorphization

A generic module contains a set of global declarations, just like normal WGSL code.

I propose first lowering a generic module, and subsequently monomorphizing it with the rules from above. For simplicity, we will only look at structs and functions.

mod Foo<A, B> {
  struct Cat {
    cuteness: A
  }

  fn bar(cat: Cat<A>) {
    return pow(cat.cuteness, 2);
  }

  fn demo<T>(input: T) -> T {
    let a = Cat<B>(1);
    ...
    return input;
  }
}

To lower this, we

  1. Go over every item (functions and structs in this case)
  2. Recursively analyse what module generics they use
  3. Add these module generics to each item's own list of generics:
    1. struct Cat<A> { ... }
    2. fn bar<A>(...) { ... }
    3. fn demo<T, B>(...) { ... }
  4. Do monomorphization with these items, instead of duplicating the entire module.

This guarantees that only the bare minimum of code gets duplicated.
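The lowering steps can be sketched in Python as a fixpoint over item references. This is a sketch under stated assumptions: each item is described by a hypothetical Item record listing the module generics it names directly and the other items it uses without explicit module-generic arguments; a real implementation would extract these from the syntax tree. The helper item is hypothetical and only exists to show the recursive case.

```python
from dataclasses import dataclass, field


@dataclass
class Item:
    own_generics: list[str]  # generics the item already declares, e.g. demo<T>
    direct: set[str]         # module generics named directly, e.g. Cat<B> inside demo
    refs: list[str] = field(default_factory=list)  # items used without explicit module-generic args


def lower(items: dict[str, Item], module_generics: list[str]) -> dict[str, list[str]]:
    """Compute each item's generic parameter list after lowering a generic module."""
    # Step 2: propagate module-generic usage through references until a fixpoint.
    needs = {name: set(item.direct) for name, item in items.items()}
    changed = True
    while changed:
        changed = False
        for name, item in items.items():
            for ref in item.refs:
                missing = needs[ref] - needs[name]
                if missing:
                    needs[name] |= missing
                    changed = True
    # Step 3: append the needed module generics, in module declaration order.
    return {
        name: item.own_generics + [g for g in module_generics if g in needs[name]]
        for name, item in items.items()
    }


# mod Foo<A, B> from above, plus a hypothetical `helper` that calls bar(...)
# without spelling out A, to show the recursive case.
foo = {
    "Cat": Item([], {"A"}),
    "bar": Item([], {"A"}),
    "demo": Item(["T"], {"B"}),
    "helper": Item([], set(), refs=["bar"]),
}
print(lower(foo, ["A", "B"]))
# {'Cat': ['A'], 'bar': ['A'], 'demo': ['T', 'B'], 'helper': ['A']}
```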

To deal with pipeline-overridable constants and bindings, we pass them in as generic parameters.

mod Assertions<Counter, IsEnabled> {
  fn assert(value: bool) {
    if(IsEnabled && !value) {
      atomicAdd(&Counter, 1u);
    }
  }
}

Then, if two separate parts of our library tree depend on the same assertions module, we handle all of the following cases correctly.


Case 1: Sameness

For example, both parts pass the same counter to the assertions module, and both enable it. Then we end up with a single monomorphized function, (assert, my_counter, true), which both parts of the module tree share.

Case 2: Differences

For example, we use two separate counters for bevy's assert dependency and for math_utils's assert dependency. Then we end up with the expected two separate assertion functions: (assert, bevy_counter, true) and (assert, math_counter, true).

Case 3: Differences, but same binding

For example, we want to use the same binding, because our host code requires a single binding. However, we want to turn off the math_utils assertions. In that case, monomorphization again generates two separate functions: (assert, my_counter, true) and (assert, my_counter, false).
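The three cases reduce to deduplicating (function, counter, IsEnabled) keys. A minimal Python sketch, assuming bindings are identified by their name string (a real implementation would compare the resolved declarations) and using an illustrative mangling scheme:

```python
def monomorphize_requests(requests: list[tuple[str, str, bool]]) -> list[str]:
    """Each request is (function, counter binding, IsEnabled) coming from one part
    of the library tree. Identical keys collapse into one emitted function."""
    unique = dict.fromkeys(requests)  # deduplicate while keeping first-seen order
    return [f"{fn}__{counter}__{str(enabled).lower()}" for fn, counter, enabled in unique]


# Case 1: same counter, both enabled -> one function
print(monomorphize_requests([("assert", "my_counter", True), ("assert", "my_counter", True)]))
# ['assert__my_counter__true']

# Case 2: different counters -> two functions
print(monomorphize_requests([("assert", "bevy_counter", True), ("assert", "math_counter", True)]))
# ['assert__bevy_counter__true', 'assert__math_counter__true']

# Case 3: same counter, math_utils assertions disabled -> two functions
print(monomorphize_requests([("assert", "my_counter", True), ("assert", "my_counter", False)]))
# ['assert__my_counter__true', 'assert__my_counter__false']
```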

Why not have bindings and pipeline-overridable constants in generic modules?

It is very reasonable to instantiate a generic module twice, with either the same bindings or different ones.

mod Foo<A> {
  @group(0) @binding(3) var<uniform> a: u32;
}

If we want a Foo<u32> and a Foo<f32>, we now have a conflicting binding at group 0, binding 3. To resolve this, we would need to introduce an additional mechanism to override a binding of a generic module.

However, once that is done, the minimal monomorphization algorithm becomes significantly more complicated. For example:

mod Foo<A> {
  @group(0) @binding(3) var<uniform> count: u32;
  fn bar() -> bool {
    return A && (count > 5);
  }
  fn baz() -> u32 {
    return count;
  }
  fn is_true() -> bool {
    return A;
  }
}

If we depend on Foo<true>, we get the following code. Please ignore the specifics of the name mangling scheme.

@group(0) @binding(3) var<uniform> count: u32;
fn bar__true() -> bool { // (bar, true)
  return true && (count > 5);
}
fn baz() -> u32 { // (baz)
  return count;
}
fn is_true__true() -> bool { // (is_true, true)
  return true;
}

Then, if we also depend on Foo<false>, we additionally get:

fn bar__false() -> bool { // (bar, false)
  return false && (count > 5);
}
fn is_true__false() -> bool { // (is_true, false)
  return false;
}
// everything else is shared

And now, if we depend on a Foo<false> with an overridden count, monomorphization gives us the following:

@group(0) @binding(3) var<uniform> other_count: u32;
// Those are different, since they use a different count
fn bar__false_2() -> bool { // (bar, false)
  return false && (other_count > 5);
}
fn baz_2() -> u32 { // (baz)
  return other_count;
}
// But the (is_true, false) can be shared

To implement such behaviour, one essentially ends up treating bindings and pipeline-overridable constants as generic parameters, which certainly raises the question of whether we need that additional bit of complexity. (Figuring out the exact algorithm and coming up with further examples is left as an exercise to the reader.)

Finally, I believe that disallowing them in generic modules mostly removes an easy footgun for our users. Instead, we should push users towards creating modules that do not depend on a very specific binding.
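The observation that bindings end up behaving like generic parameters can be sketched in Python for the Foo<A> example with the count binding. This assumes a hypothetical USES table recording which parameters each function reads; keying each function only on those parameters reproduces the sharing in the listings above.

```python
# Which parameters each function of mod Foo<A> reads. The binding `count` is
# treated as an implicit generic parameter next to A.
USES = {"bar": ["A", "count"], "baz": ["count"], "is_true": ["A"]}


def instantiate(args: dict[str, str]) -> list[tuple[str, ...]]:
    """Keys for one Foo instantiation, restricted to the parameters each function uses."""
    return [(fn, *(args[p] for p in params)) for fn, params in USES.items()]


emitted: dict[tuple[str, ...], None] = {}
for args in (
    {"A": "true", "count": "count"},         # Foo<true>
    {"A": "false", "count": "count"},        # Foo<false>
    {"A": "false", "count": "other_count"},  # Foo<false> with an overridden count
):
    for key in instantiate(args):
        emitted.setdefault(key, None)  # shared keys are emitted only once

print(len(emitted))  # 7
for key in emitted:
    print(key)
```

The seven keys correspond to bar__true, baz, is_true__true, bar__false, is_true__false, bar__false_2 and baz_2. In particular, (is_true, false) is shared because is_true never reads count.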

stefnotch commented 1 month ago

Alternative: a slight variation of this proposal could be worthwhile.

  1. We allow pipeline-overridable constants and bindings in all modules, e.g.
       mod Foo<A> {
         @group(0) @binding(3) var<uniform> count: u32;
         ...
       }
  2. All modules get rewritten such that their bindings and overridables become generic parameters. Here we make use of generic parameter defaults.
       mod Foo_Helper {
         @group(0) @binding(3) var<uniform> count: u32;
       }
       mod Foo<A, count=Foo_Helper.count> {
         ...
       }
  3. Now we have the desired semantics and can use an existing mechanism for overriding a specific binding.

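Binding generic arguments with defaults, as in Foo<A, count=Foo_Helper.count>, can be sketched in Python. The parameter table and the my_count override are hypothetical; the point is that an omitted count falls back to the hoisted binding, while an explicit one produces a different instantiation.

```python
from typing import Optional

# mod Foo<A, count=Foo_Helper.count>: a required parameter A and a parameter
# `count` that defaults to the binding hoisted into Foo_Helper.
FOO_PARAMS: list[tuple[str, Optional[str]]] = [("A", None), ("count", "Foo_Helper.count")]


def bind_args(params, positional, named=None):
    """Bind positional and named generic arguments, filling in defaults."""
    if len(positional) > len(params):
        raise TypeError("too many positional generic arguments")
    named = dict(named or {})
    bound = {}
    for i, (name, default) in enumerate(params):
        if i < len(positional):
            if name in named:
                raise TypeError(f"generic argument {name!r} given twice")
            bound[name] = positional[i]
        elif name in named:
            bound[name] = named.pop(name)
        elif default is not None:
            bound[name] = default
        else:
            raise TypeError(f"missing generic argument {name!r}")
    if named:
        raise TypeError(f"unknown generic arguments: {sorted(named)}")
    return bound


print(bind_args(FOO_PARAMS, ["f32"]))
# {'A': 'f32', 'count': 'Foo_Helper.count'}
print(bind_args(FOO_PARAMS, ["f32"], {"count": "my_count"}))
# {'A': 'f32', 'count': 'my_count'}
```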
mighdoll commented 1 month ago

"For example, instead of repeatedly writing my_generic_func<u32>(3)"

Are module generics necessarily the solution for DRYing out call site references to generic elements? Module generics seem nifty. But AFAIK module generics are not a common feature in other languages, which makes me wonder: what makes WESL different?

"without requiring full WGSL type checking"

If WGSL were to solve this problem, would they implement type inference instead? If so, should we explore some kind of cheap type inference?

Or, if type inference is too complicated for now but the natural solution is type inference on generic elements, should we look at a less clever but less ornate solution in the interim until type inference is available? (e.g. I think alias was mentioned elsewhere).

Perhaps discussing the motivation for module generics should go in a separate github issue? Feel free to split it out if so.

k2d222 commented 1 month ago

"Are module generics necessarily the solution for DRYing out call site references to generic elements?"

@mighdoll, generic modules are not about DRY imo; generic functions are. One reason WGSL is different is that code depends heavily on bindings, which modules can use but should not declare. That is the responsibility of the caller. So we need a mechanism to inject code into modules. One way is passing pointers to all functions (but that is troublesome), another is virtual or overridable module declarations, and another is generic modules.

We have been exploring these different possibilities. @ncthbrt's proposal has generic modules and mine has overridable declarations. The two approaches seem to achieve the same effect.

It does raise the questions: Are there other ways? Would it be sufficient to provide a way to override just bindings and pipeline-overridable constants?

We should make another issue.

ncthbrt commented 1 month ago

Regarding passing bindings in as generic parameters, I think that could be simply syntactic sugar for something like the following:

mod A<Binding> {
  fn fun_func() -> f32 {
    return 2.0 + Binding::value();
  }
}

override my_override: f32;

mod MyOverride {
  @inline
  fn value() -> f32 {
    return my_override;
  }
}
alias MyA = A<MyOverride>;

If we later added that sugar, it would look nearly the same from the user's perspective, but implementation-wise it would reduce the number of constructs generic modules have to handle, particularly in the beginning.
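A rough Python sketch of that desugaring as a source-to-source rewrite. It assumes the binding's type is already known and that the wrapper and alias names (MyOverride, MyA) are supplied by the caller; both are assumptions for illustration, not part of any existing tool.

```python
def desugar_binding_arg(generic_module: str, binding: str, ty: str,
                        wrapper: str, alias: str) -> str:
    """Expand a sugared form like `alias MyA = A<my_override>;` into a wrapper
    module whose inlined value() returns the binding, plus the alias."""
    return "\n".join([
        f"mod {wrapper} {{",
        "  @inline",
        f"  fn value() -> {ty} {{",
        f"    return {binding};",
        "  }",
        "}",
        f"alias {alias} = {generic_module}<{wrapper}>;",
    ])


print(desugar_binding_arg("A", "my_override", "f32", "MyOverride", "MyA"))
```

The generated text matches the hand-written MyOverride module and MyA alias, so generic modules themselves only ever see module arguments.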

ncthbrt commented 1 month ago

I originally proposed them because shaders are typically a lot less object-oriented than code in most other languages.

I think of them as zero-cost static classes.

ncthbrt commented 1 month ago

@stefnotch I think bindings should be allowed for use cases like bevy's. Perhaps the specialisation algorithm could test if the bindings resolve to the same type? And if so allow them to be shared.

stefnotch commented 1 month ago

"@stefnotch I think bindings should be allowed for use cases like bevy's. Perhaps the specialisation algorithm could test if the bindings resolve to the same type? And if so allow them to be shared."

@ncthbrt The variation of this proposal would do exactly that, using the existing algorithm. https://github.com/wgsl-tooling-wg/wesl-spec/issues/40#issuecomment-2364732805

stefnotch commented 1 month ago

@mighdoll Yes, the motivation for module generics in the first place is a separate issue that I've just glossed over. My main focus was "how would we make this work, and what would the desired semantics be".

If we end up settling on anything else, I suspect that we will end up with a very similar algorithm for bindings and overridables. So this proposal really is about the semantics.

stefnotch commented 1 month ago

"We have been exploring these different possibilities. @ncthbrt's proposal has generic modules and mine has overridable declarations. The two approaches seem to achieve the same effect."

@k2d222 I agree that they achieve the same thing. In fact, I'd go so far as to say that the overrides you have described behave like

@override(texture_bind_group: 3, texture_type: u32)
import util/sample_texture/{sample as sample_u32};

@override(texture_bind_group: 5, texture_type: f32)
import util/sample_texture/{sample as sample_f32};

That is, they can be lowered with the same steps as module generics:

  1. Go over every item
  2. Recursively analyse which overridables they use
  3. Add these to their list of generics
    1. (sample, (3, _, u32))
    2. (sample, (5, _, f32))
  4. Do monomorphization with these items

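Turning the two @override(...) imports into monomorphization keys can be sketched in Python. The overridable list of util/sample_texture is hypothetical, and texture_binding is an invented name standing in for the slot written as _ in the keys; here _ marks an overridable the import leaves unspecified.

```python
# Hypothetical overridables declared by util/sample_texture, in declaration order.
# `texture_binding` is invented to stand for the slot shown as `_` above.
SAMPLE_TEXTURE_OVERRIDABLES = ("texture_bind_group", "texture_binding", "texture_type")


def import_key(item: str, overrides: dict[str, object],
               overridables: tuple[str, ...] = SAMPLE_TEXTURE_OVERRIDABLES) -> tuple[str, tuple[str, ...]]:
    """Turn an @override(...) import into a monomorphization key; '_' marks an
    overridable that the import leaves unspecified."""
    unknown = set(overrides) - set(overridables)
    if unknown:
        raise KeyError(f"unknown overridables: {sorted(unknown)}")
    return (item, tuple(str(overrides[n]) if n in overrides else "_" for n in overridables))


print(import_key("sample", {"texture_bind_group": 3, "texture_type": "u32"}))
# ('sample', ('3', '_', 'u32'))
print(import_key("sample", {"texture_bind_group": 5, "texture_type": "f32"}))
# ('sample', ('5', '_', 'f32'))
```

Because the two keys differ, sample_u32 and sample_f32 become two monomorphized functions, exactly as two different instantiations of a generic module would.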
ncthbrt commented 1 month ago

Regarding generics vs overrides: as I've started implementing generics, I realised that there is a need for optional, named generic arguments with default values, which are semantically even closer to overrides.

stefnotch commented 1 month ago

@ncthbrt We just have to be careful to make the overrides semantics "take a module and a set of overrides, and use that to create a new module".

Otherwise we run into unsolvable conflicts when two libraries try to override the same module in different ways.
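The "take a module and a set of overrides, and create a new module" semantics can be sketched in Python with an immutable value type. ModuleInstance and the "assertions" path are hypothetical names; the sketch only shows why non-mutating overrides avoid conflicts between libraries.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModuleInstance:
    """A module path plus a sorted tuple of (overridable, value) pairs.
    Frozen, so overriding can never mutate a module another library already uses."""
    path: str
    overrides: tuple[tuple[str, str], ...] = ()

    def with_overrides(self, **new: str) -> "ModuleInstance":
        merged = dict(self.overrides)
        merged.update(new)
        return ModuleInstance(self.path, tuple(sorted(merged.items())))


base = ModuleInstance("assertions")
for_bevy = base.with_overrides(Counter="bevy_counter")
for_math = base.with_overrides(Counter="math_counter")
print(base.overrides)        # ()
print(for_bevy == for_math)  # False: two separate instances, no conflict
print(for_bevy == base.with_overrides(Counter="bevy_counter"))  # True: can be shared
```

Equal instances compare and hash equal, so identical override sets collapse into one module during monomorphization, while different ones simply coexist.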

ncthbrt commented 1 month ago

Yes, I was thinking of it in those terms, @stefnotch. It'd be akin to named optional arguments in languages like JavaScript.