Open ghost233lism opened 1 month ago
感谢您的关注! 是的,在AQA-7数据集上的norm是使用的CoRe工作中的代码, def denormalize(label, class_idx, upper=100.0): label_ranges = { 1: (21.6, 102.6), 2: (12.3, 16.87), 3: (8.0, 50.0), 4: (8.0, 50.0), 5: (46.2, 104.88), 6: (49.8, 99.36)} label_range = label_ranges[class_idx] true_label = (label.float() / float(upper)) * (label_range[1] - label_range[0]) + label_range[0] return true_label
def normalize(label, class_idx, upper=100.0): label_ranges = { 1: (21.6, 102.6), 2: (12.3, 16.87), 3: (8.0, 50.0), 4: (8.0, 50.0), 5: (46.2, 104.88), 6: (49.8, 99.36)} label_range = label_ranges[class_idx] norm_label = ((label - label_range[0]) / (label_range[1] - label_range[0])) * float(upper) return norm_label
当前的版本是基于TSA工作在FineDiving数据集上norm。抱歉一直没有进一步补充。
感谢您的回答,还有一个小问题,您在runner中115行.network_forward_train传入了class1和class2,但是AQA-7数据集是没有class概念的,这导致程序无法运行,请问这一部分也是直接替换为CoRe的代码吗?
是的,在AQA-7数据集上没有使用class分类,直接去掉class1和class2和对应的分类损失。
作者您好,感谢您的工作!我注意到您的代码在读取数据后进行了norm处理,但是在处理AQA-7数据集时,调用norm时应该传入一个数组(代表上下限),但实际上传入的是一个数。这一点CoRe工作的代码应该是正确的。