Closed: MichelleYang2017 closed this issue 5 years ago.
Hi, sorry if it is a bit confusing. The target uses a trick to make it possible to create batches.

batch_input: the batch of features given to the network (student model).
ema_batch_input: the batch of features given to the teacher model (ema: exponential moving average).
target: the target of batch_input (strong labels). To make it a batch, the trick is to give strong labels to all the data, but to put -1 for unlabeled data, and to convert strong labels to weak labels when needed for the weakly labeled data.
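To picture the trick, here is a minimal sketch (the shapes, variable names, and the max-over-time reduction are illustrative assumptions, not the repository's exact code):

import numpy as np

n_frames, n_classes = 5, 3  # hypothetical sizes, just for illustration

# Strongly labeled clip: one multi-hot vector per frame.
strong_target = np.zeros((n_frames, n_classes))
strong_target[1:4, 0] = 1  # class 0 active on frames 1 to 3

# Unlabeled clip: same shape, but filled with -1 so the loss can ignore it.
unlabeled_target = -np.ones((n_frames, n_classes))

# Weakly labeled clip: only clip-level tags are known, so the strong-shaped
# target is reduced over time (e.g. a max over frames) when the weak loss is computed.
weak_target = strong_target.max(axis=0)  # shape: (n_classes,)

Every clip then carries a target of the same shape, which is what makes batching possible.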
In PyTorch you define a dataset; in our case, DataLoadDf (line 25 of DataLoad.py) inherits from the torch.utils.data.Dataset class. DataLoadDf will give you a tuple (features, target) if no transform is given.
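Roughly, you can picture DataLoadDf with a toy dataset like this (a simplified sketch: the class name SimpleDataLoadDf and its attributes are made up, not the actual DataLoad.py code):

from torch.utils.data import Dataset

class SimpleDataLoadDf(Dataset):
    # Toy stand-in for DataLoadDf: returns (features, target), optionally transformed.
    def __init__(self, features_list, targets_list, transform=None):
        self.features_list = features_list
        self.targets_list = targets_list
        self.transform = transform

    def set_transform(self, transform):
        self.transform = transform

    def __len__(self):
        return len(self.features_list)

    def __getitem__(self, idx):
        sample = (self.features_list[idx], self.targets_list[idx])
        if self.transform is not None:
            sample = self.transform(sample)
        return sample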
However, in our case we add transforms to our data.
Among these transformations there is the augment_type="noise" argument in:
transforms = get_transforms(cfg.max_frames, scaler, augment_type="noise")
It means we call the AugmentGaussianNoise transform (its __call__ method) defined in DataLoad.py (see get_transforms in utils/utils.py).
This class transforms a tuple (features, target) into (features, noised_features, target).
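As a rough idea of what that transform does (ToyGaussianNoise and its std value are illustrative, not the actual AugmentGaussianNoise implementation):

import numpy as np

class ToyGaussianNoise:
    # Keeps the clean features and adds a noisy copy, later fed to the teacher model.
    def __init__(self, std=0.5):
        self.std = std

    def __call__(self, sample):
        features, target = sample
        noised_features = features + np.random.normal(0, self.std, features.shape)
        return features, noised_features, target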
This is done for every single sample. Then you have the torch.utils.data.DataLoader class, which defines how to go from single samples to a batch (this is handled by collate_fn, which simply concatenates the samples picked by the sampler you indicate).
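For intuition, a collate function that "just concatenates" those per-sample tuples could look roughly like this (simple_collate is an illustrative stand-in, not the repository's collate_fn):

import torch

def simple_collate(batch):
    # Stack each position of the per-sample tuples into one batch tensor.
    features = torch.stack([torch.as_tensor(sample[0]) for sample in batch])
    noised = torch.stack([torch.as_tensor(sample[1]) for sample in batch])
    targets = torch.stack([torch.as_tensor(sample[2]) for sample in batch])
    return features, noised, targets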
So, to be clearer, you could try looping over the different data. If you put this at line 253 in main.py:
# Basic DataLoadDf, coming from Dataset
for sample in train_weak_data:
    print("number of items in tuple: {}".format(len(sample)))
    features, target = sample
    print("shape of features: {}".format(features.shape))
    print("shape of targets: {}".format(target.shape))
    break

# Adding a transform to get noisy data
transforms = get_transforms(cfg.max_frames, scaler, augment_type="noise")
train_weak_data.set_transform(transforms)
for sample in train_weak_data:
    print("number of items in tuple: {}".format(len(sample)))
    features, features_noised, target = sample
    print("shape of features: {}".format(features.shape))
    print("shape of features_noised: {}".format(features_noised.shape))
    print("shape of targets: {}".format(target.shape))
    break

# DataLoader creating a batch (here the basic sampler is to get random data)
# DataLoader comes from torch.utils.data; import it if main.py does not already do so
for sample in DataLoader(train_weak_data, batch_size=10):
    print("number of items in tuple: {}".format(len(sample)))
    features, features_noised, target = sample
    print("shape of features: {}".format(features.shape))
    print("shape of features_noised: {}".format(features_noised.shape))
    print("shape of targets: {}".format(target.shape))
    break
I hope it helps.
If you have other questions, do not hesitate.
Thanks a lot. I got it.
I'm sorry, I am not familiar with PyTorch, so I have a small question I'd like to ask for your help. In train.py there is a function train, which contains this line:
for i, (batch_input, ema_batch_input, target) in enumerate(train_loader):
I don't know how to interpret (batch_input, ema_batch_input, target). What kind of data do they stand for?