Closed: Qubitium closed this issue 1 month ago
Hi, thanks for these benchmarks! And sorry for being slow to respond.
You could debug this by checking how much memory the vanilla CANINE (https://huggingface.co/google/canine-s) takes for a forward pass vs. a forward pass of the WtP model (see e.g. here: https://github.com/bminixhofer/wtpsplit/?tab=readme-ov-file#advanced-usage).
If there's a discrepancy there, I'll investigate it. It's possible that CANINE just needs a lot of memory, though; I am not super happy with that architecture and will upgrade the models to a different arch soon(ish).
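For reference, a minimal sketch of such a comparison (the input file name, the 2048-character truncation, and the helper function are illustrative assumptions, not part of wtpsplit): it measures peak CUDA memory via torch.cuda.max_memory_allocated for a single vanilla CANINE forward pass and for a WtP.split() call on the same text.

```python
# Memory comparison sketch -- "sample.txt" and the 2048-char truncation are
# placeholders, not values prescribed by wtpsplit.
import torch
from transformers import AutoModel, AutoTokenizer
from wtpsplit import WtP

text = open("sample.txt", encoding="utf-8").read()

def peak_mb(fn):
    """Run fn() and return the peak CUDA memory allocated during it, in MB."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    fn()
    return torch.cuda.max_memory_allocated() / 1024**2

# 1) vanilla CANINE: a single fp16 forward pass over one truncated block
tokenizer = AutoTokenizer.from_pretrained("google/canine-s")
canine = AutoModel.from_pretrained("google/canine-s").half().to("cuda").eval()
inputs = tokenizer(text, truncation=True, max_length=2048, return_tensors="pt").to("cuda")
with torch.no_grad():
    print(f"CANINE forward pass: {peak_mb(lambda: canine(**inputs)):.0f} MB")

# 2) WtP model (same backbone, 12 layers) segmenting the full text
wtp = WtP("wtp-canine-s-12l-no-adapters")
wtp.half().to("cuda")
print(f"WtP split: {peak_mb(lambda: wtp.split(text)):.0f} MB")
```

If the WtP number is much larger than the plain forward pass, the overhead is likely in how wtpsplit batches the text; if the two are similar, it is the CANINE architecture itself.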
Will do. Btw, if you need GPU compute to train the next model, I can provide you with an A100 80GB+. You can ping me on Twitter at qbitium.
Thanks! And that's very generous. Deferring to @markus583 since he is doing the training, but we are using TPUs so there is probably no need.
Very generous indeed! Thanks but the TPUs are very strong. I'd be very curious whether there is a discrepancy too.
Hi,
Could you check if this is still true for the newer SaT models?
@markus583 Confirmed using torch and v2.0.8 with sat-3l-sm: model VRAM usage on a 4090 is around 784MB. This is at least 2x lower than the older non-SaT models.
EDIT: not exactly apples to apples, since the SaT 3L model has 3 layers vs. our previous CANINE 12L test. I can test SaT 12L vs. CANINE 12L tomorrow.
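For reference, a minimal sketch of how a peak-VRAM figure like this can be measured (the input file is a placeholder; swapping the model string for a 12-layer checkpoint such as sat-12l would make the comparison with the CANINE 12L test layer-for-layer):

```python
# Peak-VRAM measurement sketch for SaT -- "sample.txt" is a placeholder input;
# replace "sat-3l-sm" with e.g. "sat-12l" for a layer-for-layer comparison.
import torch
from wtpsplit import SaT

text = open("sample.txt", encoding="utf-8").read()

sat = SaT("sat-3l-sm")
sat.half().to("cuda")

torch.cuda.reset_peak_memory_stats()
sentences = sat.split(text)
print(f"peak VRAM: {torch.cuda.max_memory_allocated() / 1024**2:.0f} MB")
```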
Great! Many thanks for checking; this sounds very reasonable to me. If you still encounter memory issues, you could try sat-1-sm.
@bminixhofer We are observing very high VRAM usage with the CANINE model. The wtp-canine-s-12l-no-adapters fp32 weights are only about 515MB, so we naively expected batch=1 in fp16 mode to use ~257.5MB of VRAM for the weights, plus runtime/inference costs. We didn't expect batch=1 VRAM usage to be 1.3GB. The input is a text file of around 230KB. Is this a bug or the architectural norm for the CANINE model? If it's the norm, is there anything we can do to reduce the memory footprint? Thanks.
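As a rough sanity check of that expectation, one could sum the fp16 parameter sizes of the vanilla CANINE backbone (the WtP checkpoint adds only a small head on top of it); whatever remains up to the observed 1.3GB would be activations and other runtime buffers. A hedged sketch:

```python
# Back-of-the-envelope check of the "weights only" expectation, using the vanilla
# google/canine-s backbone as a stand-in for the WtP checkpoint's parameter count.
from transformers import AutoModel

model = AutoModel.from_pretrained("google/canine-s").half()
n_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"fp16 weights: {n_bytes / 1024**2:.1f} MB")  # roughly half of the fp32 file size
```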