DeepGraphLearning / KnowledgeGraphEmbedding

MIT License

Did you just use the first batch to train the model? Can you help me solve my problem? #29

Closed: ToneLi closed this issue 4 years ago.

ToneLi commented 4 years ago

I have a question about model.py in ROTATE. ROTATE uses the `next` function to fetch training data. Shouldn't `next` be called inside a loop? As far as I can tell, ROTATE trains on only the first batch at every step, because when `next` is not inside a loop it just returns the first item of the list/dict... Can anyone help me answer this question?

```python
class BidirectionalOneShotIterator(object):

    def __init__(self, dataloader_head, dataloader_tail):
        self.iterator_head = self.one_shot_iterator(dataloader_head)
        # print("bb", next(self.iterator_head))  # one batch
        self.iterator_tail = self.one_shot_iterator(dataloader_tail)
        self.step = 0

    def __next__(self):
        self.step += 1

        # Even steps draw a head batch, odd steps a tail batch.
        if self.step % 2 == 0:
            data = next(self.iterator_head)
        else:
            data = next(self.iterator_tail)
        print("self.step", self.step)  # debug output
        return data

    @staticmethod
    def one_shot_iterator(dataloader):
        '''
        Transform a PyTorch Dataloader into python iterator
        '''
        # Generator: each next() resumes after the previous yield, and the
        # while loop starts a new pass once the dataloader is exhausted.
        while True:
            for data in dataloader:
                yield data
```
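To see why the first batch is not repeated, here is a minimal sketch that copies the `one_shot_iterator` generator from above and feeds it a plain Python list standing in for a PyTorch DataLoader (an assumption made so the example runs without PyTorch). Because the generator is created once and stored, every `next` call picks up after the previous `yield`; only re-creating the generator before each call would keep returning the first batch.

```python
def one_shot_iterator(dataloader):
    # Same body as the staticmethod quoted above: loop over the loader forever.
    while True:
        for data in dataloader:
            yield data


# A plain list stands in for a PyTorch DataLoader (assumption for this demo).
batches = ["batch0", "batch1", "batch2"]

# Created ONCE and kept, as __init__ does with self.iterator_head/_tail.
gen = one_shot_iterator(batches)
first_five = [next(gen) for _ in range(5)]
print(first_five)  # ['batch0', 'batch1', 'batch2', 'batch0', 'batch1']

# Contrast: building a fresh generator before every next() call WOULD
# always return the first batch, which is not what the class does.
always_first = [next(one_shot_iterator(batches)) for _ in range(3)]
print(always_first)  # ['batch0', 'batch0', 'batch0']
```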

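```python
def train_step(model, optimizer, train_iterator, args):
    '''
    A single train step. Apply back-propagation and return the loss
    '''
    model.train()
    optimizer.zero_grad()
    # Fetch the next batch from the shared iterator; each call advances it by one.
    positive_sample, negative_sample, subsampling_weight, mode = next(train_iterator)
    # (the rest of train_step is not quoted in this issue)
```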
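`train_step` calls `next(train_iterator)` once per call, so as long as the training loop calls `train_step` once per step with the same iterator object, `next` is effectively inside the loop. The sketch below assumes that loop shape; `toy_train_step` and `alternating` are hypothetical stand-ins (plain lists instead of DataLoaders, no model or optimizer), not code from the repository.

```python
import itertools


def toy_train_step(train_iterator):
    # Hypothetical stand-in for train_step: keeps only the line that fetches
    # data. Each call advances the SAME iterator by one batch.
    positive_sample, mode = next(train_iterator)
    return positive_sample, mode


def alternating(head_loader, tail_loader):
    # Hypothetical condensed stand-in for BidirectionalOneShotIterator:
    # odd steps draw from the tail loader, even steps from the head loader.
    head = itertools.cycle(head_loader)
    tail = itertools.cycle(tail_loader)
    step = 0
    while True:
        step += 1
        yield next(head) if step % 2 == 0 else next(tail)


head_loader = [("h0", "head-batch"), ("h1", "head-batch")]
tail_loader = [("t0", "tail-batch"), ("t1", "tail-batch")]

# Created once, outside the step loop (assumed to mirror the training script).
train_iterator = alternating(head_loader, tail_loader)

history = [toy_train_step(train_iterator) for step in range(6)]
for step, (batch, mode) in enumerate(history, start=1):
    print(step, mode, batch)
```

One difference from the original: `itertools.cycle` stores a copy of the items from its first pass and replays them, whereas `one_shot_iterator` re-iterates the loader on every pass, so a DataLoader created with `shuffle=True` can yield a different batch order each epoch.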

Edward-Sun commented 4 years ago

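Hi Tone,

The `next` function always returns the next item from a Python iterator (search for "Python iterator" to learn how this works in Python).

Lists and dicts are not Python iterators; the generator returned by `one_shot_iterator` is, so each `next` call continues from where the previous one stopped.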
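A small illustration of the distinction above, using only the standard library: `next` rejects a list (or dict) directly, while `iter()` returns an iterator over it that remembers its position.

```python
batches = ["batch0", "batch1", "batch2"]

# A list is iterable but is not itself an iterator, so next() rejects it.
try:
    next(batches)
    list_error = None
except TypeError as err:
    list_error = str(err)
print("TypeError:", list_error)  # TypeError: 'list' object is not an iterator

# iter() wraps the list in an iterator that remembers its position.
it = iter(batches)
first_three = (next(it), next(it), next(it))
print(first_three)  # ('batch0', 'batch1', 'batch2')

# Dicts behave the same way; iterating a dict yields its keys.
d = {"a": 1, "b": 2}
d_it = iter(d)
dict_keys = (next(d_it), next(d_it))
print(dict_keys)  # ('a', 'b')
```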

ToneLi commented 4 years ago

Hi Sun, thanks very much! Chen~

On Wed, Jul 1, 2020 at 9:43 PM, Zhiqing Sun (notifications@github.com) wrote:

Hi Tone,

The "next" function will always give the next item in a "python iterator" (please search "python iterator" in Google for python syntax)

list and dict are not "python iterators".
