예를 들어, 스플릿 후 695장의 아이템이 들어있는 fold가 있습니다.
그럼 이미지의 인덱스의 범위는 0 ~ 694가 되어야하고, 694번째 이미지의 인덱스는 694 여야 하는데요.
하지만 스플릿해서 가져온 탓인지 id가 694가 아닌 스플릿 전 이미지 아이디를 가지고 있습니다.
image_id 가 업데이트가 되지 않는다는 것을 알 수 있습니다.
이러한 결과로 같은 폴드 내 어노테이션 역시 이전 이미지 아이디를 가지고 있습니다.
import os
import json
import numpy as np
import pandas as pd
import argparse
import random
from tqdm import tqdm
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
path = os.path.dirname(os.path.abspath(__file__))
data_path = os.path.join(path, "..", "data")
annotations_path = os.path.join(data_path, "train_all.json")
def main(args):
random.seed(args.random_seed)
with open(annotations_path, "r") as f:
train_json = json.loads(f.read())
images = train_json["images"]
categories = train_json["categories"]
annotations = train_json["annotations"]
annotations_df = pd.DataFrame.from_dict(annotations)
x = images
y = [[0] * len(categories) for _ in range(len(images))]
for i, anno in enumerate(annotations):
i_id = anno["image_id"]
c_id = anno["category_id"] - 1
y[i_id][c_id] += 1
mskf = MultilabelStratifiedKFold(n_splits=args.n_split, shuffle=True)
path = args.path
if not os.path.exists(path):
os.mkdir(path)
for idx, (train_index, val_index) in tqdm(
enumerate(mskf.split(x, y)), total=args.n_split
):
train_dict = dict()
val_dict = dict()
for i in ["info", "licenses", "categories"]:
train_dict[i] = train_json[i]
val_dict[i] = train_json[i]
train_dict["images"] = np.array(images)[train_index].tolist()
val_dict["images"] = np.array(images)[val_index].tolist()
# 기존 이미지 id 를 Key, 새로 인덱스 순으로 부여받은 이미지 id를 Value
train_id2id = dict()
val_id2id = dict()
# train 과 val 이미지 아이디 업데이트
for i in range(len(train_dict["images"])):
if i < len(val_dict["images"]):
id_idx = train_dict["images"][i]["id"]
train_id2id[id_idx] = i
train_dict["images"][i]["id"] = i
val_id2id[id_idx] = i
val_dict["images"][i]["id"] = i
else:
id_idx = train_dict["images"][i]["id"]
train_id2id[id_idx] = i
train_dict["images"][i]["id"] = i
train_dict["annotations"] = annotations_df[
annotations_df["image_id"].isin(train_index)
].to_dict("records")
val_dict["annotations"] = annotations_df[
annotations_df["image_id"].isin(val_index)
].to_dict("records")
# train 과 val 어노테이션 안 이미지 아이디 업데이트
for i in range(len(train_dict["annotations"])):
train_dict["annotations"][i]["image_id"] = train_id2id[
train_dict["annotations"][i]["image_id"]
]
for i in range(len(val_dict["annotations"])):
val_dict["annotations"][i]["image_id"] = val_id2id[val_dict["annotations"][i]["image_id"]]
train_dir = os.path.join(path, f"train_fold{idx}.json")
val_dir = os.path.join(path, f"val_fold{idx}.json")
with open(train_dir, "w") as train_file:
json.dump(train_dict, train_file)
with open(val_dir, "w") as val_file:
json.dump(val_dict, val_file)
print("Done Make files")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--path",
"-p",
type=str,
default=os.path.join(path, "..", "data", "stratified_kfold"),
)
parser.add_argument("--n_split", "-n", type=int, default=10)
parser.add_argument("--random_seed", type=int, default=42)
args = parser.parse_args()
main(args)
이렇게 해서 실행하면
Traceback (most recent call last):
File "/opt/ml/level2-semantic-segmentation-level2-cv-14/utils/split_stratified_dataset.py", line 111, in <module>
main(args)
File "/opt/ml/level2-semantic-segmentation-level2-cv-14/utils/split_stratified_dataset.py", line 86, in main
val_dict["annotations"][i]["image_id"] = val_id2id[val_dict["annotations"][i]["image_id"]]
KeyError: 4
분명히 잘 들어가야 하는데 원래 이미지 id가 4가 없었던 것처럼 에러가 발생합니다.
버그 재현 방법
pip install iterative-stratification
설치해주시고 python 파일로 저장하신 후 실행하시면 됩니다.
해결 방향
main() 함수에서 해결이 안될것 같아서 일단 json 저장 후 다시 불러와서 고치는 방향으로 작성 중 입니다.
버그 내용
예를 들어, 스플릿 후 695장의 아이템이 들어있는 fold가 있습니다. 그럼 이미지의 인덱스의 범위는 0 ~ 694가 되어야하고, 694번째 이미지의 인덱스는 694 여야 하는데요.
하지만 스플릿해서 가져온 탓인지 id가 694가 아닌 스플릿 전 이미지 아이디를 가지고 있습니다.
image_id 가 업데이트가 되지 않는다는 것을 알 수 있습니다.
이러한 결과로 같은 폴드 내 어노테이션 역시 이전 이미지 아이디를 가지고 있습니다.
이렇게 해서 실행하면
분명히 잘 들어가야 하는데 원래 이미지 id가 4가 없었던 것처럼 에러가 발생합니다.
버그 재현 방법
설치해주시고 python 파일로 저장하신 후 실행하시면 됩니다.
해결 방향
main() 함수에서 해결이 안될것 같아서 일단 json 저장 후 다시 불러와서 고치는 방향으로 작성 중 입니다.
혹시 저 코드에서 이상한 점을 발견해주신다면 감사하겠습니다!