YuanGongND / ast

Code for the Interspeech 2021 paper "AST: Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

CPU memory increase while training #119

Open gudrb opened 5 months ago

gudrb commented 5 months ago

During the training step, the following code block causes CPU memory to grow steadily until it is completely full:

    for blk in self.v.blocks:
        x = blk(x)

Do you know of any solution to this issue?

YuanGongND commented 5 months ago

hi there,

Are you using our recipe to train the model?

Are you training on CPU or GPU? What is your CPU/GPU memory?

x = blk(x) means the input tensor passes through a Transformer layer, so it should take GPU memory. I don't know how torch manages CPU memory in this case.

-Yuan

gudrb commented 5 months ago

I am using my own training code, but I follow ast_models.py to define ASTModel. With memory_profiler (line 176 in the screenshot), I can see that the CPU memory shown in htop increases linearly during the training step. [Screenshot]

--> This is part of my code.

import torch
import torch.nn as nn
import timm
from torch.cuda.amp import autocast


class ASTModel(nn.Module):
    """
    The AST model.
    :param label_dim: the label dimension, i.e., the number of total classes, it is 527 for AudioSet, 50 for ESC-50, and 35 for speechcommands v2-35
    :param fstride: the stride of patch splitting on the frequency dimension, for 16*16 patches, fstride=16 means no overlap, fstride=10 means overlap of 6
    :param tstride: the stride of patch splitting on the time dimension, for 16*16 patches, tstride=16 means no overlap, tstride=10 means overlap of 6
    :param input_fdim: the number of frequency bins of the input spectrogram
    :param input_tdim: the number of time frames of the input spectrogram
    :param imagenet_pretrain: if use ImageNet pretrained model
    :param audioset_pretrain: if use full AudioSet and ImageNet pretrained model
    :param model_size: the model size of AST, should be in [tiny224, small224, base224, base384], base224 and base384 are same model, but are trained differently during ImageNet pretraining.
    """
    def __init__(self, label_dim=527, input_fdim=128, input_tdim=1024, imagenet_pretrain=True, audioset_pretrain=False, model_size='base384', verbose=True):

        super(ASTModel, self).__init__()
        assert timm.__version__ == '0.4.5', 'Please use timm == 0.4.5, the code might not be compatible with newer versions.'

        if verbose == True:
            print('---------------AST Model Summary---------------')
            print('ImageNet pretraining: {:s}, AudioSet pretraining: {:s}'.format(str(imagenet_pretrain),str(audioset_pretrain)))
        # override timm input shape restriction
        # timm.models.vision_transformer.PatchEmbed = PatchEmbed
        # if AudioSet pretraining is not used (but ImageNet pretraining may still apply)
        if audioset_pretrain == False:
            if model_size == 'tiny224':
                self.v = timm.create_model('vit_deit_tiny_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'small224':
                self.v = timm.create_model('vit_deit_small_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'base224':
                self.v = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'base384':
                self.v = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=imagenet_pretrain)
            else:
                raise Exception('Model size must be one of tiny224, small224, base224, base384.')

    def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=1024):
        test_input = torch.randn(1, 1, input_fdim, input_tdim)
        test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
        test_out = test_proj(test_input)
        f_dim = test_out.shape[2]
        t_dim = test_out.shape[3]
        return f_dim, t_dim

    @autocast()
    def forward(self, x):
        """
        :param x: the input spectrogram, expected shape: (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        :return: prediction
        """
        # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        # x = x.unsqueeze(1)
        # x = x.transpose(1, 2)
        B,T,F = x.shape
        # x = self.v.patch_embed(x)
        cls_tokens = self.v.cls_token.expand(B, -1, -1)
        dist_token = self.v.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.v.pos_embed
        x = self.v.pos_drop(x)
        for blk in self.v.blocks:
            x = blk(x)
        x = self.v.norm(x)
        # (excerpt ends here; the rest of forward was not posted)

--> I am only using the raw pretrained vit_deit_tiny_distilled_patch16_224, without changing the patch embedding or the positional encoding.

class LSQSL(nn.Module):

    # def __init__(self, in_chans=1, frequency=129, d1_model=32, d2_model=172, n_layers=1, hidden_size=88, bidirectional=True):

    def __init__(self, in_chans=1, time=29, frequency=129, n_seq=10, seq_length=20, d1_model=32, d2_model=128, n_layers=1, hidden_size=64, bidirectional=True, head_num=2):
        super(LSQSL, self).__init__()
        self.ast_mdl = ASTModel(label_dim=5,
                                input_fdim=192, input_tdim=196,
                                imagenet_pretrain=True, audioset_pretrain=False,
                                model_size='tiny224')

    def forward(self, raw):
        B, L, D = raw.shape
        raw = self.ast_mdl(raw)
        return raw

--> I use a GPU for training (model = LSQSL.to(device)), but blk(x) makes the CPU memory shown in htop increase linearly. When I trace the code into /python3.7/site-packages/timm/models/vision_transformer.py, the self.attn and self.mlp operations in the Block class are where the CPU memory increases.

def forward(self, x):
    x = x + self.drop_path(self.attn(self.norm1(x)))
    x = x + self.drop_path(self.mlp(self.norm2(x)))
    return x

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

--> However, the forward methods of Mlp and Attention contain nothing CPU-related and nothing that would pile up stale variables, such as detach() or append(). Do you know the reason for this unexpected CPU memory leak? I can only suspect an issue in the timm library itself.
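A per-step log of the process's resident memory can help confirm whether host memory really grows with the number of iterations, independently of which line a line profiler attributes the growth to. The following is only a minimal sketch, assuming Linux (where ru_maxrss is reported in KiB); the helper name peak_rss_mib and the placeholder loop are hypothetical and stand in for a real training loop.

```python
import resource


def peak_rss_mib():
    """Peak resident set size of this process in MiB.

    Assumption: Linux, where ru_maxrss is in KiB (macOS reports bytes).
    """
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024.0


history = []
for step in range(5):
    # ... one training step would run here (forward, backward, optimizer) ...
    history.append(peak_rss_mib())
    print("step", step, "peak RSS (MiB):", round(history[-1], 1))
```

If the logged value keeps climbing by a similar amount every step long after warm-up, something in the loop is probably holding references across iterations.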

YuanGongND commented 5 months ago

What's your GPU memory and CPU memory?

Have you tried to use our recipe (e.g., ESC-50, which is very fast to run, see our Readme) and see if the issue is still there?

This is the first time I have heard of this issue.

gudrb commented 5 months ago

I also see this memory leak with your ESC-50 training code, but not as much as with mine. I think the difference is the sequence length: I use the original sequence length of 198, which involves much more attention computation.
[Screenshot 1] [Screenshot 2] [Screenshot 3]

I am using the following hardware.

GPU: NVIDIA A100-PCIE-40GB

CPU:
Architecture: x86_64
  CPU op-mode(s): 32-bit, 64-bit
  Address sizes: 46 bits physical, 48 bits virtual
  Byte Order: Little Endian
CPU(s): 48
  On-line CPU(s) list: 0-47
Vendor ID: GenuineIntel
  Model name: Intel Xeon Processor (Skylake, IBRS)
    CPU family: 6
    Model: 85
    Thread(s) per core: 1
    Core(s) per socket: 1
    Socket(s): 48
    Stepping: 4
    BogoMIPS: 6185.46
    Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 arat pku ospke avx512_vnni md_clear
Virtualization features:
  Hypervisor vendor: KVM
  Virtualization type: full
Caches (sum of all):
  L1d: 1.5 MiB (48 instances)
  L1i: 1.5 MiB (48 instances)
  L2: 192 MiB (48 instances)
  L3: 768 MiB (48 instances)
NUMA:
  NUMA node(s): 1
  NUMA node0 CPU(s): 0-47
Vulnerabilities:
  Gather data sampling: Unknown: Dependent on hypervisor status
  Itlb multihit: KVM: Mitigation: VMX unsupported
  L1tf: Mitigation; PTE Inversion
  Mds: Mitigation; Clear CPU buffers; SMT Host state unknown
  Meltdown: Mitigation; PTI
  Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
  Retbleed: Mitigation; IBRS
  Spec rstack overflow: Not affected
  Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
  Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
  Spectre v2: Mitigation; IBRS, IBPB conditional, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
  Srbds: Not affected
  Tsx async abort: Mitigation; Clear CPU buffers; SMT Host state unknown

gudrb commented 5 months ago

Thank you for answering my questions. The problem was not in the timm library or in the AST code. In my training loop I was accumulating the loss tensor itself, without .item(), as in "running_loss += loss". I changed it to "running_loss += loss.item()" and the RAM leak was fixed!
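A likely explanation for why this fix works: while loss is still attached to the autograd graph, "running_loss += loss" turns running_loss into a tensor whose history chains every iteration's loss, so graph bookkeeping from past iterations stays referenced and host memory keeps growing. loss.item() returns a plain Python float and carries no history. The sketch below only illustrates that difference; the tiny nn.Linear model, the random data, and the names running_leaky and running_safe are assumptions for illustration, not the training code from this thread.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 1)  # hypothetical stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

running_leaky = 0.0  # will accumulate tensors and keep autograd history alive
running_safe = 0.0   # will accumulate Python floats, no history retained

for step in range(5):
    x = torch.randn(4, 8)
    y = torch.randn(4, 1)
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

    running_leaky += loss        # becomes a tensor linked to every past loss
    running_safe += loss.item()  # plain float; the graph can be released

leaky_is_tensor = isinstance(running_leaky, torch.Tensor)
leaky_has_history = running_leaky.grad_fn is not None
safe_is_float = isinstance(running_safe, float)
print(leaky_is_tensor, leaky_has_history, safe_is_float)
```

A quick check in an existing loop is to inspect the accumulator: if it is a tensor with a non-None grad_fn, it is holding on to history. float(loss) or loss.detach() are common alternatives to loss.item() for the same purpose.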

YuanGongND commented 5 months ago

thanks for letting me know. Good luck with your research!