NELSONZHAO / zhihu

This repo contains the source code from my personal column (https://zhuanlan.zhihu.com/zhaoyeyu), implemented using Python 3.6. It includes Natural Language Processing and Computer Vision projects, such as text generation, machine translation, deep convolutional GAN, and other hands-on code.
3.5k stars 2.14k forks

Dimension mismatch between training_logits and targets #34

Open pengwei-iie opened 5 years ago

pengwei-iie commented 5 years ago

cost = tf.contrib.seq2seq.sequence_loss(training_logits, targets, masks)

tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [5632] vs. [6400]

My target has shape (256, 25), but the training_logits I get back has shape (256, 22, 358), where 358 is the vocabulary size.
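
The two numbers in the error message are the flattened batch-by-time sizes: 5632 = 256 × 22 and 6400 = 256 × 25. Below is a minimal sketch of that arithmetic in plain Python. It assumes (this is not confirmed in the thread) that the decoder stopped after 22 steps because that was the longest target length it was given. The names `flat_logits`, `flat_targets` and `trimmed_targets` are hypothetical and only stand in for the real tensors.

```python
batch_size, vocab_size = 256, 358
logits_steps, target_steps = 22, 25  # time dimensions reported in this issue

# Collapsing batch and time into a single axis gives the two sizes
# that appear in the error message.
flat_logits = batch_size * logits_steps
flat_targets = batch_size * target_steps
print(flat_logits, flat_targets)  # 5632 6400

# Hypothetical padded target rows; 0 stands in for the <PAD> id.
targets = [[0] * target_steps for _ in range(batch_size)]

# One possible workaround (an assumption, not the repo author's fix):
# trim the targets to the number of steps the decoder produced.
# This is only safe when the dropped columns contain nothing but padding.
trimmed_targets = [row[:logits_steps] for row in targets]
print(len(trimmed_targets), len(trimmed_targets[0]))  # 256 22
```

Trimming hides the symptom. Making the padded width and the lengths passed to the decoder agree in the batching code is likely the cleaner fix.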

I changed the code as follows, and now the shapes match:


def pad_batch_sentence(batch, max_length, pad_id):
    # max_length = max([len(sentence) for sentence in batch])
    return [sentence + [pad_id] * (max_length - len(sentence)) for sentence in batch]

def get_batches(sources, targets, batch_size):

    for batch_i in range(0, len(sources) // batch_size):
        start_i = batch_i * batch_size

        # Slice the right amount for the batch
        sources_batch = sources[start_i:start_i + batch_size]
        targets_batch = targets[start_i:start_i + batch_size]

        pad_idx = source_vocab_to_int.get("<PAD>")
        sources_batch_pad = np.array(pad_batch_sentence(sources_batch, max_source_sentence_length, pad_idx))
        targets_batch_pad = np.array(pad_batch_sentence(targets_batch, max_target_sentence_length, pad_idx))
        # Need the lengths for the _lengths parameters
        # The lengths should not be computed on the padded batch, because they are all 25
        targets_lengths = []
        for target in targets_batch_pad:
            targets_lengths.append(len(target))

        source_lengths = []
        for source in sources_batch_pad:
            source_lengths.append(len(source))

        yield sources_batch_pad, targets_batch_pad, source_lengths, targets_lengths

But this way, the source_lengths passed in are all (20, 20, 20, ...) and the targets_lengths are all (25, 25, 25, ...).

sxlprince commented 4 years ago

I also think there is a problem here: this way every source length is just the maximum length after padding.
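
One way to keep the true lengths is to measure each sentence before padding and to pad each batch only up to its own longest sentence, so the padded width equals the largest length in the batch. The sketch below is a hedged variant of the functions posted above, not the repo's code. It assumes `sources` and `targets` are lists of unpadded token-id lists, and the separate `source_pad_id` / `target_pad_id` parameters are hypothetical. If the data was already padded to 20 / 25 upstream, the lengths would instead have to be computed by counting non-pad tokens.

```python
def pad_batch_sentence(batch, pad_id):
    """Pad every sentence to the longest sentence in this batch."""
    max_length = max(len(sentence) for sentence in batch)
    return [sentence + [pad_id] * (max_length - len(sentence)) for sentence in batch]


def get_batches(sources, targets, batch_size, source_pad_id, target_pad_id):
    """Yield padded batches together with the true (pre-padding) lengths."""
    for batch_i in range(len(sources) // batch_size):
        start_i = batch_i * batch_size
        sources_batch = sources[start_i:start_i + batch_size]
        targets_batch = targets[start_i:start_i + batch_size]

        # Measure lengths BEFORE padding so they differ per sentence.
        source_lengths = [len(sentence) for sentence in sources_batch]
        targets_lengths = [len(sentence) for sentence in targets_batch]

        yield (pad_batch_sentence(sources_batch, source_pad_id),
               pad_batch_sentence(targets_batch, target_pad_id),
               source_lengths,
               targets_lengths)


# Toy data: the token ids are made up, and 0 is the assumed <PAD> id.
sources = [[5, 6], [7, 8, 9], [4], [3, 2, 1, 9]]
targets = [[11, 12, 13], [14], [15, 16], [17, 18, 19, 20, 21]]
batches = list(get_batches(sources, targets, batch_size=2,
                           source_pad_id=0, target_pad_id=0))
print(batches[0])
```

With this layout the width of each padded target batch equals `max(targets_lengths)`, which is the condition the shape error above suggests the loss needs. The padded lists can be wrapped in `np.array(...)` as in the original code before feeding them to the graph.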

sxlprince commented 4 years ago

After changing it, I get an error...