MIC-DKFZ / batchgenerators

A framework for data augmentation for 2D and 3D image classification and segmentation
Apache License 2.0

RuntimeWarning in color_augmentations #118

Closed xxsxxsxxs666 closed 9 months ago

xxsxxsxxs666 commented 9 months ago


Here is the code.

for c in range(data_sample.shape[0]):
    retain_stats_here = retain_stats() if callable(retain_stats) else retain_stats
    if retain_stats_here:
        mn = data_sample[c].mean()
        sd = data_sample[c].std()
    if np.random.random() < 0.5 and gamma_range[0] < 1:
        gamma = np.random.uniform(gamma_range[0], 1)
    else:
        gamma = np.random.uniform(max(gamma_range[0], 1), gamma_range[1])
    minm = data_sample[c].min()
    rnge = data_sample[c].max() - minm
    data_sample[c] = np.power(((data_sample[c] - minm) / float(rnge + epsilon)), gamma) * float(rnge + epsilon) + minm
    if retain_stats_here:
        data_sample[c] = data_sample[c] - data_sample[c].mean()
        data_sample[c] = data_sample[c] / (data_sample[c].std() + 1e-8) * sd
        data_sample[c] = data_sample[c] + mn
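To experiment with the loop quoted above outside the library, here is a minimal standalone sketch of it. The function name `gamma_per_channel`, its signature and defaults, and the assumption that `data_sample` is a float NumPy array of shape (channels, ...) are hypothetical; the sketch also swaps `np.random.random`/`np.random.uniform` for a NumPy `Generator` so runs are reproducible. Wrapping the call in `np.errstate(all="raise")` turns floating-point RuntimeWarnings into `FloatingPointError`, whose traceback points at the exact line that triggers the problem.

```python
import numpy as np


def gamma_per_channel(data_sample, gamma_range=(0.5, 2.0), epsilon=1e-7,
                      retain_stats=False, rng=None):
    """Hypothetical standalone copy of the per-channel gamma loop.

    data_sample: float array of shape (C, ...); modified in place and returned.
    """
    rng = np.random.default_rng() if rng is None else rng
    for c in range(data_sample.shape[0]):
        retain_stats_here = retain_stats() if callable(retain_stats) else retain_stats
        if retain_stats_here:
            # Remember the channel's original mean and standard deviation.
            mn = data_sample[c].mean()
            sd = data_sample[c].std()
        # With probability 0.5 (if the range allows it) draw gamma < 1, else gamma >= 1.
        if rng.random() < 0.5 and gamma_range[0] < 1:
            gamma = rng.uniform(gamma_range[0], 1)
        else:
            gamma = rng.uniform(max(gamma_range[0], 1), gamma_range[1])
        # Scale the channel to [0, 1), apply the power, then scale back.
        minm = data_sample[c].min()
        rnge = data_sample[c].max() - minm
        scale = float(rnge + epsilon)
        data_sample[c] = np.power((data_sample[c] - minm) / scale, gamma) * scale + minm
        if retain_stats_here:
            # Restore the original mean and std (the division discussed in this issue).
            data_sample[c] = data_sample[c] - data_sample[c].mean()
            data_sample[c] = data_sample[c] / (data_sample[c].std() + 1e-8) * sd
            data_sample[c] = data_sample[c] + mn
    return data_sample


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(2, 16, 16)).astype(np.float32)
    x[1] = 3.0  # a constant channel, i.e. std() == 0
    # Raise instead of warn, so the traceback shows which line misbehaves.
    with np.errstate(all="raise"):
        y = gamma_per_channel(x.copy(), retain_stats=True, rng=rng)
    print(y.mean(axis=(1, 2)), y.std(axis=(1, 2)))
```

Feeding it the actual channel that produced the warning (with its real dtype) is a quick way to check whether the std-restoring division is really the culprit.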

I think the problem occurs when data_sample[c].std() is too small. Maybe data_sample[c] / (data_sample[c].std() * sd + 1e-8) would be better?
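Taken literally, the suggested expression changes the result rather than only guarding the division: after centering, it yields a channel whose standard deviation is std / (std * sd + 1e-8), i.e. roughly 1/sd (or roughly 1 if the trailing `* sd` is kept), so the original sd is restored only when sd happens to be 1. Also, in float32 or float64 the existing `+ 1e-8` already keeps that denominator positive, so it may be worth confirming which line actually emits the warning. Below is a sketch of one alternative, assuming that a numerically constant channel should simply receive the target mean: rescale only when the current std exceeds a threshold, otherwise shift without dividing. `restore_stats`, `proposed`, and the threshold `eps` are hypothetical names for illustration, not batchgenerators code.

```python
import numpy as np


def restore_stats(channel, mn, sd, eps=1e-8):
    """Hypothetical helper: give `channel` mean `mn` and std `sd`.

    If the channel is (numerically) constant, rescaling is undefined, so only
    the mean is restored instead of dividing by a near-zero std.
    """
    centered = channel - channel.mean()
    cur_sd = centered.std()
    if cur_sd > eps:
        centered = centered * (sd / cur_sd)
    return centered + mn


def proposed(channel, mn, sd):
    """The expression suggested above, taken literally (no trailing `* sd`)."""
    centered = channel - channel.mean()
    return centered / (centered.std() * sd + 1e-8) + mn


rng = np.random.default_rng(0)
x = rng.normal(loc=10.0, scale=4.0, size=1000) ** 2  # stand-in for a gamma-transformed channel
mn, sd = 10.0, 4.0
print(restore_stats(x, mn, sd).std())  # ≈ 4.0
print(proposed(x, mn, sd).std())       # ≈ 0.25, not 4.0
```

The threshold is absolute here; for data on very different scales a threshold relative to the original sd may be a better choice.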