Unbabel / OpenKiwi

Open-Source Machine Translation Quality Estimation in PyTorch
https://unbabel.github.io/OpenKiwi/
GNU Affero General Public License v3.0

training failed when only predicting source tags #91

Closed iamhere1 closed 3 years ago

iamhere1 commented 3 years ago

When I train BERT-based QE models to predict only the source tags, training fails right after the metrics on the validation set are reported, with the error shown below. It seems the call "self.monitor_op(current - self.min_delta, self.best)" can only run on CPU. When I train the model to predict all the tags (source tags, target tags and the others), training works fine.

[Screenshot of the error traceback]

daandouwe commented 3 years ago

This is interesting, thank you. I believe we have never trained on just source tags, so that must be why we have never encountered this error before. Thank you for reporting it.

Let me set up a small experiment to try to reproduce this. I will report back once I understand how it could be fixed.

daandouwe commented 3 years ago

I just trained BERT- and XLM-R-based models on GPU predicting only source tags, and did not encounter this problem. I used

trainer:
    main_metric:
        - source_tags_MCC

and

system:
    model:
        outputs:
            word_level:
                target: false
                gaps: false
                source: true

Could you share (the relevant parts of) your config? For example, which main_metric are you using?

iamhere1 commented 3 years ago

Thank you for your help. I use source_tags_F1_MULT and source_tags_CORRECT as my main metrics, and sentence-level scores are not predicted in my experiment. The relevant parts of my config are as follows.

outputs:
        ####################################################
        # Output options configure the downstream tasks the
        #  model will be trained on by adding specific layers
        #  responsible for transforming decoder features into
        #  predictions.
        word_level:
            target: false
            gaps: false
            source: true
            class_weights:
                target_tags:
                    BAD: 3.0
                gap_tags:
                    BAD: 5.0
                source_tags:
                    BAD: 3.0
        sentence_level:
            hter: false
            use_distribution: false
            binary: false
        n_layers_output: 2
        sentence_loss_weight: 1

and

main_metric:
    - source_tags_F1_MULT
    - source_tags_CORRECT
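With two entries under main_metric, the metrics dump posted further down in this thread contains a combined key, source_tags_F1_MULT+source_tags_CORRECT, whose value (0.9702) matches 0.3655 + 0.6047, the sum of the two individual values. A minimal sketch of that combination, assuming it is a plain sum keyed by the metric names joined with "+" (combined_main_metric is a hypothetical helper, not OpenKiwi's API):

```python
from typing import Mapping, Sequence, Tuple


def combined_main_metric(
    metrics: Mapping[str, float], names: Sequence[str]
) -> Tuple[str, float]:
    """Combine several main metrics into a single monitored value.

    Hypothetical helper. Assumption: the combined key is the metric names
    joined with '+' and its value is the plain sum of the individual values.
    """
    key = "+".join(names)
    value = sum(metrics[name] for name in names)
    return key, value


if __name__ == "__main__":
    # Training-set values copied from the metrics dump in this thread.
    key, value = combined_main_metric(
        {"source_tags_F1_MULT": 0.36551724137931035, "source_tags_CORRECT": 0.6047},
        ["source_tags_F1_MULT", "source_tags_CORRECT"],
    )
    print(key, round(value, 4))  # source_tags_F1_MULT+source_tags_CORRECT 0.9702
```

If any summed value is a CUDA tensor, the result is a CUDA tensor as well (adding a Python float to a CUDA tensor keeps the result on the GPU). That would explain why the monitored value ends up on cuda:0 even though source_tags_F1_MULT on its own is a plain float.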
iamhere1 commented 3 years ago

Hi @daandouwe, I replaced the code in callbacks.py as follows, and it seems to work now:

    # mode_dict = {
    #     'min': np.less,
    #     'max': np.greater,
    #     'auto': np.greater if 'acc' in self.monitor else np.less,
    # }
    mode_dict = {
        'min': torch.lt,
        'max': torch.gt,
        'auto': torch.gt if 'acc' in self.monitor else torch.lt,
    }
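torch.lt and torch.gt avoid NumPy's conversion of the tensor, but they expect a tensor as their first argument, so a metric that comes back as a plain Python float (as source_tags_F1_MULT does in the dump further down) may still cause trouble. A device-agnostic alternative is to turn both operands into Python floats before comparing. A minimal sketch, using hypothetical helper names (as_float, make_monitor_op) and the standard operator module in place of NumPy:

```python
import operator
from typing import Any, Callable


def as_float(value: Any) -> float:
    # torch tensors and NumPy scalars expose .item(); for a one-element tensor
    # it returns a Python number regardless of the device the tensor lives on.
    if hasattr(value, "item"):
        return float(value.item())
    return float(value)


def make_monitor_op(mode: str, monitor: str) -> Callable[[Any, Any], bool]:
    # Mirrors the 'min' / 'max' / 'auto' choice of mode_dict, but compares
    # plain Python floats so CPU tensors, CUDA tensors and floats can be mixed.
    ops = {
        "min": operator.lt,
        "max": operator.gt,
        "auto": operator.gt if "acc" in monitor else operator.lt,
    }
    op = ops[mode]

    def monitor_op(current: Any, best: Any) -> bool:
        return op(as_float(current), as_float(best))

    return monitor_op


if __name__ == "__main__":
    monitor_op = make_monitor_op("max", "val_source_tags_MCC")
    print(monitor_op(0.1847, float("-inf")))  # True
```

Using this in the callback would mean building monitor_op with something like make_monitor_op(mode, self.monitor) instead of looking it up in mode_dict; whether self.best also needs converting depends on how the callback initializes it, which is not shown in this thread.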
daandouwe commented 3 years ago

Thanks!

It turns out that the problem is caused by not all of the metric values having been moved to CPU.

Inspecting the values in the metrics dictionary at https://github.com/Unbabel/OpenKiwi/blob/master/kiwi/training/callbacks.py#L94 shows the following:

{'F1_BAD': tensor(0.6000),
 'F1_OK': tensor(0.6092),
 'loss': tensor(205.8955, device='cuda:0'),
 'metrics': {'F1_BAD': tensor(0.6000),
             'F1_OK': tensor(0.6092),
             'source_tags_CORRECT': tensor(0.6047, device='cuda:0'),
             'source_tags_F1_MULT': 0.36551724137931035,
             'source_tags_F1_MULT+source_tags_CORRECT': tensor(0.9702, device='cuda:0'),
             'source_tags_MCC': tensor(0.2713, dtype=torch.float64)},
 'source_tags_CORRECT': tensor(0.6047, device='cuda:0'),
 'source_tags_F1_MULT': 0.36551724137931035,
 'source_tags_F1_MULT+source_tags_CORRECT': tensor(0.9702, device='cuda:0'),
 'source_tags_MCC': tensor(0.2713, dtype=torch.float64),
 'val_F1_BAD': tensor(0.4514),
 'val_F1_OK': tensor(0.6182),
 'val_loss': tensor(141.6612, device='cuda:0'),
 'val_loss_source_tags': tensor(141.6612, device='cuda:0'),
 'val_source_tags_CORRECT': tensor(0.5497, device='cuda:0'),
 'val_source_tags_F1_MULT': 0.27903935726135765,
 'val_source_tags_F1_MULT+source_tags_CORRECT': tensor(0.8288, device='cuda:0'),
 'val_source_tags_MCC': tensor(0.1847, dtype=torch.float64)}

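This also explains why source_tags_MCC works as the main metric while source_tags_CORRECT does not: the MCC value is a CPU tensor, whereas CORRECT (and the combined F1_MULT+CORRECT value) lives on cuda:0. I believe this flew under the radar because it only seems to affect the lesser-used metrics.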
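To spot such entries without reading through the whole dump, one could walk the dictionary and report every value whose device is not the CPU. A minimal sketch relying on the .device attribute that torch tensors expose (non_cpu_entries is a hypothetical helper, not part of OpenKiwi):

```python
from typing import Any, Dict, List


def non_cpu_entries(metrics: Dict[str, Any], prefix: str = "") -> List[str]:
    """Return the keys of metric values that live on a non-CPU device.

    Nested dicts (like the 'metrics' entry in the dump) are searched
    recursively; their keys are reported as 'parent/child'.
    """
    found: List[str] = []
    for key, value in metrics.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            found.extend(non_cpu_entries(value, prefix=f"{name}/"))
            continue
        device = getattr(value, "device", None)
        if device is None:
            continue  # plain Python numbers carry no device
        # torch.device exposes .type ('cpu', 'cuda', ...); fall back to str().
        device_type = getattr(device, "type", str(device))
        if device_type != "cpu":
            found.append(name)
    return found
```

Applied to the dictionary above, this should flag the loss, CORRECT and F1_MULT+CORRECT entries (including their val_ and nested metrics/ versions) and leave out the MCC, F1 and F1_MULT entries.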

We will need to solve this by making sure all the metrics return torch tensors that have been moved to CPU. Preferably, all the values returned by metrics will actually just be python floats.
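One way to get there would be to normalize the metrics dictionary once, before it reaches any callback. A minimal sketch, assuming metric values are Python numbers, one-element tensors or NumPy scalars (anything exposing .item()), possibly nested as in the dump above; the helper names are hypothetical:

```python
from typing import Any, Dict


def to_python_float(value: Any) -> Any:
    # A one-element torch tensor (on any device) or a NumPy scalar exposes
    # .item(), which returns a plain Python number; other values pass through.
    if hasattr(value, "item"):
        return float(value.item())
    return value


def normalize_metrics(metrics: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `metrics` with tensor-like values replaced by floats."""
    return {
        key: normalize_metrics(value) if isinstance(value, dict) else to_python_float(value)
        for key, value in metrics.items()
    }
```

Calling .item() copies each value to the host, so downstream code such as the monitor comparison would only ever see Python floats; where exactly such a step would belong in OpenKiwi is left open here.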

We will try to fix this in a PR.

iamhere1 commented 3 years ago

Yes, thank you for your help! Would replacing the NumPy functions with PyTorch operations be another solution?

iamhere1 commented 3 years ago

Yeah, fixing it at the root, so that all metrics return CPU values or Python floats, is more reasonable. Thank you!

daandouwe commented 3 years ago

As for another solution, you could edit

        current = metrics.get(self.monitor)
        if self.monitor_op(current - self.min_delta, self.best):

to

        current = metrics.get(self.monitor)
        if not isinstance(current, float):
            current = current.cpu().item()
        if self.monitor_op(current - self.min_delta, self.best):

in https://github.com/Unbabel/OpenKiwi/blob/master/kiwi/training/callbacks.py#L96 (in your local kiwi path).

I tried this and it worked, but of course it's not as nice as dealing with the problem at the root ;). We'll keep you updated on that.

iamhere1 commented 3 years ago

OK, thanks for your help. It's working now.

captainvera commented 3 years ago

Hey @iamhere1,

I'm glad it's working! Closing this issue due to inactivity.