potamides / DeTikZify

Synthesizing Graphics Programs for Scientific Figures and Sketches with TikZ
Apache License 2.0
453 stars, 11 forks

Compatibility issues with macOS Metal (MPS) #1

Closed (thebluepotato closed this issue 3 months ago)

thebluepotato commented 3 months ago

Hi! I just tried to DeTikZify a flowchart with the local web UI. While it was churning away and printing some TikZ code, it crashed with the following error:

  File "/Users/user/Developer/DeTikZify/detikzify/evaluate/imagesim.py", line 93, in get_similarity
    return F.cosine_similarity(img1_feats.double(), img2_feats.double(), dim=0).item()
                               ^^^^^^^^^^^^^^^^^^^
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

Here's the full traceback: https://pastebin.com/qPDsnbLi

I also had to manually downgrade transformers to 4.38.1 based on this: https://github.com/huggingface/transformers/issues/29431

Note: this was with the MCTS option enabled. I tried again with regular sampling; it completed, but the result was not very convincing.

thebluepotato commented 3 months ago

Changing .double() to .float() seems to fix this, unsurprisingly, but I don't know if it really affects quality.

potamides commented 3 months ago

Thanks for the bug report! We can probably just move the tensors to the CPU before calling .double(). Since this seems to have a small performance impact for EMD (not the default), I will implement it only for backends that do not support double.
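
For readers hitting the same error, here is a minimal sketch of that idea. It is not the actual DeTikZify patch: the helper name `as_double` and the set `NO_FLOAT64_DEVICES` are hypothetical, and "mps" is the only device type the traceback above confirms as lacking float64 support. The helper only relies on the `device.type`, `.cpu()` and `.double()` members that PyTorch tensors provide.

```python
# Hypothetical helper, not part of DeTikZify.
# Device types assumed to lack float64 support; only "mps" is
# confirmed by the error message reported in this issue.
NO_FLOAT64_DEVICES = {"mps"}


def as_double(tensor):
    """Return `tensor` as float64.

    Expects a torch.Tensor (or any object exposing `device.type`,
    `.cpu()` and `.double()`). Tensors on a backend without float64
    support are moved to the CPU first; all others are converted in
    place on their current device, so no extra transfer happens there.
    """
    if tensor.device.type in NO_FLOAT64_DEVICES:
        tensor = tensor.cpu()
    return tensor.double()
```

With such a helper, the failing line in `get_similarity` could presumably become `F.cosine_similarity(as_double(img1_feats), as_double(img2_feats), dim=0).item()`. Both arguments end up on the same device, because they either both stay where they are or both move to the CPU.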