yl4579 / StyleTTS2

StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models

Low GPU Utilization during training #217

Open ayushtues opened 6 months ago

ayushtues commented 6 months ago

Hi, I have been trying to train a StyleTTS2 model from scratch on the LibriTTS 460 dataset, currently going through the first stage via train_first.py

GPU utilisation during training is very low (~30%). I am using a single H100 with batch_size = 8 and max_len = 300 to fit it on a single GPU.

Such low utilisation means the script is not using the GPU efficiently, and there are likely bottlenecks that, if addressed, could make training faster.

Has anyone observed similar issues while training the model from scratch, or does anyone have ideas for improving GPU utilisation?

cc @yl4579

lucasgris commented 6 months ago

Yes, same here; it seems there is a bottleneck, but using accelerate seems to help a little. Are you using accelerate? Try setting num_processes.

(screenshot: GPU utilization graph)
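For reference, the launch I mean is along the lines of `accelerate launch --num_processes <N> train_first.py --config_path ./Configs/config.yml` (the `--num_processes` flag is from the accelerate CLI; adjust the script and config paths to your setup).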

ayushtues commented 6 months ago

Yes @lucasgris, I am using accelerate and have played around with num_workers. Even in the graph you shared, utilization consistently dips very low (<25% GPU util). Any luck with improving that?

lucasgris commented 6 months ago

Not yet, but I think it is worth trying to identify where the code is slow. If I have any updates, I will share them here.
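
If anyone wants to dig in before then, a quick first pass could be torch.profiler around a few steps. This is a minimal sketch, not tied to this repo's code; `train_step(batch)` is a stand-in for one iteration of the loop in train_first.py:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Profile a handful of training steps to see whether time goes to CUDA kernels
# or to CPU-side Python work / data loading. `train_step(batch)` is a
# placeholder for one iteration of the training loop.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for i, batch in enumerate(train_dataloader):
        train_step(batch)
        if i >= 5:
            break

# Entries with large "Self CPU" time but little CUDA time usually point at a
# CPU-side bottleneck.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=25))
```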

Selectorrr commented 6 months ago

Confirming the problem of low GPU utilization:

(screenshot: GPU utilization, 2024-03-27 17:30:06)

It seems that some computation pinned to a single CPU core is the bottleneck:

(screenshot: CPU core utilization, 2024-03-27 17:30:17)

borrero-c commented 6 months ago

Also having this problem with train_finetune_accelerate.py. I haven't dug too deep, but the accelerator.backward() calls seem to be taking a very long time, specifically in this code block: https://github.com/yl4579/StyleTTS2/blob/5cedc71c333f8d8b8551ca59378bdcc7af4c9529/train_finetune_accelerate.py#L449-L464
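
One thing to keep in mind when reading those timings: CUDA kernels are launched asynchronously, so a plain wall-clock timer often charges queued-up forward-pass work to whatever call synchronizes next, which is frequently the backward pass. A minimal timing sketch (the `loss` variable is just a stand-in for whatever is backpropagated in that block):

```python
import time
import torch

# Synchronize before and after so the measurement covers only the backward
# pass, not kernels still queued from the forward pass.
torch.cuda.synchronize()
t0 = time.perf_counter()
accelerator.backward(loss)   # `loss` stands in for the loss used in the script
torch.cuda.synchronize()
print(f"backward: {time.perf_counter() - t0:.3f}s")
```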

Selectorrr commented 6 months ago

I tried the following options one by one, without success:

1. Without accelerator and with accelerator
2. Increasing num_processes from 1 to 2
3. Decreasing max_len from 600 to 290
4. Switching the decoder from hifigan to istftnet

borrero-c commented 6 months ago

Also seeing low GPU utilization and high single-core CPU utilization:

(screenshot: GPU and CPU utilization, 2024-03-27 16:27:23)

It also seems like the issue goes away after the first epoch finishes: my GPU starts being utilized consistently and the CPU load becomes more evenly distributed.

ayushtues commented 6 months ago

@borrero-c thanks for looking into this. I didn't observe anything changing after 1 epoch; it stays low for me. Also, the accelerator.backward() call might be expected to take time, since it is doing the backward pass.

Selectorrr commented 6 months ago

I did a little research and ran a line-by-line profiler. Pay attention to the % Time column.
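
(For anyone who wants to reproduce this: the tables below are line_profiler output. Add an `@profile` decorator to the functions of interest and run something like `kernprof -l -v train_first.py --config_path Configs/config.yml`; the exact command may differ for your setup.)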

MAIN LOOP:

```text
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   162         2          8.8      4.4      0.0  for epoch in range(start_epoch, 5):
   163         1          0.3      0.3      0.0      running_loss = 0
   164         1          3.8      3.8      0.0      start_time = time.time()
   165
   166         1       7624.1   7624.1      0.0      _ = [model[key].train() for key in model]
   167
   168         2       2430.4   1215.2      0.0      pgbar = tqdm(desc=f"Epoch {epoch + 1}/{epochs}", unit='Step', total=len(train_list) // batch_size, smoothing=0,
   169         1          0.1      0.1      0.0                   initial=1)
   170       102     525418.3   5151.2      0.3      for i, batch in enumerate(train_dataloader):
   171       102         73.2      0.7      0.0          if i > 100:
   172         1     265667.4 265667.4      0.1              break
   173       101      36354.2    359.9      0.0          pgbar.update(1)
   174       101        917.4      9.1      0.0          waves = batch[0]
   175       101       5605.0     55.5      0.0          batch = [b.to(device) for b in batch[1:]]
   176       101       2789.0     27.6      0.0          texts, input_lengths, _, _, mels, mel_input_length, _ = batch
   177
   178       202       1903.5      9.4      0.0          with torch.no_grad():
   179       101      77350.0    765.8      0.0              mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
   180       101      12146.6    120.3      0.0              text_mask = length_to_mask(input_lengths).to(texts.device)
   181
   182       101   20453215.4 202507.1     10.2          ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
   183
   184       101        911.2      9.0      0.0          s2s_attn = s2s_attn.transpose(-1, -2)
   185       101       1402.6     13.9      0.0          s2s_attn = s2s_attn[..., 1:]
   186       101        334.4      3.3      0.0          s2s_attn = s2s_attn.transpose(-1, -2)
   187
   188       202       2396.6     11.9      0.0          with torch.no_grad():
   189       101      33850.8    335.2      0.0              attn_mask = (~mask).unsqueeze(-1).expand(mask.shape[0], mask.shape[1], text_mask.shape[-1]).float().transpose(-1, -2)
   190       101      16570.0    164.1      0.0              attn_mask = attn_mask.float() * (~text_mask).unsqueeze(-1).expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1]).float()
   191       101       5279.7     52.3      0.0              attn_mask = (attn_mask < 1)
   192
   193       101       3047.2     30.2      0.0          s2s_attn.masked_fill_(attn_mask, 0.0)
   194
   195       202       1703.6      8.4      0.0          with torch.no_grad():
   196       101      48330.6    478.5      0.0              mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
   197       101     416141.6   4120.2      0.2              s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
   198
   199                                                    # encode
   200       101    1624539.8  16084.6      0.8          t_en = model.text_encoder(texts, input_lengths, text_mask)
   201
   202                                                    # 50% of chance of using monotonic version
   203       101        416.9      4.1      0.0          if bool(random.getrandbits(1)):
   204        43       4864.5    113.1      0.0              asr = (t_en @ s2s_attn)
   205                                                    else:
   206        58      12170.7    209.8      0.0              asr = (t_en @ s2s_attn_mono)
   207
   208                                                    # get clips
   209       101       5637.2     55.8      0.0          mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
   210       101      14759.9    146.1      0.0          mel_len = min([int(mel_input_length_all.min().item() / 2 - 1), max_len // 2])
   211       101       2895.5     28.7      0.0          mel_len_st = int(mel_input_length.min().item() / 2 - 1)
   212
   213       101        499.9      4.9      0.0          en = []
   214       101        521.1      5.2      0.0          gt = []
   215       101        432.4      4.3      0.0          wav = []
   216       101        427.5      4.2      0.0          st = []
   217
   218       909       1352.1      1.5      0.0          for bib in range(len(mel_input_length)):
   219       808      17282.6     21.4      0.0              mel_length = int(mel_input_length[bib].item() / 2)
   220
   221       808       6093.9      7.5      0.0              random_start = np.random.randint(0, mel_length - mel_len)
   222       808      12116.3     15.0      0.0              en.append(asr[bib, :, random_start:random_start+mel_len])
   223       808       6309.5      7.8      0.0              gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
   224
   225       808       1675.8      2.1      0.0              y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
   226       808      69490.9     86.0      0.0              wav.append(torch.from_numpy(y).to(device))
   227
   228                                                        # style reference (better to be different from the GT)
   229       808       5077.0      6.3      0.0              random_start = np.random.randint(0, mel_length - mel_len_st)
   230       808       8738.8     10.8      0.0              st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
   231
   232       101       5238.2     51.9      0.0          en = torch.stack(en)
   233       101       2928.9     29.0      0.0          gt = torch.stack(gt).detach()
   234       101       2246.8     22.2      0.0          st = torch.stack(st).detach()
   235
   236       101       7146.6     70.8      0.0          wav = torch.stack(wav).float().detach()
   237
   238                                                    # clip too short to be used by the style encoder
   239       101        202.3      2.0      0.0          if gt.shape[-1] < 80:
   240                                                        continue
   241
   242       202       2124.9     10.5      0.0          with torch.no_grad():
   243       101      44210.4    437.7      0.0              real_norm = log_norm(gt.unsqueeze(1)).squeeze(1).detach()
   244       101    2671261.3  26448.1      1.3              F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
   245
   246       101    2978410.2  29489.2      1.5          s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
   247
   248       101   17613113.7 174387.3      8.8          y_rec = model.decoder(en, F0_real, real_norm, s)
   249
   250                                                    # discriminator loss
   251
   252       101         70.8      0.7      0.0          if epoch >= TMA_epoch:
   253       101     565364.9   5597.7      0.3              optimizer.zero_grad()
   254       101   11707820.2 115919.0      5.8              d_loss = dl(wav.detach().unsqueeze(1).float(), y_rec.detach()).mean()
   255       101   18847437.9 186608.3      9.4              accelerator.backward(d_loss)
   256       101     313779.6   3106.7      0.2              optimizer.step('msd')
   257       101     294492.0   2915.8      0.1              optimizer.step('mpd')
   258                                                    else:
   259                                                        d_loss = 0
   260
   261                                                    # generator loss
   262       101     237334.6   2349.8      0.1          optimizer.zero_grad()
   263       101     282369.5   2795.7      0.1          loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
   264
   265       101         51.6      0.5      0.0          if epoch >= TMA_epoch: # start TMA training
   266       101        419.4      4.2      0.0              loss_s2s = 0
   267       909      10903.2     12.0      0.0              for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
   268       808      89627.5    110.9      0.0                  loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
   269       101       2396.6     23.7      0.0              loss_s2s /= texts.size(0)
   270
   271       101      11985.0    118.7      0.0              loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
   272
   273       101    6983523.9  69143.8      3.5              loss_gen_all = gl(wav.detach().unsqueeze(1).float(), y_rec).mean()
   274       101    4046595.1  40065.3      2.0              loss_slm = wl(wav.detach(), y_rec).mean()
   275
   276       505     812033.9   1608.0      0.4              g_loss = loss_params.lambda_mel * loss_mel + \
   277       101       1428.5     14.1      0.0                       loss_params.lambda_mono * loss_mono + \
   278       101       1285.5     12.7      0.0                       loss_params.lambda_s2s * loss_s2s + \
   279       101       1268.2     12.6      0.0                       loss_params.lambda_gen * loss_gen_all + \
   280       101       1230.7     12.2      0.0                       loss_params.lambda_slm * loss_slm
   281
   282                                                    else:
   283                                                        loss_s2s = 0
   284                                                        loss_mono = 0
   285                                                        loss_gen_all = 0
   286                                                        loss_slm = 0
   287                                                        g_loss = loss_mel
   288
   289       101      14339.2    142.0      0.0          running_loss += accelerator.gather(loss_mel).mean().item()
   290
   291       101   99737870.0 987503.7     49.6          accelerator.backward(g_loss)
   292
   293       101     199636.4   1976.6      0.1          optimizer.step('text_encoder')
   294       101     290944.4   2880.6      0.1          optimizer.step('style_encoder')
   295       101    2382230.7  23586.4      1.2          optimizer.step('decoder')
   296
   297       101         72.7      0.7      0.0          if epoch >= TMA_epoch:
   298       101     430973.2   4267.1      0.2              optimizer.step('text_aligner')
   299                                                        # optimizer.step('pitch_extractor')
   300
   301       101         82.0      0.8      0.0          iters = iters + 1
   302
   303       101        386.2      3.8      0.0          if (i+1)%log_interval == 0 and accelerator.is_main_process:
   304        20       1296.7     64.8      0.0              status = 'Epoch [%d/%d], Step [%d/%d], Mel Loss: %.5f, Gen Loss: %.5f, Disc Loss: %.5f, Mono Loss: %.5f, S2S Loss: %.5f, SLM Loss: %.5f' % (
   305        10         17.5      1.7      0.0                  epoch + 1, epochs, i + 1, len(train_list) // batch_size, running_loss / log_interval, loss_gen_all,
   306        10          2.6      0.3      0.0                  d_loss, loss_mono, loss_s2s, loss_slm)
   307                                                        # log_print (status, logger)
   308        10       2629.4    262.9      0.0              pgbar.set_postfix_str(status)
   309        10       1553.5    155.4      0.0              writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
   310        10       2915.1    291.5      0.0              writer.add_scalar('train/gen_loss', loss_gen_all, iters)
   311        10       1903.1    190.3      0.0              writer.add_scalar('train/d_loss', d_loss, iters)
   312        10       2026.3    202.6      0.0              writer.add_scalar('train/mono_loss', loss_mono, iters)
   313        10       1550.2    155.0      0.0              writer.add_scalar('train/s2s_loss', loss_s2s, iters)
   314        10       1451.7    145.2      0.0              writer.add_scalar('train/slm_loss', loss_slm, iters)
   315
   316        10         11.9      1.2      0.0              running_loss = 0
   317
   318                                                # print('Time elasped:', time.time()-start_time)
   319
   320         1          0.4      0.4      0.0      loss_test = 0
   321
   322         1       6907.6   6907.6      0.0      _ = [model[key].eval() for key in model]
```

As expected, if we exclude accelerator.backward(), GPU utilization increases by about 20%, but it remains uneven.

(screenshot: GPU utilization with accelerator.backward() excluded, 2024-03-29 10:20:12)

If I additionally exclude line 182, `ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)`, GPU utilization becomes more uniform:

(screenshot: GPU utilization with text_aligner also excluded, 2024-03-29 10:16:51)

Additional performance details about text_aligner:

```text
ASRCNN
File: /app/Utils/ASR/models.py
Function: forward at line 37

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    37                                           @profile
    38                                           def forward(self, x, src_key_padding_mask=None, text_input=None):
    39       101     116166.6   1150.2      0.5      x = self.to_mfcc(x)
    40       101     742848.0   7354.9      3.2      x = self.init_cnn(x)
    41       101    4113638.1  40729.1     17.6      x = self.cnns(x)
    42       101     576767.4   5710.6      2.5      x = self.projection(x)
    43       101       1441.8     14.3      0.0      x = x.transpose(1, 2)
    44       101     102451.0   1014.4      0.4      ctc_logit = self.ctc_linear(x)
    45       101         51.5      0.5      0.0      if text_input is not None:
    46       101   17664905.7 174900.1     75.8          _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input)
    47       101         66.9      0.7      0.0          return ctc_logit, s2s_logit, s2s_attn
    48                                               else:
    49                                                   return ctc_logit
```

```text
ASRS2S
File: /app/Utils/ASR/models.py
Function: forward at line 118

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   118                                           @profile
   119                                           def forward(self, memory, memory_mask, text_input):
   120                                               """
   121                                               moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
   122                                               moemory_mask.shape = (B, L, )
   123                                               texts_input.shape = (B, T)
   124                                               """
   125       101      73718.9    729.9      0.5      self.initialize_decoder_states(memory, memory_mask)
   126                                               # text random mask
   127       101       8880.2     87.9      0.1      random_mask = (torch.rand(text_input.shape) < self.random_mask).to(text_input.device)
   128       101       2321.3     23.0      0.0      _text_input = text_input.clone()
   129       101     273189.4   2704.8      1.7      _text_input.masked_fill_(random_mask, self.unk_index)
   130       101       8951.4     88.6      0.1      decoder_inputs = self.embedding(_text_input).transpose(0, 1) # -> [T, B, channel]
   131       202       3652.6     18.1      0.0      start_embedding = self.embedding(
   132       101       4901.0     48.5      0.0          torch.LongTensor([self.sos]*decoder_inputs.size(1)).to(decoder_inputs.device))
   133       101      29957.5    296.6      0.2      decoder_inputs = torch.cat((start_embedding.unsqueeze(0), decoder_inputs), dim=0)
   134
   135       101         50.4      0.5      0.0      hidden_outputs, logit_outputs, alignments = [], [], []
   136     12503      27916.0      2.2      0.2      while len(hidden_outputs) < decoder_inputs.size(0):
   137
   138     12402     124859.6     10.1      0.8          decoder_input = decoder_inputs[len(hidden_outputs)]
   139     12402   15663074.1   1262.9     95.9          hidden, logit, attention_weights = self.decode(decoder_input)
   140     12402      12834.5      1.0      0.1          hidden_outputs += [hidden]
   141     12402       4052.0      0.3      0.0          logit_outputs += [logit]
   142     12402       4427.6      0.4      0.0          alignments += [attention_weights]
   143
   144       101      57422.0    568.5      0.4      hidden_outputs, logit_outputs, alignments = \
   145       202      37085.3    183.6      0.2          self.parse_decoder_outputs(
   146       101         17.3      0.2      0.0              hidden_outputs, logit_outputs, alignments)
   147
   148       101         40.2      0.4      0.0      return hidden_outputs, logit_outputs, alignments
```

```text
File: /app/Utils/ASR/models.py
Function: decode at line 149

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   149                                           @profile
   150                                           def decode(self, decoder_input):
   151
   152     12077     451589.1     37.4      2.9      cell_input = torch.cat((decoder_input, self.attention_context), -1)
   153     24154    1601496.2     66.3     10.2      self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
   154     12077       1534.5      0.1      0.0          cell_input,
   155     12077       4154.4      0.3      0.0          (self.decoder_hidden, self.decoder_cell))
   156
   157     24154     395431.2     16.4      2.5      attention_weights_cat = torch.cat(
   158     24154      94829.2      3.9      0.6          (self.attention_weights.unsqueeze(1),
   159     24154      67172.2      2.8      0.4           self.attention_weights_cum.unsqueeze(1)),dim=1)
   160
   161     24154   10655347.3    441.1     68.1      self.attention_context, self.attention_weights = self.attention_layer(
   162     12077       2301.5      0.2      0.0          self.decoder_hidden,
   163     12077       2838.1      0.2      0.0          self.memory,
   164     12077       2822.3      0.2      0.0          self.processed_memory,
   165     12077       1527.4      0.1      0.0          attention_weights_cat,
   166     12077       3143.5      0.3      0.0          self.mask)
   167
   168     12077     231777.8     19.2      1.5      self.attention_weights_cum += self.attention_weights
   169
   170     12077     264389.2     21.9      1.7      hidden_and_context = torch.cat((self.decoder_hidden, self.attention_context), -1)
   171     12077    1005236.5     83.2      6.4      hidden = self.project_to_hidden(hidden_and_context)
   172
   173                                               # dropout to increasing g
   174     12077     860518.1     71.3      5.5      logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))
   175
   176     12077       5210.3      0.4      0.0      return hidden, logit, self.attention_weights
```
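
If I read the decode() numbers right, the cost is not one big op but thousands of tiny ones: ASRS2S runs a Python while loop over text positions, and every step does an LSTM cell plus attention on small tensors, so the GPU mostly waits for the CPU to launch kernels. A toy sketch of the pattern (shapes are made up, not taken from the repo):

```python
import torch

# Illustration of the pattern behind ASRS2S.forward/decode: a Python-level loop
# that advances an LSTMCell one timestep at a time. Each iteration launches a
# few small CUDA kernels, so the run is dominated by per-step launch overhead
# and the sequential dependency, not by GPU compute.
device = "cuda" if torch.cuda.is_available() else "cpu"
B, T, H = 8, 120, 256                      # hypothetical batch / steps / hidden size
cell = torch.nn.LSTMCell(H, H).to(device)
x = torch.randn(T, B, H, device=device)
h = torch.zeros(B, H, device=device)
c = torch.zeros(B, H, device=device)

for t in range(T):                         # step t depends on step t-1
    h, c = cell(x[t], (h, c))              # small matmuls -> poor GPU occupancy
```

Because each step feeds its hidden state and attention context into the next one, this part is hard to parallelize across time, which would explain why excluding text_aligner smooths out GPU utilization.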
borrero-c commented 6 months ago

Looked into it some more: my steps are taking 40-20 seconds, and the .backward() call is taking 20-10 seconds of that, respectively.

When training picks up after that first epoch (and the GPU is more consistently utilized), steps are ~4 seconds each and the backward call takes ~2 seconds.

It is also interesting that this code block takes a good amount of time to complete: https://github.com/yl4579/StyleTTS2/blob/5cedc71c333f8d8b8551ca59378bdcc7af4c9529/train_finetune_accelerate.py#L306-L312

It seems that for each step, ~25% of the time is spent in the loop above and ~50% in the .backward() call on line 464. Not sure how/if those could be improved; this isn't really my area of expertise.
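
In case it helps someone narrow this down further, here is a rough way to attribute per-step time to each section (a sketch; the section names are illustrative, not from the repo):

```python
import time
from contextlib import contextmanager

import torch

# Times a named section of the training step. The synchronize() calls keep
# asynchronously launched CUDA kernels from being charged to the next section.
@contextmanager
def timed(name, totals):
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    yield
    torch.cuda.synchronize()
    totals[name] = totals.get(name, 0.0) + time.perf_counter() - t0

# Usage inside the loop, e.g.:
#   totals = {}
#   with timed("clip prep", totals): ...      # the per-sample slicing loop
#   with timed("backward", totals): accelerator.backward(g_loss)
#   print(totals)
```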