fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io
Other
1.4k stars 244 forks source link

导入惊蛰框架后,原本的训练流程出错 #590

Open Kline-song opened 6 days ago

Kline-song commented 6 days ago

Description

我在尝试使用惊蛰框架(SpikingJelly)将ANN网络结构转换为SNN网络结构时发现:导入惊蛰框架后,原本可以正常运行的ANN训练代码出现警告,且后续的训练流程无法继续进行。警告内容为:

WARNING:tensorboardX.x2num:NaN or Inf found in input tensor.

导入的代码为:

from spikingjelly.activation_based import neuron

而我并没有使用导入的惊蛰模块,仅仅是导入了惊蛰框架。

请问为什么会发生这种情况?

For faster response

@fangwei123456

Issue type

SpikingJelly version

0.0.0.0.14

Met4physics commented 6 days ago

你可以提供完整的代码吗

Kline-song commented 6 days ago

> 你可以提供完整的代码吗

@Met4physics 可以的。我想用SNN实现wenet开源工具包中Transformer的变体Conformer。目前只是刚刚导入了spikingjelly框架,就发现无法进行训练了;如果注释掉导入的代码,则可以正常训练。以下是我的encoder代码。


"""Encoder definition."""
from typing import Tuple, Optional

import copy import torch

from wenet.transformer.attention import MultiHeadedAttention from wenet.transformer.attention import RelPositionMultiHeadedAttention from wenet.transformer.convolution import ConvolutionModule from wenet.transformer.embedding import PositionalEncoding from wenet.transformer.embedding import RelPositionalEncoding from wenet.transformer.embedding import NoPositionalEncoding from wenet.snnconformer.encoder_layer import SNNConformerEncoderLayer from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward from wenet.transformer.subsampling import Conv2dSubsampling4 from wenet.transformer.subsampling import Conv2dSubsampling6 from wenet.transformer.subsampling import Conv2dSubsampling8 from wenet.transformer.subsampling import EmbedinigNoSubsampling from wenet.transformer.subsampling import LinearNoSubsampling from wenet.transformer.subsampling import IdentitySubsampling

from wenet.utils.common import get_activation from wenet.utils.mask import make_pad_mask from wenet.utils.mask import add_optional_chunk_mask

from spikingjelly.activation_based import neuron

class SNNBaseEncoder(torch.nn.Module):
    """Shared front end and chunk-wise inference logic for SNN encoders.

    The base class builds the optional global CMVN, the input
    subsampling + positional-encoding module (``self.embed``) and the final
    LayerNorm (``self.after_norm``). It does NOT create the encoder layers:
    subclasses must assign ``self.encoders``, which ``forward`` and
    ``forward_chunk`` iterate over.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: Optional[torch.nn.Module] = None,
        use_dynamic_left_chunk: bool = False,
    ):
        """
        Args:
            input_size (int): input feature dim
            output_size (int): dimension of attention / encoder output
            attention_heads (int): number of heads of multi-head attention
                (not used by the base class; accepted for subclasses)
            linear_units (int): hidden units of the position-wise feed
                forward (not used by the base class)
            num_blocks (int): number of encoder blocks (not used by the
                base class)
            dropout_rate (float): dropout rate passed to the input layer
            positional_dropout_rate (float): dropout rate after adding
                positional encoding
            attention_dropout_rate (float): dropout rate in attention (not
                used by the base class)
            input_layer (str): input layer type, one of
                [identity, linear, conv2d, conv2d6, conv2d8, embed]
            pos_enc_layer_type (str): positional encoding layer type, one of
                [abs_pos, rel_pos, no_pos]
            normalize_before (bool):
                True: use layer_norm before each sub-block of a layer, and
                    apply ``after_norm`` to the final encoder output.
                False: use layer_norm after each sub-block of a layer.
            static_chunk_size (int): chunk size for static chunk training
                and decoding
            use_dynamic_chunk (bool): whether to use a dynamic chunk size
                for training. Use either a fixed chunk
                (static_chunk_size > 0) or a dynamic chunk size
                (use_dynamic_chunk = True).
            global_cmvn (Optional[torch.nn.Module]): optional GlobalCMVN
                module applied to the input features
            use_dynamic_left_chunk (bool): whether to use a dynamic left
                chunk in dynamic chunk training

        Raises:
            ValueError: if ``pos_enc_layer_type`` or ``input_layer`` is not
                one of the supported values.
        """
        super().__init__()
        self._output_size = output_size

        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "no_pos":
            pos_enc_class = NoPositionalEncoding
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        # This must be ONE if/elif chain: a separate `if` for "identity"
        # would let that value fall through to the final `else` and raise.
        if input_layer == 'identity':
            subsampling_class = IdentitySubsampling
        elif input_layer == "linear":
            subsampling_class = LinearNoSubsampling
        elif input_layer == "conv2d":
            subsampling_class = Conv2dSubsampling4
        elif input_layer == "conv2d6":
            subsampling_class = Conv2dSubsampling6
        elif input_layer == "conv2d8":
            subsampling_class = Conv2dSubsampling8
        elif input_layer == "embed":
            subsampling_class = EmbedinigNoSubsampling
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.global_cmvn = global_cmvn
        self.embed = subsampling_class(
            input_size,
            output_size,
            dropout_rate,
            pos_enc_class(output_size, positional_dropout_rate),
        )

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk

    def output_size(self) -> int:
        """Return the encoder output dimension given at construction."""
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a padded batch of input features.

        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for
                decoding, the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size,
                                              num_decoding_left_chunks)
        # `self.encoders` is provided by the subclass.
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks

    def forward_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward just one chunk.

        The tensor defaults are created once at definition time; this
        method only reads them (size queries / slicing), never mutates them.

        Args:
            xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
                where time == (chunk_size - 1) * subsample_rate
                + subsample.right_context + 1
            offset (int): current offset in encoder output time stamp
            required_cache_size (int): cache size required for next chunk
                computation
                >=0: actual cache size
                <0: means all history cache is required
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in
                conformer, (elayers, b=1, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`
            att_mask (torch.Tensor): attention mask passed to every layer

        Returns:
            torch.Tensor: output of current input xs,
                with shape (b=1, chunk_size, hidden-dim).
            torch.Tensor: new attention cache required for next chunk, with
                dynamic shape (elayers, head, ?, d_k * 2)
                depending on required_cache_size.
            torch.Tensor: new conformer cnn cache required for next chunk,
                with same shape as the original cnn_cache.
        """
        assert xs.size(0) == 1
        # tmp_masks is just for interface compatibility
        tmp_masks = torch.ones(1,
                               xs.size(1),
                               device=xs.device,
                               dtype=torch.bool)
        tmp_masks = tmp_masks.unsqueeze(1)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
        # NOTE(xcsong): After  embed, shape(xs) is (b=1, chunk_size, hidden-dim)
        elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
        chunk_size = xs.size(1)
        attention_key_size = cache_t1 + chunk_size
        pos_emb = self.embed.position_encoding(
            offset=offset - cache_t1, size=attention_key_size)
        if required_cache_size < 0:
            next_cache_start = 0
        elif required_cache_size == 0:
            next_cache_start = attention_key_size
        else:
            next_cache_start = max(attention_key_size - required_cache_size, 0)
        r_att_cache = []
        r_cnn_cache = []
        for i, layer in enumerate(self.encoders):
            # NOTE(xcsong): Before layer.forward
            #   shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
            #   shape(cnn_cache[i])       is (b=1, hidden-dim, cache_t2)
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs, att_mask, pos_emb,
                att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
                cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache
            )
            # NOTE(xcsong): After layer.forward
            #   shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
            #   shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
            r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
            r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
        if self.normalize_before:
            xs = self.after_norm(xs)

        # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
        #   ? may be larger than cache_t1, it depends on required_cache_size
        r_att_cache = torch.cat(r_att_cache, dim=0)
        # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
        r_cnn_cache = torch.cat(r_cnn_cache, dim=0)

        return xs, r_att_cache, r_cnn_cache

    def forward_chunk_by_chunk(
        self,
        xs: torch.Tensor,
        decoding_chunk_size: int,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward input chunk by chunk, like a streaming decoder.

        Here we should pay special attention to computation cache in the
        streaming style forward chunk by chunk. Three things should be taken
        into account for computation in the current network:
            1. transformer/conformer encoder layers output cache
            2. convolution in conformer
            3. convolution in subsampling

        However, we don't implement subsampling cache for:
            1. We can control subsampling module to output the right result
               by overlapping input instead of cache left context, even
               though it wastes some computation, but subsampling only takes
               a very small fraction of computation in the whole model.
            2. Typically, there are several convolution layers with
               subsampling in subsampling module, it is tricky and
               complicated to do cache with different convolution layers
               with different subsampling rate.
            3. Currently, nn.Sequential is used to stack all the convolution
               layers in subsampling, we need to rewrite it to make it work
               with cache, which is not preferred.

        Args:
            xs (torch.Tensor): (1, max_len, dim)
            decoding_chunk_size (int): decoding chunk size, must be > 0
            num_decoding_left_chunks (int): number of left chunks kept in the
                attention cache; a negative value keeps all history

        Returns:
            ys: encoder output concatenated over chunks, (1, T', hidden-dim)
            masks: all-True bool mask of shape (1, 1, T')
        """
        assert decoding_chunk_size > 0
        # The model is trained by static or dynamic chunk
        assert self.static_chunk_size > 0 or self.use_dynamic_chunk
        subsampling = self.embed.subsampling_rate
        context = self.embed.right_context + 1  # Add current frame
        stride = subsampling * decoding_chunk_size
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        num_frames = xs.size(1)
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        outputs = []
        offset = 0
        # Negative when num_decoding_left_chunks < 0 -> keep all history.
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks

        # Feed forward overlap input step by step
        for cur in range(0, num_frames - context + 1, stride):
            end = min(cur + decoding_window, num_frames)
            chunk_xs = xs[:, cur:end, :]
            (y, att_cache, cnn_cache) = self.forward_chunk(
                chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
            outputs.append(y)
            offset += y.size(1)
        ys = torch.cat(outputs, 1)
        masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool)
        return ys, masks

class SNNConformerEncoder(SNNBaseEncoder):
    """Conformer encoder built from ``SNNConformerEncoderLayer`` blocks."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: Optional[torch.nn.Module] = None,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = True,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
    ):
        """Construct an SNNConformerEncoder.

        Args:
            input_size to use_dynamic_left_chunk: see SNNBaseEncoder.
            positionwise_conv_kernel_size (int): kernel size of positionwise
                conv1d layer (currently not used by this constructor).
            macaron_style (bool): whether to add a second (macaron-style)
                positionwise feed-forward module to each layer.
            selfattention_layer_type (str): ignored; kept only for config
                compatibility. The attention class is chosen from
                ``pos_enc_layer_type`` instead ("rel_pos" selects
                RelPositionMultiHeadedAttention, anything else selects
                MultiHeadedAttention).
            activation_type (str): encoder activation function type.
            use_cnn_module (bool): whether to use the convolution module.
            cnn_module_kernel (int): kernel size of the convolution module.
            causal (bool): whether to use causal convolution or not.
            cnn_module_norm (str): normalization type passed to the
                convolution module.
        """
        super().__init__(input_size, output_size, attention_heads,
                         linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         static_chunk_size, use_dynamic_chunk,
                         global_cmvn, use_dynamic_left_chunk)
        activation = get_activation(activation_type)

        # self-attention module definition
        if pos_enc_layer_type != "rel_pos":
            encoder_selfattn_layer = MultiHeadedAttention
        else:
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
        )
        # feed-forward module definition
        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer = ConvolutionModule
        convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                  cnn_module_norm, causal)

        # One fresh set of sub-modules per block (no parameter sharing).
        self.encoders = torch.nn.ModuleList([
            SNNConformerEncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(
                    *positionwise_layer_args) if macaron_style else None,
                convolution_layer(
                    *convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
            ) for _ in range(num_blocks)
        ])
Met4physics commented 6 days ago

非常奇怪的bug。你在训练代码中使用tensorboardX了吗?

Kline-song commented 6 days ago

@Met4physics 感谢回复!

我在train.py中调用了tensorboardX中的SummaryWriter,对训练过程进行可视化。

具体代码如下所示:

from __future__ import print_function

import argparse
import copy
import datetime
import deepspeed
import json
import logging
import os

import torch
import torch.distributed as dist
import torch.optim as optim
import yaml

from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live  # noqa
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live  # noqa
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
# 这里引入了tensorboardX的SummaryWriter
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader

from wenet.dataset.dataset import Dataset
from wenet.utils.checkpoint import (load_checkpoint, save_checkpoint,
                                    load_trained_modules)
from wenet.utils.executor import Executor
from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols
from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing
from wenet.utils.config import override_config
from wenet.utils.init_model import init_model
from wenet.utils.common import str2bool

def get_args():
    """Build the command-line parser for training and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the training, DDP and DeepSpeed options
        (DeepSpeed adds ``--deepspeed`` / ``--deepspeed_config`` itself).
    """
    parser = argparse.ArgumentParser(description='training your network')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--data_type',
                        default='raw',
                        choices=['raw', 'shard'],
                        help='train and cv data type')
    parser.add_argument('--train_data', required=True, help='train data file')
    parser.add_argument('--cv_data', required=True, help='cv data file')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this local rank, -1 for cpu')
    parser.add_argument('--model_dir', required=True, help='save model dir')
    parser.add_argument('--checkpoint', help='checkpoint model')
    parser.add_argument('--tensorboard_dir',
                        default='tensorboard',
                        help='tensorboard log dir')
    parser.add_argument('--ddp.rank',
                        dest='rank',
                        default=0,
                        type=int,
                        help='global rank for distributed training')
    parser.add_argument('--ddp.world_size',
                        dest='world_size',
                        default=-1,
                        type=int,
                        help='''number of total processes/gpus for
                        distributed training''')
    parser.add_argument('--ddp.dist_backend',
                        dest='dist_backend',
                        default='nccl',
                        choices=['nccl', 'gloo'],
                        help='distributed backend')
    parser.add_argument('--ddp.init_method',
                        dest='init_method',
                        default=None,
                        help='ddp init method')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for reading')
    parser.add_argument('--pin_memory',
                        action='store_true',
                        default=False,
                        help='Use pinned memory buffers used for reading')
    parser.add_argument('--use_amp',
                        action='store_true',
                        default=False,
                        help='Use automatic mixed precision training')
    parser.add_argument('--fp16_grad_sync',
                        action='store_true',
                        default=False,
                        help='Use fp16 gradient sync for ddp')
    parser.add_argument('--cmvn', default=None, help='global cmvn file')
    parser.add_argument('--symbol_table',
                        required=True,
                        help='model unit symbol table for training')
    parser.add_argument("--non_lang_syms",
                        help="non-linguistic symbol file. One symbol per line.")
    parser.add_argument('--prefetch',
                        default=100,
                        type=int,
                        help='prefetch number')
    parser.add_argument('--bpe_model',
                        default=None,
                        type=str,
                        help='bpe model for english part')
    parser.add_argument('--override_config',
                        action='append',
                        default=[],
                        help="override yaml config")
    parser.add_argument("--enc_init",
                        default=None,
                        type=str,
                        help="Pre-trained model to initialize encoder")
    # Split on commas and drop empty entries, so "a,,b" or a trailing comma
    # does not produce an empty module prefix (the filter must test each
    # element, not the whole input string).
    parser.add_argument("--enc_init_mods",
                        default="encoder.",
                        type=lambda s: [str(mod) for mod in s.split(",") if mod != ""],
                        help="List of encoder modules \
                        to initialize ,separated by a comma")
    parser.add_argument('--lfmmi_dir',
                        default='',
                        required=False,
                        help='LF-MMI dir')

    # Begin deepspeed related config
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='local rank passed from distributed launcher')
    parser.add_argument('--deepspeed.save_states',
                        dest='save_states',
                        default='model_only',
                        choices=['model_only', 'model+optimizer'],
                        help='save model/optimizer states')
    # End deepspeed related config
    parser.add_argument('--find_unused_parameters',
                        type=str2bool,
                        default=True,
                        help='find unused parameters in ddp')

    # DeepSpeed automatically adds '--deepspeed' and '--deepspeed_config'
    # to the parser
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    return args

def main():
    """Parse arguments, build data / model / optimizer, and run training.

    Three execution modes are supported: a single process (CPU or one GPU),
    PyTorch native DDP (``--ddp.world_size > 1``) and DeepSpeed
    (``--deepspeed``). The process with ``local_rank == 0`` also writes
    ``train.yaml``, the TorchScript export ``init.zip``, per-epoch info
    YAML files and the tensorboardX logs.
    """
    import shutil  # used to delete temporary DeepSpeed checkpoint dirs

    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    # NOTE(xcsong): deepspeed sets CUDA_VISIBLE_DEVICES internally, so only
    # pin the visible device ourselves when deepspeed is not used (reading
    # the variable back would raise KeyError if it were unset).
    if not args.deepspeed:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    # Set random seed
    torch.manual_seed(777)
    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if args.override_config:
        configs = override_config(configs, args.override_config)
    ds_configs = {}
    if args.deepspeed:
        with open(args.deepspeed_config, 'r') as fin:
            ds_configs = json.load(fin)
        if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
            configs["ds_dtype"] = "fp16"
        elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
            configs["ds_dtype"] = "bf16"
        else:
            configs["ds_dtype"] = "fp32"

    # deepspeed reads world_size from env
    if args.deepspeed:
        assert args.world_size == -1
    # distributed means pytorch native ddp, it parses world_size from args
    distributed = args.world_size > 1
    local_rank = args.rank
    world_size = args.world_size
    if distributed:
        logging.info('training on multiple gpus, this gpu {}'.format(args.gpu))
        dist.init_process_group(args.dist_backend,
                                init_method=args.init_method,
                                world_size=world_size,
                                rank=local_rank)
    elif args.deepspeed:
        # Update local_rank & world_size from environment variables
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        deepspeed.init_distributed(dist_backend=args.dist_backend,
                                   init_method=args.init_method,
                                   rank=local_rank,
                                   world_size=world_size)

    symbol_table = read_symbol_table(args.symbol_table)

    train_conf = configs['dataset_conf']
    cv_conf = copy.deepcopy(train_conf)
    # No augmentation or shuffling for cross-validation data
    cv_conf['speed_perturb'] = False
    cv_conf['spec_aug'] = False
    cv_conf['spec_sub'] = False
    cv_conf['spec_trim'] = False
    cv_conf['shuffle'] = False
    non_lang_syms = read_non_lang_symbols(args.non_lang_syms)

    if args.deepspeed:
        assert train_conf['batch_conf']['batch_type'] == "static"
        assert ds_configs["train_micro_batch_size_per_gpu"] == 1
        configs['accum_grad'] = ds_configs["gradient_accumulation_steps"]
    train_dataset = Dataset(args.data_type, args.train_data, symbol_table,
                            train_conf, args.bpe_model, non_lang_syms, True)
    cv_dataset = Dataset(args.data_type,
                         args.cv_data,
                         symbol_table,
                         cv_conf,
                         args.bpe_model,
                         non_lang_syms,
                         partition=False)

    train_data_loader = DataLoader(train_dataset,
                                   batch_size=None,
                                   pin_memory=args.pin_memory,
                                   num_workers=args.num_workers,
                                   prefetch_factor=args.prefetch)
    cv_data_loader = DataLoader(cv_dataset,
                                batch_size=None,
                                pin_memory=args.pin_memory,
                                num_workers=args.num_workers,
                                prefetch_factor=args.prefetch)

    if 'fbank_conf' in configs['dataset_conf']:
        input_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins']
    else:
        input_dim = configs['dataset_conf']['mfcc_conf']['num_mel_bins']
    vocab_size = len(symbol_table)

    # Save configs to model_dir/train.yaml for inference and export
    configs['input_dim'] = input_dim
    configs['output_dim'] = vocab_size
    configs['cmvn_file'] = args.cmvn
    configs['is_json_cmvn'] = True
    configs['lfmmi_dir'] = args.lfmmi_dir

    model_dir = args.model_dir
    if local_rank == 0:
        # The output dir must exist before train.yaml / init.zip go into it.
        os.makedirs(model_dir, exist_ok=True)
        saved_config_path = os.path.join(model_dir, 'train.yaml')
        with open(saved_config_path, 'w') as fout:
            data = yaml.dump(configs)
            fout.write(data)

    # Init asr model from configs
    model = init_model(configs)
    num_params = sum(p.numel() for p in model.parameters())
    if local_rank == 0:
        print(model)
        print('the number of model params: {:,d}'.format(num_params))

    # !!!IMPORTANT!!!
    # Try to export the model by script, if fails, we should refine
    # the code to satisfy the script export requirements
    if local_rank == 0:
        script_model = torch.jit.script(model)
        script_model.save(os.path.join(model_dir, 'init.zip'))
    executor = Executor()
    # If specify checkpoint, load some info from checkpoint
    if args.checkpoint is not None:
        infos = load_checkpoint(model, args.checkpoint, local_rank)
    elif args.enc_init is not None:
        logging.info('load pretrained encoders: {}'.format(args.enc_init))
        infos = load_trained_modules(model, args)
    else:
        infos = {}
    start_epoch = infos.get('epoch', -1) + 1
    step = infos.get('step', -1)

    num_epochs = configs.get('max_epoch', 100)
    writer = None
    if local_rank == 0:
        # normpath drops a trailing separator, otherwise basename() is ''.
        exp_id = os.path.basename(os.path.normpath(model_dir))
        # TensorBoard logging through tensorboardX's SummaryWriter
        writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id))

    if distributed:  # native pytorch ddp
        assert (torch.cuda.is_available())
        # NOTE(review): local_rank holds args.rank (the global rank) here,
        # while CUDA_VISIBLE_DEVICES was set to args.gpu above; if that
        # leaves a single visible device, set_device(local_rank) may be an
        # invalid ordinal for rank > 0 -- confirm against the launcher.
        torch.cuda.set_device(local_rank)
        # cuda model is required for nn.parallel.DistributedDataParallel
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(
            model, find_unused_parameters=args.find_unused_parameters)
        device = torch.device("cuda")
        if args.fp16_grad_sync:
            from torch.distributed.algorithms.ddp_comm_hooks import (
                default as comm_hooks,
            )
            model.register_comm_hook(
                state=None, hook=comm_hooks.fp16_compress_hook
            )
    elif args.deepspeed:  # deepspeed
        # NOTE(xcsong): look in detail how the memory estimator API works:
        #   https://deepspeed.readthedocs.io/en/latest/memory.html#discussion
        if local_rank == 0:
            logging.info("Estimating model states memory needs (zero2)...")
            estimate_zero2_model_states_mem_needs_all_live(
                model, num_gpus_per_node=world_size, num_nodes=1)
            logging.info("Estimating model states memory needs (zero3)...")
            estimate_zero3_model_states_mem_needs_all_live(
                model, num_gpus_per_node=world_size, num_nodes=1)
        device = None  # Init device later
        pass  # Init DeepSpeed later
    else:
        use_cuda = args.gpu >= 0 and torch.cuda.is_available()
        device = torch.device('cuda' if use_cuda else 'cpu')
        model = model.to(device)

    if configs['optim'] == 'adam':
        optimizer = optim.Adam(model.parameters(), **configs['optim_conf'])
    elif configs['optim'] == 'adamw':
        optimizer = optim.AdamW(model.parameters(), **configs['optim_conf'])
    else:
        raise ValueError("unknown optimizer: " + configs['optim'])
    scheduler_type = None
    if configs['scheduler'] == 'warmuplr':
        scheduler_type = WarmupLR
        scheduler = WarmupLR(optimizer, **configs['scheduler_conf'])
    elif configs['scheduler'] == 'NoamHoldAnnealing':
        scheduler_type = NoamHoldAnnealing
        scheduler = NoamHoldAnnealing(optimizer, **configs['scheduler_conf'])
    else:
        raise ValueError("unknown scheduler: " + configs['scheduler'])

    if args.deepspeed:
        if "optimizer" in ds_configs:
            # NOTE(xcsong): Disable custom optimizer if it is set in ds_config,
            # extremely useful when enable cpu_offload, DeepspeedCpuAdam
            # could be 4~5x faster than torch native adam
            optimizer = None
            if "scheduler" in ds_configs:
                scheduler = None
            else:
                def scheduler(opt):
                    return scheduler_type(opt, **configs['scheduler_conf'])
        model, optimizer, _, scheduler = deepspeed.initialize(
            args=args, model=model, optimizer=optimizer,
            lr_scheduler=scheduler, model_parameters=model.parameters())

    final_epoch = None
    configs['rank'] = local_rank
    configs['is_distributed'] = distributed  # pytorch native ddp
    configs['is_deepspeed'] = args.deepspeed  # deepspeed
    configs['use_amp'] = args.use_amp
    if args.deepspeed and start_epoch == 0:
        # NOTE(xcsong): All ranks should call this API, but only rank 0
        #   save the general model params. see:
        #   https://github.com/microsoft/DeepSpeed/issues/2993
        with torch.no_grad():
            model.save_checkpoint(save_dir=model_dir, tag='init')
            if args.save_states == "model_only" and local_rank == 0:
                convert_zero_checkpoint_to_fp32_state_dict(
                    model_dir, "{}/init.pt".format(model_dir), tag='init')
                # Equivalent of `rm -rf`, without going through a shell.
                shutil.rmtree(os.path.join(model_dir, "init"),
                              ignore_errors=True)
    elif not args.deepspeed and start_epoch == 0 and local_rank == 0:
        save_model_path = os.path.join(model_dir, 'init.pt')
        save_checkpoint(model, save_model_path)

    # Start training loop
    executor.step = step
    scheduler.set_step(step)
    # used for pytorch amp mixed precision training
    scaler = None
    if args.use_amp:
        scaler = torch.cuda.amp.GradScaler()

    for epoch in range(start_epoch, num_epochs):
        train_dataset.set_epoch(epoch)
        configs['epoch'] = epoch
        lr = optimizer.param_groups[0]['lr']
        logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
        device = model.local_rank if args.deepspeed else device
        executor.train(model, optimizer, scheduler, train_data_loader, device,
                       writer, configs, scaler)
        total_loss, num_seen_utts = executor.cv(model, cv_data_loader, device,
                                                configs)
        cv_loss = total_loss / num_seen_utts

        logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss))
        infos = {
            'epoch': epoch, 'lr': lr, 'cv_loss': cv_loss, 'step': executor.step,
            'save_time': datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
        }
        if local_rank == 0:
            writer.add_scalar('epoch/cv_loss', cv_loss, epoch)
            writer.add_scalar('epoch/lr', lr, epoch)
            with open("{}/{}.yaml".format(model_dir, epoch), 'w') as fout:
                data = yaml.dump(infos)
                fout.write(data)
        if args.deepspeed:
            with torch.no_grad():
                model.save_checkpoint(save_dir=model_dir,
                                      tag='{}'.format(epoch),
                                      client_state=infos)
                if args.save_states == "model_only" and local_rank == 0:
                    convert_zero_checkpoint_to_fp32_state_dict(
                        model_dir, "{}/{}.pt".format(model_dir, epoch),
                        tag='{}'.format(epoch))
                    shutil.rmtree(os.path.join(model_dir, str(epoch)),
                                  ignore_errors=True)
        elif not args.deepspeed and local_rank == 0:
            save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
            save_checkpoint(model, save_model_path, infos)
        final_epoch = epoch

    if local_rank == 0:
        if final_epoch is not None:
            final_model_path = os.path.join(model_dir, 'final.pt')
            # lexists() also sees a dangling final.pt symlink, which exists()
            # reports as missing and os.symlink() would then fail on.
            if os.path.lexists(final_model_path):
                os.remove(final_model_path)
            os.symlink('{}.pt'.format(final_epoch), final_model_path)
        # Close (and flush) the writer even when no epoch was run.
        writer.close()

if __name__ == "__main__":
    # Run training only when this file is executed as a script.
    main()
Met4physics commented 6 days ago

目前问题原因并不明确,neuron源代码中没有使用tensorboardX。暂时的解决方法,我觉得你可以试试用wandb或者手动打印来取代tensorboardX。不过需要注意:这个警告只说明写入日志的数值(例如loss)本身已经是NaN或Inf,更换日志工具并不能消除NaN本身。建议同时直接打印每一步的loss,定位NaN最早出现的位置。