Nie-Yingying opened this issue 2 months ago
Hi @Nie-Yingying!
I have a suggestion to run XCOMET-XXL on a 40 GB GPU, but it's still not integrated. In the file `comet/encoders/xlmr_xl.py`, replace the model init to load in 16-bit:
```python
def __init__(
    self, pretrained_model: str, load_pretrained_weights: bool = True
) -> None:
    super(Encoder, self).__init__()
    self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(pretrained_model)
    if load_pretrained_weights:
        self.model = XLMRobertaXLModel.from_pretrained(
            pretrained_model, add_pooling_layer=False
        )
    else:
        print("Loading model in fp16")
        self.model = XLMRobertaXLModel(
            XLMRobertaXLConfig.from_pretrained(
                pretrained_model, torch_dtype=torch.float16, device_map="auto"
            ),
            add_pooling_layer=False,
        )
    self.model.encoder.output_hidden_states = True
```
This will load the model with half the memory and should solve your problem. I'll integrate this soon.
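For intuition on why half precision is expected to fit: XCOMET-XXL is built on XLM-R XXL, which has roughly 10.7B parameters (an approximate figure, not stated in this thread). A quick back-of-the-envelope check of the weight memory alone:

```python
def weights_gib(n_params: float, bytes_per_param: int) -> float:
    """Memory needed just to hold the weights, in GiB (ignores activations/buffers)."""
    return n_params * bytes_per_param / 1024**3

N = 10.7e9  # approximate parameter count of XLM-R XXL / XCOMET-XXL
print(f"fp32: {weights_gib(N, 4):.1f} GiB")  # ~39.9 GiB: already fills a 40 GB card
print(f"fp16: {weights_gib(N, 2):.1f} GiB")  # ~19.9 GiB: leaves headroom for activations
```

So in fp32 the weights alone saturate a 40 GB device before any activations are allocated, which matches the OOM above.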
@ricardorei I did something very similar for the XL: I converted it to fp16 and then changed just one line in `feedforward.py`. After that I wanted to go even further and use bitsandbytes / HF `load_in_8bit` / `load_in_4bit = True`, but the integration between Lightning and HF is a mess. Last, FYI, I did this as a WIP: https://huggingface.co/vince62s/wmt23-cometkiwi-da-roberta-xl, adapting your code to the existing HF XLM-RoBERTa-XL code. We are trying to implement it in CTranslate2 for much faster inference.
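For reference, a sketch of what the bitsandbytes route looks like on the plain HF side. This is an illustrative configuration fragment only (it needs a CUDA GPU plus the `bitsandbytes` and `accelerate` packages, the model id is a stand-in, and it does not solve the Lightning/HF integration problem mentioned above):

```python
from transformers import BitsAndBytesConfig, XLMRobertaXLModel

# Illustrative only: quantized loading of the underlying HF encoder.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)  # or BitsAndBytesConfig(load_in_4bit=True)
model = XLMRobertaXLModel.from_pretrained(
    "facebook/xlm-roberta-xl",       # stand-in; use your own checkpoint/path
    quantization_config=bnb_config,
    device_map="auto",
)
```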
Sorry to tell you, but it's still OOM.
❓ Questions and Help

#### What is your question?

I can predict scores successfully on CPU. But when the model is loaded onto the GPU, there is an OOM error.
#### Code

```python
from comet import download_model, load_from_checkpoint

model_path = download_model("Unbabel/XCOMET-XXL")
model_path = "./XCOMET-XXL/checkpoints/model.ckpt"
model = load_from_checkpoint(model_path, reload_hparams=True)

data = [
    {
        "src": "Boris Johnson teeters on edge of favour with Tory MPs",
        "mt": "Boris Johnson ist bei Tory-Abgeordneten völlig in der Gunst",
        "ref": "Boris Johnsons Beliebtheit bei Tory-MPs steht auf der Kippe",
    }
]
model_output = model.predict(data, batch_size=1, gpus=1)

# Segment-level scores
print(model_output.scores)
# System-level score
print(model_output.system_score)
# Score explanation (error spans)
print(model_output.metadata.error_spans)
```
hparams.yaml:

![image](https://github.com/Unbabel/COMET/assets/65881015/593d40a2-4bcb-43b5-89af-967454ae4217)
#### What have you tried?

#### What's your environment?

```
# Name                    Version          Build            Channel
_libgcc_mutex             0.1              main
_openmp_mutex             5.1              1_gnu
aiohttp                   3.9.5            pypi_0           pypi
aiosignal                 1.3.1            pypi_0           pypi
async-timeout             4.0.3            pypi_0           pypi
attrs                     23.2.0           pypi_0           pypi
ca-certificates           2024.3.11        h06a4308_0
certifi                   2024.2.2         pypi_0           pypi
charset-normalizer        3.3.2            pypi_0           pypi
colorama                  0.4.6            pypi_0           pypi
entmax                    1.3              pypi_0           pypi
filelock                  3.13.4           pypi_0           pypi
frozenlist                1.4.1            pypi_0           pypi
fsspec                    2024.3.1         pypi_0           pypi
huggingface-hub           0.22.2           pypi_0           pypi
idna                      3.7              pypi_0           pypi
jinja2                    3.1.3            pypi_0           pypi
jsonargparse              3.13.1           pypi_0           pypi
ld_impl_linux-64          2.38             h1181459_1
libffi                    3.4.4            h6a678d5_0
libgcc-ng                 11.2.0           h1234567_1
libgomp                   11.2.0           h1234567_1
libstdcxx-ng              11.2.0           h1234567_1
lightning-utilities       0.11.2           pypi_0           pypi
lxml                      5.2.1            pypi_0           pypi
markupsafe                2.1.5            pypi_0           pypi
mpmath                    1.3.0            pypi_0           pypi
multidict                 6.0.5            pypi_0           pypi
ncurses                   6.4              h6a678d5_0
networkx                  3.1              pypi_0           pypi
numpy                     1.24.4           pypi_0           pypi
nvidia-cublas-cu12        12.1.3.1         pypi_0           pypi
nvidia-cuda-cupti-cu12    12.1.105         pypi_0           pypi
nvidia-cuda-nvrtc-cu12    12.1.105         pypi_0           pypi
nvidia-cuda-runtime-cu12  12.1.105         pypi_0           pypi
nvidia-cudnn-cu12         8.9.2.26         pypi_0           pypi
nvidia-cufft-cu12         11.0.2.54        pypi_0           pypi
nvidia-curand-cu12        10.3.2.106       pypi_0           pypi
nvidia-cusolver-cu12      11.4.5.107       pypi_0           pypi
nvidia-cusparse-cu12      12.1.0.106       pypi_0           pypi
nvidia-nccl-cu12          2.19.3           pypi_0           pypi
nvidia-nvjitlink-cu12     12.4.127         pypi_0           pypi
nvidia-nvtx-cu12          12.1.105         pypi_0           pypi
openssl                   3.0.13           h7f8727e_0
packaging                 24.0             pypi_0           pypi
pandas                    2.0.3            pypi_0           pypi
pip                       24.0             pypi_0           pypi
portalocker               2.8.2            pypi_0           pypi
protobuf                  4.25.3           pypi_0           pypi
python                    3.8.19           h955ad1f_0
python-dateutil           2.9.0.post0      pypi_0           pypi
pytorch-lightning         2.2.2            pypi_0           pypi
pytz                      2024.1           pypi_0           pypi
pyyaml                    6.0.1            pypi_0           pypi
readline                  8.2              h5eee18b_0
regex                     2024.4.16        pypi_0           pypi
requests                  2.31.0           pypi_0           pypi
sacrebleu                 2.4.2            pypi_0           pypi
safetensors               0.4.3            pypi_0           pypi
scipy                     1.10.1           pypi_0           pypi
sentencepiece             0.1.99           pypi_0           pypi
setuptools                68.2.2           py38h06a4308_0
six                       1.16.0           pypi_0           pypi
sqlite                    3.41.2           h5eee18b_0
sympy                     1.12             pypi_0           pypi
tabulate                  0.9.0            pypi_0           pypi
tk                        8.6.12           h1ccaba5_0
tokenizers                0.15.2           pypi_0           pypi
torch                     2.2.2            pypi_0           pypi
torchmetrics              0.10.3           pypi_0           pypi
tqdm                      4.66.2           pypi_0           pypi
transformers              4.39.3           pypi_0           pypi
triton                    2.2.0            pypi_0           pypi
typing-extensions         4.11.0           pypi_0           pypi
tzdata                    2024.1           pypi_0           pypi
unbabel-comet             2.2.2            pypi_0           pypi
urllib3                   2.2.1            pypi_0           pypi
wheel                     0.41.2           py38h06a4308_0
xz                        5.4.6            h5eee18b_0
yarl                      1.9.4            pypi_0           pypi
zlib                      1.2.13           h5eee18b_0
```