Eventual-Inc / Daft

Distributed DataFrame for Python designed for the cloud, powered by Rust
https://getdaft.io
Apache License 2.0

Text-to-image generation notebook failing with GPU OOM #609

Closed: jaychia closed this 1 year ago

jaychia commented 1 year ago

Describe the bug

See notebook: https://colab.research.google.com/github/Eventual-Inc/Daft/blob/main/tutorials/text_to_image/text_to_image_generation.ipynb#scrollTo=b500e7f5

Our multithreaded PyRunner may not be respecting GPU/CPU resource requests, and may be running multiple tasks in parallel when it should not.
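The suspected failure mode is a runner that dispatches tasks without checking their resource requests first. The sketch below shows a thread-pool runner that admits a task only while its GPU request fits in the free budget. `Task`, `run_with_gpu_limit`, and the FIFO dispatch policy are hypothetical illustrations, not Daft's actual PyRunner code.

```python
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class Task:
    """Hypothetical unit of work plus the number of GPUs it requests."""
    fn: Callable[[], Any]
    num_gpus: int = 0


def run_with_gpu_limit(tasks: List[Task], total_gpus: int, max_workers: int = 8) -> List[Any]:
    """Run tasks on a thread pool without ever letting the summed GPU requests
    of in-flight tasks exceed total_gpus. Results come back in input order."""
    for t in tasks:
        if t.num_gpus > total_gpus:
            # Such a task could never be admitted, so fail fast instead of hanging.
            raise ValueError(f"task requests {t.num_gpus} GPUs, only {total_gpus} available")

    results: List[Any] = [None] * len(tasks)
    pending: List[Tuple[int, Task]] = list(enumerate(tasks))  # not yet dispatched
    inflight: Dict[Future, int] = {}  # running future -> index of its task
    gpus_free = total_gpus

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        while pending or inflight:
            # Dispatch in FIFO order while the next task's GPU request still fits.
            while pending and pending[0][1].num_gpus <= gpus_free:
                idx, task = pending.pop(0)
                gpus_free -= task.num_gpus
                inflight[pool.submit(task.fn)] = idx
            # Block until at least one task finishes, then return its GPUs to the budget.
            done, _ = wait(list(inflight), return_when=FIRST_COMPLETED)
            for fut in done:
                idx = inflight.pop(fut)
                gpus_free += tasks[idx].num_gpus
                results[idx] = fut.result()  # re-raises the task's exception, if any
    return results
```

Because dispatch is strictly FIFO, a GPU task waiting at the head of the queue also holds back any CPU-only tasks behind it; a production scheduler might let smaller requests skip ahead.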

xcharleslin commented 1 year ago

@jaychia In that link I get a GPU OOM, but I can't find any evidence that multiple tasks are being run in parallel. I see a single model init and a single UDF call before the OOM:

```
GenerateImageFromTextGPU.__init__()
using device cuda
downloading tokenizer params
intializing TextTokenizer
downloading encoder params
initializing DalleBartEncoder
downloading decoder params
initializing DalleBartDecoder
downloading detokenizer params
initializing VQGanDetokenizer
GenerateImageFromTextGPU.__call__(['Photo pour Japanese pagoda and old house in Kyoto at twilight - image libre de droit'])
ERROR:daft.udf:Encountered error when running user-defined function GenerateImageFromTextGPU
---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
<ipython-input-9-f70ecca7159b> in <module>
     34 
     35 resource_request = ResourceRequest(num_gpus=1) if USE_GPU else None
---> 36 images_df.with_column(
     37     "generated_image",
     38     GenerateImageFromTextGPU(images_df["TEXT"]),

30 frames
/usr/local/lib/python3.8/dist-packages/torch/functional.py in einsum(*args)
    376         # the path for contracting 0 or 1 time(s) is already optimized
    377         # or the user has disabled using opt_einsum
--> 378         return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
    379 
    380     path = None

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 14.76 GiB total capacity; 13.96 GiB already allocated; 3.88 MiB free; 13.98 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```
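The figures in the OOM message suggest the card is genuinely full rather than fragmented. They are rounded to two decimals, so the differences below are approximate, but PyTorch's reserved pool barely exceeds what it has allocated, so the `max_split_size_mb` hint (aimed at the reserved >> allocated case) is unlikely to help here:

$$
\begin{aligned}
\text{reserved} - \text{allocated} &\approx 13.98\ \text{GiB} - 13.96\ \text{GiB} = 0.02\ \text{GiB} \approx 20\ \text{MiB} \\
\text{total} - \text{reserved} &\approx 14.76\ \text{GiB} - 13.98\ \text{GiB} = 0.78\ \text{GiB}
\end{aligned}
$$

With only 3.88 MiB reported free, the roughly 0.78 GiB outside PyTorch's reserved pool is presumably held by other consumers such as the CUDA context; the message alone does not say which.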

The dataframe also only has a single partition.

Do you have more info about the resource request violation?
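Rather than inferring parallelism from log order, you can count overlapping calls directly. The sketch below assumes you can wrap the function doing the GPU work; `track_concurrency`, `max_concurrent`, and `threads` are hypothetical names for illustration, not part of Daft.

```python
import functools
import threading
from typing import Any, Callable


def track_concurrency(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap fn so it records the peak number of simultaneous calls
    (wrapper.max_concurrent) and the names of the threads that ran it
    (wrapper.threads). A peak above 1 means calls really overlapped."""
    lock = threading.Lock()
    active = 0

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal active
        with lock:
            active += 1
            wrapper.max_concurrent = max(wrapper.max_concurrent, active)
            # Note whether the call ran on the main thread or a pool worker.
            wrapper.threads.add(threading.current_thread().name)
        try:
            return fn(*args, **kwargs)
        finally:
            with lock:
                active -= 1

    wrapper.max_concurrent = 0
    wrapper.threads = set()
    return wrapper
```

After a run, `max_concurrent == 1` would rule out overlapping UDF calls, while `threads` would show whether the work ran outside the main thread, which matters if the model itself behaves differently off the main thread.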

jaychia commented 1 year ago

Ah, I could be wrong then. I had assumed that was the cause, since nothing else changed that could plausibly cause this issue. Feel free to assign the issue back to me and I can investigate further.

xcharleslin commented 1 year ago

@jaychia I'll run with it for a bit longer :)

xcharleslin commented 1 year ago

It does seem to be the multithreaded PyRunner PR, though. The notebook works on e73ddcd and breaks on the next nightly, 342535e.

xcharleslin commented 1 year ago
Full stack trace:

```
GenerateImageFromTextGPU.__init__()
using device cuda
downloading tokenizer params
intializing TextTokenizer
downloading encoder params
initializing DalleBartEncoder
downloading decoder params
initializing DalleBartDecoder
downloading detokenizer params
initializing VQGanDetokenizer
GenerateImageFromTextGPU.__call__(['Photo pour Japanese pagoda and old house in Kyoto at twilight - image libre de droit'])
ERROR:daft.udf:Encountered error when running user-defined function GenerateImageFromTextGPU
---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
in <module>
     34 
     35 resource_request = ResourceRequest(num_gpus=1) if USE_GPU else None
---> 36 images_df.with_column(
     37     "generated_image",
     38     GenerateImageFromTextGPU(images_df["TEXT"]),

30 frames
/usr/local/lib/python3.8/dist-packages/daft/api_annotations.py in _wrap(*args, **kwargs)
     13     def _wrap(*args, **kwargs):
     14         timed_method = time_df_method(func)
---> 15         return timed_method(*args, **kwargs)
     16 
     17     return _wrap

/usr/local/lib/python3.8/dist-packages/daft/analytics.py in tracked_method(*args, **kwargs)
    170         start = time.time()
    171         try:
--> 172             result = method(*args, **kwargs)
    173         except Exception as e:
    174             _ANALYTICS_CLIENT.track_df_method_call(

/usr/local/lib/python3.8/dist-packages/daft/dataframe/dataframe.py in show(self, n)
    206             df = df.limit(n)
    207 
--> 208         df.collect(num_preview_rows=n)
    209         result = df._result
    210         assert result is not None

/usr/local/lib/python3.8/dist-packages/daft/api_annotations.py in _wrap(*args, **kwargs)
     13     def _wrap(*args, **kwargs):
     14         timed_method = time_df_method(func)
---> 15         return timed_method(*args, **kwargs)
     16 
     17     return _wrap

/usr/local/lib/python3.8/dist-packages/daft/analytics.py in tracked_method(*args, **kwargs)
    170         start = time.time()
    171         try:
--> 172             result = method(*args, **kwargs)
    173         except Exception as e:
    174             _ANALYTICS_CLIENT.track_df_method_call(

/usr/local/lib/python3.8/dist-packages/daft/dataframe/dataframe.py in collect(self, num_preview_rows)
   1183             DataFrame: DataFrame with materialized results.
   1184         """
-> 1185         self._materialize_results()
   1186 
   1187         assert self._result is not None

/usr/local/lib/python3.8/dist-packages/daft/dataframe/dataframe.py in _materialize_results(self)
   1165         context = get_context()
   1166         if self._result is None:
-> 1167             self._result_cache = context.runner().run(self._plan)
   1168         result = self._result
   1169         assert result is not None

/usr/local/lib/python3.8/dist-packages/daft/runners/pyrunner.py in run(self, logplan)
    206                 done_task = inflight_tasks.pop(done_id)
    207 
--> 208                 partitions = done.result()
    209 
    210                 if isinstance(done_task, MultiOutputExecutionStep):

/usr/lib/python3.8/concurrent/futures/_base.py in result(self, timeout)
    435                     raise CancelledError()
    436                 elif self._state == FINISHED:
--> 437                     return self.__get_result()
    438 
    439                 self._condition.wait(timeout)

/usr/lib/python3.8/concurrent/futures/_base.py in __get_result(self)
    387         if self._exception:
    388             try:
--> 389                 raise self._exception
    390             finally:
    391                 # Break a reference cycle with the exception in self._exception

/usr/lib/python3.8/concurrent/futures/thread.py in run(self)
     55 
     56         try:
---> 57             result = self.fn(*self.args, **self.kwargs)
     58         except BaseException as exc:
     59             self.future.set_exception(exc)

/usr/local/lib/python3.8/dist-packages/daft/runners/pyrunner.py in build_partitions(instruction_stack, *inputs)
    251         partitions = list(inputs)
    252         for instruction in instruction_stack:
--> 253             partitions = instruction.run(partitions)
    254 
    255         return partitions

/usr/local/lib/python3.8/dist-packages/daft/execution/execution_step.py in run(self, inputs)
    243 
    244     def run(self, inputs: list[vPartition]) -> list[vPartition]:
--> 245         return self._project(inputs)
    246 
    247     def _project(self, inputs: list[vPartition]) -> list[vPartition]:

/usr/local/lib/python3.8/dist-packages/daft/execution/execution_step.py in _project(self, inputs)
    247     def _project(self, inputs: list[vPartition]) -> list[vPartition]:
    248         [input] = inputs
--> 249         return [input.eval_expression_list(self.projection)]
    250 
    251 

/usr/local/lib/python3.8/dist-packages/daft/runners/partitioning.py in eval_expression_list(self, exprs)
    207 
    208     def eval_expression_list(self, exprs: ExpressionList) -> vPartition:
--> 209         tile_list = [self.eval_expression(e) for e in exprs]
    210         new_columns = {t.column_name: t for t in tile_list}
    211         return vPartition(columns=new_columns, partition_id=self.partition_id)

/usr/local/lib/python3.8/dist-packages/daft/runners/partitioning.py in (.0)
    207 
    208     def eval_expression_list(self, exprs: ExpressionList) -> vPartition:
--> 209         tile_list = [self.eval_expression(e) for e in exprs]
    210         new_columns = {t.column_name: t for t in tile_list}
    211         return vPartition(columns=new_columns, partition_id=self.partition_id)

/usr/local/lib/python3.8/dist-packages/daft/runners/partitioning.py in eval_expression(self, expr)
    201             required_blocks[name] = block
    202         exec = ExpressionExecutor()
--> 203         result = exec.eval(expr, required_blocks)
    204         expr_name = expr.name()
    205         assert expr_name is not None

/usr/local/lib/python3.8/dist-packages/daft/expressions.py in eval(self, expr, operands)
     97             return DataBlock.make_block(result)
     98         elif isinstance(expr, AliasExpression):
---> 99             result = self.eval(expr._expr, operands)
    100             return result
    101         elif isinstance(expr, CallExpression):

/usr/local/lib/python3.8/dist-packages/daft/expressions.py in eval(self, expr, operands)
    117             eval_kwargs = {kw: self.eval(a, operands) for kw, a in expr._kwargs.items()}
    118 
--> 119             results = expr._func(*eval_args, **eval_kwargs)
    120 
    121             if ExpressionType.is_py(expr._func_ret_type):

/usr/local/lib/python3.8/dist-packages/daft/udf.py in pre_process_data_block_func(*args, **kwargs)
    258 
    259         try:
--> 260             results = initialized_func(*converted_args, **converted_kwargs)
    261         except:
    262             logger.error(f"Encountered error when running user-defined function {func.__name__}")

in __call__(self, text_col)
     21     def __call__(self, text_col):
     22         print(f"GenerateImageFromTextGPU.__call__({text_col})")
---> 23         return [
     24             self.model.generate_image(
     25                 t,

in (.0)
     22         print(f"GenerateImageFromTextGPU.__call__({text_col})")
     23         return [
---> 24             self.model.generate_image(
     25                 t,
     26                 seed=-1,

/usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_image(self, *args, **kwargs)
    279             progressive_outputs=False
    280         )
--> 281         return next(image_stream)
    282 
    283 

/usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_image_stream(self, *args, **kwargs)
    259     def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]:
    260         image_stream = self.generate_raw_image_stream(*args, **kwargs)
--> 261         for image in image_stream:
    262             image = image.to(torch.uint8).to('cpu').numpy()
    263             yield Image.fromarray(image)

/usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_raw_image_stream(self, text, seed, grid_size, progressive_outputs, is_seamless, temperature, top_k, supercondition_factor, is_verbose)
    238             torch.cuda.empty_cache()
    239             with torch.cuda.amp.autocast(dtype=self.dtype):
--> 240                 image_tokens[:, i + 1], attention_state = self.decoder.sample_tokens(
    241                     settings=settings,
    242                     attention_mask=attention_mask,

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in sample_tokens(self, settings, **kwargs)
    175 
    176     def sample_tokens(self, settings, **kwargs) -> Tuple[LongTensor, FloatTensor]:
--> 177         logits, attention_state = self.forward(**kwargs)
    178         image_count = logits.shape[0] // 2
    179         temperature = settings[[0]]

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, attention_mask, encoder_state, attention_state, prev_tokens, token_index)
    162         decoder_state = self.layernorm_embedding.forward(decoder_state)
    163         for i in range(self.layer_count):
--> 164             decoder_state, attention_state[i] = self.layers[i].forward(
    165                 decoder_state,
    166                 encoder_state,

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, decoder_state, encoder_state, attention_state, attention_mask, token_index)
     88         residual = decoder_state
     89         decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
---> 90         decoder_state, attention_state = self.self_attn.forward(
     91             decoder_state=decoder_state,
     92             attention_state=attention_state,

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, decoder_state, attention_state, attention_mask, token_index)
     43             values = attention_state[batch_count:]
     44 
---> 45         decoder_state = super().forward(keys, values, queries, attention_mask)
     46         return decoder_state, attention_state
     47 

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_encoder.py in forward(self, keys, values, queries, attention_mask)
     54         attention_weights += attention_bias
     55         attention_weights = torch.softmax(attention_weights, -1)
---> 56         attention_output: FloatTensor = torch.einsum(
     57             "bhqk,bkhc->bqhc",
     58             attention_weights,

/usr/local/lib/python3.8/dist-packages/torch/functional.py in einsum(*args)
    376         # the path for contracting 0 or 1 time(s) is already optimized
    377         # or the user has disabled using opt_einsum
--> 378         return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
    379 
    380     path = None

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 14.76 GiB total capacity; 13.96 GiB already allocated; 3.88 MiB free; 13.98 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```
xcharleslin commented 1 year ago
Got a minimal repro: MinDalle runs out of GPU memory when run in a thread. Daft is not involved at all.

```python
import torch
from min_dalle import MinDalle
import PIL.Image
from concurrent.futures import ThreadPoolExecutor

USE_GPU = True

def f(text: str, dir: str) -> PIL.Image.Image:
    return MinDalle(
        models_root=f'./{dir}',
        dtype=torch.float32,
        # Tell the min-dalle library to load the model on GPU or CPU
        device="cuda" if USE_GPU else "cpu",
        is_mega=False,
        is_reusable=True
    ).generate_image(
        text,
        seed=-1,
        grid_size=1,
        is_seamless=False,
        temperature=1,
        top_k=256,
        supercondition_factor=32,
    )

# Without threading: works
f("hello", "tmp1")

# In a worker thread: does not work
tpe = ThreadPoolExecutor()
tpe.submit(f, "hello2", "world2").result()  # GPU OOMs here
```
xcharleslin commented 1 year ago

Filed here:

For now, we'll just modify the notebook to predownload the model weights.

xcharleslin commented 1 year ago

@jaychia I couldn't get the predownloading to work within a single notebook run. Could you try giving it a shot?

The only times I've gotten it to work involved refreshing the notebook 🤔:

Things I've tried to predownload that didn't work: