sooftware / kospeech

Open-Source Toolkit for End-to-End Korean Automatic Speech Recognition leveraging PyTorch and Hydra.
https://sooftware.github.io/kospeech/
Apache License 2.0
603 stars 191 forks source link

train 시 validate index error #46

Closed ghost closed 4 years ago

ghost commented 4 years ago

안녕하세요. 매번 업데이트 해주시는 코드 열심히 따라가면서 공부하고 있습니다! 감사합니다! 제가 Aihub 데이터셋이 아닌 다른 데이터셋을 사용해서 train, validate를 시도해보고 있는데요. 아래와 같은 에러가 발생하였습니다.

Traceback (most recent call last):
  File "./main.py", line 111, in <module>
    main()
  File "./main.py", line 107, in main
    train(opt)
  File "./main.py", line 86, in train
    num_epochs=opt.num_epochs, teacher_forcing_ratio=opt.teacher_forcing_ratio, resume=opt.resume)
  File "../kospeech/trainer/supervised_trainer.py", line 161, in train
    valid_loss, valid_cer = self.validate(model, valid_queue)
  File "../kospeech/trainer/supervised_trainer.py", line 327, in validate
    loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), targets.contiguous().view(-1))
  File "/home/stt_py/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __c
all__
    result = self.forward(*input, **kwargs)
  File "../kospeech/optim/loss.py", line 59, in forward
    label_smoothed.scatter_(1, target.data.unsqueeze(1), self.confidence)
RuntimeError: invalid argument 4: Index tensor must have same size as output tensor apart from the sp
ecified dimension at /pytorch/aten/src/THC/generic/THCTensorScatterGather.cu:328

확인해봤을때 supevised_trainer.py의 validate 함수에서 324번째줄의 logit과 target의 shape이 맞지 않더라구요.

그래서 코드를 보고 있는데 310번쨰 줄에 targets이 들어가있지 않던데 targets을 추가하는게 맞지 않은가해서요. 크게 변경없이 제 나름대로 코드를 다시 짜봤는데 문제가 없는지 확인 부탁드려도 될까요?

    def validate(self, model: nn.Module, queue: queue.Queue) -> float:
        """
        Run training one epoch

        Args:
            model (torch.nn.Module): model to train
            queue (queue.Queue): validation queue, containing input, targets, input_lengths, target_lengths

        Returns: loss, cer
            - **loss** (float): loss of validation
            - **cer** (float): character error rate of validation
        """
        cer = 1.0
        total_loss = 0.
        total_num = 0.

        model.eval()
        logger.info('validate() start')

        with torch.no_grad():
            while True:
                inputs, targets, input_lengths, target_lengths = queue.get()

                if inputs.shape[0] == 0:
                    break
                inputs = inputs.to(self.device)
                #targets = targets[:, 1:].to(self.device)
                targets = targets.to(self.device) # train과 같은 형태로 변경
                model.cuda()

                if self.architecture == 'seq2seq':
                    model.module.flatten_parameters()
                    #output = model(inputs=inputs, input_lengths=input_lengths,
                    #               teacher_forcing_ratio=0.0, return_decode_dict=False)
                    output = model(inputs=inputs, input_lengths=input_lengths, targets=targets, # targets 추가
                                   teacher_forcing_ratio=0.0, return_decode_dict=False) 
                    logit = torch.stack(output, dim=1).to(self.device)
                    targets = targets[:, 1:] # train과 같은 형태로 변경

                elif self.architecture == 'transformer':
                    logit = model(inputs, input_lengths, return_decode_dict=False)

                else:
                    raise ValueError("Unsupported architecture : {0}".format(self.architecture))

                hypothesis = logit.max(-1)[1]
                cer = self.metric(targets, hypothesis)
                logit = logit[:, :targets.size(1), :] # 이 부분은 train에는 없던데 validate만 해당하는 걸까요?
                loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), targets.contiguous().view(-1))

                total_loss += loss.item()
                total_num += sum(input_lengths)

        logger.info('validate() completed')
        return total_loss / total_num, cer

바쁘실텐데 읽어주셔서 감사합니다!

sooftware commented 4 years ago

line 310에 targets를 넘겨주면 안됩니다.
티쳐포싱이란 개념에 대해서 검색해보시거나 전에 제가 정리한 글이있습니다. 링크

targets이란 결국 ground truth (정답) 인데, 트레이닝시에는 빠르고 정확한 학습을 위해 정답 레이블을 다음 입력으로 넣어주게 되지만, validate나 evaluate시에는 실제 성능을 확인하고자 진행하는 과정이므로 티쳐포싱을 사용하지 않습니다. 티쳐포싱을 사용하게 되면 좀 더 정확하게 예측하게 될 것이므로 편법이라고 볼 수 있습니다.

그리고 주석 다신 부분인

targets = targets[:, 1:] # train과 같은 형태로 변경

같은 경우는 같은 형태로 변경하는 과정이 아니라 sos 토큰을 없애주는 과정입니다.

targets에는 sos 가 나 다 라 마 바 사 . . . eos 와 같은 형태인데, 입력과 아웃풋의 이상적인 형태는 다음과 같습니다.

criterion에는 output의 이상적인 형태와 비교해야 하므로 를 제거해주는 과정입니다.

데이터셋은 어떤 데이터셋을 쓰시는지 몰라서 어떻게 수정해야할지는 말씀드리기 애매하나,
validate시에 loss 계산을 따로 하지 않고 cer만 측정하시게 되면 문제없이 넘어가지 않을까 싶습니다.

ghost commented 4 years ago

데이터셋은 aihub와 동일한 형태로 진행했는데 그러네요 ㅠ 따로 더 살펴봐야겠네요! train은 문제가 없는데 validate만 문제가 생기네요ㅠ 말씀하신대로 loss를 제거했을때는 문제없이 잘 돌아갑니다! 자세한 설명 감사드립니다!