google / aqt

[pallas] Relands PR #22552: Simplify handling of BlockMapping and GridMapping #676

Closed: copybara-service[bot] closed this PR 1 month ago.

copybara-service[bot] commented 1 month ago

BlockSpec, GridSpec and PrefetchScalarGridSpec are now simple dataclasses that just store the parameters passed from the API. They are then canonicalized and converted to BlockMapping and GridMapping, which contain less optional metadata. In particular, BlockMapping is never None. This consolidates the code that preprocesses the block and grid parameters and simplifies the code downstream.
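A minimal, stdlib-only sketch of the spec-to-mapping split described above. The BlockSpec and BlockMapping classes and the to_block_mapping helper are hypothetical, much smaller stand-ins for the real Pallas types. The defaults (whole-array block, index map that always selects block 0) roughly mirror how Pallas treats an omitted BlockSpec, but are stated here as an assumption. The point is that defaulting happens once, so downstream code never has to handle a None mapping.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple

# Hypothetical user-facing spec: it only records what the caller passed,
# with None meaning "use the default".
@dataclass(frozen=True)
class BlockSpec:
    block_shape: Optional[Tuple[int, ...]] = None
    index_map: Optional[Callable[..., Tuple[int, ...]]] = None

# Hypothetical canonical form: every field is filled in, never None.
@dataclass(frozen=True)
class BlockMapping:
    block_shape: Tuple[int, ...]
    index_map: Callable[..., Tuple[int, ...]]

def to_block_mapping(spec: Optional[BlockSpec],
                     array_shape: Tuple[int, ...]) -> BlockMapping:
    """Canonicalize a possibly missing BlockSpec into a BlockMapping."""
    if spec is None:
        spec = BlockSpec()
    # Default block: the whole array.
    block_shape = spec.block_shape if spec.block_shape is not None else tuple(array_shape)
    index_map = spec.index_map
    if index_map is None:
        rank = len(array_shape)

        # Default index map: ignore the grid indices and always pick block 0.
        def index_map(*grid_indices):
            return (0,) * rank
    return BlockMapping(block_shape=block_shape, index_map=index_map)

bm = to_block_mapping(None, (8, 128))
print(bm.block_shape, bm.index_map(3))  # (8, 128) (0, 0)
```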

grid now defaults to () instead of None.
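Using () rather than None as the default is convenient because () is itself a valid grid of rank zero: code can iterate over it or take products of its dimensions without a None check. The normalize_grid and num_programs helpers below are hypothetical, and accepting a bare int as a 1-D grid is an assumption made for illustration.

```python
import math
from typing import Tuple, Union

def normalize_grid(grid: Union[int, Tuple[int, ...]] = ()) -> Tuple[int, ...]:
    """Hypothetical helper: canonicalize the user's grid argument to a tuple."""
    if isinstance(grid, int):
        return (grid,)  # assumed convenience: a bare int means a 1-D grid
    return tuple(grid)

def num_programs(grid: Tuple[int, ...]) -> int:
    # math.prod of an empty tuple is 1, so the default grid () runs the kernel once.
    return math.prod(grid)

print(normalize_grid(), num_programs(normalize_grid()))  # () 1
print(normalize_grid(4), num_programs((4, 2)))           # (4,) 8
```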

Added more fields to BlockMapping (block_aval, array_shape_dtype, and source). The source field is used in error messages. The array_shape_dtype field makes it unnecessary to process BlockMappings zipped with in_shapes. With these fields, we can now add a check_invariants method that is called during testing or when config.enable_checks is true.
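A sketch of what such a check_invariants method could look like, assuming simplified types: ShapeDtype is a stand-in for jax.ShapeDtypeStruct, ENABLE_CHECKS is a hypothetical stand-in for config.enable_checks, and make_block_mapping and the "inputs[1]" source string are made up for illustration. The particular invariants checked here (rank and dtype agreement, ignoring mapped dimensions) are assumptions, not the actual list in the change. The source field lets each error say which operand is at fault.

```python
from dataclasses import dataclass
from typing import Tuple

ENABLE_CHECKS = True  # hypothetical stand-in for config.enable_checks

@dataclass(frozen=True)
class ShapeDtype:  # stand-in for jax.ShapeDtypeStruct
    shape: Tuple[int, ...]
    dtype: str

@dataclass(frozen=True)
class BlockMapping:
    block_shape: Tuple[int, ...]
    block_aval: ShapeDtype         # what the kernel sees for one block
    array_shape_dtype: ShapeDtype  # the whole operand, so no zipping with in_shapes
    source: str                    # human-readable origin, used in error messages

    def check_invariants(self) -> None:
        if len(self.block_shape) != len(self.array_shape_dtype.shape):
            raise ValueError(
                f"{self.source}: block_shape {self.block_shape} has rank "
                f"{len(self.block_shape)} but the array has shape "
                f"{self.array_shape_dtype.shape}")
        if self.block_aval.shape != self.block_shape:
            raise ValueError(
                f"{self.source}: block_aval shape {self.block_aval.shape} "
                f"does not match block_shape {self.block_shape}")
        if self.block_aval.dtype != self.array_shape_dtype.dtype:
            raise ValueError(f"{self.source}: block and array dtypes differ")

def make_block_mapping(**kwargs) -> BlockMapping:
    """Hypothetical constructor that only pays for checks when enabled."""
    bm = BlockMapping(**kwargs)
    if ENABLE_CHECKS:
        bm.check_invariants()
    return bm

good = make_block_mapping(
    block_shape=(8, 128), block_aval=ShapeDtype((8, 128), "float32"),
    array_shape_dtype=ShapeDtype((64, 128), "float32"), source="inputs[0]")
try:
    make_block_mapping(
        block_shape=(8,), block_aval=ShapeDtype((8,), "float32"),
        array_shape_dtype=ShapeDtype((64, 128), "float32"), source="inputs[1]")
except ValueError as e:
    print(e)  # inputs[1]: block_shape (8,) has rank 1 but the array has shape (64, 128)
```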

Added more fields and a check_invariants method to GridMapping, since it is such an important data structure. The new fields are index_map_avals, index_map_tree (to encode the calling convention for the index map functions), num_inputs, and num_outputs. The last two make it possible to recover in_shapes and out_shapes from the GridMapping. Previously there was some redundancy of information between in_shapes and out_shapes.
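A simplified sketch of how num_inputs and num_outputs let in_shapes and out_shapes be derived instead of stored. It assumes block_mappings lists inputs first and then outputs; scalar prefetch and scratch operands, as well as index_map_avals and index_map_tree, are left out, and all class names are hypothetical stand-ins.

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class ShapeDtype:
    shape: Tuple[int, ...]
    dtype: str

@dataclass(frozen=True)
class BlockMapping:
    block_shape: Tuple[int, ...]
    array_shape_dtype: ShapeDtype

@dataclass(frozen=True)
class GridMapping:
    grid: Tuple[int, ...]
    block_mappings: Tuple[BlockMapping, ...]  # assumed order: inputs, then outputs
    num_inputs: int
    num_outputs: int

    @property
    def in_shapes(self) -> Tuple[ShapeDtype, ...]:
        return tuple(bm.array_shape_dtype
                     for bm in self.block_mappings[:self.num_inputs])

    @property
    def out_shapes(self) -> Tuple[ShapeDtype, ...]:
        end = self.num_inputs + self.num_outputs
        return tuple(bm.array_shape_dtype
                     for bm in self.block_mappings[self.num_inputs:end])

    def check_invariants(self) -> None:
        if len(self.block_mappings) != self.num_inputs + self.num_outputs:
            raise ValueError("expected one BlockMapping per input and output")
        for bm in self.block_mappings:
            if len(bm.block_shape) != len(bm.array_shape_dtype.shape):
                raise ValueError(f"rank mismatch in {bm}")

x = ShapeDtype((64, 128), "float32")
bm = BlockMapping((8, 128), x)
gm = GridMapping(grid=(8,), block_mappings=(bm, bm, bm), num_inputs=2, num_outputs=1)
gm.check_invariants()
print(len(gm.in_shapes), len(gm.out_shapes))  # 2 1
```

Since both shape lists are computed from the same tuple, a primitive that receives such a grid_mapping has no need for separate in_shapes and out_shapes parameters.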

We no longer need the in_shapes and out_shapes parameters to pallas_call_p, since it already has grid_mapping.

Moved some of the logic for handling scalar prefetch and scratch shapes from PrefetchScalarGridSpec.get_grid_mapping to GridSpec.get_grid_mapping, and thus removed code duplication.
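A sketch of why a single get_grid_mapping on the base class is enough: if the base class already knows how to lay out scalar-prefetch and scratch operands, the subclass only has to supply a count. The operand order used here (scalar prefetch, inputs, outputs, scratch) follows the Pallas TPU kernel convention as I understand it and should be treated as an assumption; the layout dictionary returned is a hypothetical simplification of a real GridMapping.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple

@dataclass(frozen=True)
class GridSpec:
    grid: Tuple[int, ...] = ()
    in_specs: Tuple[Any, ...] = ()
    out_specs: Tuple[Any, ...] = ()
    scratch_shapes: Tuple[Any, ...] = ()

    # Class-level default: a plain GridSpec has no scalar-prefetch operands.
    num_scalar_prefetch = 0

    def get_grid_mapping(self) -> Dict[str, range]:
        """Hypothetical: kernel-argument positions of each operand group."""
        sizes = [("scalar_prefetch", self.num_scalar_prefetch),
                 ("inputs", len(self.in_specs)),
                 ("outputs", len(self.out_specs)),
                 ("scratch", len(self.scratch_shapes))]
        layout, start = {}, 0
        for name, n in sizes:
            layout[name] = range(start, start + n)
            start += n
        return layout

@dataclass(frozen=True)
class PrefetchScalarGridSpec(GridSpec):
    # Only adds a field; get_grid_mapping is inherited, so there is no duplication.
    num_scalar_prefetch: int = 0

spec = PrefetchScalarGridSpec(grid=(4,), in_specs=(None, None), out_specs=(None,),
                              scratch_shapes=(None,), num_scalar_prefetch=1)
print({k: list(v) for k, v in spec.get_grid_mapping().items()})
# {'scalar_prefetch': [0], 'inputs': [1, 2], 'outputs': [3], 'scratch': [4]}
```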

Removed some dead code for implementing the interpret mode.

Previously, hoisted consts were not accounted for in in_shapes. This is now fixed, since we no longer keep track of in_shapes separately.
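A small sketch of why deriving in_shapes fixes this class of bug: when constants closed over by the kernel are turned into extra inputs, bumping num_inputs is the only bookkeeping needed, and in_shapes follows automatically. The GridMapping here stores bare shapes instead of BlockMappings, and hoist_consts (which prepends the constants) is hypothetical; where the real code places hoisted consts is not stated in the description.

```python
from dataclasses import dataclass, replace
from typing import Iterable, Tuple

Shape = Tuple[int, ...]

@dataclass(frozen=True)
class GridMapping:
    array_shapes: Tuple[Shape, ...]  # one entry per operand: inputs, then outputs
    num_inputs: int
    num_outputs: int

    @property
    def in_shapes(self) -> Tuple[Shape, ...]:
        # Derived, so it can never fall out of sync with the operand list.
        return self.array_shapes[:self.num_inputs]

def hoist_consts(gm: GridMapping, const_shapes: Iterable[Shape]) -> GridMapping:
    """Hypothetical: turn closed-over constants into extra leading inputs."""
    consts = tuple(const_shapes)
    return replace(gm,
                   array_shapes=consts + gm.array_shapes,
                   num_inputs=gm.num_inputs + len(consts))

gm = GridMapping(array_shapes=((64, 128), (64, 128)), num_inputs=1, num_outputs=1)
gm2 = hoist_consts(gm, [(128,)])
print(gm2.in_shapes)  # ((128,), (64, 128))
```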

Renamed GridMapping.mapped_dims to GridMapping.vmapped_dims to avoid confusion with the use of "mapped" in block shapes.
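The two meanings that the rename separates can be sketched as follows. A "mapped" block dimension is marked with None in the block shape and is squeezed out of what the kernel sees; a vmapped dimension is a grid dimension added by batching. Both helpers are hypothetical, and the assumption that each vmap prepends one grid dimension is made only for illustration.

```python
from typing import Optional, Tuple

MAPPED = None  # sentinel used in this sketch for a "mapped" (squeezed) block dimension

def kernel_block_shape(block_shape: Tuple[Optional[int], ...]) -> Tuple[int, ...]:
    """Shape of one block as the kernel sees it: mapped dims are squeezed out."""
    return tuple(d for d in block_shape if d is not MAPPED)

def add_vmapped_dim(grid: Tuple[int, ...], vmapped_dims: Tuple[int, ...],
                    axis_size: int) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Hypothetical batching step: prepend a grid dim and record it as vmapped."""
    new_grid = (axis_size,) + grid
    # Previously recorded vmapped dims shift right by one because of the prepend.
    return new_grid, (0,) + tuple(d + 1 for d in vmapped_dims)

print(kernel_block_shape((None, 8, 128)))  # (8, 128)
print(add_vmapped_dim((4,), (), 3))        # ((3, 4), (0,))
print(add_vmapped_dim((3, 4), (0,), 2))    # ((2, 3, 4), (0, 1))
```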

Added a test for the calling convention, including dynamic grid dimensions.
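A sketch of the kind of calling convention such a test exercises: each index map is called once per program with the grid indices followed by the scalar-prefetch operands. This ordering follows the Pallas TPU PrefetchScalarGridSpec convention as I understand it and is an assumption here. enumerate_blocks is hypothetical, and a "dynamic" grid dimension is modeled simply as a size that is only read at call time.

```python
import itertools
from typing import Any, Callable, List, Sequence, Tuple

def enumerate_blocks(grid: Sequence[int],
                     index_map: Callable[..., Tuple[int, ...]],
                     scalar_prefetch: Sequence[Any] = ()
                     ) -> List[Tuple[Tuple[int, ...], Tuple[int, ...]]]:
    """Call index_map(*grid_indices, *scalar_prefetch) for every program."""
    grid = tuple(int(d) for d in grid)  # dynamic sizes are resolved here, at call time
    results = []
    # itertools.product() with no ranges yields one empty tuple, so grid () runs once.
    for ids in itertools.product(*(range(d) for d in grid)):
        results.append((ids, index_map(*ids, *scalar_prefetch)))
    return results

# A prefetched table chooses which row block each program reads.
order = [2, 0, 1]
print(enumerate_blocks([3], lambda i, order_ref: (order_ref[i], 0), [order]))
# [((0,), (2, 0)), ((1,), (0, 0)), ((2,), (1, 0))]
```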

There is more work to be done: with the new information in GridMapping, it should be possible to clean up the code throughout that extracts various parts of the inputs and outputs. This should be a set of local changes, which I will do separately once I merge this large global change.