teamtomo / membrain-seg

membrane segmentation in 3D for cryo-ET

Accelerating inference performance with torch.compile #67

Open saugatkandel opened 1 month ago


I noticed that membrain-seg does not currently use JIT compilation. Using torch.compile (https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) can be a simple way to speed up the model for both training and inference. Simply wrapping the model,

```python
import torch

model_new = torch.compile(model)  # `model` is the existing torch.nn.Module
```

is usually enough to roughly double performance. Are there any barriers to incorporating this change into membrain-seg?
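A minimal sketch of how the one-line change could be made opt-in and benchmarked, assuming PyTorch 2.0 or later. `TinySegNet` is a hypothetical stand-in for the real network, and `maybe_compile` and `time_forward` are illustrative helper names, not part of membrain-seg. The warm-up calls matter: the first call to a compiled model pays the compilation cost, so timing it would hide any speedup.

```python
import time

import torch
import torch.nn as nn


class TinySegNet(nn.Module):
    """Hypothetical stand-in for a 3D segmentation network (NOT membrain-seg's model)."""

    def __init__(self, channels: int = 8) -> None:
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv3d(1, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv3d(channels, 1, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


def maybe_compile(model: nn.Module, enabled: bool = True, **compile_kwargs) -> nn.Module:
    """Wrap `model` with torch.compile when requested and available (PyTorch >= 2.0).

    The compiled wrapper shares parameters with `model`, so weights need no copying.
    """
    if enabled and hasattr(torch, "compile"):
        return torch.compile(model, **compile_kwargs)
    return model


@torch.no_grad()
def time_forward(model: nn.Module, x: torch.Tensor, warmup: int = 2, iters: int = 5) -> float:
    """Return the average forward-pass time in seconds.

    Warm-up iterations absorb one-off compilation cost before timing starts.
    """
    for _ in range(warmup):
        model(x)
    if x.is_cuda:
        torch.cuda.synchronize()  # wait for queued GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        model(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters
```

For example, compare `time_forward(model, x)` with `time_forward(maybe_compile(model), x)` on a realistic patch size and device; the actual speedup will depend on hardware, PyTorch version, and input shape. Keeping compilation behind a flag also leaves an easy fallback if a platform does not support it.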

I have been using this in my own membrain-seg wrappers, and it has worked very well. However, making JIT compilation work with on-the-fly Fourier cropping and rescaling may require more than this one-line change.
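One possible way to handle the Fourier cropping concern, sketched under the assumption that rescaling can run as a separate preprocessing step: keep the FFT-based rescaling in eager mode and compile only the network. Passing `dynamic=True` to `torch.compile` asks it to generate shape-generic code, so volumes with different rescaled sizes are less likely to each trigger a fresh recompilation. `fourier_crop`, `segment`, and `TinySegNet` are hypothetical illustrations, not membrain-seg's implementation.

```python
from __future__ import annotations

import torch
import torch.nn as nn


def fourier_crop(volume: torch.Tensor, out_shape: tuple[int, int, int]) -> torch.Tensor:
    """Downsample a 3D volume by keeping only its central (low) frequencies.

    Hypothetical helper for illustration; not membrain-seg's implementation.
    """
    dims = (-3, -2, -1)
    spectrum = torch.fft.fftshift(torch.fft.fftn(volume, dim=dims), dim=dims)
    # Crop a centred block so the zero frequency lands at index out_size // 2.
    slices = tuple(
        slice(n // 2 - m // 2, n // 2 - m // 2 + m)
        for n, m in zip(volume.shape[-3:], out_shape)
    )
    cropped = spectrum[(..., *slices)]
    result = torch.fft.ifftn(torch.fft.ifftshift(cropped, dim=dims), dim=dims).real
    # fftn is unnormalised and ifftn divides by the new voxel count,
    # so rescale to preserve the mean intensity.
    return result * (result.shape[-3:].numel() / volume.shape[-3:].numel())


class TinySegNet(nn.Module):
    """Hypothetical stand-in for the real segmentation network."""

    def __init__(self, channels: int = 8) -> None:
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv3d(1, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv3d(channels, 1, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


def segment(model: nn.Module, volume: torch.Tensor, out_shape: tuple[int, int, int]) -> torch.Tensor:
    """Rescale in eager mode, then run the (possibly compiled) network."""
    small = fourier_crop(volume, out_shape)  # FFT step stays outside the compiled model
    with torch.no_grad():
        return model(small[None, None])  # add batch and channel dimensions
```

Usage would look like `compiled = torch.compile(TinySegNet().eval(), dynamic=True)` followed by `segment(compiled, volume, (64, 64, 64))`. Keeping the complex-valued FFT out of the compiled graph is a conservative choice; whether the compiler handles it well inside the graph depends on the PyTorch version and backend.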