Status: Open — opened by kpaigwar 1 month ago
######################################################################################## # Spec 1 ######################################################################################## fused_query_key_value = {'shape' : [1, 1, 32, 1280], 'shard_shape' : (32, 32)} all_gather_output = {'shape' : [4, 1, 32, 1280], 'shard_shape' : (32*4, 32)} output_mem_config = ttnn.create_sharded_memory_config( shape=(32*4, 32), core_grid=ttnn.CoreGrid(y=5, x=8), strategy=ttnn.ShardStrategy.WIDTH, orientation=ttnn.ShardOrientation.ROW_MAJOR, use_height_and_width_as_shard_shape=True, ) gathered_tensor = ttnn.line_all_gather(fused_query_key_value, dim=0, num_links=2, cluster_axis=1, device_mesh=self.device_mesh, memory_config=output_mem_config) ######################################################################################## # Spec 2 ######################################################################################## attn_output_tensor = {'shape' : [1, 1, 32, 2048], 'shard_shape' : [1, 1, 32, 64]} all_gather_output = {'shape' : [8, 1, 32, 2048], 'shard_shape' : (32*8, 64)} output_mem_config = ttnn.create_sharded_memory_config( shape=(32*8, 64), core_grid=ttnn.CoreGrid(y=4, x=8), strategy=ttnn.ShardStrategy.WIDTH, orientation=ttnn.ShardOrientation.ROW_MAJOR, use_height_and_width_as_shard_shape=True, ) gathered_tensor = ttnn.line_all_gather(attn_output_tensor, dim=0, num_links=2, cluster_axis=0, device_mesh=self.device_mesh, memory_config=output_mem_config)
fyi @SeanNijjar @cglagovich