ShusenTang / Dive-into-DL-PyTorch

本项目将《动手学深度学习》(Dive into Deep Learning)原书中的MXNet实现改为PyTorch实现。
http://tangshusen.me/Dive-into-DL-PyTorch
Apache License 2.0
18.17k stars 5.38k forks source link

3.2.2data_iter里面为什么用到循环 #114

Closed TronYY closed 4 years ago

TronYY commented 4 years ago

bug描述 在3.2.2存在以下代码,我想知道为什么需要用到循环for i in range(0, num_examples, batch_size) 这里第一次循环结束就会返回结果了

# 本函数已保存在d2lzh包中方便以后使用
def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)  # 样本的读取顺序是随机的
    for i in range(0, num_examples, batch_size):
        j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)]) # 最后一次可能不足一个batch
        yield  features.index_select(0, j), labels.index_select(0, j)
ShusenTang commented 4 years ago

建议你去查查yield的用法