beniz opened this issue 5 months ago
Hi~ Can you provide more details, like the exact commands you ran and the full error log?
Hi @PeizeSun thanks, and apologies for not having provided more details.
I've pre-computed the codes with the following command (removing the ten_crop flag for debugging):
bash scripts/autoregressive/extract_codes_c2i.sh --vq-ckpt /path/to/models/vq_ds16_c2i.pt --data-path /path/to/butterflies/ --code-path /path/to/butterflies/codes_256/ --image-size 256
The toy dataset is single-class and available from https://www.joligen.com/datasets/butterflies.tar
The generated codes in the codes_256 dir seem to be OK:
ls -l codes_256/
total 44
drwxrwxr-x 2 b b 20480 Jun 14 17:19 imagenet256_codes
drwxrwxr-x 2 b b 20480 Jun 14 17:19 imagenet256_labels
From printing the shapes, the features are (correctly, afaik) of shape [1, 2, 256], and the labels of shape [1].
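A quick way to sanity-check an individual code file (a sketch; 0.npy is simply the first file the extraction script writes):
import numpy as np

# Hypothetical spot-check of the first extracted sample.
codes = np.load("codes_256/imagenet256_codes/0.npy")
labels = np.load("codes_256/imagenet256_labels/0.npy")
print(codes.shape)   # (1, 2, 256): 1 sample, 2 augmentations (presumably image + flip), 256 tokens
print(labels.shape)  # (1,)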
My diff on the training code is below; I've downsized the training to a single GPU for debugging purposes.
diff --git a/autoregressive/train/train_c2i.py b/autoregressive/train/train_c2i.py
index 3b43aa5..031b868 100644
--- a/autoregressive/train/train_c2i.py
+++ b/autoregressive/train/train_c2i.py
@@ -15,6 +15,8 @@ import time
import inspect
import argparse
+import sys
+sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from utils.logger import create_logger
from utils.distributed import init_distributed_mode
from utils.ema import update_ema, requires_grad
diff --git a/dataset/imagenet.py b/dataset/imagenet.py
index c07f6cb..6d0e185 100644
--- a/dataset/imagenet.py
+++ b/dataset/imagenet.py
@@ -23,8 +23,8 @@ class CustomDataset(Dataset):
# self.feature_files = sorted(os.listdir(feature_dir))
# self.label_files = sorted(os.listdir(label_dir))
# TODO: make it configurable
- self.feature_files = [f"{i}.npy" for i in range(1281167)]
- self.label_files = [f"{i}.npy" for i in range(1281167)]
+ self.feature_files = [f"{i}.npy" for i in range(951)]
+ self.label_files = [f"{i}.npy" for i in range(951)]
def __len__(self):
assert len(self.feature_files) == len(self.label_files), \
@@ -58,4 +58,4 @@ def build_imagenet_code(args):
label_dir = f"{args.code_path}/imagenet{args.image_size}_labels"
assert os.path.exists(feature_dir) and os.path.exists(label_dir), \
f"please first run: bash scripts/autoregressive/extract_codes_c2i.sh ..."
- return CustomDataset(feature_dir, label_dir)
\ No newline at end of file
+ return CustomDataset(feature_dir, label_dir)
diff --git a/scripts/autoregressive/train_c2i.sh b/scripts/autoregressive/train_c2i.sh
index ecc6a98..5638ebb 100644
--- a/scripts/autoregressive/train_c2i.sh
+++ b/scripts/autoregressive/train_c2i.sh
@@ -1,6 +1,12 @@
# !/bin/bash
set -x
+nnodes=1
+nproc_per_node=1
+node_rank=0
+master_addr=127.0.0.1
+master_port=29500
+
torchrun \
--nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \
--master_addr=$master_addr --master_port=$master_port \
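As an aside, instead of patching the hard-coded file count, something like the commented-out listing in CustomDataset could be restored; a rough sketch:
import os

def list_code_files(feature_dir: str, label_dir: str):
    # Enumerate whatever .npy files exist instead of assuming a fixed count
    # (1281167 or 951); mirrors the commented-out lines in CustomDataset.
    key = lambda f: int(os.path.splitext(f)[0])  # numeric sort of i.npy names
    feature_files = sorted(os.listdir(feature_dir), key=key)
    label_files = sorted(os.listdir(label_dir), key=key)
    assert len(feature_files) == len(label_files)
    return feature_files, label_files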
I run training with
bash scripts/autoregressive/train_c2i.sh --cloud-save-path /path/to/models/gpt_b/ --code-path /data1/path/to/butterflies/codes_256/ --image-size 256 --global-batch-size 256 --gpt-model GPT-B --num-classes 1 --no-compile
The full error is below:
Traceback (most recent call last):
File "/path/to/apps/LlamaGen/autoregressive/train/train_c2i.py", line 296, in <module>
main(args)
File "/path/to/apps/LlamaGen/autoregressive/train/train_c2i.py", line 196, in main
_, loss = model(cond_idx=c_indices, idx=z_indices[:,:-1], targets=z_indices)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1523, in forward
else self._run_ddp_forward(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
return self.module(*inputs, **kwargs) # type: ignore[index]
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/path/to/apps/LlamaGen/autoregressive/train/../../autoregressive/models/gpt.py", line 364, in forward
h = layer(h, freqs_cis, input_pos, mask)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/path/to/apps/LlamaGen/autoregressive/train/../../autoregressive/models/gpt.py", line 255, in forward
h = x + self.drop_path(self.attention(self.attention_norm(x), freqs_cis, start_pos, mask))
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/path/to/apps/LlamaGen/autoregressive/train/../../autoregressive/models/gpt.py", line 220, in forward
xq = apply_rotary_emb(xq, freqs_cis)
File "/path/to/apps/LlamaGen/autoregressive/train/../../autoregressive/models/gpt.py", line 424, in apply_rotary_emb
freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) # (1, seq_len, 1, head_dim//2, 2)
RuntimeError: shape '[1, 512, 1, 32, 2]' is invalid for input of size 16448
[2024-06-17 11:04:43,428] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 263727) of binary: /usr/bin/python3
Traceback (most recent call last):
File "/usr/local/bin/torchrun", line 8, in <module>
sys.exit(main())
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
return f(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 812, in main
run(args)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 803, in run
elastic_launch(
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 135, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
autoregressive/train/train_c2i.py FAILED
------------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2024-06-17_11:04:43
host : neptune10
rank : 0 (local_rank: 0)
exitcode : 1 (pid: 263727)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
I may have missed something regarding the RoPE embedding and how it replicates across the number of crops.
I've hit the same issue.
I first ran the following command to generate codes on the ImageNet dataset:
torchrun --nproc_per_node 2 autoregressive/train/extract_codes_c2i.py --vq-model VQ-16 --vq-ckpt ./vq_ds16_c2i.pt --data-path xxx --code-path xxx --image-size 256
and then ran the following command to train:
torchrun --nproc_per_node 8 autoregressive/train/train_c2i.py --code-path xxx --results-dir xxx --no-compile --image-size 256
It raises an error in the apply_rotary_emb function in autoregressive/models/gpt.py:
freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) # (1, seq_len, 1, head_dim//2, 2)
RuntimeError: shape '[1, 512, 1, 32, 2]' is invalid for input of size 16448
I printed the sizes of the variables before the line xq = apply_rotary_emb(xq, freqs_cis) and found that xq is torch.Size([128, 512, 12, 64]) and freqs_cis is torch.Size([257, 32, 2]). If I comment out xq = apply_rotary_emb(xq, freqs_cis) and xk = apply_rotary_emb(xk, freqs_cis), training runs normally.
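For what it's worth, the failure can be reproduced outside the model with dummy tensors of exactly these sizes (a sketch; the reshape mirrors the one the traceback points at, and GPT-B has 12 heads with head_dim 64):
import torch

# Dummy tensors with the sizes printed above.
xq = torch.randn(128, 512, 12, 64)    # (bs, seq_len, n_head, head_dim)
freqs_cis = torch.randn(257, 32, 2)   # (cls_token_num + 16*16, head_dim//2, 2)

xshaped = xq.float().reshape(*xq.shape[:-1], -1, 2)  # (128, 512, 12, 32, 2)
# This is the view apply_rotary_emb attempts: it needs 1*512*1*32*2 = 32768
# elements, but freqs_cis holds only 257*32*2 = 16448, hence the RuntimeError.
# Note 512 = 1 cond token + (2*256 - 1) image tokens, i.e. both augmentations
# flattened into the sequence dimension instead of one being selected.
freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)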
@PeizeSun could you help to solve this problem? thanks.
I have met the same issue.
I also met the same issue. Did you solve this problem? @Baijiong-Lin @Menoly-xin @beniz @PeizeSun
Just regenerate the image codes, and make sure you do not change the directory names shown in the instructions. For more details, you can check how the ImageNet dataset is handled.
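Concretely, build_imagenet_code in dataset/imagenet.py asserts exactly these directory names (the code_path below is just a placeholder):
import os

code_path = "/path/to/butterflies/codes_256"  # placeholder
image_size = 256
# build_imagenet_code expects exactly these directory names:
feature_dir = f"{code_path}/imagenet{image_size}_codes"
label_dir = f"{code_path}/imagenet{image_size}_labels"
assert os.path.exists(feature_dir) and os.path.exists(label_dir), \
    "please first run: bash scripts/autoregressive/extract_codes_c2i.sh ..."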
Thanks for your reply. I have resolved the issue. It was due to the hard-coded file-name assumptions in lines 12, 14, and 15 of the CustomDataset class in imagenet.py. Also, if you do not use --ten_crop when extracting data, the features you obtain will have shape (B, S, 2, D). In this case you need to select one of the two augmented features at line 188 in train.py, e.g. x = x[:, :, int(torch.rand(1) < 0.5), :] (casting to int so it indexes the size-2 dimension rather than boolean-masking it).
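A minimal sketch of that selection on dummy data, assuming the (B, S, 2, D) shape described above:
import torch

# Dummy batch of features shaped (B, S, 2, D): 2 augmentations per sample.
x = torch.randn(8, 1, 2, 256)
# Randomly keep one of the two augmented features for the whole batch.
x = x[:, :, int(torch.rand(1) < 0.5), :]
print(x.shape)  # torch.Size([8, 1, 256])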
I noticed that freqs_cis is the position code used in attention. During training, freqs_cis is produced by this function:
def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
    # split the dimension into half, one for x and one for y
    half_dim = n_elem // 2
    freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim))
    t = torch.arange(grid_size, device=freqs.device)
    freqs = torch.outer(t, freqs)  # (grid_size, head_dim // 2)
    freqs_grid = torch.concat([
        freqs[:, None, :].expand(-1, grid_size, -1),
        freqs[None, :, :].expand(grid_size, -1, -1),
    ], dim=-1)  # (grid_size, grid_size, head_dim // 2)
    cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)], dim=-1)  # (grid_size, grid_size, head_dim // 2, 2)
    cache = cache_grid.flatten(0, 1)
    cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache])  # (cls_token_num+grid_size**2, head_dim // 2, 2)
    return cond_cache
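For example, with grid_size=16 (a 256px image gives a 16x16 latent grid), n_elem=64 (GPT-B's head_dim), and cls_token_num=1, the function above produces exactly the torch.Size([257, 32, 2]) reported earlier. A quick check, assuming the function is in scope:
import torch

# 1 class token + 16*16 grid positions, head_dim//2 = 32, (cos, sin) pairs.
cache = precompute_freqs_cis_2d(grid_size=16, n_elem=64, cls_token_num=1)
print(cache.shape)  # torch.Size([257, 32, 2])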
You can see that its size depends on grid_size, which is the latent-grid size; here the latent space has 576 positions. But if you follow the authors' code, you will see they used data augmentation when generating the training data, so the token shape is [bs, 5760]. I think the bug is that they should either not use ten_crop, or change the function "precompute_freqs_cis_2d" to "precompute_freqs_cis".
I have a better idea. I noticed that x carries the 10 crops as an extra batch-like dimension, and I think the mistake is reshaping x there. They should instead do
z_indices = x.squeeze(0)
and modify forward in gpt.py:
def forward(
    self,
    idx: torch.Tensor,
    cond_idx: torch.Tensor,  # cond_idx_or_embed
    input_pos: Optional[torch.Tensor] = None,
    targets: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
    valid: Optional[torch.Tensor] = None,
):
    if idx is not None and cond_idx is not None:  # training or naive inference
        cond_embeddings = self.cls_embedding(cond_idx, train=self.training)[:, :self.cls_token_num]
        token_embeddings = self.tok_embeddings(idx)
        cond_embeddings = cond_embeddings.repeat(10, 1, 1)  # changed: replicate the class embeddings across the 10 crops
        token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)
I think this is a logical way to make training (or the batched forward) work.
By the way, the assert z_indices.shape[0] == c_indices.shape[0] should be deleted; it no longer holds once the crops are folded into the batch.
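On dummy shapes, a sketch of why the repeat is needed (sizes here are illustrative, not the repo's exact ones):
import torch

bs, crops, seq, dim = 4, 10, 255, 768                  # illustrative sizes
cond_embeddings = torch.randn(bs, 1, dim)              # one class embedding per sample
token_embeddings = torch.randn(bs * crops, seq, dim)   # 10 crops folded into the batch

# Replicate the class embeddings across the 10 crops so the batch dims agree.
cond_embeddings = cond_embeddings.repeat(10, 1, 1)     # -> (bs*10, 1, dim)
h = torch.cat((cond_embeddings, token_embeddings), dim=1)
print(h.shape)  # torch.Size([40, 256, 768])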
Hi, thanks for the interesting work.
I'm playing a bit with the code on a simple single-class dataset of 256x256 images, and I've modified basic things (hard-coded ImageNet numbers, etc.).
I'm hitting the error above on the RoPE embedding. Chasing the issue, it seems to come from a mismatch between the precomputed freqs_cis and the reshaping of the attention vectors. The mismatch appears to be mostly due to the number of augmentations (I went from 10 to 2 during debugging). If this error rings a bell, I'd appreciate any hint :) I can see how to fix it with a hack (reducing augmentation to none), but I believe something else is wrong, otherwise it wouldn't work at all.
Thanks!