MadryLab / trak

A fast, effective data attribution method for neural networks in PyTorch
https://trak.csail.mit.edu/
MIT License

How to implement get_output() function for segmentation task? #47

Closed: yangchaohua closed this issue 11 months ago

yangchaohua commented 11 months ago

Hi! Thank you for sharing this work. I want to use TRAK for a segmentation task, where each pixel of the model's output is a binary classification (foreground or not). I'm not sure which is more suitable for computing f(z; θ): the sum of the per-pixel outputs or their average. Is there a better approach?

kristian-georgiev commented 11 months ago

Hi @yangchaohua. We haven't run experiments with segmentation models. Based on our preliminary work with diffusion models (where the output is also high-dimensional), my guess is that the sum of per-pixel outputs is a good place to start.
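
For concreteness, here is a minimal sketch of what a summed per-pixel output could look like in plain PyTorch. It is not taken from the TRAK codebase and has not been validated on segmentation models. The function name `segmentation_output`, its argument layout, the assumption that the model returns one logit per pixel with shape (1, 1, H, W), and the choice of the per-pixel log-odds of the correct class as the "output" are all assumptions made for illustration.

```python
import torch
from torch.func import functional_call


def segmentation_output(model, weights, buffers, image, mask):
    """Hypothetical scalar output f(z; theta) for binary segmentation.

    image: (C, H, W) tensor; mask: (H, W) tensor of 0/1 labels.
    Assumes the model returns one logit per pixel, shape (1, 1, H, W).
    """
    # Run the model functionally so gradients can be taken w.r.t. `weights`.
    logits = functional_call(model, (weights, buffers), (image.unsqueeze(0),))
    logits = logits.squeeze(0).squeeze(0)  # (H, W)

    # Per-pixel log-odds of the correct class (an assumption, see above):
    # +logit for foreground pixels, -logit for background pixels.
    signs = 2.0 * mask.to(logits.dtype) - 1.0

    # Sum over pixels, as suggested in the reply above.
    return (signs * logits).sum()
```

On sum versus average: for a fixed image size, the average is just the sum divided by H*W, so the two differ only by a constant factor in the output and its gradients. Switching to the mean would amount to replacing `.sum()` with `.mean()` in the last line.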