GilesStrong / tomopt

TomOpt: Differential Muon Tomography Optimisation
GNU Affero General Public License v3.0

Memory consumption of VoxelNet limits the number of muons and voxels that can be used #97

Open GilesStrong opened 2 years ago

GilesStrong commented 2 years ago

Problem

VoxelNet acts on tensors of shape (volumes, voxels, muons, features) and, as part of its graph construction, expands these into (volumes, voxels, muons, muons, new features) before collapsing them back to the original shape. Although the forward method loops over the volumes (so the shape actually processed is just (voxels, muons, features)), the memory consumption is still very high.
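To make the scaling concrete, here is a minimal sketch (not TomOpt's actual code; the shapes, sizes, and the concatenation-based pair construction are assumptions) of why the pairwise expansion dominates memory: the pair tensor grows quadratically with the number of muons per voxel.

```python
import torch

# Illustrative only: per-muon features are expanded to muon-pair features,
# so memory grows quadratically with the number of muons per voxel.
voxels, muons, features = 4, 8, 16  # kept tiny so this actually runs

x = torch.randn(voxels, muons, features)  # (voxels, muons, features)

# Build pair features for a single voxel: concatenate each muon's features
# with every other muon's features.
xv = x[0]  # (muons, features)
pairs = torch.cat(
    (
        xv.unsqueeze(1).expand(-1, muons, -1),  # (muons, muons, features)
        xv.unsqueeze(0).expand(muons, -1, -1),  # (muons, muons, features)
    ),
    dim=-1,
)  # (muons, muons, 2*features)
print(pairs.shape)

# Rough float32 memory if this were materialised at a realistic load,
# e.g. 1000 voxels x 500 muons x 16 features:
# 1000 * 500**2 * (2 * 16) * 4 bytes ~ 32 GB for the pair tensor alone.
```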

Potential solutions

Loop over voxels in the first part of the network

The first part of the network computes a muon representation per voxel, (voxels, muon representation), and this computation is independent of the other voxels, meaning that the muon representations could be computed serially rather than in parallel. This reduces memory consumption at the cost of processing time.
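As an illustration, here is a minimal sketch of the parallel and serial strategies, assuming a hypothetical `muon_rep_net` and a simple mean-pooling over muons; neither is TomOpt's actual architecture, the point is only the looping pattern.

```python
import torch
from torch import nn, Tensor

# Hypothetical per-muon network; the real VoxelNet layers would go here.
muon_rep_net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))

def muon_reps_parallel(x: Tensor) -> Tensor:
    """All voxels at once: intermediate tensors scale with the number of voxels."""
    # x: (voxels, muons, features) -> (voxels, rep)
    return muon_rep_net(x).mean(dim=1)

def muon_reps_serial(x: Tensor) -> Tensor:
    """One voxel at a time: smaller intermediate tensors, more loop overhead."""
    reps = [muon_rep_net(xv).mean(dim=0) for xv in x]  # each xv: (muons, features)
    return torch.stack(reps)  # (voxels, rep)

x = torch.randn(100, 50, 16)
assert torch.allclose(muon_reps_parallel(x), muon_reps_serial(x), atol=1e-5)
```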

Compile parts of the network

PyTorch makes it "easy" to compile parts of the network in C++ and CUDA. According to Jan Kieseler, this heavily reduces memory consumption and processing time, at the cost of development time and model flexibility. He has sent me some examples, and I have also gone through the official PyTorch tutorial on writing and compiling custom kernels. The main difficulty is that the backwards pass to compute the gradients must also be written manually, and how well it is written can have a heavy impact on performance: in my testing of PyTorch's examples, the backwards pass was actually slower when compiled, although the forwards pass was slightly quicker.
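To illustrate the manual-backward requirement, here is a minimal sketch using a `torch.autograd.Function` with a hand-derived gradient, i.e. the kind of derivation one would need before porting an op to a compiled C++/CUDA kernel. The op (a sum of pairwise differences) and all names are purely illustrative and are not one of TomOpt's kernels.

```python
import torch

class PairwiseDiffSum(torch.autograd.Function):
    """Toy op: out[i] = sum_j (feats[j] - feats[i]), with a hand-written backward."""

    @staticmethod
    def forward(ctx, feats: torch.Tensor) -> torch.Tensor:
        # feats: (muons, features); equivalent to summing the pairwise differences
        # without materialising the (muons, muons, features) tensor.
        ctx.n_muons = feats.shape[0]
        return feats.sum(dim=0, keepdim=True) - feats.shape[0] * feats

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
        # d out[i] / d feats[k] = 1 - n_muons * delta_{ik}, hence:
        return grad_out.sum(dim=0, keepdim=True) - ctx.n_muons * grad_out

# Check the hand-written gradient against autograd on the naive pairwise expression.
feats = torch.randn(50, 8, dtype=torch.float64, requires_grad=True)
out = PairwiseDiffSum.apply(feats)
(out ** 2).sum().backward()

feats2 = feats.detach().clone().requires_grad_(True)
naive = (feats2.unsqueeze(0) - feats2.unsqueeze(1)).sum(dim=1)
(naive ** 2).sum().backward()

assert torch.allclose(feats.grad, feats2.grad)
```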

There are several parts of the GNN that are candidates for compilation: