segment-any-text / wtpsplit

Toolkit to segment text into sentences or other semantic units in a robust, efficient and adaptable way.

Canine model and High VRAM usage #115

Closed. Qubitium closed this issue 1 month ago.

Qubitium commented 9 months ago

@bminixhofer We are observing very high VRAM usage with the CANINE model. The wtp-canine-s-12l-no-adapters fp32 weights are only about 515 MB, so we naively expected batch=1 in fp16 mode to use roughly 257.5 MB of VRAM for the weights plus runtime/inference overhead. We didn't expect batch=1 to use 1.3 GB of VRAM. The input is a ~230 KB text file.

Is this a bug, or is it the architectural norm for the CANINE model? If it's the norm, is there anything we can do to reduce the memory footprint? Thanks.

```python
from wtpsplit import WtP

wtp = WtP("wtp-canine-s-12l-no-adapters")
wtp.half().to(device="cuda")
```
| batch | VRAM (GB) |
|------:|----------:|
| 1     | 1.309     |
| 2     | 1.335     |
| 4     | 1.385     |
| 6     | 1.428     |
| 8     | 1.487     |
| 10    | 1.542     |
| 12    | 1.583     |
| 14    | 1.639     |
| 16    | 1.688     |
| 32    | 2.094     |
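
For reproducibility, a minimal sketch of how these peak-memory numbers could be collected; the measurement loop, the input file name, and the `batch_size` argument to `split` are my assumptions, not details from the original run:

```python
import torch
from wtpsplit import WtP

# Load the 12-layer CANINE WtP model in fp16 on the GPU.
wtp = WtP("wtp-canine-s-12l-no-adapters")
wtp.half().to("cuda")

text = open("input.txt").read()  # placeholder for the ~230 KB test file

for batch_size in (1, 2, 4, 8, 16, 32):
    torch.cuda.reset_peak_memory_stats()
    wtp.split(text, batch_size=batch_size)  # assumes split() takes batch_size
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"batch={batch_size}: {peak_gb:.3f} GB")
```
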
bminixhofer commented 9 months ago

Hi, thanks for these benchmarks! And sorry for being slow to respond.

You could debug this by checking how much memory the vanilla CANINE (https://huggingface.co/google/canine-s) takes for a forward pass vs. a forward pass of the WtP model (see e.g. here: https://github.com/bminixhofer/wtpsplit/?tab=readme-ov-file#advanced-usage).
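A rough way to run that comparison, sketched here purely as an illustration (the input text and lengths are placeholders, and the transformers loading code is my assumption rather than anything prescribed in this thread):

```python
import torch
from transformers import AutoModel, AutoTokenizer
from wtpsplit import WtP

text = "This is a test. " * 2000  # placeholder input

# Peak memory of a forward pass through vanilla CANINE.
tokenizer = AutoTokenizer.from_pretrained("google/canine-s")
canine = AutoModel.from_pretrained("google/canine-s").half().to("cuda")
inputs = tokenizer(text, return_tensors="pt", truncation=True).to("cuda")
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    canine(**inputs)
print(f"canine-s peak: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")

# Peak memory of the WtP model on the same text.
del canine, inputs
torch.cuda.empty_cache()
wtp = WtP("wtp-canine-s-12l-no-adapters")
wtp.half().to("cuda")
torch.cuda.reset_peak_memory_stats()
wtp.split(text)
print(f"WtP 12l peak:  {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")
```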

If there's a discrepancy there, I'll investigate it. It's possible that CANINE just needs a lot of memory, though; I am not super happy with that architecture and will upgrade the models to a different arch soon(ish).

Qubitium commented 9 months ago

Will do. Btw, if you need GPU compute to train the next model, I can provide you with an A100 80+GB. You can ping me on Twitter at qbitium.

bminixhofer commented 8 months ago

Thanks! And that's very generous. I'm deferring to @markus583 since he is doing the training, but we are using TPUs, so there is probably no need.

markus583 commented 8 months ago

Very generous indeed! Thanks, but the TPUs are very strong. I'd be very curious whether there is a discrepancy, too.

markus583 commented 1 month ago

Hi,

Could you check if this is still true for the newer SaT models?

Qubitium commented 1 month ago

@markus583 Confirmed: using torch and v2.0.8 with the sat-3l-sm model, VRAM usage on a 4090 is around 784 MB. That is at least 2x lower than the older non-SaT models.

EDIT: not exactly apples to apples, since the SaT model here has 3 layers versus the 12-layer CANINE in our previous test. I can test SaT 12L vs. CANINE 12L tomorrow.
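
For reference, the SaT setup mirrors the WtP snippet above; a minimal sketch following the wtpsplit README (the input file name is a placeholder):

```python
from wtpsplit import SaT

sat = SaT("sat-3l-sm")
sat.half().to("cuda")

text = open("input.txt").read()  # placeholder for the ~230 KB test file
sentences = sat.split(text)
```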

markus583 commented 1 month ago

Great! Many thanks for checking; this sounds very reasonable to me. If you still encounter memory issues, you could try sat-1l-sm.