open-mmlab / StyleShot

StyleShot: A SnapShot on Any Style. A model that can transfer any style onto any content, generating high-quality, personalized stylized images without per-image fine-tuning!
https://styleshot.github.io/
MIT License
267 stars 16 forks

About training problem #18

Closed nldhuyen0047 closed 3 months ago

nldhuyen0047 commented 3 months ago

Hi, I would like to train the StyleShot model on my dataset, but I have a question regarding the GPU specifications required for training. I am using a GPU with 24 GB of VRAM; is this suitable for training the model? Thank you so much.

Jeoyal commented 3 months ago

Hi @nldhuyen0047, thank you for your interest in our work. We trained StyleShot on a single machine with 8 A100 GPUs (80 GB of VRAM each) for 300k steps with a batch size of 16 per GPU. If you're training on a GPU with 24 GB of VRAM, you will probably need a smaller batch size combined with gradient accumulation, which leads to a longer training time. Training time also depends on the size of your dataset.
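As a rough illustration of the arithmetic behind that advice: 8 GPUs with 16 samples each give an effective batch size of 128 per optimizer step. The sketch below computes how many gradient-accumulation steps a smaller setup would need to reach the same effective batch size. The helper name `accumulation_steps` and the example values (2 GPUs, per-GPU batch size 4) are illustrative assumptions, not settings recommended by the authors, and this thread does not say whether the training script exposes an accumulation option.

```python
import math


def accumulation_steps(target_batch: int, num_gpus: int, per_gpu_batch: int) -> int:
    """Smallest number of accumulation steps whose effective batch reaches the target."""
    samples_per_step = num_gpus * per_gpu_batch
    return math.ceil(target_batch / samples_per_step)


# Reference setup from the reply above: 8 GPUs x 16 samples per GPU.
reference_batch = 8 * 16

# Hypothetical smaller setup: 2 GPUs with a per-GPU batch size of 4.
steps = accumulation_steps(reference_batch, num_gpus=2, per_gpu_batch=4)
print(reference_batch, steps)  # 128 16
```

Each optimizer update then needs 16 forward/backward passes instead of one, which is where the longer training time comes from.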

nldhuyen0047 commented 3 months ago

Thank you so much.

nldhuyen0047 commented 3 months ago

Hi, I would like to ask about preparing data for my dataset.

I have reviewed the Style Gallery dataset and noticed that the JSON files across the different datasets are inconsistent. What are the minimum content requirements, so that I can prepare my own dataset? Is it enough to include just image_file and content_prompt?

Could you please explain to me what image_encoder_path is in the training stage?

Thank you so much.

Jeoyal commented 3 months ago

Yes, including just image_file and content_prompt is sufficient when preparing a dataset for training. The image_encoder_path in the training stage refers to the pre-trained weights of the transformer blocks. In StyleShot, we use image_encoder_path = 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K'. You can download it from here. Thank you for your interest again!!!

nldhuyen0047 commented 3 months ago

If I want to train on a real-image dataset, should I look for more datasets that are similar to mine but in different styles, as you did with your StyleGallery dataset?

Thank you so much for your guidance.

Jeoyal commented 3 months ago

You can start by training with your dataset, and later, depending on the visual results, add more datasets.

nldhuyen0047 commented 3 months ago

Thank you so much.

Jeoyal commented 3 months ago

You're welcome :). Please feel free to contact us if you have any further questions.

nldhuyen0047 commented 3 months ago

Hi, I have run into some problems with the stage 1 training process T_T.

This is my command line to train the model:

accelerate launch --num_processes 2 --multi_gpu --gpu_ids "all" --mixed_precision "bf16" \
  tutorial_train_styleshot_stage_1.py \
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
  --image_encoder_path="laion/CLIP-ViT-H-14-laion2B-s32B-b79K" \
  --image_json_file="data/Real/data.json" \
  --image_root_path="data/Real/images" \
  --mixed_precision="bf16" \
  --resolution=512 \
  --train_batch_size=4 \
  --dataloader_num_workers=2 \
  --learning_rate=1e-04 \
  --weight_decay=0.01 \
  --output_dir="output" \
  --save_steps=10000

I got the error:

Traceback (most recent call last):
  File "tutorial_train_styleshot_stage_1.py", line 514, in <module>
    main()
  File "tutorial_train_styleshot_stage_1.py", line 455, in main
    for step, batch in enumerate(train_dataloader):
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/accelerate/data_loader.py", line 454, in __iter__
    current_batch = next(dataloader_iter)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/_utils.py", line 643, in reraise
    raise RuntimeError(msg) from None
RuntimeError: Caught JSONDecodeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "tutorial_train_styleshot_stage_1.py", line 81, in __getitem__
    item = json.loads(self.data[idx])
  File "/home/miniconda3/envs/styleshot/lib/python3.8/json/__init__.py", line 357, in loads
    return _default_decoder.decode(s)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/json/decoder.py", line 337, in decode
    obj, end = self.raw_decode(s, idx=_w(s, 0).end())
  File "/home/miniconda3/envs/styleshot/lib/python3.8/json/decoder.py", line 353, in raw_decode
    obj, end = self.scan_once(s, idx)
json.decoder.JSONDecodeError: Expecting property name enclosed in double quotes: line 2 column 1 (char 6)

RuntimeError: Caught JSONDecodeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "tutorial_train_styleshot_stage_1.py", line 81, in __getitem__
    item = json.loads(self.data[idx])
  File "/home/miniconda3/envs/styleshot/lib/python3.8/json/__init__.py", line 357, in loads
    return _default_decoder.decode(s)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/json/decoder.py", line 337, in decode
    obj, end = self.raw_decode(s, idx=_w(s, 0).end())
  File "/home/miniconda3/envs/styleshot/lib/python3.8/json/decoder.py", line 355, in raw_decode
    raise JSONDecodeError("Expecting value", s, err.value) from None
json.decoder.JSONDecodeError: Expecting value: line 1 column 5 (char 4)

WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 682820 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 682819) of binary: /home/miniconda3/envs/styleshot/bin/python
Traceback (most recent call last):
  File "/home/miniconda3/envs/styleshot/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
    args.func(args)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/accelerate/commands/launch.py", line 1097, in launch_command
    multi_gpu_launcher(args)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/accelerate/commands/launch.py", line 734, in multi_gpu_launcher
    distrib_run.run(args)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

Have you ever encountered a similar situation? If possible, could you please explain what is going wrong?

Thank you so much.

Jeoyal commented 3 months ago

I haven't encountered this issue before, but the problem seems to occur at item = json.loads(self.data[idx]). This suggests that your JSON file is not constructed correctly. Could you show me two lines from data.json?
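One way to check the file before launching training is to parse it the way the traceback suggests the dataset does, with one `json.loads` call per line, and report the lines that fail. This is a minimal sketch; the helper name `find_bad_lines` and the sample strings are made up for illustration and are not part of the StyleShot code.

```python
import json
from typing import List, Tuple


def find_bad_lines(lines: List[str]) -> List[Tuple[int, str]]:
    """Return (1-based line number, error message) for each line that is not valid JSON."""
    bad = []
    for number, line in enumerate(lines, start=1):
        try:
            json.loads(line)
        except json.JSONDecodeError as err:
            bad.append((number, err.msg))
    return bad


# A pretty-printed record is spread over several lines, so no single line parses.
pretty = ['{', '  "image_file": "a.jpg",', '  "content_prompt": "a street"', '},']
# One complete object per line parses cleanly.
compact = ['{"image_file": "a.jpg", "content_prompt": "a street"}']

print(find_bad_lines(pretty))
print(find_bad_lines(compact))  # []
```

To check a real file, the lines could be read with `open(path).read().splitlines()` and passed to the same helper.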

nldhuyen0047 commented 3 months ago

These are two entries from my data.json: { "image_file": "data/Real/images/000031.jpg", "content_prompt": "Many people are crossing the street while many cars are parked on the side of the road, in the background there are several tall brick buildings with many signs" }, { "image_file": "data/Real/images/000034.jpg", "content_prompt": "A building with windows and paved with cobblestones below has a few shops and in front there are some food carts covered with tarps, the street is paved with cobblestones" },

Jeoyal commented 3 months ago

Each record must be on a single line (i.e., it needs to be written to the JSON file using json.dumps()). Every line in your JSON file must look like this: {"image_file": "data/Real/images/000031.jpg", "content_prompt": "Many people are crossing the street while many cars are parked on the side of the road, in the background there are several tall brick buildings with many signs"}
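A minimal sketch of producing such a file: `json.dumps` with its default arguments never emits a raw newline, so writing one dumped record followed by `"\n"` gives one object per line. The helper name `write_json_lines`, the temporary output path, and the shortened prompts are placeholders for illustration only.

```python
import json
import os
import tempfile


def write_json_lines(records, path):
    """Write one JSON object per line (no surrounding list, no trailing commas)."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


# Placeholder records; only image_file and content_prompt are included.
records = [
    {"image_file": "data/Real/images/000031.jpg", "content_prompt": "people crossing a street"},
    {"image_file": "data/Real/images/000034.jpg", "content_prompt": "a cobblestone street with food carts"},
]

out_path = os.path.join(tempfile.mkdtemp(), "data.json")
write_json_lines(records, out_path)

# Read the file back line by line, as a per-line loader would.
with open(out_path, encoding="utf-8") as f:
    lines = f.read().splitlines()
parsed = [json.loads(line) for line in lines]
print(len(lines), parsed == records)  # 2 True
```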

nldhuyen0047 commented 3 months ago

Oh thank you.

I fixed the code, but now I get another error:

...
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000075.jpg'
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000085.jpg'
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000005.jpg'
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000060.jpg'
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000074.jpg'
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000059.jpg'
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000051.jpg'
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000011.jpg'
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000040.jpg'
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000043.jpg'
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000077.jpg'
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000044.jpg'
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000050.jpg'
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000041.jpg'
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000080.jpg'
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000007.jpg'
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000049.jpg'
[Errno 2] No such file or directory: '/home/StyleShot/data/Real/images/data/Real/images/000046.jpg'
Traceback (most recent call last):
  File "tutorial_train_styleshot_stage_1.py", line 514, in <module>
    main()
  File "tutorial_train_styleshot_stage_1.py", line 455, in main
    for step, batch in enumerate(train_dataloader):
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/accelerate/data_loader.py", line 454, in __iter__
    current_batch = next(dataloader_iter)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "tutorial_train_styleshot_stage_1.py", line 97, in __getitem__
    raw_image = self.crop(raw_image)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torchvision/transforms/transforms.py", line 95, in __call__
    img = t(img)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torchvision/transforms/transforms.py", line 361, in forward
    return F.resize(img, self.size, self.interpolation, self.maxsize, self.antialias)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torchvision/transforms/functional.py", line 476, in resize
    _, image_height, image_width = get_dimensions(img)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torchvision/transforms/functional.py", line 78, in get_dimensions
    return F_pil.get_dimensions(img)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torchvision/transforms/_functional_pil.py", line 31, in get_dimensions
    raise TypeError(f"Unexpected type {type(img)}")
TypeError: Unexpected type <class 'NoneType'>

WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 684271 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 1 (pid: 684272) of binary: /home/miniconda3/envs/styleshot/bin/python
Traceback (most recent call last):
  File "/home/miniconda3/envs/styleshot/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
    args.func(args)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/accelerate/commands/launch.py", line 1097, in launch_command
    multi_gpu_launcher(args)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/accelerate/commands/launch.py", line 734, in multi_gpu_launcher
    distrib_run.run(args)
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/miniconda3/envs/styleshot/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

tutorial_train_styleshot_stage_1.py FAILED

Jeoyal commented 3 months ago

Your image path is incorrect: it should be /home/StyleShot/data/Real/images/000075.jpg rather than /home/StyleShot/data/Real/images/data/Real/images/000075.jpg.

nldhuyen0047 commented 3 months ago

Yes, the path in the JSON file is "data/Real/images/000075.jpg", but I don't know why the resolved path is wrong. I'm checking the code again.

Jeoyal commented 3 months ago

It is caused by your --image_root_path="data/Real/images", which gets prepended to each image_file and produces the wrong image path.

nldhuyen0047 commented 3 months ago

So could you please tell me what path I should put in image_root_path?

Jeoyal commented 3 months ago

/home/StyleShot/ should be fine.
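The duplicated directory in the error messages is consistent with the loader joining `image_root_path` and each record's `image_file`. The sketch below assumes exactly that (a plain path join, which is an inference from the error output rather than something confirmed from the training script) and shows how different roots resolve. The function name `resolve` is hypothetical.

```python
import posixpath


def resolve(image_root_path: str, image_file: str) -> str:
    # Assumption: the loader simply joins the root with the per-record path.
    return posixpath.join(image_root_path, image_file)


image_file = "data/Real/images/000075.jpg"

# Root repeats the directory already present in image_file: path is duplicated.
print(resolve("data/Real/images", image_file))
# data/Real/images/data/Real/images/000075.jpg

# Root is the repository directory: path is correct.
print(resolve("/home/StyleShot/", image_file))
# /home/StyleShot/data/Real/images/000075.jpg

# Empty root: image_file is used as-is, relative to the working directory.
print(resolve("", image_file))
# data/Real/images/000075.jpg
```

Under this assumption, either the root should be the directory that image_file is relative to, or image_file should contain only the part of the path below the root.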

nldhuyen0047 commented 3 months ago

The code is running now; I used --image_root_path="".

Thank you so much ^^.