borisdayma / dalle-mini

DALL·E Mini - Generate images from a text prompt
https://www.craiyon.com
Apache License 2.0

`mega-1-fp16` OOM with 12GB VRAM on RTX 2080 Ti #213

Closed: danielchalef closed this issue 2 years ago

danielchalef commented 2 years ago

Release: 0.1.0 / https://github.com/borisdayma/dalle-mini/commit/00d389bfa5586fde0a51e250f7ec3757bb7e704c

According to this, I should be able to load mega-1-fp16 on an RTX 2080 Ti with 12 GB of VRAM. No other processes are running on the GPU, and this workstation has only one device.

Running the notebook tools/inference/inference_pipeline.ipynb:

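import jax.numpy as jnp
from dalle_mini import DalleBart

# Load dalle-mini (DALLE_MODEL and DALLE_COMMIT_ID are set in an earlier notebook cell)
model, params = DalleBart.from_pretrained(
    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, #_do_init=False
)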

_do_init=False is commented out intentionally; see https://github.com/borisdayma/dalle-mini/issues/212

Loading runs out of memory (OOM) with the errors below.

Before importing JAX, I also tried the following, without success:

%env XLA_PYTHON_CLIENT_ALLOCATOR=platform

wandb: Downloading large artifact mega-1-fp16:latest, 4938.53MB. 7 files... Done. 0:0:0
2022-06-07 09:45:57.580584: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 384.00MiB (rounded to 402653184)requested by op 
2022-06-07 09:45:57.580789: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:491] **__************************************************************************************************
2022-06-07 09:45:57.581250: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2141] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 402653184 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:     8.4KiB
              constant allocation:        64B
        maybe_live_out allocation:    5.25GiB
     preallocated temp allocation:  288.04MiB
  preallocated temp fragmentation:       620B (0.00%)
                 total allocation:    5.53GiB
              total fragmentation:     5.3KiB (0.00%)
Peak buffers:
    Buffer 1:
        Size: 768.00MiB
        Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/daniel/anaconda3/envs/dalle/lib/python3.10/site-packages/flax/core/axes_scan.py" source_line=132
        XLA Label: fusion
        Shape: f32[24,4096,2048]
        ==========================

    Buffer 2:
        Size: 768.00MiB
        Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/daniel/anaconda3/envs/dalle/lib/python3.10/site-packages/flax/core/axes_scan.py" source_line=132
        XLA Label: fusion
        Shape: f32[24,2048,4096]
        ==========================

    Buffer 3:
        Size: 768.00MiB
        Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/daniel/anaconda3/envs/dalle/lib/python3.10/site-packages/flax/core/axes_scan.py" source_line=132
        XLA Label: fusion
        Shape: f32[24,2048,4096]
        ==========================

    Buffer 4:
        Size: 384.00MiB
        Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/daniel/anaconda3/envs/dalle/lib/python3.10/site-packages/flax/core/axes_scan.py" source_line=132
        XLA Label: fusion
        Shape: f32[24,2048,2048]
        ==========================

    Buffer 5:
        Size: 384.00MiB
        Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/daniel/anaconda3/envs/dalle/lib/python3.10/site-packages/flax/core/axes_scan.py" source_line=132
        XLA Label: fusion
        Shape: f32[24,2048,2048]
        ==========================

    Buffer 6:
        Size: 384.00MiB
        Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/daniel/anaconda3/envs/dalle/lib/python3.10/site-packages/flax/core/axes_scan.py" source_line=132
        XLA Label: fusion
        Shape: f32[24,2048,2048]
        ==========================

    Buffer 7:
        Size: 384.00MiB
        Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/daniel/anaconda3/envs/dalle/lib/python3.10/site-packages/flax/core/axes_scan.py" source_line=132
        XLA Label: fusion
        Shape: f32[24,2048,2048]
        ==========================

    Buffer 8:
        Size: 384.00MiB
        Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/daniel/anaconda3/envs/dalle/lib/python3.10/site-packages/flax/core/axes_scan.py" source_line=132
        XLA Label: fusion
        Shape: f32[24,2048,2048]
        ==========================

    Buffer 9:
        Size: 384.00MiB
        Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/daniel/anaconda3/envs/dalle/lib/python3.10/site-packages/flax/core/axes_scan.py" source_line=132
        XLA Label: fusion
        Shape: f32[24,2048,2048]
        ==========================

    Buffer 10:
        Size: 384.00MiB
        Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/daniel/anaconda3/envs/dalle/lib/python3.10/site-packages/flax/core/axes_scan.py" source_line=132
        XLA Label: fusion
        Shape: f32[24,2048,2048]
        ==========================

    Buffer 11:
        Size: 384.00MiB
        Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/daniel/anaconda3/envs/dalle/lib/python3.10/site-packages/flax/core/axes_scan.py" source_line=132
        XLA Label: fusion
        Shape: f32[24,2048,2048]
        ==========================

    Buffer 12:
        Size: 32.00MiB
        XLA Label: fusion
        Shape: f32[2048,4096]
        ==========================

    Buffer 13:
        Size: 32.00MiB
        XLA Label: fusion
        Shape: f32[2048,4096]
        ==========================

    Buffer 14:
        Size: 16.00MiB
        Operator: op_name="jit(scan)/jit(main)/while/body/jit(remat(core_fn))/threefry2x32" source_file="/home/daniel/anaconda3/envs/dalle/lib/python3.10/site-packages/flax/core/scope.py" source_line=746
        XLA Label: custom-call
        Shape: u32[4194304]
        ==========================

    Buffer 15:
        Size: 16.00MiB
        Operator: op_name="jit(scan)/jit(main)/while/body/jit(remat(core_fn))/threefry2x32" source_file="/home/daniel/anaconda3/envs/dalle/lib/python3.10/site-packages/flax/core/scope.py" source_line=746
        XLA Label: custom-call
        Shape: u32[4194304]
        ==========================

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Input In [4], in <cell line: 7>()
      4 from transformers import CLIPProcessor, FlaxCLIPModel
      6 # Load dalle-mini
----> 7 model, params = DalleBart.from_pretrained(
      8     DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, #_do_init=False
      9 )
     11 # Load VQGAN
     12 vqgan, vqgan_params = VQModel.from_pretrained(
     13     VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
     14 )

File ~/anaconda3/envs/dalle/lib/python3.10/site-packages/dalle_mini/model/utils.py:25, in PretrainedFromWandbMixin.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
     22         artifact = wandb.Api().artifact(pretrained_model_name_or_path)
     23     pretrained_model_name_or_path = artifact.download(tmp_dir)
---> 25 return super(PretrainedFromWandbMixin, cls).from_pretrained(
     26     pretrained_model_name_or_path, *model_args, **kwargs
     27 )

File ~/anaconda3/envs/dalle/lib/python3.10/site-packages/transformers/modeling_flax_utils.py:556, in FlaxPreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, dtype, *model_args, **kwargs)
    553     resolved_archive_file = None
    555 # init random models
--> 556 model = cls(config, *model_args, **model_kwargs)
    558 if from_pt:
    559     state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)

File ~/anaconda3/envs/dalle/lib/python3.10/site-packages/transformers/models/bart/modeling_flax_bart.py:918, in FlaxBartPreTrainedModel.__init__(self, config, input_shape, seed, dtype, **kwargs)
    909 def __init__(
    910     self,
    911     config: BartConfig,
   (...)
    915     **kwargs
    916 ):
    917     module = self.module_class(config=config, dtype=dtype, **kwargs)
--> 918     super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

File ~/anaconda3/envs/dalle/lib/python3.10/site-packages/transformers/modeling_flax_utils.py:117, in FlaxPreTrainedModel.__init__(self, config, module, input_shape, seed, dtype)
    114 self.dtype = dtype
    116 # randomly initialized parameters
--> 117 random_params = self.init_weights(self.key, input_shape)
    119 # save required_params as set
    120 self._required_params = set(flatten_dict(unfreeze(random_params)).keys())

File ~/anaconda3/envs/dalle/lib/python3.10/site-packages/transformers/models/bart/modeling_flax_bart.py:936, in FlaxBartPreTrainedModel.init_weights(self, rng, input_shape)
    933 params_rng, dropout_rng = jax.random.split(rng)
    934 rngs = {"params": params_rng, "dropout": dropout_rng}
--> 936 return self.module.init(
    937     rngs,
    938     input_ids,
    939     attention_mask,
    940     decoder_input_ids,
    941     decoder_attention_mask,
    942     position_ids,
    943     decoder_position_ids,
    944 )["params"]

    [... skipping hidden 11 frame]

File ~/anaconda3/envs/dalle/lib/python3.10/site-packages/dalle_mini/model/modeling.py:1330, in FlaxBartForConditionalGenerationModule.__call__(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, position_ids, decoder_position_ids, output_attentions, output_hidden_states, return_dict, deterministic)
   1317 def __call__(
   1318     self,
   1319     input_ids,
   (...)
   1328     deterministic: bool = True,
   1329 ):
-> 1330     outputs = self.model(
   1331         input_ids=input_ids,
   1332         attention_mask=attention_mask,
   1333         decoder_input_ids=decoder_input_ids,
   1334         decoder_attention_mask=decoder_attention_mask,
   1335         position_ids=position_ids,
   1336         decoder_position_ids=decoder_position_ids,
   1337         output_attentions=output_attentions,
   1338         output_hidden_states=output_hidden_states,
   1339         return_dict=return_dict,
   1340         deterministic=deterministic,
   1341     )
   1343     hidden_states = outputs[0]
   1345     if self.config.tie_word_embeddings:

    [... skipping hidden 3 frame]

File ~/anaconda3/envs/dalle/lib/python3.10/site-packages/transformers/models/bart/modeling_flax_bart.py:878, in FlaxBartModule.__call__(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, position_ids, decoder_position_ids, output_attentions, output_hidden_states, return_dict, deterministic)
    855 def __call__(
    856     self,
    857     input_ids,
   (...)
    866     deterministic: bool = True,
    867 ):
    868     encoder_outputs = self.encoder(
    869         input_ids=input_ids,
    870         attention_mask=attention_mask,
   (...)
    875         deterministic=deterministic,
    876     )
--> 878     decoder_outputs = self.decoder(
    879         input_ids=decoder_input_ids,
    880         attention_mask=decoder_attention_mask,
    881         position_ids=decoder_position_ids,
    882         encoder_hidden_states=encoder_outputs[0],
    883         encoder_attention_mask=attention_mask,
    884         output_attentions=output_attentions,
    885         output_hidden_states=output_hidden_states,
    886         return_dict=return_dict,
    887         deterministic=deterministic,
    888     )
    890     if not return_dict:
    891         return decoder_outputs + encoder_outputs

    [... skipping hidden 3 frame]

File ~/anaconda3/envs/dalle/lib/python3.10/site-packages/dalle_mini/model/modeling.py:1244, in FlaxBartDecoder.__call__(self, input_ids, attention_mask, position_ids, encoder_hidden_states, encoder_attention_mask, init_cache, output_attentions, output_hidden_states, return_dict, deterministic)
   1241 hidden_states = self.layernorm_embedding(hidden_states)
   1242 hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
-> 1244 outputs = self.layers(
   1245     hidden_states,
   1246     attention_mask,
   1247     encoder_hidden_states,
   1248     encoder_attention_mask,
   1249     deterministic=deterministic,
   1250     init_cache=init_cache,
   1251     output_attentions=output_attentions,
   1252     output_hidden_states=output_hidden_states,
   1253     return_dict=return_dict,
   1254 )
   1256 if self.final_ln is None:
   1257     final_output = outputs[0]

    [... skipping hidden 3 frame]

File ~/anaconda3/envs/dalle/lib/python3.10/site-packages/dalle_mini/model/modeling.py:997, in FlaxBartDecoderLayerCollection.__call__(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, deterministic, init_cache, output_attentions, output_hidden_states, return_dict)
    995     hidden_states = (hidden_states,)
    996     # we use a scale on all norms (even last layer) to allow scanning
--> 997     hidden_states, _ = nn.scan(
    998         layer,
    999         variable_axes={"params": 0, "cache": 0},
   1000         split_rngs={"params": True, "dropout": True},
   1001         in_axes=(
   1002             nn.broadcast,
   1003             nn.broadcast,
   1004             nn.broadcast,
   1005             nn.broadcast,
   1006             nn.broadcast,
   1007             nn.broadcast,
   1008         ),
   1009         length=n_layers,
   1010     )(
   1011         self.config,
   1012         dtype=self.dtype,
   1013         add_norm=self.config.ln_positions == "postln",
   1014         name="FlaxBartDecoderLayers",
   1015     )(
   1016         hidden_states,
   1017         attention_mask,
   1018         encoder_hidden_states,
   1019         encoder_attention_mask,
   1020         init_cache,
   1021         output_attentions,
   1022         deterministic,
   1023     )
   1024     hidden_states = hidden_states[0]
   1026 else:

    [... skipping hidden 4 frame]

File ~/anaconda3/envs/dalle/lib/python3.10/site-packages/flax/core/axes_scan.py:145, in scan.<locals>.scan_fn(broadcast_in, init, *args)
    142   out_flat.append(const)
    143 broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
--> 145 c, ys = lax.scan(body_fn, init, xs, length=length, reverse=reverse)
    146 ys = jax.tree_map(transpose_from_front, out_axes, ys)
    147 ys = jax.tree_map(
    148     lambda ax, const, y: (const if ax is broadcast else y), out_axes,
    149     constants_out, ys)

    [... skipping hidden 7 frame]

File ~/anaconda3/envs/dalle/lib/python3.10/site-packages/jax/_src/dispatch.py:615, in _execute_compiled(name, compiled, input_handler, output_buffer_counts, result_handlers, effects, kept_var_idx, *args)
    613 if effects:
    614   input_bufs_flat, token_handler = _add_tokens(effects, device, input_bufs_flat)
--> 615 out_bufs_flat = compiled.execute(input_bufs_flat)
    616 check_special(name, out_bufs_flat)
    617 if output_buffer_counts is None:

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 402653184 bytes.
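The buffer dump above can be cross-checked with a little arithmetic. Buffers 1 to 11 all have a leading dimension of 24, which is consistent with per-layer parameters being stacked along axis 0 by the `nn.scan` call in the traceback (`variable_axes={"params": 0}`). They are also `f32` even though `dtype=jnp.float16` was passed; in the Flax models from Transformers, `dtype` sets the computation dtype rather than the parameter dtype. The sketch below (plain Python, no JAX needed) recomputes the sizes from the reported shapes; the bytes-per-element table is the only assumption.

```python
# Recompute the peak-buffer sizes in the OOM dump from their dtypes and shapes.
from math import prod

BYTES_PER_ELEMENT = {"f32": 4, "f16": 2, "u32": 4}
MIB = 1024 ** 2


def buffer_bytes(dtype: str, shape: tuple) -> int:
    """Size in bytes of a dense array with the given dtype and shape."""
    return BYTES_PER_ELEMENT[dtype] * prod(shape)


# Shapes of the 24-layer stacked buffers listed as Buffers 1 to 11.
scanned = [(24, 4096, 2048), (24, 2048, 4096), (24, 2048, 4096)] + [(24, 2048, 2048)] * 8

sizes_f32 = [buffer_bytes("f32", s) for s in scanned]
total_f32 = sum(sizes_f32)
# Same arrays if they were stored in half precision instead.
total_f16 = sum(buffer_bytes("f16", s) for s in scanned)

print(sizes_f32[0] // MIB)   # 768
print(sizes_f32[3])          # 402653184
print(total_f32 / 1024**3)   # 5.25
print(total_f16 / 1024**3)   # 2.625
```

The eleven stacked buffers add up to the 5.25 GiB `maybe_live_out` figure, and the failed 402653184-byte request is exactly one `f32[24,2048,2048]` buffer. Storing them as `f16` would halve that total, but not materializing this random initialization at all saves the whole amount.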

danielchalef commented 2 years ago

The model loads after changing how JAX allocates GPU memory; see https://github.com/borisdayma/dalle-mini/issues/185#issuecomment-1145961653
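The linked comment is not reproduced in this thread, so its exact settings are not known here. As a hedged sketch: JAX documents three environment variables that control GPU memory, `XLA_PYTHON_CLIENT_PREALLOCATE`, `XLA_PYTHON_CLIENT_MEM_FRACTION` and `XLA_PYTHON_CLIENT_ALLOCATOR` (the last one is what the `%env` line in the report sets). They must be in place before JAX initializes its GPU backend, so the simplest rule is to set them before `import jax`. The helper `configure_jax_gpu_memory` is hypothetical.

```python
import os
import sys
import warnings
from typing import Optional


def configure_jax_gpu_memory(preallocate: bool = False,
                             mem_fraction: Optional[float] = None,
                             allocator: Optional[str] = None) -> dict:
    """Hypothetical helper: set JAX's GPU-memory env vars and return them.

    Must run before JAX creates its GPU client; calling it before
    `import jax` is the simplest way to guarantee that.
    """
    if "jax" in sys.modules:
        warnings.warn("jax is already imported; these settings may have no effect")
    # Disable (or enable) grabbing a large block of GPU memory up front.
    settings = {"XLA_PYTHON_CLIENT_PREALLOCATE": "true" if preallocate else "false"}
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        # Fraction of GPU memory to preallocate when preallocation is on.
        settings["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(mem_fraction)
    if allocator is not None:
        # "platform" allocates and frees exactly on demand (slower, no caching).
        settings["XLA_PYTHON_CLIENT_ALLOCATOR"] = allocator
    os.environ.update(settings)
    return settings


configure_jax_gpu_memory(preallocate=False)
# import jax  # import only after the variables are set
```

These variables only change how memory is reserved and released; they cannot make a computation fit if its live buffers exceed the card's free memory.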

drdaxxy commented 2 years ago

The model loader in 🤗 Transformers wastes a lot of VRAM there; _do_init=False is what prevents that. As I said in #212 just now, that option will work in newer versions of Transformers.
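The traceback shows where the waste comes from: `FlaxPreTrainedModel.__init__(self, config, module, input_shape, seed, dtype)` calls `self.init_weights(...)` to build a full set of random parameters before any checkpoint weights are loaded. The sketch below checks whether the installed Transformers accepts `_do_init` (assuming newer releases expose it as a parameter of `FlaxPreTrainedModel.__init__`, which the signature in this traceback lacks) and only then loads with `_do_init=False`. `supports_do_init` and `load_dalle_mini` are hypothetical helpers, and `from dalle_mini import DalleBart` is an assumed import path.

```python
import inspect


def supports_do_init(model_base_cls) -> bool:
    """True if the Flax base class's __init__ accepts a `_do_init` argument.

    The traceback in this issue shows __init__(self, config, module,
    input_shape, seed, dtype) with no `_do_init`, i.e. a Transformers
    version that always runs random initialization.
    """
    return "_do_init" in inspect.signature(model_base_cls.__init__).parameters


def load_dalle_mini(model_ref, commit_id=None):
    """Hypothetical helper: load DALL-E Mini weights without random init."""
    # Imports are deferred so this module can be imported without a GPU stack.
    import jax.numpy as jnp
    from transformers import FlaxPreTrainedModel
    from dalle_mini import DalleBart  # assumed import path

    if not supports_do_init(FlaxPreTrainedModel):
        raise RuntimeError(
            "Installed transformers does not support _do_init; upgrade it first."
        )
    # With _do_init=False, from_pretrained skips init_weights and returns
    # (model, params) instead of allocating a random copy of every weight.
    return DalleBart.from_pretrained(
        model_ref, revision=commit_id, dtype=jnp.float16, _do_init=False
    )
```

In the notebook this would be called as `model, params = load_dalle_mini(DALLE_MODEL, DALLE_COMMIT_ID)`; with `_do_init=False`, `from_pretrained` returns the model and its parameters as a pair.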

(Since you're linking my comment, I should mention that I haven't actually tried running this on a 12 GB GPU. I have, however, seen VRAM use stay safely within the limits I described (after subtracting graphics-system overhead) with _do_init=False, and seen it peak higher during loading when that argument is removed.)
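To reproduce the comparison described here (VRAM staying within limits with `_do_init=False` and peaking higher without it), peak device memory can be read from JAX's allocator statistics. This is a sketch assuming a recent JAX in which `Device.memory_stats()` exists and, on GPU, reports keys such as `bytes_in_use`, `peak_bytes_in_use` and `bytes_limit`; it may return nothing on other backends. `summarize_memory` and `report_device_memory` are hypothetical names.

```python
from typing import Mapping, Optional

MIB = 1024 ** 2


def summarize_memory(stats: Optional[Mapping[str, int]]) -> str:
    """Format an allocator-stats mapping (values in bytes) as a MiB summary."""
    if not stats:
        return "memory stats unavailable on this backend"
    # Missing keys read as 0, so check the raw mapping if a value looks wrong.
    in_use = stats.get("bytes_in_use", 0) / MIB
    peak = stats.get("peak_bytes_in_use", 0) / MIB
    limit = stats.get("bytes_limit")
    line = f"in use: {in_use:.0f} MiB, peak: {peak:.0f} MiB"
    if limit:
        line += f", limit: {limit / MIB:.0f} MiB"
    return line


def report_device_memory(device_index: int = 0) -> str:
    """Query JAX for one device's allocator stats (assumes a recent JAX)."""
    import jax  # deferred: set any XLA_* environment variables before this runs

    device = jax.devices()[device_index]
    try:
        stats = device.memory_stats()
    except (AttributeError, NotImplementedError, RuntimeError):
        stats = None
    return summarize_memory(stats)
```

Because the allocator's peak is typically tracked for the lifetime of the process, comparing loading with and without `_do_init=False` needs a fresh Python process (or notebook kernel) for each run, calling `print(report_device_memory())` right after `from_pretrained` returns.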