lxztju / pytorch_classification

利用 PyTorch 实现图像分类的一套完整代码，涵盖训练、预测、TTA、模型融合、模型部署、用 CNN 提取特征后再用 SVM 或随机森林等进行分类，以及模型蒸馏。
MIT License
1.38k stars 339 forks source link

为什么这几个代码文件没有上传呢？ #19

Closed shichengcn closed 3 years ago

shichengcn commented 3 years ago

from data import get_train_transform, get_test_transform from data import get_random_eraser

import numpy as np

def get_random_eraser(p=0.5, s_l=0.02, s_h=0.4, r_1=0.3, r_2=3.3333333333333335,
                      v_l=0, v_h=255, pixel_level=False):
    """Build a Random Erasing augmentation callable.

    The returned function, with probability ``p``, overwrites one randomly
    placed axis-aligned rectangle of its input array with random values
    drawn from ``np.random.uniform(v_l, v_h, ...)``.

    Args:
        p: Probability that a given call erases anything at all.
        s_l: Lower bound of the erased area, as a fraction of ``img_h * img_w``.
        s_h: Upper bound of the erased area, as a fraction of ``img_h * img_w``.
        r_1: Lower bound of the rectangle's aspect ratio (height / width).
        r_2: Upper bound of the rectangle's aspect ratio (height / width).
        v_l: Lower bound of the fill values.
        v_h: Upper bound of the fill values.
        pixel_level: If True, draw an independent fill value for every erased
            element; otherwise fill the whole rectangle with a single value.

    Returns:
        A function ``eraser(input_img)`` that accepts a NumPy array whose
        first two axes are height and width (e.g. ``(H, W)`` or
        ``(H, W, C)``), modifies it IN PLACE, and returns the same array.
    """

    def eraser(input_img):
        # Only the two leading (spatial) axes matter; any trailing axes such
        # as channels are erased together, so 2-D grayscale also works.
        img_h, img_w = input_img.shape[:2]

        # Skip the augmentation with probability 1 - p.
        if np.random.rand() > p:
            return input_img

        # Rejection-sample a rectangle until it lies fully inside the image.
        # Since w * h ~= s and h / w ~= r, s controls area and r aspect ratio.
        # NOTE(review): if no rectangle with the sampled area/aspect can fit
        # (e.g. very elongated images, or s_l > 1), this loop may never end.
        while True:
            s = np.random.uniform(s_l, s_h) * img_h * img_w
            r = np.random.uniform(r_1, r_2)
            w = int(np.sqrt(s / r))
            h = int(np.sqrt(s * r))
            left = np.random.randint(0, img_w)
            top = np.random.randint(0, img_h)
            if left + w <= img_w and top + h <= img_h:
                break

        if pixel_level:
            # One random value per erased element, including trailing axes.
            c = np.random.uniform(v_l, v_h, (h, w) + input_img.shape[2:])
        else:
            # A single scalar broadcast over the whole rectangle.
            c = np.random.uniform(v_l, v_h)
        input_img[top:top + h, left:left + w] = c
        return input_img

    return eraser