huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0
25.49k stars 5.28k forks

Enable Flax Stable Diffusion 2.1 model in the example #5461

Closed RissyRan closed 11 months ago

RissyRan commented 11 months ago

Describe the bug

The current Flax text-to-image training example (examples/text_to_image/train_text_to_image_flax.py) supports the model CompVis/stable-diffusion-v1-4-flax.

Request: add a boolean from_pt argument and pass it to CLIPTokenizer.from_pretrained, FlaxCLIPTextModel.from_pretrained, FlaxAutoencoderKL.from_pretrained, and FlaxUNet2DConditionModel.from_pretrained, so that the example can convert PyTorch checkpoints on load. This would let us use stabilityai/stable-diffusion-2-1 with Flax.
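A minimal sketch of what the requested change could look like, assuming the script parses its options with argparse. The `--from_pt` flag and the `pretrained_kwargs` / `load_components` helpers are hypothetical names for illustration, not the merged implementation. The `from_pt`, `revision`, `subfolder`, and `dtype` keyword arguments are existing `from_pretrained` options in transformers and diffusers. Unlike the list above, the sketch passes only `revision` to the tokenizer, since a tokenizer has no weights to convert. The conversion path needs PyTorch installed alongside JAX.

```python
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Flax text-to-image fine-tuning (sketch)")
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
    parser.add_argument("--revision", type=str, default=None)
    # Hypothetical flag from this issue: load PyTorch weights and convert them to Flax.
    parser.add_argument(
        "--from_pt",
        action="store_true",
        help="Convert PyTorch checkpoints (e.g. stabilityai/stable-diffusion-2-1) to Flax on load.",
    )
    return parser.parse_args(argv)


def pretrained_kwargs(args):
    # Keyword arguments shared by every weight-bearing from_pretrained call.
    return {"revision": args.revision, "from_pt": args.from_pt}


def load_components(args, weight_dtype):
    # Heavy imports are deferred so argument parsing works without them installed.
    from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel
    from transformers import CLIPTokenizer, FlaxCLIPTextModel

    path = args.pretrained_model_name_or_path
    kwargs = pretrained_kwargs(args)
    # The tokenizer has no weights, so it only needs the revision.
    tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer", revision=args.revision)
    text_encoder = FlaxCLIPTextModel.from_pretrained(
        path, subfolder="text_encoder", dtype=weight_dtype, **kwargs
    )
    # diffusers Flax models return a (module, params) pair.
    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
        path, subfolder="vae", dtype=weight_dtype, **kwargs
    )
    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
        path, subfolder="unet", dtype=weight_dtype, **kwargs
    )
    return tokenizer, text_encoder, vae, vae_params, unet, unet_params
```

In a training script this would be called as something like `load_components(parse_args(), jnp.bfloat16)` when `--mixed_precision=bf16` is set.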

Reproduction

A successful run on TPU using the sd-2.1 branch of my fork (https://github.com/RissyRan/diffusers.git):

git clone -b sd-2.1 https://github.com/RissyRan/diffusers.git
cd diffusers && pip install .
pip install -U -r examples/text_to_image/requirements_flax.txt
pip install clu tensorflow
export PATH=$PATH:$HOME/.local/bin

cd ~/diffusers/examples/text_to_image

python3 train_text_to_image_flax.py \
  --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-1 \
  --dataset_name=lambdalabs/pokemon-blip-captions \
  --resolution=256 --center_crop --random_flip \
  --train_batch_size=8 --mixed_precision=bf16 \
  --max_train_steps=3000 --learning_rate=1e-05 --max_grad_norm=1 \
  --output_dir=sd-bf16-model \
  --image_column=image --caption_column=text
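As a rough check of how these flags map onto the training summary in the logs (833 examples, total batch size 32, 116 epochs), the epoch count can be reproduced with a little arithmetic. This is a sketch under two assumptions: the device count of 4 is inferred from the logged total batch size of 32 divided by --train_batch_size=8, and the dataloader is assumed to drop the final incomplete batch. The `training_schedule` helper is hypothetical.

```python
import math


def training_schedule(num_examples, per_device_batch_size, num_devices, max_train_steps):
    # Global batch size: every device processes its own per-device batch in parallel.
    total_batch_size = per_device_batch_size * num_devices
    # Assumes the last incomplete batch is dropped, so an epoch holds only full global batches.
    steps_per_epoch = num_examples // total_batch_size
    # Enough epochs to reach max_train_steps; the final epoch may stop early.
    num_epochs = math.ceil(max_train_steps / steps_per_epoch)
    return total_batch_size, steps_per_epoch, num_epochs


print(training_schedule(833, 8, 4, 3000))  # (32, 26, 116)
```

Keeping the partial batch instead would give 27 steps per epoch and only 112 epochs, so the logged 116 is consistent with dropping it. Likewise, 116 epochs at the reported 13.97 s/it come to about 1620 s, matching the 27:00 shown by the progress bar.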

Logs

loading file vocab.json from cache at /home/ranran/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2-1/snapshots/5cae40e6a2745ae2b01ad92ae5043f95f23644d6/tokenizer/vocab.json
loading file merges.txt from cache at /home/ranran/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2-1/snapshots/5cae40e6a2745ae2b01ad92ae5043f95f23644d6/tokenizer/merges.txt
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at /home/ranran/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2-1/snapshots/5cae40e6a2745ae2b01ad92ae5043f95f23644d6/tokenizer/special_tokens_map.json
loading file tokenizer_config.json from cache at /home/ranran/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2-1/snapshots/5cae40e6a2745ae2b01ad92ae5043f95f23644d6/tokenizer/tokenizer_config.json
loading file tokenizer.json from cache at None
Adding <|startoftext|> to the vocabulary
Adding <|endoftext|> to the vocabulary
Adding ! to the vocabulary
loading configuration file config.json from cache at /home/ranran/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2-1/snapshots/5cae40e6a2745ae2b01ad92ae5043f95f23644d6/text_encoder/config.json
Model config CLIPTextConfig {
  "_name_or_path": "hf-models/stable-diffusion-v2-768x768/text_encoder",
  "architectures": [
    "CLIPTextModel"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "dropout": 0.0,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_size": 1024,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 77,
  "model_type": "clip_text_model",
  "num_attention_heads": 16,
  "num_hidden_layers": 23,
  "pad_token_id": 1,
  "projection_dim": 512,
  "torch_dtype": "float32",
  "transformers_version": "4.34.0",
  "vocab_size": 49408
}

Downloading pytorch_model.bin: 100%|████████████████████████████████████████████████████████████████████████████████| 1.36G/1.36G [00:05<00:00, 263MB/s]
loading weights file pytorch_model.bin from cache at /home/ranran/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2-1/snapshots/5cae40e6a2745ae2b01ad92ae5043f95f23644d6/text_encoder/pytorch_model.bin
Loading PyTorch weights from /home/ranran/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2-1/snapshots/5cae40e6a2745ae2b01ad92ae5043f95f23644d6/text_encoder/pytorch_model.bin
PyTorch checkpoint contains 340,387,917 parameters.
Some weights of the model checkpoint at stabilityai/stable-diffusion-2-1 were not used when initializing FlaxCLIPTextModel: {('text_model', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of FlaxCLIPTextModel were initialized from the model checkpoint at stabilityai/stable-diffusion-2-1.
If your task is similar to the task the model of the checkpoint was trained on, you can already use FlaxCLIPTextModel for predictions without further training.
Downloading (…)main/vae/config.json: 100%|█████████████████████████████████████████████████████████████████████████████| 611/611 [00:00<00:00, 5.96MB/s]
Downloading (…)on_pytorch_model.bin: 100%|████████████████████████████████████████████████████████████████████████████| 335M/335M [00:01<00:00, 220MB/s]
Downloading (…)ain/unet/config.json: 100%|█████████████████████████████████████████████████████████████████████████████| 939/939 [00:00<00:00, 9.72MB/s]
Downloading (…)on_pytorch_model.bin: 100%|██████████████████████████████████████████████████████████████████████████| 3.46G/3.46G [00:13<00:00, 261MB/s]
10/07/2023 00:20:22 - INFO - __main__ - ***** Running training *****
10/07/2023 00:20:22 - INFO - __main__ -   Num examples = 833
10/07/2023 00:20:22 - INFO - __main__ -   Num Epochs = 116
10/07/2023 00:20:22 - INFO - __main__ -   Instantaneous batch size per device = 8
10/07/2023 00:20:22 - INFO - __main__ -   Total train batch size (w. parallel & distributed) = 32
10/07/2023 00:20:22 - INFO - __main__ -   Total optimization steps = 3000
Epoch... (1/116 | Loss: 0.7180335521697998)                                                                                                             
Epoch... (2/116 | Loss: 0.43771886825561523)
Epoch... (3/116 | Loss: 0.26194703578948975)
Epoch... (4/116 | Loss: 0.13119475543498993)                                                                                                            
Epoch... (5/116 | Loss: 0.1085905134677887)                                                                                                             
Epoch... (6/116 | Loss: 0.09535876661539078)                                                                                                            
Epoch... (7/116 | Loss: 0.1026906818151474)                                                                                                             
Epoch... (8/116 | Loss: 0.12104104459285736)                                                                                                            
Epoch... (9/116 | Loss: 0.10453010350465775)                                                                                                            
Epoch... (10/116 | Loss: 0.17062756419181824)                                                                                                           
Epoch... (11/116 | Loss: 0.14385651051998138)                                                                                                           
Epoch... (12/116 | Loss: 0.2031339406967163)                                                                                                            
Epoch... (13/116 | Loss: 0.16443254053592682)                                                                                                           
Epoch... (14/116 | Loss: 0.20865577459335327)                                                                                                           
Epoch... (15/116 | Loss: 0.18580672144889832)                                                                                                           
Epoch... (16/116 | Loss: 0.15176695585250854)                                                                                                           
Epoch... (17/116 | Loss: 0.19028088450431824)                                                                                                           
Epoch... (18/116 | Loss: 0.19435754418373108)                                                                                                           
Epoch... (19/116 | Loss: 0.18245089054107666)                                                                                                           
Epoch... (20/116 | Loss: 0.19878770411014557)                                                                                                           
Epoch... (21/116 | Loss: 0.1917506754398346)                                                                                                            
Epoch... (22/116 | Loss: 0.17860080301761627)                                                                                                           
Epoch... (23/116 | Loss: 0.22167636454105377)                                                                                                           
Epoch... (24/116 | Loss: 0.15921765565872192)                                                                                                           
Epoch... (25/116 | Loss: 0.15307846665382385)                                                                                                           
Epoch... (26/116 | Loss: 0.1590590476989746)                                                                                                            
Epoch... (27/116 | Loss: 0.18025700747966766)                                                                                                           
Epoch... (28/116 | Loss: 0.1344955563545227)                                                                                                            
Epoch... (29/116 | Loss: 0.15993043780326843)                                                                                                           
Epoch... (30/116 | Loss: 0.1948901116847992)                                                                                                            
Epoch... (31/116 | Loss: 0.18099309504032135)                                                                                                           
Epoch... (32/116 | Loss: 0.1292792558670044)                                                                                                            
Epoch... (33/116 | Loss: 0.17469605803489685)                                                                                                           
Epoch... (34/116 | Loss: 0.1763019859790802)                                                                                                            
Epoch... (35/116 | Loss: 0.2052479088306427)                                                                                                            
Epoch... (36/116 | Loss: 0.15418601036071777)                                                                                                           
Epoch... (37/116 | Loss: 0.19385230541229248)                                                                                                           
Epoch... (38/116 | Loss: 0.23844414949417114)                                                                                                           
Epoch... (39/116 | Loss: 0.18779703974723816)                                                                                                           
Epoch... (40/116 | Loss: 0.1603996753692627)                                                                                                            
Epoch... (41/116 | Loss: 0.14351682364940643)                                                                                                           
Epoch... (42/116 | Loss: 0.13258004188537598)                                                                                                           
Epoch... (43/116 | Loss: 0.1390378475189209)                                                                                                            
Epoch... (44/116 | Loss: 0.17080891132354736)                                                                                                           
Epoch... (45/116 | Loss: 0.21226170659065247)                                                                                                           
Epoch... (46/116 | Loss: 0.1779744029045105)                                                                                                            
Epoch... (47/116 | Loss: 0.17431044578552246)                                                                                                           
Epoch... (48/116 | Loss: 0.21351423859596252)                                                                                                           
Epoch... (49/116 | Loss: 0.15806937217712402)                                                                                                           
Epoch... (50/116 | Loss: 0.19611066579818726)                                                                                                           
Epoch... (51/116 | Loss: 0.19705402851104736)                                                                                                           
Epoch... (52/116 | Loss: 0.16212433576583862)                                                                                                           
Epoch... (53/116 | Loss: 0.2077130675315857)                                                                                                            
Epoch... (54/116 | Loss: 0.16614729166030884)                                                                                                           
Epoch... (55/116 | Loss: 0.21163761615753174)                                                                                                           
Epoch... (56/116 | Loss: 0.16779109835624695)                                                                                                           
Epoch... (57/116 | Loss: 0.14466387033462524)                                                                                                           
Epoch... (58/116 | Loss: 0.14965006709098816)                                                                                                           
Epoch... (59/116 | Loss: 0.15549156069755554)                                                                                                           
Epoch... (60/116 | Loss: 0.22302648425102234)                                                                                                           
Epoch... (61/116 | Loss: 0.20787739753723145)                                                                                                           
Epoch... (62/116 | Loss: 0.22567233443260193)                                                                                                           
Epoch... (63/116 | Loss: 0.2293650060892105)                                                                                                            
Epoch... (64/116 | Loss: 0.16813695430755615)                                                                                                           
Epoch... (65/116 | Loss: 0.15347810089588165)                                                                                                           
Epoch... (66/116 | Loss: 0.21587051451206207)                                                                                                           
Epoch... (67/116 | Loss: 0.1463858187198639)                                                                                                            
Epoch... (68/116 | Loss: 0.16381247341632843)                                                                                                           
Epoch... (69/116 | Loss: 0.20754028856754303)                                                                                                           
Epoch... (70/116 | Loss: 0.207645446062088)                                                                                                             
Epoch... (71/116 | Loss: 0.21864044666290283)                                                                                                           
Epoch... (72/116 | Loss: 0.1616446077823639)                                                                                                            
Epoch... (73/116 | Loss: 0.23569294810295105)                                                                                                           
Epoch... (74/116 | Loss: 0.1805119663476944)                                                                                                            
Epoch... (75/116 | Loss: 0.14790382981300354)                                                                                                           
Epoch... (76/116 | Loss: 0.16655220091342926)                                                                                                           
Epoch... (77/116 | Loss: 0.21114017069339752)                                                                                                           
Epoch... (78/116 | Loss: 0.14628037810325623)                                                                                                           
Epoch... (79/116 | Loss: 0.18384355306625366)                                                                                                           
Epoch... (80/116 | Loss: 0.1145143210887909)                                                                                                            
Epoch... (81/116 | Loss: 0.1327746957540512)                                                                                                            
Epoch... (82/116 | Loss: 0.12168194353580475)                                                                                                           
Epoch... (83/116 | Loss: 0.14710308611392975)                                                                                                           
Epoch... (84/116 | Loss: 0.11781108379364014)                                                                                                           
Epoch... (85/116 | Loss: 0.10309168696403503)                                                                                                           
Epoch... (86/116 | Loss: 0.13551287353038788)                                                                                                           
Epoch... (87/116 | Loss: 0.13492515683174133)                                                                                                           
Epoch... (88/116 | Loss: 0.15749606490135193)                                                                                                           
Epoch... (89/116 | Loss: 0.15951259434223175)                                                                                                           
Epoch... (90/116 | Loss: 0.19044746458530426)                                                                                                           
Epoch... (91/116 | Loss: 0.10226136445999146)                                                                                                           
Epoch... (92/116 | Loss: 0.16856661438941956)                                                                                                           
Epoch... (93/116 | Loss: 0.14837288856506348)                                                                                                           
Epoch... (94/116 | Loss: 0.12632006406784058)                                                                                                           
Epoch... (95/116 | Loss: 0.1429070234298706)                                                                                                            
Epoch... (96/116 | Loss: 0.136387437582016)                                                                                                             
Epoch... (97/116 | Loss: 0.15173980593681335)                                                                                                           
Epoch... (98/116 | Loss: 0.1286359429359436)                                                                                                            
Epoch... (99/116 | Loss: 0.16469153761863708)                                                                                                           
Epoch... (100/116 | Loss: 0.12364685535430908)                                                                                                          
Epoch... (101/116 | Loss: 0.13709455728530884)                                                                                                          
Epoch... (102/116 | Loss: 0.10448510944843292)                                                                                                          
Epoch... (103/116 | Loss: 0.1209930032491684)                                                                                                           
Epoch... (104/116 | Loss: 0.0959380716085434)                                                                                                           
Epoch... (105/116 | Loss: 0.10715331137180328)                                                                                                          
Epoch... (106/116 | Loss: 0.13205498456954956)                                                                                                          
Epoch... (107/116 | Loss: 0.12471425533294678)                                                                                                          
Epoch... (108/116 | Loss: 0.12001261860132217)                                                                                                          
Epoch... (109/116 | Loss: 0.10570372641086578)                                                                                                          
Epoch... (110/116 | Loss: 0.10314325243234634)                                                                                                          
Epoch... (111/116 | Loss: 0.11420712620019913)                                                                                                          
Epoch... (112/116 | Loss: 0.104721799492836)                                                                                                            
Epoch... (113/116 | Loss: 0.1211758404970169)                                                                                                           
Epoch... (114/116 | Loss: 0.12255255877971649)                                                                                                          
Epoch... (115/116 | Loss: 0.1030556932091713)                                                                                                           
Epoch... (116/116 | Loss: 0.12407614290714264)                                                                                                          
Epoch ... : 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [27:00<00:00, 13.97s/it]
Downloading (…)lve/main/config.json: 100%|█████████████████████████████████████████████████████████████████████████| 4.55k/4.55k [00:00<00:00, 31.7MB/s]
loading configuration file config.json from cache at /home/ranran/.cache/huggingface/hub/models--CompVis--stable-diffusion-safety-checker/snapshots/cb41f3a270d63d454d385fc2e4f571c487c253c5/config.json
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.
`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.
Model config CLIPConfig {
  "_name_or_path": "clip-vit-large-patch14/",
  "architectures": [
    "SafetyChecker"
  ],
  "initializer_factor": 1.0,
  "logit_scale_init_value": 2.6592,
  "model_type": "clip",
  "projection_dim": 768,
  "text_config": {
    "dropout": 0.0,
    "hidden_size": 768,
    "intermediate_size": 3072,
    "model_type": "clip_text_model",
    "num_attention_heads": 12
  },
  "torch_dtype": "float32",
  "transformers_version": "4.34.0",
  "vision_config": {
    "dropout": 0.0,
    "hidden_size": 1024,
    "intermediate_size": 4096,
    "model_type": "clip_vision_model",
    "num_attention_heads": 16,
    "num_hidden_layers": 24,
    "patch_size": 14
  }
}

Downloading pytorch_model.bin: 100%|████████████████████████████████████████████████████████████████████████████████| 1.22G/1.22G [00:04<00:00, 246MB/s]
loading weights file pytorch_model.bin from cache at /home/ranran/.cache/huggingface/hub/models--CompVis--stable-diffusion-safety-checker/snapshots/cb41f3a270d63d454d385fc2e4f571c487c253c5/pytorch_model.bin
Loading PyTorch weights from /home/ranran/.cache/huggingface/hub/models--CompVis--stable-diffusion-safety-checker/snapshots/cb41f3a270d63d454d385fc2e4f571c487c253c5/pytorch_model.bin
PyTorch checkpoint contains 303,981,845 parameters.
Some weights of the model checkpoint at CompVis/stable-diffusion-safety-checker were not used when initializing FlaxStableDiffusionSafetyChecker: {('vision_model', 'vision_model', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxStableDiffusionSafetyChecker from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxStableDiffusionSafetyChecker from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of FlaxStableDiffusionSafetyChecker were initialized from the model checkpoint at CompVis/stable-diffusion-safety-checker.
If your task is similar to the task the model of the checkpoint was trained on, you can already use FlaxStableDiffusionSafetyChecker for predictions without further training.
Downloading (…)rocessor_config.json: 100%|█████████████████████████████████████████████████████████████████████████████| 316/316 [00:00<00:00, 3.03MB/s]
loading configuration file preprocessor_config.json from cache at /home/ranran/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/preprocessor_config.json
size should be a dictionary on of the following set of keys: ({'width', 'height'}, {'shortest_edge'}, {'longest_edge', 'shortest_edge'}, {'longest_edge'}), got 224. Converted to {'shortest_edge': 224}.
crop_size should be a dictionary on of the following set of keys: ({'width', 'height'}, {'shortest_edge'}, {'longest_edge', 'shortest_edge'}, {'longest_edge'}), got 224. Converted to {'height': 224, 'width': 224}.
Image processor CLIPImageProcessor {
  "crop_size": {
    "height": 224,
    "width": 224
  },
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "feature_extractor_type": "CLIPFeatureExtractor",
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "CLIPImageProcessor",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 224
  }
}

Configuration saved in /home/ranran/diffusers/examples/text_to_image/sd-bf16-model/text_encoder/config.json
Model weights saved in /home/ranran/diffusers/examples/text_to_image/sd-bf16-model/text_encoder/flax_model.msgpack
tokenizer config file saved in sd-bf16-model/tokenizer/tokenizer_config.json
Special tokens file saved in sd-bf16-model/tokenizer/special_tokens_map.json
added tokens file saved in sd-bf16-model/tokenizer/added_tokens.json
`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.
`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.
Configuration saved in /home/ranran/diffusers/examples/text_to_image/sd-bf16-model/safety_checker/config.json
Model weights saved in /home/ranran/diffusers/examples/text_to_image/sd-bf16-model/safety_checker/flax_model.msgpack
Image processor saved in sd-bf16-model/feature_extractor/preprocessor_config.json
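The FlaxCLIPTextModel warning about an unused ('text_model', 'embeddings', 'position_ids') entry can be cross-checked against the logged count of 340,387,917 PyTorch parameters. The sketch below counts the trainable parameters implied by the text-encoder config in the logs, assuming the standard CLIP text-transformer layout: biases on every linear layer, two LayerNorms per block plus a final LayerNorm, and no projection head. The helper name is hypothetical.

```python
def clip_text_encoder_param_count(vocab_size, max_positions, hidden, intermediate, num_layers):
    # Token and position embedding tables.
    embeddings = vocab_size * hidden + max_positions * hidden
    # Self-attention: q, k, v and output projections, each a hidden x hidden weight plus bias.
    attention = 4 * (hidden * hidden + hidden)
    # MLP: fc1 (hidden -> intermediate) and fc2 (intermediate -> hidden), both with biases.
    mlp = (hidden * intermediate + intermediate) + (intermediate * hidden + hidden)
    # Two LayerNorms per layer, each with a scale and a bias vector.
    layer_norms = 2 * (2 * hidden)
    per_layer = attention + mlp + layer_norms
    # Final LayerNorm after the encoder stack.
    final_norm = 2 * hidden
    return embeddings + num_layers * per_layer + final_norm


# Values from the CLIPTextConfig printed in the logs.
params = clip_text_encoder_param_count(49408, 77, 1024, 4096, 23)
position_ids_buffer = 77  # position_ids index buffer: one entry per position, shape (1, 77)
print(params)                        # 340387840
print(params + position_ids_buffer)  # 340387917
```

The 77-entry gap is exactly the size of the position_ids buffer, which suggests the warning is benign: only the index buffer is skipped, consistent with the log line stating that all weights of FlaxCLIPTextModel were initialized from the checkpoint.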

System Info

Google Cloud TPU

Who can help?

@yiyixuxu @patrickvonplaten

patrickvonplaten commented 11 months ago

Hey @RissyRan,

Would you maybe like to open a PR? cc @pcuenca here

RissyRan commented 11 months ago

Yes sure, @patrickvonplaten! Let me send out a PR.

RissyRan commented 11 months ago

Created the PR: https://github.com/huggingface/diffusers/pull/5501