huggingface / setfit

Efficient few-shot learning with Sentence Transformers
https://hf.co/docs/setfit
Apache License 2.0
2.24k stars 223 forks

create_model_card() TemplateAssertionError #481

Closed: abedini-arteriaai closed this issue 9 months ago

abedini-arteriaai commented 9 months ago

Hi SetFit team,

Is there a way to disable create_model_card so that no README is written?

model_body.save already accepts create_model_card=False, but SetFit's _save_pretrained then calls self.create_model_card unconditionally to write README.md.

That call raises a TemplateAssertionError, so if possible I would like to skip that step entirely, as I don't need the README.

Context for the code in SetFit:

 /databricks/python/lib/python3.8/site-packages/setfit/modeling.py in _save_pretrained(self, save_directory)
    695         self.model_body.save(path=save_directory, create_model_card=False)
    696         # Save the README
--> 697         self.create_model_card(path=save_directory, model_name=save_directory)

Context for the TemplateAssertionError:

[my own file]
--> 131         self.trainer.model.save_pretrained('setfit_model')
    132         # writing metrics in a file
    133         json_obj = dumps(self.eval_metrics)

/databricks/python/lib/python3.8/site-packages/huggingface_hub/hub_mixin.py in save_pretrained(self, save_directory, config, repo_id, push_to_hub, **kwargs)
     57 
     58         # saving model weights/files
---> 59         self._save_pretrained(save_directory)
     60 
     61         # saving config

/databricks/python/lib/python3.8/site-packages/setfit/modeling.py in _save_pretrained(self, save_directory)
    695         self.model_body.save(path=save_directory, create_model_card=False)
    696         # Save the README
--> 697         self.create_model_card(path=save_directory, model_name=save_directory)
    698         # Move the head to the CPU before saving
    699         if self.has_differentiable_head:

/databricks/python/lib/python3.8/site-packages/setfit/modeling.py in create_model_card(self, path, model_name)
    668 
    669         with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
--> 670             f.write(self.generate_model_card())
    671 
    672     def generate_model_card(self) -> str:

/databricks/python/lib/python3.8/site-packages/setfit/modeling.py in generate_model_card(self)
    676             str: The model card string.
    677         """
--> 678         return generate_model_card(self)
    679 
    680     def _save_pretrained(self, save_directory: Union[Path, str]) -> None:

/databricks/python/lib/python3.8/site-packages/setfit/model_card.py in generate_model_card(model)
    574 def generate_model_card(model: "SetFitModel") -> str:
    575     template_path = Path(__file__).parent / "model_card_template.md"
--> 576     model_card = ModelCard.from_template(card_data=model.model_card_data, template_path=template_path, hf_emoji="🤗")
    577     return model_card.content

/databricks/python/lib/python3.8/site-packages/huggingface_hub/repocard.py in from_template(cls, card_data, template_path, **template_kwargs)
    403             ```
    404         """
--> 405         return super().from_template(card_data, template_path, **template_kwargs)
    406 
    407 

/databricks/python/lib/python3.8/site-packages/huggingface_hub/repocard.py in from_template(cls, card_data, template_path, **template_kwargs)
    321         kwargs = card_data.to_dict().copy()
    322         kwargs.update(template_kwargs)  # Template_kwargs have priority
--> 323         template = jinja2.Template(Path(template_path or cls.default_template_path).read_text())
    324         content = template.render(card_data=card_data.to_yaml(), **kwargs)
    325         return cls(content)

/databricks/python/lib/python3.8/site-packages/jinja2/environment.py in __new__(cls, source, block_start_string, block_end_string, variable_start_string, variable_end_string, comment_start_string, comment_end_string, line_statement_prefix, line_comment_prefix, trim_blocks, lstrip_blocks, newline_sequence, keep_trailing_newline, extensions, optimized, undefined, finalize, autoescape, enable_async)
   1029             enable_async,
   1030         )
-> 1031         return env.from_string(source, template_class=cls)
   1032 
   1033     @classmethod

/databricks/python/lib/python3.8/site-packages/jinja2/environment.py in from_string(self, source, globals, template_class)
    939         globals = self.make_globals(globals)
    940         cls = template_class or self.template_class
--> 941         return cls.from_code(self, self.compile(source), globals, None)
    942 
    943     def make_globals(self, d):

/databricks/python/lib/python3.8/site-packages/jinja2/environment.py in compile(self, source, name, filename, raw, defer_init)
    636             return self._compile(source, filename)
    637         except TemplateSyntaxError:
--> 638             self.handle_exception(source=source_hint)
    639 
    640     def compile_expression(self, source, undefined_to_none=True):

/databricks/python/lib/python3.8/site-packages/jinja2/environment.py in handle_exception(self, source)
    830         from .debug import rewrite_traceback_stack
    831 
--> 832         reraise(*rewrite_traceback_stack(source=source))
    833 
    834     def join_path(self, template, parent):

/databricks/python/lib/python3.8/site-packages/jinja2/_compat.py in reraise(tp, value, tb)
     26     def reraise(tp, value, tb=None):
     27         if value.__traceback__ is not tb:
---> 28             raise value.with_traceback(tb)
     29         raise value
     30 

<unknown> in template()

TemplateAssertionError: no test named 'False'

Edit: for others' reference, upgrading Jinja2 resolves the TemplateAssertionError.
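Because the fix was "upgrade Jinja2", a quick version check before saving can help. The following is a minimal sketch. The traceback appears to come from a Jinja2 2.x install (it passes through jinja2/_compat.py, which to my knowledge was removed in Jinja2 3.0), so the 3.0.0 cut-off below is an assumption, not a version confirmed in this thread. parse_release, jinja2_needs_upgrade and MINIMUM_JINJA2 are hypothetical helper names, and only the standard library is used (importlib.metadata needs Python 3.8+).

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple

# Assumption: 3.0.0 is chosen only because the traceback goes through
# jinja2/_compat.py (a Jinja2 2.x module); the thread names no fixed version.
MINIMUM_JINJA2 = "3.0.0"


def parse_release(v: str) -> Tuple[int, ...]:
    """Turn "2.11.3" into (2, 11, 3); pre-release suffixes like "rc1" are ignored."""
    parts = []
    for piece in v.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def jinja2_needs_upgrade(installed: str, minimum: str = MINIMUM_JINJA2) -> bool:
    a, b = parse_release(installed), parse_release(minimum)
    # Pad with zeros so that "3.0" and "3.0.0" compare as equal
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a < b


if __name__ == "__main__":
    try:
        installed = version("Jinja2")
    except PackageNotFoundError:
        installed = None
    if installed is None:
        print("Jinja2 is not installed")
    elif jinja2_needs_upgrade(installed):
        print(f"Jinja2 {installed} is older than {MINIMUM_JINJA2}; try: pip install -U Jinja2")
    else:
        print(f"Jinja2 {installed} looks recent enough")
```

Comparing integer tuples instead of raw strings avoids the classic pitfall where "2.11" sorts before "2.9" lexicographically.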

tomaarsen commented 9 months ago

Hello!

It isn't possible to explicitly disable README creation when saving a model, but you can subclass SetFitModel and override create_model_card with a no-op. For example:

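from typing import Optional

from setfit import SetFitModel

class SetFitModelWithoutModelCard(SetFitModel):
    # No-op override: _save_pretrained still calls this, but nothing is written
    def create_model_card(self, path: str, model_name: Optional[str] = "SetFit Model") -> None:
        pass

model = SetFitModelWithoutModelCard.from_pretrained("...")
# usage continues like normal; model.save_pretrained(...) no longer writes README.md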

I'm not sure why you're experiencing that error to begin with; I'd like to figure that out in the future as well, but this should help in the meantime.
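To see why overriding one method is enough, here is a self-contained toy that mirrors the save path from the traceback: save_pretrained writes the weights and then always calls create_model_card, so a subclass that turns that hook into a no-op skips the README without touching the rest of the save logic. ToyModel, ToyModelWithoutModelCard and saved_files are hypothetical stand-ins, not SetFit classes, and neither setfit nor Jinja2 is needed to run it.

```python
import os
import tempfile
from typing import List, Optional


class ToyModel:
    """Hypothetical stand-in for SetFitModel: saving always ends by writing a README."""

    def save_pretrained(self, save_directory: str) -> None:
        os.makedirs(save_directory, exist_ok=True)
        # Stand-in for model_body.save(..., create_model_card=False)
        with open(os.path.join(save_directory, "weights.bin"), "wb") as f:
            f.write(b"\x00")
        # Mirrors _save_pretrained in the traceback: this call is unconditional
        self.create_model_card(path=save_directory, model_name=save_directory)

    def create_model_card(self, path: str, model_name: Optional[str] = "SetFit Model") -> None:
        with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
            f.write(f"# {model_name}\n")


class ToyModelWithoutModelCard(ToyModel):
    # The inherited save_pretrained dispatches to this override, so README.md is never written
    def create_model_card(self, path: str, model_name: Optional[str] = "SetFit Model") -> None:
        pass


def saved_files(model: ToyModel) -> List[str]:
    """Save into a throwaway directory and report which files were created."""
    with tempfile.TemporaryDirectory() as tmp:
        model.save_pretrained(tmp)
        return sorted(os.listdir(tmp))


print(saved_files(ToyModel()))                  # ['README.md', 'weights.bin']
print(saved_files(ToyModelWithoutModelCard()))  # ['weights.bin']
```

Because the override keeps the same signature as the original method, any caller that passes path and model_name (as _save_pretrained does in the traceback) still works unchanged.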