NVIDIA / VariantWorks

Deep Learning based variant calling toolkit - https://clara-parabricks.github.io/VariantWorks/
Apache License 2.0

VariantWorks simple_consensus_caller does not reach zero training loss / perfect training accuracy when expected #166

Closed: michaelbrownid closed this issue 3 years ago

michaelbrownid commented 3 years ago

Using the VariantWorks simple_consensus_caller sample, I train a large model (a 4-layer RNN with 512 GRU units) on a single ZMW. I expect the training loss to go to zero and the training accuracy to go to 100%, since the number of model parameters vastly exceeds the size of the training data.
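To put a number on "vastly larger", the sketch below counts trainable parameters for a bidirectional GRU stack plus a linear classifier, using PyTorch's documented per-layer shapes for `nn.GRU` (`weight_ih` is 3H x I, `weight_hh` is 3H x H, plus two bias vectors of length 3H). It assumes 10 input features and 5 output classes (the values passed to `NeuralNetwork` in ripBUG.py) and, following the comment in that script, a training set of 32 windows of 1024 positions. `gru_layer_params` and `consensus_rnn_params` are hypothetical helpers, not VariantWorks functions.

```python
def gru_layer_params(input_size: int, hidden: int) -> int:
    """Parameters of one GRU layer in one direction (PyTorch layout).

    weight_ih: (3*hidden, input_size), weight_hh: (3*hidden, hidden),
    bias_ih and bias_hh: (3*hidden,) each.
    """
    return 3 * hidden * (input_size + hidden + 2)


def consensus_rnn_params(n_features: int, n_classes: int,
                         gru_size: int, gru_layers: int) -> int:
    """Hypothetical helper: bidirectional GRU stack + linear classifier."""
    total = 0
    for layer in range(gru_layers):
        # Layers after the first receive the concatenated forward/backward outputs.
        layer_input = n_features if layer == 0 else 2 * gru_size
        total += 2 * gru_layer_params(layer_input, gru_size)  # 2 directions
    # Linear(2 * gru_size, n_classes): weight matrix plus bias vector.
    total += 2 * gru_size * n_classes + n_classes
    return total


# Assumed training-set size: 32 windows x 1024 positions (comment in ripBUG.py).
n_labels = 32 * 1024
for gru_size in (512, 256):
    params = consensus_rnn_params(10, 5, gru_size, 4)
    print(f"gru_size={gru_size}: {params:,} parameters, "
          f"{params / n_labels:.0f} per training label")
```

Under these assumptions the 512-unit, 4-layer model has about 15.8 million parameters, several hundred per training label, so memorising a single ZMW is a reasonable expectation.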

I ran the experiment with VariantWorks and then with the same model architecture in the simplest possible PyTorch implementation (included below).

mkdir bug
cd bug

module add minimap2/2.17
module add samtools

python /home/UNIXHOME/mbrown/mbrown/workspace2020Q3/VariantWorks/VariantWorks/samples/simple_consensus_caller/pileup_hdf5_generator.py \
-r /home/UNIXHOME/mbrown/mbrown/workspace2020Q3/VariantWorks/VariantWorks/samples/simple_consensus_caller/data/samples/1 -o train.hdf -t 4

python3 /home/UNIXHOME/mbrown/mbrown/workspace2020Q3/VariantWorks/VariantWorks/samples/simple_consensus_caller/consensus_trainer.py \
--train-hdf train.hdf \
--epochs 4096 \
--gru_size 512 \
--gru_layers 4 \
--lr 1.0E-03 \
--model-dir /home/UNIXHOME/mbrown/mbrown/workspace2021Q2/train-zero/bug \
> bug.out 2> bug.err

### now run the bare-bones PyTorch ripout
python3 /home/UNIXHOME/mbrown/mbrown/workspace2021Q2/train-zero/ripBUG.py > ripbug.out 2> ripbug.err

================================
Summarize:

dat = read.table("bug/bug.out.dat",head=T,sep="\t")
datr = read.table("bug/ripbug.out",head=F,sep="\t",skip=1) # skip the leading "Using ... device" line

png("bugLoss.png")
plot(datr$V2,datr$V4,type="b",col="red")
points(dat$epoch,dat$loss,type="b",col="black")
title("train loss. black=VariantWorks, red=PyTorch. VW does not go to zero")
dev.off()

dat$qv = -10*log10(1.0-dat$acc)
datr$qv = -10*log10(1.0-datr$V6+1.0E-5)

png("bugAcc.png")
plot(datr$V2,datr$qv,type="b",col="red")
points(dat$epoch,dat$qv,type="b",col="black")
title("train acc. black=VariantWorks, red=PyTorch.\nVW does not reach accuracy=1.0 (plotted as QV50)")
dev.off()
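For readers without R, the same summary step can be sketched in Python. It assumes the tab-separated line layout that ripBUG.py prints (`Epoch`, epoch, `avgTrainLoss:`, loss, `TrainCorrect:`, accuracy, then an empty trailing field), which is why the R code reads columns V2, V4 and V6, and it skips the leading `Using ... device` line. As in the R code, adding 1.0E-5 caps QV at 50 when accuracy reaches 1.0. `parse_rip_log` and `accuracy_to_qv` are hypothetical names, and the sample lines are illustrative only, not results from this issue.

```python
import math


def accuracy_to_qv(acc: float, eps: float = 1.0e-5) -> float:
    """Phred-style QV from accuracy; eps caps the value at -10*log10(eps) = 50."""
    return -10.0 * math.log10(1.0 - acc + eps)


def parse_rip_log(lines):
    """Yield (epoch, loss, accuracy) tuples from ripBUG.py stdout lines."""
    for line in lines:
        fields = line.rstrip("\n").split("\t")
        if fields[0] != "Epoch":  # e.g. the "Using cuda device" header line
            continue
        # fields: Epoch, <epoch>, avgTrainLoss:, <loss>, TrainCorrect:, <acc>, ""
        yield int(fields[1]), float(fields[3]), float(fields[5])


# Illustrative lines in the ripBUG.py format (made-up values).
sample = [
    "Using cpu device\n",
    "Epoch\t1\tavgTrainLoss:\t1.6\tTrainCorrect:\t0.2\t\n",
    "Epoch\t2\tavgTrainLoss:\t0.0001\tTrainCorrect:\t1.0\t\n",
]
for epoch, loss, acc in parse_rip_log(sample):
    print(epoch, loss, round(accuracy_to_qv(acc), 2))
```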

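[Figure: bugLoss.png, training loss per epoch (black = VariantWorks, red = PyTorch)]

[Figure: bugAcc.png, training accuracy as QV per epoch (black = VariantWorks, red = PyTorch)]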

ripBug.py

import torch
from torch import nn
import torch.nn.functional as F  # needed for F.softmax in forward()

import h5py
import sys
import numpy as np

################################
# Data

ff= h5py.File("train.hdf","r")

datTrainFeatures = torch.from_numpy(ff["features"][()].astype("float32"))  # read the whole dataset into memory as float32
datTrainLabels = torch.from_numpy(ff["labels"][()].astype("int64"))  # pytorch cross entropy wants long

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

class NeuralNetwork(nn.Module):
    def __init__(self, input_feature_size, num_output_logits,
                 gru_size=128, gru_layers=2, apply_softmax=False):
        """Construct a Consensus RNN NeMo instance.

        Args:
            input_feature_size : Length of input feature set.
            num_output_logits : Number of output classes of classifier.
            gru_size : Number of units in RNN
            gru_layers : Number of layers in RNN
            apply_softmax : Apply softmax to the output of the classifier.

        Returns:
            Instance of class.
        """
        super().__init__()
        self.num_output_logits = num_output_logits
        self.apply_softmax = apply_softmax
        self.gru_size = gru_size
        self.gru_layers = gru_layers

        self.gru = nn.GRU(input_feature_size, gru_size, gru_layers, batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(2 * gru_size, self.num_output_logits)  # 2* for bidirectional

        # self._device = torch.device(
        #     "cuda" if self.placement == DeviceType.GPU else "cpu")
        # self.to(self._device)

    def forward(self, encoding):
        """Run the network on an input batch.

        Args:
            encoding : Input sequence to run network on.

        Returns:
            Output of forward pass.
        """
        encoding, h_n = self.gru(encoding)
        encoding = self.classifier(encoding)
        if self.apply_softmax:
            encoding = F.softmax(encoding, dim=2)
        return encoding

model = NeuralNetwork(10, 5, 256, 4).to(device)  # 10 features, 5 classes, gru_size=256 (the VariantWorks run used --gru_size 512), 4 layers

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train(dataFeatures, dataLabels, model, loss_fn, optimizer):
    size = dataFeatures.shape[0]*dataFeatures.shape[1] # 32 windows of 1024 bases each
    model.train()
    losssum, correctsum = 0, 0
    batch = 0
    X = dataFeatures
    y = dataLabels
    if True: # one batch
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)

        newpred = pred.permute(0, 2, 1)  # (batch, seq, classes) -> (batch, classes, seq), the layout CrossEntropyLoss expects

        loss = loss_fn(newpred, y)
        losssum += loss.item()
        correctsum += (newpred.argmax(1) == y).type(torch.float).sum().item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # do sum not avg for loss to follow variantworks
            myloss = losssum  # / size
            mycorrect = correctsum / size
            print(f"avgTrainLoss:\t{myloss}\tTrainCorrect:\t{mycorrect}\t", end="")
            print()

epochs = 4096
for t in range(epochs):
    print(f"Epoch\t{t+1}\t",end="")
    train(datTrainFeatures, datTrainLabels, model, loss_fn, optimizer)

michaelbrownid commented 3 years ago

I believe I have found the issue. The torch.nn.CrossEntropyLoss documentation says the "input is expected to contain raw, unnormalized scores for each class", in other words the logits before the softmax rather than the probabilities after it (a somewhat poor choice of words on the PyTorch docs page). Consequently, apply_softmax on the last layer of the model should be False during training and True only when you want to test and get final probabilities. consensus_trainer.py incorrectly sets it to True during training.
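Why this stops the loss from reaching zero: `nn.CrossEntropyLoss` applies log-softmax internally (it combines `LogSoftmax` and `NLLLoss`), so a model that already applies softmax gets squashed twice. Its outputs lie in [0, 1], so the second softmax can never make the correct class much more likely than the others, and the per-position loss has a positive floor. The pure-Python sketch below illustrates this for the 5-class output used here; `cross_entropy` is a hand-written single-position stand-in for the PyTorch loss, not the library function, and the best case for softmax outputs is a one-hot probability vector.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(scores, target):
    """Single-position stand-in for nn.CrossEntropyLoss: -log softmax(scores)[target]."""
    return -math.log(softmax(scores)[target])


n_classes = 5

# Feeding raw logits: a growing margin drives the loss toward zero.
for margin in (1.0, 5.0, 20.0):
    logits = [margin] + [0.0] * (n_classes - 1)
    print(f"logits, margin {margin:>4}: loss = {cross_entropy(logits, 0):.6f}")

# Feeding softmax outputs: even a perfect one-hot probability vector is
# squashed again, so the loss cannot fall below log(e + 4) - 1.
best_probs = [1.0] + [0.0] * (n_classes - 1)
floor = cross_entropy(best_probs, 0)
print(f"softmax inputs, best case: loss = {floor:.4f}")
```

With 5 classes the floor is about 0.905 nats per position, which is consistent with a training loss that flattens out instead of going to zero.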

tijyojwad commented 3 years ago

Hi @michaelbrownid ! Thank you for finding this issue and the solution :). We'll get that fix merged ASAP!

ohadmo commented 3 years ago

Hey @michaelbrownid! This issue is resolved in #167.