mapillary / inplace_abn

In-Place Activated BatchNorm for Memory-Optimized Training of DNNs
BSD 3-Clause "New" or "Revised" License
1.32k stars 187 forks source link

CUDA:out of memory #79

Closed loveddy closed 5 years ago

loveddy commented 5 years ago

i run the train.py and the cuda memory just get larger after serveral epochs,and here is the strange thing: if i don't create new batches use the same data, the cuda memory usage stays still,but if i create new batches, the cuda memory get larger.I find out that the pytorch turioal don't run epochs but iterations, so i don't know where is the problem in my code.I need your help.... here is my code

train.py

import torch
import random
from dataloader import loadDataset, getBatches, sentence2enco
from tqdm import tqdm
import math
import torch.optim as optim
import os
from config import config
from model import Encoder, AttnDecoder, MLP, Embedding
import time
import numpy as np

torch.backends.cudnn.enabled = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher_forcing_ratio = 0.5
SOS_token = 1

config = config()

def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    return loss, nTotal.item()

def createOptimizer(encoder_s, decoder_s, encoder_t, decoder_t, mlp, embedding, optimizer_type, learning_rate,
                    weight_decay, iter):
    if optimizer_type == 'Adam':
        encoder_s_opt = optim.Adam(encoder_s.parameters(), lr=learning_rate, weight_decay=weight_decay)
        decoder_s_opt = optim.Adam(decoder_s.parameters(), lr=learning_rate, weight_decay=weight_decay)
        encoder_t_opt = optim.Adam(encoder_t.parameters(), lr=learning_rate, weight_decay=weight_decay)
        decoder_t_opt = optim.Adam(decoder_t.parameters(), lr=learning_rate, weight_decay=weight_decay)
        mlp_opt = optim.Adam(mlp.parameters(), lr=learning_rate, weight_decay=weight_decay)
        emb_opt = optim.Adam(embedding.parameters(), lr=learning_rate, weight_decay=weight_decay)
    else:
        encoder_s_opt = optim.SGD(encoder_s.parameters(), lr=learning_rate, weight_decay=weight_decay)
        decoder_s_opt = optim.SGD(decoder_s.parameters(), lr=learning_rate, weight_decay=weight_decay)
        encoder_t_opt = optim.SGD(encoder_t.parameters(), lr=learning_rate, weight_decay=weight_decay)
        decoder_t_opt = optim.SGD(decoder_t.parameters(), lr=learning_rate, weight_decay=weight_decay)
        mlp_opt = optim.SGD(mlp.parameters(), lr=learning_rate, weight_decay=weight_decay)
        emb_opt = optim.SGD(embedding.parameters(), lr=learning_rate, weight_decay=weight_decay)

    if iter >= config._after - 1:
        for param_group in encoder_s_opt.param_groups:
            param_group['lr'] = (param_group['lr'] * 0.9 if param_group[
                                                                'lr'] > config.learning_rate_bottom else config.learning_rate_bottom)
        for param_group in decoder_s_opt.param_groups:
            param_group['lr'] = (param_group['lr'] * 0.9 if param_group[
                                                                'lr'] > config.learning_rate_bottom else config.learning_rate_bottom)
        for param_group in encoder_t_opt.param_groups:
            param_group['lr'] = (param_group['lr'] * 0.9 if param_group[
                                                                'lr'] > config.learning_rate_bottom else config.learning_rate_bottom)
        for param_group in decoder_t_opt.param_groups:
            param_group['lr'] = (param_group['lr'] * 0.9 if param_group[
                                                                'lr'] > config.learning_rate_bottom else config.learning_rate_bottom)
        for param_group in mlp_opt.param_groups:
            param_group['lr'] = (param_group['lr'] * 0.9 if param_group[
                                                                'lr'] > config.learning_rate_bottom else config.learning_rate_bottom)
        for param_group in emb_opt.param_groups:
            param_group['lr'] = (param_group['lr'] * 0.9 if param_group[
                                                                'lr'] > config.learning_rate_bottom else config.learning_rate_bottom)
    return encoder_s_opt, decoder_s_opt, encoder_t_opt, decoder_t_opt, mlp_opt, emb_opt

def trainIters(encoder_s, decoder_s, encoder_t, decoder_t, mlp, embedding, optimizer_type, train_samples, val_samples,
               learning_rate, weight_decay):
    current_step = 0
    best_loss = 100.
    val_batches = getBatches(val_samples, config.batch_size)
    loss_fn = torch.nn.MSELoss(reduction='none')
    train_batches = getBatches(train_samples, config.batch_size)
    for iter in range(config.numEpochs):
        if (iter + 1) % 5 == 0:
            train_batches = getBatches(train_samples, config.batch_size)
        print("----- Epoch {}/{} -----".format(iter + 1, config.numEpochs))

        encoder_s_opt, decoder_s_opt, encoder_t_opt, decoder_t_opt, mlp_opt, emb_opt = createOptimizer(encoder_s,
                                                                                                       decoder_s,
                                                                                                       encoder_t,
                                                                                                       decoder_t, mlp,
                                                                                                       embedding,
                                                                                                       optimizer_type,
                                                                                                       learning_rate,
                                                                                                       weight_decay,
                                                                                                       iter)
        for next_batch in train_batches:
            current_step += 1
            loss_train = train(encoder_s, decoder_s, encoder_t, decoder_t, mlp, embedding, encoder_s_opt,
                               decoder_s_opt,
                               encoder_t_opt, decoder_t_opt, mlp_opt, emb_opt, next_batch, config.clip_max_norm,
                               loss_fn)
            if current_step % config.steps_per_checkpoint == 0:
                total_loss = 0.
                total_per = 0.
                with torch.no_grad():
                    for nextBatch in val_batches:
                        _loss = val(encoder_s, decoder_t, mlp, embedding, nextBatch)
                        perplexity = math.exp(float(_loss)) if _loss < 50 else float('inf')
                        total_loss += _loss * len(nextBatch.encoder_inputs_length)
                        total_per += perplexity * len(nextBatch.encoder_inputs_length)
                val_loss = total_loss / len(val_samples)
                val_per = total_per / len(val_samples)
                tqdm.write("----- Step %d -- Loss_train %.4f -- Loss_test %.4f -- Time %s" % (
                current_step, loss_train, val_loss, time.strftime('%Y.%m.%d %H:%M:%S', time.localtime(time.time()))))
                if val_loss < best_loss:
                    best_loss = val_loss
                    torch.save(encoder_s, config.model_dir + '/encoder_s_val')
                    torch.save(decoder_s, config.model_dir + '/decoder_s_val')
                    torch.save(encoder_t, config.model_dir + '/encoder_t_val')
                    torch.save(decoder_t, config.model_dir + '/decoder_t_val')
                    torch.save(mlp, config.model_dir + '/mlp_val')
                    torch.save(embedding, config.model_dir + '/embedding_val')
                else:
                    torch.save(encoder_s, config.model_dir + '/encoder_s')
                    torch.save(decoder_s, config.model_dir + '/decoder_s')
                    torch.save(encoder_t, config.model_dir + '/encoder_t')
                    torch.save(decoder_t, config.model_dir + '/decoder_t')
                    torch.save(mlp, config.model_dir + '/mlp')
                    torch.save(embedding, config.model_dir + '/embedding')

def train(encoder_s, decoder_s, encoder_t, decoder_t, mlp, embedding, encoder_s_opt, decoder_s_opt,
          encoder_t_opt, decoder_t_opt, mlp_opt, emb_opt, batch,
          clip_max_norm, loss_fn):
    encoder_s.train()
    decoder_s.train()
    encoder_t.train()
    decoder_t.train()
    mlp.train()
    embedding.train()
    encoder_s_opt.zero_grad()
    decoder_s_opt.zero_grad()
    encoder_t_opt.zero_grad()
    decoder_t_opt.zero_grad()
    mlp_opt.zero_grad()
    emb_opt.zero_grad()
    encoder_inputs = batch.encoder_inputs.to(device)
    encoder_inputs_length = batch.encoder_inputs_length.to(device)
    mask_s = batch.mask_s.to(device)
    decoder_targets = batch.decoder_targets.to(device)
    decoder_targets_length = batch.decoder_targets_length.to(device)
    mask_t = batch.mask_t.to(device)
    batch_size = len(batch.encoder_inputs)

    # 计算J1(θ)=−logP(x̃ |x;θ)
    loss_1 = 0
    print_losses_1 = 0
    n_totals_1 = 0
    output_length = encoder_inputs.size()[1]
    encoder_output, encoder_hidden = encoder_s(embedding(encoder_inputs), encoder_inputs_length)
    decoder_hidden = encoder_hidden
    decoder_input = torch.LongTensor([SOS_token] * batch_size).unsqueeze(1).to(device)
    use_teacher_forcing = True if random.random() < config.SCHEDULED_SAMPLING_RATIO else False
    if use_teacher_forcing:
        for i in range(0, output_length):
            decoder_output, logits, _, decoder_hidden = decoder_s(embedding(decoder_input), decoder_hidden,
                                                                  encoder_output)
            decoder_input = encoder_inputs[:, i].view(batch_size, 1)
            mask_loss_1, nTotal_1 = maskNLLLoss(logits, encoder_inputs[:, i], mask_s[:, i])
            loss_1 += mask_loss_1
            print_losses_1 += mask_loss_1.item() * nTotal_1
            n_totals_1 += nTotal_1
    else:
        for i in range(0, output_length):
            decoder_output, logits, _, decoder_hidden = decoder_s(embedding(decoder_input), decoder_hidden,
                                                                  encoder_output)
            topv, topi = logits.data.topk(1)
            ni = topi
            decoder_input = torch.LongTensor(ni).view(batch_size, 1).to(device)
            mask_loss_1, nTotal_1 = maskNLLLoss(logits, encoder_inputs[:, i], mask_s[:, i])
            loss_1 += mask_loss_1
            print_losses_1 += mask_loss_1.item() * nTotal_1
            n_totals_1 += nTotal_1

    # 计算J2(φ)=−logP(ỹ|y;φ)
    loss_2 = 0
    print_losses_2 = 0
    n_totals_2 = 0
    output_length = decoder_targets.size()[1]
    encoder_output, encoder_hidden = encoder_t(embedding(decoder_targets), decoder_targets_length, use_pack=False)
    decoder_hidden = encoder_hidden
    decoder_input = torch.LongTensor([SOS_token] * batch_size).unsqueeze(1).to(device)
    use_teacher_forcing = True if random.random() < config.SCHEDULED_SAMPLING_RATIO else False
    if use_teacher_forcing:
        for i in range(0, output_length):
            decoder_output, logits, _, decoder_hidden = decoder_t(embedding(decoder_input), decoder_hidden,
                                                                  encoder_output)
            decoder_input = decoder_targets[:, i].view(batch_size, 1)
            mask_loss_2, nTotal_2 = maskNLLLoss(logits, decoder_targets[:, i], mask_t[:, i])
            loss_2 += mask_loss_2
            print_losses_2 += mask_loss_2.item() * nTotal_2
            n_totals_2 += nTotal_2
    else:
        for i in range(0, output_length):
            decoder_output, logits, _, decoder_hidden = decoder_t(embedding(decoder_input), decoder_hidden,
                                                                  encoder_output)
            topv, topi = logits.data.topk(1)
            ni = topi
            decoder_input = torch.LongTensor(ni).view(batch_size, 1).to(device)
            mask_loss_2, nTotal_2 = maskNLLLoss(logits, decoder_targets[:, i], mask_t[:, i])
            loss_2 += mask_loss_2
            print_losses_2 += mask_loss_2.item() * nTotal_2
            n_totals_2 += nTotal_2

    # 计算J3(γ)= 1∥t−s∥2
    _, s = encoder_s(embedding(encoder_inputs), encoder_inputs_length)
    _, t = encoder_t(embedding(decoder_targets), decoder_targets_length, use_pack=False)

    loss_3 = torch.sum(loss_fn(mlp(s), t))

    # 计算J4(θ,φ,γ) = −logP(y|x;θ,φ,γ)
    loss_4 = 0
    print_losses_4 = 0
    n_totals_4 = 0
    output_length = decoder_targets.size()[1]
    encoder_output, encoder_hidden = encoder_s(embedding(encoder_inputs), encoder_inputs_length)
    decoder_hidden = mlp(encoder_hidden)
    decoder_input = torch.LongTensor([SOS_token] * batch_size).unsqueeze(1).to(device)
    use_teacher_forcing = True if random.random() < config.SCHEDULED_SAMPLING_RATIO else False
    if use_teacher_forcing:
        for i in range(0, output_length):
            decoder_output, logits, _, decoder_hidden = decoder_t(embedding(decoder_input), decoder_hidden,
                                                                  encoder_output)
            decoder_input = decoder_targets[:, i].view(batch_size, 1)
            mask_loss_4, nTotal_4 = maskNLLLoss(logits, decoder_targets[:, i], mask_t[:, i])
            loss_4 += mask_loss_4
            print_losses_4 += mask_loss_4.item() * nTotal_4
            n_totals_4 += nTotal_4
    else:
        for i in range(0, output_length):
            decoder_output, logits, _, decoder_hidden = decoder_t(embedding(decoder_input), decoder_hidden,
                                                                  encoder_output)
            topv, topi = logits.data.topk(1)
            ni = topi
            decoder_input = torch.LongTensor(ni).view(batch_size, 1).to(device)
            mask_loss_4, nTotal_4 = maskNLLLoss(logits, decoder_targets[:, i], mask_t[:, i])
            loss_4 += mask_loss_4
            print_losses_4 += mask_loss_4.item() * nTotal_4
            n_totals_4 += nTotal_4

    loss = loss_1 + loss_2 + 0.01 * loss_3 + loss_4
    loss.backward()
    torch.nn.utils.clip_grad_norm_(encoder_s.parameters(), max_norm=clip_max_norm)
    torch.nn.utils.clip_grad_norm_(decoder_s.parameters(), max_norm=clip_max_norm)
    torch.nn.utils.clip_grad_norm_(encoder_t.parameters(), max_norm=clip_max_norm)
    torch.nn.utils.clip_grad_norm_(decoder_t.parameters(), max_norm=clip_max_norm)
    torch.nn.utils.clip_grad_norm_(mlp.parameters(), max_norm=clip_max_norm)
    torch.nn.utils.clip_grad_norm_(embedding.parameters(), max_norm=clip_max_norm)
    decoder_s_opt.step()
    encoder_s_opt.step()
    decoder_t_opt.step()
    encoder_t_opt.step()
    mlp_opt.step()
    emb_opt.step()

    return print_losses_4 / n_totals_4

def val(encoder, decoder, mlp, embedding, batch):
    encoder.eval()
    decoder.eval()
    mlp.eval()
    embedding.eval()
    encoder_inputs = batch.encoder_inputs.to(device)
    encoder_inputs_length = batch.encoder_inputs_length.to(device)
    decoder_targets = batch.decoder_targets.to(device)
    mask_t = batch.mask_t.to(device)
    batch_size = len(batch.encoder_inputs)
    output_length = decoder_targets.size()[1]
    encoder_output, encoder_hidden = encoder(embedding(encoder_inputs), encoder_inputs_length)
    decoder_hidden = mlp(encoder_hidden)
    decoder_input = torch.LongTensor([SOS_token] * batch_size).unsqueeze(1).to(device)

    loss = 0
    print_losses = 0
    n_totals = 0

    for i in range(0, output_length):
        decoder_output, logits, _, decoder_hidden = decoder(embedding(decoder_input), decoder_hidden, encoder_output)
        decoder_input = decoder_targets[:, i].view(batch_size, 1)
        mask_loss, nTotal = maskNLLLoss(logits, decoder_targets[:, i], mask_t[:, i])
        loss += mask_loss
        print_losses += mask_loss.item() * nTotal
        n_totals += nTotal
    return print_losses / n_totals

def build_model():
    data_path = config.data_path
    word2id, id2word, pretrain_embedding, train_samples, val_samples, test_samples = loadDataset(
        data_path)
    if os.path.exists(config.model_dir):
        print('Reloading model from ' + config.model_dir)
        encoder_s = torch.load(config.model_dir + '/encoder_s')
        decoder_s = torch.load(config.model_dir + '/decoder_s')
        encoder_t = torch.load(config.model_dir + '/encoder_t')
        decoder_t = torch.load(config.model_dir + '/decoder_t')
        mlp = torch.load(config.model_dir + '/mlp')
        embedding = torch.load(config.model_dir + '/embedding')
    else:
        print('Building model to ' + config.model_dir)
        os.mkdir(config.model_dir)
        encoder_s = Encoder(input_size=config.embedding_size,
                            hidden_size=config.cell_size,
                            drop_prob=config.keep_prob,
                            cell_type=config.cell_name,
                            nonlinearity=config.nonlinearity,
                            num_layers=config.num_layers,
                            bidirectional=config.bidirectional
                            )

        decoder_s = AttnDecoder(input_size=config.embedding_size,
                                output_size=len(word2id),
                                hidden_size=config.cell_size,
                                drop_prob=config.keep_prob,
                                cell_type=config.cell_name,
                                nonlinearity=config.nonlinearity,
                                num_layers=config.num_layers,
                                bidirectional=config.bidirectional,
                                attn=config.attn
                                )
        encoder_t = Encoder(input_size=config.embedding_size,
                            hidden_size=config.cell_size,
                            drop_prob=config.keep_prob,
                            cell_type=config.cell_name,
                            nonlinearity=config.nonlinearity,
                            num_layers=config.num_layers,
                            bidirectional=config.bidirectional
                            )

        decoder_t = AttnDecoder(input_size=config.embedding_size,
                                output_size=len(word2id),
                                hidden_size=config.cell_size,
                                drop_prob=config.keep_prob,
                                cell_type=config.cell_name,
                                nonlinearity=config.nonlinearity,
                                num_layers=config.num_layers,
                                bidirectional=config.bidirectional,
                                attn=config.attn
                                )
        mlp = MLP(config.cell_size, config.cell_size)
        embedding = Embedding(pretrain_embedding)

        torch.save(encoder_s.to(device), config.model_dir + '/encoder_s')
        torch.save(decoder_s.to(device), config.model_dir + '/decoder_s')
        torch.save(encoder_t.to(device), config.model_dir + '/encoder_t')
        torch.save(decoder_t.to(device), config.model_dir + '/decoder_t')
        torch.save(mlp.to(device), config.model_dir + '/mlp')
        torch.save(embedding.to(device), config.model_dir + '/embedding')
    return encoder_s.to(device), decoder_s.to(device), encoder_t.to(device), decoder_t.to(device), mlp.to(
        device), embedding.to(device), train_samples, val_samples, test_samples

encoder_s, decoder_s, encoder_t, decoder_t, mlp, embedding, train_samples, val_samples, test_samples = build_model()
trainIters(encoder_s, decoder_s, encoder_t, decoder_t, mlp, embedding, config.optimizer_type, train_samples,
           val_samples,
           config.learning_rate, config.weight_decay)

model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SOS_token = 1
EOS_token = 2

class Embedding(nn.Module):
    def __init__(self, embedding):
        super(Embedding, self).__init__()
        self.embedding = torch.nn.Embedding.from_pretrained(torch.FloatTensor(np.asarray(embedding, dtype=float)),
                                                            freeze=False)

    def forward(self, input):
        return self.embedding(input)

class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, drop_prob, cell_type='GRU', nonlinearity='tanh',
                 num_layers=1, bidirectional=False):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.cell_type = cell_type
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.drop_out = nn.Dropout(drop_prob)
        if cell_type == 'RNN':
            self.cell = nn.RNN(input_size, hidden_size, num_layers, dropout=(0 if num_layers == 1 else drop_prob),
                               batch_first=True, bidirectional=bidirectional)
        elif cell_type == 'LSTM':
            self.cell = nn.LSTM(input_size, hidden_size, num_layers, dropout=(0 if num_layers == 1 else drop_prob),
                                batch_first=True,
                                bidirectional=bidirectional)
        else:
            self.cell = nn.GRU(input_size, hidden_size, num_layers, dropout=(0 if num_layers == 1 else drop_prob),
                               batch_first=True, bidirectional=bidirectional)

    def forward(self, input, lengths, hidden=None, use_pack=True):
        if use_pack:
            packed = torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=True)
            outputs, hidden = self.cell(packed, hidden)
            outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        else:
            outputs, hidden = self.cell(input, hidden)
        return outputs, hidden

    def init_state(self, batch_size):
        if self.cell_type == 'LSTM':
            state = (
                torch.zeros(self.num_layers * 2 if self.bidirectional else self.num_layers, batch_size,
                            self.hidden_size,
                            device=device),
                torch.zeros(self.num_layers * 2 if self.bidirectional else self.num_layers, batch_size,
                            self.hidden_size,
                            device=device))
        else:
            state = torch.zeros(self.num_layers * 2 if self.bidirectional else self.num_layers, batch_size,
                                self.hidden_size, device=device)
        return state

class BasicDecoder(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, drop_prob, cell_type='GRU', nonlinearity='tanh',
                 num_layers=1, bidirectional=False):
        super(BasicDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.cell_type = cell_type
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        if drop_prob < 1.0:
            self.drop_out = nn.Dropout(drop_prob)
        if cell_type == 'RNN':
            self.cell = nn.RNN(input_size, hidden_size, num_layers, batch_first=True,
                               dropout=(0 if num_layers == 1 else drop_prob),
                               bidirectional=bidirectional)
        elif cell_type == 'LSTM':
            self.cell = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True,
                                dropout=(0 if num_layers == 1 else drop_prob),
                                bidirectional=bidirectional)
        else:
            self.cell = nn.GRU(input_size, hidden_size, num_layers, batch_first=True,
                               dropout=(0 if num_layers == 1 else drop_prob),
                               bidirectional=bidirectional)
        self.out = nn.Linear(self.hidden_size * 2 if bidirectional else self.hidden_size, output_size)

    def forward(self, input, hidden=None):
        cell_output, hidden = self.cell(input, hidden)
        proj_output = F.log_softmax(self.out(cell_output), dim=2)
        return cell_output, proj_output, hidden

    def init_state(self, batch_size):
        if self.cell_type == 'LSTM':
            state = (
                torch.zeros(self.num_layers * 2 if self.bidirectional else self.num_layers, batch_size,
                            self.hidden_size,
                            device=device),
                torch.zeros(self.num_layers * 2 if self.bidirectional else self.num_layers, batch_size,
                            self.hidden_size,
                            device=device))
        else:
            state = torch.zeros(self.num_layers * 2 if self.bidirectional else self.num_layers, batch_size,
                                self.hidden_size, device=device)
        return state

class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = torch.nn.Parameter(torch.FloatTensor(torch.ones(hidden_size)))
        self.out = nn.Linear(self.hidden_size * 2, self.hidden_size)

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(-1, encoder_output.size(1), -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # Transpose max_length and batch_size dimensions
        attn_energies = F.softmax(attn_energies, dim=1).unsqueeze(1)
        context = attn_energies.bmm(encoder_outputs)
        attn_res = torch.tanh(self.out(torch.cat((context, hidden), dim=2)))
        # Return the softmax normalized probability scores (with added dimension)
        return attn_res

class AttnDecoder(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, drop_prob, attn='general', cell_type='GRU',
                 nonlinearity='tanh', num_layers=1, bidirectional=False):
        super(AttnDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.cell_type = cell_type
        if drop_prob < 1.0:
            self.drop_out = nn.Dropout(drop_prob)

        self.attn = Attn(attn, hidden_size * 2 if bidirectional else hidden_size)
        if cell_type == 'RNN':
            self.cell = nn.RNN(input_size * 2 if bidirectional else input_size, hidden_size, num_layers,
                               dropout=(0 if num_layers == 1 else drop_prob), batch_first=True,
                               bidirectional=bidirectional)
        elif cell_type == 'LSTM':
            self.cell = nn.LSTM(input_size * 2 if bidirectional else input_size, hidden_size, num_layers,
                                dropout=(0 if num_layers == 1 else drop_prob), batch_first=True,
                                bidirectional=bidirectional)
        else:
            self.cell = nn.GRU(input_size * 2 if bidirectional else input_size, hidden_size, num_layers,
                               dropout=(0 if num_layers == 1 else drop_prob), batch_first=True,
                               bidirectional=bidirectional)
        self.out = nn.Linear(self.hidden_size * 2 if bidirectional else self.hidden_size, output_size)

    def forward(self, input, hidden, encoder_outputs):
        cell_output, hidden = self.cell(input, hidden)
        attn_output = self.attn(cell_output, encoder_outputs)
        proj_output = F.softmax(self.out(attn_output), dim=2).squeeze(1)
        proj_output_log = F.log_softmax(self.out(attn_output), dim=2).squeeze(1)
        return attn_output, proj_output, proj_output_log, hidden

    def init_state(self, batch_size):
        if self.cell_type == 'LSTM':
            state = (
                torch.zeros(self.num_layers * 2 if self.bidirectional else self.num_layers, batch_size,
                            self.hidden_size,
                            device=device),
                torch.zeros(self.num_layers * 2 if self.bidirectional else self.num_layers, batch_size,
                            self.hidden_size,
                            device=device))
        else:
            state = torch.zeros(self.num_layers * 2 if self.bidirectional else self.num_layers, batch_size,
                                self.hidden_size, device=device)
        return state

class MLP(nn.Module):
    def __init__(self, input_size, output_size):
        super(MLP, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.out = nn.Sequential(
            nn.Linear(input_size, input_size * 2),
            nn.Linear(input_size * 2, input_size * 2),
            nn.Linear(input_size * 2, output_size)
        )

    def forward(self, input):
        return self.out(input)

class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder, mlp, embedding):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.mlp = mlp
        self.embedding = embedding
        self.encoder.eval()
        self.decoder.eval()
        self.mlp.eval()
        self.embedding.eval()

    def forward(self, input_seq, input_length, max_length):
        # Forward input through encoder model
        encoder_outputs, encoder_hidden = self.encoder(self.embedding(input_seq), input_length)
        # Prepare encoder's final hidden layer to be first hidden input to the decoder
        decoder_hidden = self.mlp(encoder_hidden)
        # Initialize decoder input with SOS_token
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long)
        # Initialize tensors to append decoded words to
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        # Iteratively decode one word token at a time
        for _ in range(max_length):
            # Forward pass through decoder
            cell_output, proj_output, _, hidden = self.decoder(self.embedding(decoder_input), decoder_hidden,
                                                               encoder_outputs)
            # Obtain most likely word token and its softmax score
            decoder_scores, decoder_input = torch.max(proj_output, dim=1)
            # Record token and score
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare current token to be next decoder input (add a dimension)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        # Return collections of word tokens and scores
        return all_tokens, all_scores

class BeamSearchDecoder(nn.Module):
    def __init__(self, encoder, decoder, mlp, embedding, beam_size, max_step, alpha):
        super(BeamSearchDecoder, self).__init__()
        self.k = beam_size
        self.max_step = max_step
        self.encoder = encoder
        self.decoder = decoder
        self.alpha = alpha
        self.mlp = mlp
        self.embedding = embedding
        self.encoder.eval()
        self.decoder.eval()
        self.mlp.eval()
        self.embedding.eval()

    def forward(self, input_seq, input_length):
        encoder_hidden = self.encoder.init_state(input_seq.size()[0])
        encoder_outputs, encoder_hidden = self.encoder(self.embedding(input_seq), input_length)
        decoder_hidden = self.mlp(encoder_hidden)
        decoder_input = torch.LongTensor([[SOS_token]]).to(device)
        decoder_output, _, logit, decoder_hidden = self.decoder(self.embedding(decoder_input), decoder_hidden,
                                                                encoder_outputs)
        topv, topk = logit.data.topk(self.k)
        samples = [[] for i in range(self.k)]
        dead_k = 0
        final_samples = []
        for index in range(self.k):
            topk_prob = topv[0][index]
            topk_index = int(topk[0][index])
            samples[index] = [[topk_index], topk_prob, 0, 0, decoder_hidden, encoder_outputs]

        for _ in range(self.max_step):
            tmp = []
            for index in range(len(samples)):
                tmp.extend(self.beamSearchInfer(samples[index], index))

            # 筛选出topk
            samples_temp = []
            for sample in tmp:
                sample[3] = sample[2] / len(sample[0])
                samples_temp.append(sample)

            samples = sorted(samples_temp, key=self._score, reverse=True)
            samples = samples[:(self.k - dead_k)]
            s_tt = []
            for sample in samples:
                if sample[0][-1] == EOS_token:
                    final_samples.append(sample)
                    dead_k += 1
                else:
                    s_tt.append(sample)
            samples = s_tt
            if len(samples) == 0:
                break

        if len(final_samples) < self.k:
            final_samples.extend(samples[:(self.k - dead_k)])
        return final_samples

    def _score(self, s):
        return s[3]

    def beamSearchInfer(self, sample, k):
        samples = []
        decoder_input = torch.LongTensor([[sample[0][-1]]]).to(device)
        sequence, pre_scores, fin_scores, ave_scores, decoder_hidden, encoder_outputs = sample
        decoder_output, _, logit, decoder_hidden = self.decoder(self.embedding(decoder_input), decoder_hidden,
                                                                encoder_outputs)

        # choose topk
        topv, topk = logit.data.topk(self.k)
        for k in range(self.k):
            topk_prob = topv[0][k]
            topk_index = int(topk.data[0][k])
            pre_scores += topk_prob
            fin_scores = pre_scores - (k - 1) * self.alpha
            samples.append(
                [sequence + [topk_index], pre_scores, fin_scores, ave_scores, decoder_hidden, encoder_outputs])
        return samples

dataloader.py

import os
import jieba
import pickle
import random
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
padToken, goToken, eosToken, unknownToken = 0, 1, 2, 3

class Batch:
    # batch类,里面包含了encoder输入,decoder输入以及他们的长度
    def __init__(self):
        self.encoder_inputs = None
        self.encoder_inputs_length = None
        self.decoder_targets = None
        self.decoder_targets_length = None
        self.mask_t = None
        self.mask_s = None

def loadDataset(filename):
    """
    :param filename: 数据的路径,数据是一个json结构,包含三部分,分别是word2id,即word到id的转换,
    id2word,即id到word的转换 ,以及训练数据trainingSamples,是一个二维数组,形状为N*2,每一行包含问题和回答
    :return: 通过pickle解析我们的数据,返回上述的三部分内容。
    """
    dataset_path = os.path.join(filename)
    print('Loading dataset from {}'.format(dataset_path))
    with open(dataset_path, 'rb') as handle:
        data = pickle.load(handle)
        word2id = data['word2id']
        id2word = data['id2word']
        train_samples = data['train_samples']
        val_samples = data['val_samples']
        test_samples = data['test_samples']
        pretrain_embedding = data['pretrain_embedding']
    return word2id, id2word, pretrain_embedding, train_samples, val_samples, test_samples

def by_score(t):
    return len(t[0])

def createBatch(samples):
    '''
    根据给出的samples(就是一个batch的数据),进行padding并构造成placeholder所需要的数据形式
    :param samples: 一个batch的样本数据,列表,每个元素都是[question, answer]的形式,id
    :return: 处理完之后可以直接传入feed_dict的数据格式
    '''
    batch = Batch()
    _samples = sorted(samples, key=by_score, reverse=True)
    encoder_inputs_length = [len(sample[0]) for sample in _samples]
    decoder_targets_length = [len(sample[1]) + 1 for sample in _samples]
    batch.encoder_inputs_length = torch.LongTensor(encoder_inputs_length)
    # 模型训练时要加eos,所以这里targrt——length要加1
    batch.decoder_targets_length = torch.LongTensor(decoder_targets_length)

    max_source_length = max(encoder_inputs_length)
    max_target_length = max(decoder_targets_length)
    encoder_inputs = []
    decoder_targets = []
    mask_t = []
    mask_s = []
    for index in range(len(_samples)):
        # 将source进行反序并PAD值本batch的最大长度
        # source = list(reversed(sample[0]))
        pad = [padToken] * (max_source_length - len(_samples[index][0]))
        encoder_inputs.append(_samples[index][0] + pad)
        mask_s.append([1] * encoder_inputs_length[index] + [0] * (max_source_length - encoder_inputs_length[index]))

        # 将target进行PAD,并添加END符号
        target = _samples[index][1]
        target.append(2)
        pad = [padToken] * (max_target_length - decoder_targets_length[index])
        decoder_targets.append(target + pad)
        mask_t.append([1] * decoder_targets_length[index] + [0] * (max_target_length - decoder_targets_length[index]))
        # batch.target_inputs.append([goToken] + target + pad[:-1])
    batch.encoder_inputs = torch.LongTensor(encoder_inputs)
    batch.decoder_targets = torch.LongTensor(decoder_targets)
    batch.mask_t = torch.ByteTensor(mask_t)
    batch.mask_s = torch.ByteTensor(mask_s)

    return batch

def getBatches(data, batch_size):
    '''
    根据读取出来的所有数据和batch_size将原始数据分成不同的小batch。对每个batch索引的样本调用createBatch函数进行处理
    :param data: loadDataset函数读取之后的trainingSamples,就是QA对的列表
    :param batch_size: batch大小
    :param en_de_seq_len: 列表,第一个元素表示source端序列的最大长度,第二个元素表示target端序列的最大长度
    :return: 列表,每个元素都是一个batch的样本数据,可直接传入feed_dict进行训练
    '''
    # 每个epoch之前都要进行样本的shuffle
    random.shuffle(data)
    batches = []
    data_len = len(data)

    def genNextSamples():
        for i in range(0, data_len, batch_size):
            yield data[i:min(i + batch_size, data_len)]

    for samples in genNextSamples():
        batch = createBatch(samples)
        batches.append(batch)
    return batches

def sentence2enco(sentence, word2id):
    '''
    测试的时候将用户输入的句子转化为可以直接feed进模型的数据,现将句子转化成id,然后调用createBatch处理
    :param sentence: 用户输入的句子
    :param word2id: 单词与id之间的对应关系字典
    :param en_de_seq_len: 列表,第一个元素表示source端序列的最大长度,第二个元素表示target端序列的最大长度
    :return: 处理之后的数据,可直接feed进模型进行预测
    '''
    if sentence == '':
        return None
    # 分词
    tokens = [word for word in jieba.cut(sentence)]
    # tokens = sentence
    if len(tokens) > 20:
        return None
    # 将每个单词转化为id
    wordIds = []
    for token in tokens:
        wordIds.append(word2id.get(token, unknownToken))
    # 调用createBatch构造batch
    batch = createBatch([[wordIds, []]])
    return batch
rotabulo commented 5 years ago

@loveddy please do not open issues that are not strictly ascribable to our repo.