danpovey / torch_iterative_sampling


Sampling code #1

Open · danpovey opened this issue 3 years ago

danpovey commented 3 years ago

@zhu-han this repo contains the sampling code I mentioned to you. The "iterative" aspect of it is not needed here, we just treat it as a simple way to sample from a distribution.

Below is some icefall code, called unsupervised.py, that I was going to use to sample CTC transcripts for use in unsupervised training. I believe sampling is more correct than taking the top one, and will avoid it collapsing to blank (with top-one, any frame where blank is merely the most likely symbol always yields blank).

# Some utilities for unsupervised training                                                                                                                                                                            

import random
from typing import Optional, Sequence, Tuple, TypeVar, Union, Dict, List

import math
import torch
from k2 import RaggedTensor, Fsa
from torch import nn
from torch import Tensor
import torch_iterative_sampling

Supervisions = Dict[str, torch.Tensor]

def sample_ctc_transcripts_ragged(
        ctc_output: Tensor,
        paths_per_sequence: int,
        modified_topo: bool) -> RaggedTensor:
    """                                                                                                                                                                                                               
      ctc_output: a Tensor of shape (N, T, C), i.e. (batch, time, num_symbols),                                                                                                                                       
                 containing normalized log-probs.                                                                                                                                                                     
      paths_per_sequence: The number of separately sampled paths that are requested                                                                                                                                   
                 per sequence                                                                                                                                                                                         
      modified_topo:  True if the system is using the modified CTC topology where two                                                                                                                                 
                consecutive instances of a nonzero symbol can mean either one or two                                                                                                                                  
                copies of the original symbol.                                                                                                                                                                        

    Returns a RaggedTensor, on the same device as `ctc_output`, with shape
    (N * paths_per_sequence, None), where the 1st index is batch_idx *
    paths_per_sequence + path_idx and the 2nd index is the position in the
    token sequence.  The returned RaggedTensor will contain no 0's; those
    will have been removed, as blanks.
    """
    (N, T, C) = ctc_output.shape

    # The 'seq_len' arg below is something specific to the "iterative" part of                                                                                                                                        
    # torch_iterative_sampling, which has to do with "sampling without replacement";                                                                                                                                  
    # here, we don't really want to do "iterative sampling", we just want to                                                                                                                                          
    # sample from the distribution once.                                                                                                                                                                              

    probs = ctc_output.exp()
    sampled = torch_iterative_sampling.iterative_sample(probs,
                                                        num_seqs=paths_per_sequence,
                                                        seq_len=1).to(dtype=torch.int32)
    # `sampled` now has shape:                                                                                                                                                                                        
    # (N, T, paths_per_sequence, 1)                                                                                                                                                                                   
    sampled = sampled.squeeze(3).transpose(1, 2)
    # `sampled` now has shape (N, paths_per_sequence, T)                                                                                                                                                              

    # identical_mask is of shape (N, paths_per_sequence, T-1), and                                                                                                                                                    
    # contains True at each position if sampled[n,s,t] == sampled[n,s,t+1].                                                                                                                                           
    identical_mask = sampled[:,:,1:] == sampled[:,:,:-1]

    if modified_topo:
        # If we are using the modified/simplified CTC topology, it is possible for                                                                                                                                    
        # two consecutive instances of a nonzero symbol to represent either                                                                                                                                           
        # one symbol or two.  We choose either, with probability 0.5.  I think this                                                                                                                                   
        # is correct, perhaps should check though.                                                                                                                                                                    
        identical_mask = identical_mask & (torch.rand(*identical_mask.shape,
                                                      device=identical_mask.device) < 0.5)
    # The following statement replaces repeats of nonzero symbols with 0, so only the                                                                                                                                 
    # final symbol in a chain of identical, consecutive symbols will retain its                                                                                                                                       
    # nonzero value.                                                                                                                                                                                                  
    sampled[:,:,:-1].masked_fill_(identical_mask, 0)

    sampled = sampled.reshape(N * paths_per_sequence, T)

    # The shape of ragged_sampled is the same as that of `sampled` (it is regular);
    # if you query it, though, it will come up as (N * paths_per_sequence, None).
    ragged_sampled = RaggedTensor(sampled)

    # Remove 0's from the ragged tensor, to keep only "real" (non-blank) symbols.                                                                                                                                     
    ragged_sampled = ragged_sampled.remove_values_leq(0)

    # Note: you can create the CTC graphs with
    # k2.ctc_graph(ragged_sampled, modified={True,False}); you can turn it
    # into a List[List[int]] with ragged_sampled.tolist().  (See the usage
    # sketch after this code.)
    return ragged_sampled

def _test_sample_ctc_transcripts_ragged():
    for device in ['cpu'] + (['cuda'] if torch.cuda.is_available() else []):
        # simple case: N == 2, T == 2, C == 3
        ctc_output = torch.Tensor( [[[ 0., 1., 0. ], [ 1., 0., 0. ] ],
                                    [[ 1., 0., 0. ], [ 0., 0., 1. ] ]]).to(device=device).log()
        r = sample_ctc_transcripts_ragged(ctc_output, paths_per_sequence=1, modified_topo=False)
        print("r = ", r)
        assert r == RaggedTensor('[[1], [2]]', dtype=torch.int32,
                                 device=device)

    for device in ['cpu'] + (['cuda'] if torch.cuda.is_available() else []):
        # simple case: N == 2, T == 3, C == 3, with repeats.
        # We use modified_topo == False, so the repeats should be removed.
        ctc_output = torch.Tensor( [[[ 0., 1., 0. ], [0., 1., 0.], [ 1., 0., 0. ] ],
                                    [[ 1., 0., 0. ], [0., 0., 1.], [ 0., 0., 1. ] ]]).to(device=device).log()
        r = sample_ctc_transcripts_ragged(ctc_output, paths_per_sequence=1, modified_topo=False)
        print("r = ", r)
        assert r == RaggedTensor('[[1], [2]]', dtype=torch.int32,
                                 device=device)

if __name__ == "__main__":
    _test_sample_ctc_transcripts_ragged()
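
For reference, here is a minimal usage sketch of the note near the end of the function, showing how the sampled transcripts might be fed to k2.ctc_graph. The inputs below are illustrative, not part of the file above: `ctc_output` is assumed to be as in the tests, and paths_per_sequence=4 is arbitrary.

# Illustrative follow-up: build CTC graphs from the sampled transcripts.
import k2

paths = sample_ctc_transcripts_ragged(ctc_output, paths_per_sequence=4,
                                      modified_topo=False)
# k2.ctc_graph accepts the RaggedTensor directly (or a List[List[int]]);
# pass modified=True if using the modified CTC topology.
graphs = k2.ctc_graph(paths, modified=False)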
zhu-han commented 3 years ago

Cool! Is there any specific reason to use this repo for the sampling? I found there is a similar function in PyTorch: torch.distributions.categorical.Categorical
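
For reference, a minimal sketch (not from the original thread) of how Categorical could do the same one-shot sampling, assuming `ctc_output` of shape (N, T, C) with normalized log-probs and `paths_per_sequence` as above:

import torch

# Categorical over the last dim; the batch shape is (N, T).
dist = torch.distributions.Categorical(logits=ctc_output)
# Drawing paths_per_sequence independent samples gives shape
# (paths_per_sequence, N, T); permute to (N, paths_per_sequence, T)
# to match the layout used in sample_ctc_transcripts_ragged.
sampled = dist.sample((paths_per_sequence,)).permute(1, 0, 2).to(torch.int32)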

danpovey commented 3 years ago

Oh, I didn't know about that. Then that should be fine.

zhu-han commented 3 years ago

OK, thanks! I’ll try to use this sampling idea in the unsupervised training.