Clay-foundation / model

The Clay Foundation Model (in development)
https://clay-foundation.github.io/model/
Apache License 2.0

Float 16 on GPU? #271

Open cjerzak opened 3 weeks ago

cjerzak commented 3 weeks ago

Hey Clay! Great work on v1.

I've been running some experiments with the model.

Do you know if anyone on your end has been able to get float16 precision working on CPU or GPU?

I tried both calling model.half() and feeding in float16 tensors, but I get an error indicating that an incompatible float32 array is being used somewhere in the model (I think it might trace back to the metadata part?). Setting torch.set_default_dtype(torch.float16) didn't do the trick either. I will keep trying.
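The error described above is consistent with a common cause: `model.half()` casts the model's floating-point parameters and buffers, but not the inputs you pass in, so any metadata tensor left as float32 hits float16 weights. The sketch below is not Clay's code. It assumes a hypothetical `ToyEncoder` with a pixel input and a lat/lon metadata input, plus a hypothetical `cast_floating` helper that casts every floating-point input to the model's dtype. It uses float16 on GPU and falls back to bfloat16 on CPU, because older PyTorch releases lack float16 matmul kernels on CPU. Tensors that a model creates inside its own `forward` with an explicit float32 dtype would still need changing in the model code.

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for an encoder that embeds image pixels plus
    metadata such as lat/lon. This is NOT Clay's actual architecture."""

    def __init__(self, in_dim: int = 8, meta_dim: int = 2, hidden: int = 16):
        super().__init__()
        self.pixel_proj = nn.Linear(in_dim, hidden)
        self.meta_proj = nn.Linear(meta_dim, hidden)

    def forward(self, pixels: torch.Tensor, latlon: torch.Tensor) -> torch.Tensor:
        # Each Linear layer expects its input in the same dtype as its weights.
        return self.pixel_proj(pixels) + self.meta_proj(latlon)


def cast_floating(batch: dict, dtype: torch.dtype) -> dict:
    """Hypothetical helper: cast every floating-point tensor in `batch` to
    `dtype`, leaving integer tensors and non-tensor values untouched."""
    return {
        key: value.to(dtype)
        if torch.is_tensor(value) and value.is_floating_point()
        else value
        for key, value in batch.items()
    }


device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 on GPU; bfloat16 on CPU because older PyTorch releases lack
# float16 matmul kernels on CPU.
dtype = torch.float16 if device == "cuda" else torch.bfloat16

# Like model.half(), this casts floating-point parameters and buffers only.
model = ToyEncoder().to(device=device, dtype=dtype).eval()
batch = {
    "pixels": torch.randn(4, 8, device=device, dtype=dtype),
    "latlon": torch.randn(4, 2, device=device),  # metadata left as float32
}

mixed_failed = False
try:
    with torch.no_grad():
        model(**batch)
except RuntimeError as err:
    # Linear refuses to multiply a float32 input by low-precision weights.
    mixed_failed = True
    print("dtype mismatch:", err)

with torch.no_grad():
    out = model(**cast_floating(batch, dtype))
print(out.dtype)
```

With a real model, the same check is to print the dtype of every tensor in the input batch just before the forward call.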

yellowcap commented 3 weeks ago

Hi @cjerzak thanks for the compliment! 😊

I am not sure if it is possible to use float16 with Clay. That is probably a question @srmsoumya can answer best. But could you explain why float32 is not an option for you? Is it to optimize GPU usage?

cjerzak commented 3 weeks ago

Thank you! Mostly we were just curious to see if we could further optimize the model in terms of runtime (by running the model with larger batches of images and so forth).
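If the goal is mainly faster inference on larger batches, an alternative worth trying is PyTorch's automatic mixed precision instead of `model.half()`: the weights stay float32, and inside `torch.autocast` eligible ops such as `nn.Linear` cast their inputs to the lower-precision dtype automatically, so float32 metadata does not trigger a dtype error in those ops. This is a sketch with the same hypothetical `ToyEncoder` (not Clay's architecture); whether all of Clay's layers behave well under autocast is an open assumption, and ops outside autocast's lists simply run in whatever dtype their inputs have.

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for an encoder with pixel and metadata inputs;
    NOT Clay's actual architecture."""

    def __init__(self, in_dim: int = 8, meta_dim: int = 2, hidden: int = 16):
        super().__init__()
        self.pixel_proj = nn.Linear(in_dim, hidden)
        self.meta_proj = nn.Linear(meta_dim, hidden)

    def forward(self, pixels: torch.Tensor, latlon: torch.Tensor) -> torch.Tensor:
        return self.pixel_proj(pixels) + self.meta_proj(latlon)


device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 autocast is the usual choice on CUDA; CPU autocast has historically
# supported bfloat16.
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

model = ToyEncoder().to(device).eval()  # weights stay in float32
batch_size = 256  # illustrative value only
pixels = torch.randn(batch_size, 8, device=device)
latlon = torch.randn(batch_size, 2, device=device)

# Inside autocast, linear casts its inputs and weights to amp_dtype, so the
# float32 inputs (pixels and metadata alike) need no manual casting.
with torch.inference_mode(), torch.autocast(device_type=device, dtype=amp_dtype):
    out = model(pixels, latlon)

print(out.dtype, tuple(out.shape))
```

Keeping float32 master weights and letting autocast choose per-op precision avoids hunting down every float32 tensor in the model, at the cost of some casting overhead.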