lyutyuh / ASP

PyTorch implementation and pre-trained models for ASP - Autoregressive Structured Prediction with Language Models, EMNLP 22. https://arxiv.org/pdf/2210.14698.pdf
MIT License

Missing test evaluation #11

Open Niklss opened 1 year ago

Niklss commented 1 year ago

The code for test evaluation is missing; only dev evaluation exists.

# Evaluate
                    if self.scheduler._step_count % self.config['eval_frequency'] == 0:
                        logger.info('Dev')

                        f1, _ = self.evaluate(
                            model, examples_dev, stored_info, self.scheduler._step_count
                        )
                        logger.info('Test')
                        f1_test = 0. # hard-coded to zero; test F1 is never computed
                        if f1 > max_f1:
                            max_f1 = max(max_f1, f1)
                            max_f1_test = 0. 
                            self.save_model_checkpoint(
                                model, self.optimizer, self.scheduler, self.scheduler._step_count, epo
                            )

                        logger.info('Eval max f1: %.2f' % max_f1)
                        logger.info('Test max f1: %.2f' % max_f1_test)
                        start_time = time.time()
HMTTT commented 1 year ago

I'm also hitting this issue. How can I evaluate on the test dataset?

Niklss commented 1 year ago

I'm also hitting this issue. How can I evaluate on the test dataset?

I simply rewrote the code to also evaluate on the test set and save the best model based on the test eval. But be careful: the test dataset is pretty big, so it's better to increase eval_frequency so you don't get stuck in evaluation for too long.

# Evaluate
                    if self.scheduler._step_count % self.config['eval_frequency'] == 0:
                        logger.info('Dev')

                        f1, _ = self.evaluate(
                            model, examples_dev, stored_info, self.scheduler._step_count
                        )
                        logger.info('Test')
                        f1_test, _ = self.evaluate(
                            model, examples_test, stored_info, self.scheduler._step_count
                        )
                        max_f1 = max(max_f1, f1)
                        if f1_test > max_f1_test:
                            max_f1_test = f1_test
                            self.save_model_checkpoint(
                                model, self.optimizer, self.scheduler, self.scheduler._step_count, epo
                            )

                        logger.info('Eval max f1: %.2f' % max_f1)
                        logger.info('Test max f1: %.2f' % max_f1_test)
                        start_time = time.time()
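To make the change easier to reason about in isolation, here is a minimal, self-contained sketch of the pattern above: evaluate on both dev and test every `eval_frequency` steps, track the best scores, and checkpoint only when the test metric improves. `evaluate_fn` and `save_fn` are hypothetical stand-ins for the repo's `self.evaluate` and `self.save_model_checkpoint`; this is an illustration of the control flow, not the actual trainer code.

```python
def run_eval_loop(total_steps, eval_frequency, evaluate_fn, save_fn):
    """Periodically evaluate and checkpoint; return (max_f1_dev, max_f1_test).

    evaluate_fn(split, step) -> float F1 score for 'dev' or 'test'.
    save_fn(step) is called whenever the test F1 improves.
    """
    max_f1, max_f1_test = 0.0, 0.0
    for step in range(1, total_steps + 1):
        if step % eval_frequency != 0:
            continue  # only evaluate every eval_frequency steps
        f1 = evaluate_fn('dev', step)
        f1_test = evaluate_fn('test', step)
        max_f1 = max(max_f1, f1)          # track best dev F1 for logging
        if f1_test > max_f1_test:         # checkpoint on test improvement
            max_f1_test = f1_test
            save_fn(step)
    return max_f1, max_f1_test
```

Because the test evaluation runs inside the same periodic branch as dev, raising `eval_frequency` (as suggested above) directly reduces how often the expensive test pass happens.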