kuprel / min-dalle

min(DALL·E) is a fast, minimal port of DALL·E Mini to PyTorch
MIT License
3.48k stars 255 forks source link

GPU OOM if model ran in Python multithreading #101

Open xcharleslin opened 1 year ago

xcharleslin commented 1 year ago

Minimum repro:

import torch
from min_dalle import MinDalle
from concurrent.futures import ThreadPoolExecutor

USE_GPU = True
def f(text: str, root: str):
    return MinDalle(
        models_root=f'./{root}',
        dtype=torch.float32,
        device="cuda",
        is_mega=False, 
        is_reusable=True,
    ).generate_image(
        text,
        seed=-1,
        grid_size=1,
        is_seamless=False,
        temperature=1,
        top_k=256,
        supercondition_factor=32,
    )

# No threading works
f("hello", "root1")  

# Threading does not work
tpe = ThreadPoolExecutor()
tpe.submit(f, "hello2", "root2").result()  # GPU OOMs here

The last line fails with OutOfMemoryError: CUDA out of memory.

(click for full stack trace) ```python using device cuda downloading tokenizer params intializing TextTokenizer downloading encoder params initializing DalleBartEncoder downloading decoder params initializing DalleBartDecoder downloading detokenizer params initializing VQGanDetokenizer --------------------------------------------------------------------------- OutOfMemoryError Traceback (most recent call last) in 5 fut = tpe.submit(f, "abc", "def") 6 ----> 7 fut.result() 12 frames /usr/lib/python3.8/concurrent/futures/_base.py in result(self, timeout) 442 raise CancelledError() 443 elif self._state == FINISHED: --> 444 return self.__get_result() 445 else: 446 raise TimeoutError() /usr/lib/python3.8/concurrent/futures/_base.py in __get_result(self) 387 if self._exception: 388 try: --> 389 raise self._exception 390 finally: 391 # Break a reference cycle with the exception in self._exception /usr/lib/python3.8/concurrent/futures/thread.py in run(self) 55 56 try: ---> 57 result = self.fn(*self.args, **self.kwargs) 58 except BaseException as exc: 59 self.future.set_exception(exc) in f(text, dir) 5 USE_GPU = True 6 def f(text: str, dir: str) -> PIL.Image.Image: ----> 7 return MinDalle( 8 models_root=f'./{dir}', 9 dtype=torch.float32, /usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_image(self, *args, **kwargs) 279 progressive_outputs=False 280 ) --> 281 return next(image_stream) 282 283 /usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_image_stream(self, *args, **kwargs) 259 def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]: 260 image_stream = self.generate_raw_image_stream(*args, **kwargs) --> 261 for image in image_stream: 262 image = image.to(torch.uint8).to('cpu').numpy() 263 yield Image.fromarray(image) /usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_raw_image_stream(self, text, seed, grid_size, progressive_outputs, is_seamless, temperature, top_k, supercondition_factor, is_verbose) 238 torch.cuda.empty_cache() 239 with torch.cuda.amp.autocast(dtype=self.dtype): --> 240 image_tokens[:, i + 1], attention_state = self.decoder.sample_tokens( 241 settings=settings, 242 attention_mask=attention_mask, /usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in sample_tokens(self, settings, **kwargs) 175 176 def sample_tokens(self, settings, **kwargs) -> Tuple[LongTensor, FloatTensor]: --> 177 logits, attention_state = self.forward(**kwargs) 178 image_count = logits.shape[0] // 2 179 temperature = settings[[0]] /usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, attention_mask, encoder_state, attention_state, prev_tokens, token_index) 162 decoder_state = self.layernorm_embedding.forward(decoder_state) 163 for i in range(self.layer_count): --> 164 decoder_state, attention_state[i] = self.layers[i].forward( 165 decoder_state, 166 encoder_state, /usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, decoder_state, encoder_state, attention_state, attention_mask, token_index) 88 residual = decoder_state 89 decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state) ---> 90 decoder_state, attention_state = self.self_attn.forward( 91 decoder_state=decoder_state, 92 attention_state=attention_state, /usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, decoder_state, attention_state, attention_mask, token_index) 43 values = attention_state[batch_count:] 44 ---> 45 decoder_state = super().forward(keys, values, queries, attention_mask) 46 return decoder_state, attention_state 47 /usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_encoder.py in forward(self, keys, values, queries, attention_mask) 47 queries /= queries.shape[-1] ** 0.5 48 attention_bias = (1 - attention_mask.to(torch.float32)) * -1e12 ---> 49 attention_weights: FloatTensor = torch.einsum( 50 'bqhc,bkhc->bhqk', 51 queries, /usr/local/lib/python3.8/dist-packages/torch/functional.py in einsum(*args) 376 # the path for contracting 0 or 1 time(s) is already optimized 377 # or the user has disabled using opt_einsum --> 378 return _VF.einsum(equation, operands) # type: ignore[attr-defined] 379 380 path = None OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 14.76 GiB total capacity; 13.67 GiB already allocated; 17.88 MiB free; 13.69 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF ```