BobXWu / FASTopic

A Fast, Adaptive, Stable, and Transferable Topic Model
https://pypi.org/project/fastopic/
Apache License 2.0

Load pre-trained model #4

Closed: GuishePerez closed this issue 4 months ago

GuishePerez commented 4 months ago

Hi! First of all, thanks for sharing the code for your model. It's a very interesting paper and I would like to give it a try with my own data.

However, I'm first trying to reproduce some toy examples like the following one, and I'm having issues loading a pre-trained model.

This is what I tried so far:

```python
from fastopic import FASTopic
import topmost
from topmost.data import download_dataset

# Download data and pre-process if required
dataset_name = "NYT"
download_dataset(dataset_name, cache_path="./datasets")
dataset = topmost.data.DynamicDataset("./datasets/NYT", as_tensor=False)
docs = dataset.train_texts
preprocessing = None

# Number of topics
K = 50

# Define doc embedding model
doc_embedder = "all-MiniLM-L6-v2"  # default one

# Instantiate model's object
model = FASTopic(
    K,
    preprocessing,
    doc_embed_model=doc_embedder,
)

# Train
topic_top_words, doc_topic_dist = model.fit_transform(docs)

# Save
model.save_model(f"/root/test/models/{doc_embedder}_{dataset_name}")

# Load the saved model into a new instance
dataset_name = "NYT"

# Number of topics
K = 50

# Define doc embedding model
doc_embedder = "all-MiniLM-L6-v2"  # default one

# Instantiate model's object
model = FASTopic(
    K,
    preprocessing=None,
    doc_embed_model=doc_embedder,
)
model.load_model(f"/root/test/models/{doc_embedder}_{dataset_name}")
```


This throws the following error:

```console
Traceback (most recent call last):
  File "/root/test/test_fasttopic.py", line 51, in <module>
    model.load_model(f"/root/test/models/{doc_embedder}_{dataset_name}")
  File "/usr/local/lib/python3.10/site-packages/fastopic/FASTopic.py", line 217, in load_model
    self.model.load_state_dict(torch.load(f"{path}.zip"))
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for fastopic:
        Unexpected key(s) in state_dict: "word_embeddings", "topic_embeddings", "word_weights", "topic_weights", "DT_ETP.init_b_dist", "DT_ETP.b_dist", "TW_ETP.init_b_dist", "TW_ETP.b_dist".
```

I think the problem is that `_fastopic.py` has two different init methods: Python's built-in constructor (`__init__()`) and the custom one you've created (`init()`). The second one is called when running `fit_transform()` to train the model for the first time. However, when I later try to load that trained model to perform inference on new docs, the new `FASTopic.model` object has not initialized any of the layers expected by the saved `state_dict`, as the error trace shows.
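The failure mode described above can be reproduced with plain PyTorch. This is a minimal sketch, not FASTopic's actual code: `LazyModel` and its `init()` method are hypothetical and only mimic the pattern of creating parameters in a separate call after `__init__()`. Loading a `state_dict` into an instance that has not been initialized raises the same kind of "Unexpected key(s)" error, while building the parameters with matching shapes first lets the load succeed.

```python
import torch
from torch import nn


class LazyModel(nn.Module):
    """Hypothetical module whose parameters only exist after init() is called."""

    def __init__(self):
        super().__init__()  # no parameters are registered here

    def init(self, vocab_size: int, num_topics: int, embed_dim: int = 8):
        # Parameters are created lazily, e.g. once the vocabulary is known at fit time
        self.word_embeddings = nn.Parameter(torch.randn(vocab_size, embed_dim))
        self.topic_embeddings = nn.Parameter(torch.randn(num_topics, embed_dim))


# "Training": build the parameters, then take a snapshot of the weights
trained = LazyModel()
trained.init(vocab_size=100, num_topics=5)
state = trained.state_dict()

# A fresh instance has no parameters yet, so strict loading rejects every key
fresh = LazyModel()
load_error = None
try:
    fresh.load_state_dict(state)
except RuntimeError as err:
    load_error = err
    print(err)  # reports the unexpected keys "word_embeddings", "topic_embeddings"

# Building the parameters with the same shapes first lets the weights load
restored = LazyModel()
restored.init(vocab_size=100, num_topics=5)
restored.load_state_dict(state)
```

In this sketch, a loader has to rebuild the layers (with the same vocabulary size and number of topics) before calling `load_state_dict()`.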

Can you provide a minimal example of how to do it properly? Or point out what I could be doing wrong?

Thanks in advance!

Guille

BobXWu commented 4 months ago

Thank you for your interest in our work! I've just released a new version (0.0.4) that solves this issue. Please give it a try.

GuishePerez commented 4 months ago

Hi @BobXWu and thanks for your quick response.

The new release fixes the error I mentioned, but I think a simple change could make it even more useful. Instead of saving/loading only the PyTorch model weights, it would be better to save/load the FASTopic object itself, so that anyone who loads a pretrained model can also call the analysis methods (`get_topic()`, `visualize_topic()`, `get_top_words()`, ...). With the current implementation, calling those methods would require updating the attributes of the current FASTopic instance (for example, `beta` or `vocab`) to match those of the pretrained object, and that is not done.
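The difference between the two approaches can be illustrated with a toy sketch that uses only the standard library. `ToyTopicModel`, its `fit()` and `get_top_words()` methods, and the sample vocabulary are hypothetical stand-ins, not FASTopic's API, and `pickle` stands in for `joblib`. Restoring only the weights (here `beta`) leaves `vocab` unset, so the analysis method cannot run, while pickling the whole object restores every attribute.

```python
import pickle


class ToyTopicModel:
    """Hypothetical stand-in for a topic model that has analysis methods."""

    def __init__(self, num_topics: int = 2):
        self.num_topics = num_topics
        self.vocab = None  # learned during fit
        self.beta = None   # topic-word weights, learned during fit

    def fit(self, vocab, beta):
        # A real model would learn beta; here it is passed in directly
        self.vocab = list(vocab)
        self.beta = [list(row) for row in beta]
        return self

    def get_top_words(self, num_top_words: int = 2):
        # Needs BOTH beta and vocab to turn weight indices back into words
        if self.vocab is None or self.beta is None:
            raise RuntimeError("model is not fitted")
        top_words = []
        for row in self.beta:
            ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
            top_words.append([self.vocab[i] for i in ranked[:num_top_words]])
        return top_words


model = ToyTopicModel().fit(
    vocab=["game", "team", "vote", "party"],
    beta=[[0.5, 0.4, 0.05, 0.05], [0.05, 0.05, 0.6, 0.3]],
)

# Weights-only: a fresh instance gets beta back, but vocab is still missing
weights_only = ToyTopicModel()
weights_only.beta = pickle.loads(pickle.dumps(model.beta))

# Whole object: every attribute, including vocab, is restored
restored = pickle.loads(pickle.dumps(model))
print(restored.get_top_words())  # [['game', 'team'], ['vote', 'party']]
```

Calling `weights_only.get_top_words()` in this toy raises its `RuntimeError`, which mirrors the situation where attributes such as `vocab` are not restored on the new instance.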

I've replaced the original `save_model()` and `load_model()` methods with these new ones:

```python
# Added some new dependency imports
from datetime import datetime
from pathlib import Path

import joblib

# `torch`, `Union`, `Preprocessing` and `logger` are assumed to be imported/defined elsewhere in the FASTopic module.


class FASTopic:
    def __init__(
        self,
        num_topics: int=50,  # I set this default in order to do `model = FASTopic.from_pretrained(path)`
        preprocessing: Preprocessing=None,
        doc_embed_model: Union[str, callable]="all-MiniLM-L6-v2",
        num_top_words: int=15,
        DT_alpha: float=3.0,
        TW_alpha: float=2.0,
        theta_temp: float=1.0,
        epochs: int=200,
        learning_rate: float=0.002,
        device: str=None,
        save_memory: bool=False,
        batch_size: int=None,
        log_interval: int=10,
        verbose: bool=False,
    ):
        ...

    def save(self, path: str, model_name: str=None, overwrite: bool=False):
        """Saves the FASTopic object and its PyTorch model weights to the specified path.

        This method saves both the internal state of the FASTopic object (`self`) and the weights of its PyTorch model.

        Args:
            path (str): The parent directory for the model. Missing directories are created.
            model_name (str, optional): The name of the model subdirectory. If not provided, a default name based on the current timestamp is generated. Defaults to None.
            overwrite (bool, optional): Whether to overwrite an existing model at `path/model_name`. Defaults to False.

        Raises:
            FileExistsError: If `overwrite` is False and a model already exists at `path/model_name`.

        Returns:
            None

        This method creates the following files in `path/model_name`:

            * `fastopic.pkl`: The serialized FASTopic object, saved with `joblib.dump`.
            * `pt_model.bin`: The PyTorch model weights, saved with `torch.save`.
        """
        # Format output paths
        if model_name is None:
            model_name = f"fastopic_model_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
        path = Path(path).joinpath(model_name)

        # Refuse to overwrite an existing model unless explicitly asked to
        if path.exists():
            if not overwrite:
                raise FileExistsError(
                    f"There is an existing model in the provided path {path}. "
                    "If you want to save it anyway, pass `overwrite=True` as an input argument."
                )
            logger.info(f"Overwriting model {path}")
        path.mkdir(parents=True, exist_ok=True)

        file_path_fastopic = path.joinpath("fastopic.pkl")
        file_path_fastopic_model = path.joinpath("pt_model.bin")
        # Save FASTopic object
        joblib.dump(self, file_path_fastopic)
        # Save FASTopic PyTorch model weights
        torch.save(self.model.state_dict(), file_path_fastopic_model)
        logger.info(
            f"FASTopic model saved in {file_path_fastopic}. Load this model at any time by passing this file path to `FASTopic.from_pretrained(file_path)`.\n"
            f"FASTopic PyTorch model weights saved in {file_path_fastopic_model}."
        )

    @staticmethod
    def from_pretrained(file_path: str) -> "FASTopic":
        """Loads a pre-trained FASTopic model from a saved file.

        This static method loads a FASTopic instance previously saved with `save()`.

        Args:
            file_path (str): The path to the serialized FASTopic object (`fastopic.pkl`) created by `save()`.

        Returns:
            FASTopic: The FASTopic instance loaded from the provided file.

        Raises:
            FileNotFoundError: If the specified `file_path` does not exist.
        """
        # Load FASTopic pretrained instance
        pretrained_model = joblib.load(file_path)
        return pretrained_model
```

I'm still saving the PyTorch model weights. It's not required, but I've kept it in case it's useful to have them separately at some point.

I've tested it and it works as expected: the underlying `FASTopic.model` (the PyTorch module) correctly loads all the pretrained values, such as `beta` and `trained_theta`.

Here's an example:

```python
from numpy.testing import assert_almost_equal

from fastopic import FASTopic
import topmost
from topmost.data import download_dataset

# Download data and pre-process if required
dataset_name = "NYT"
download_dataset(dataset_name, cache_path="./datasets")
dataset = topmost.data.DynamicDataset("./datasets/NYT", as_tensor=False)
docs = dataset.train_texts
preprocessing = None

# Number of topics
K = 50

# Define doc embedding model
doc_embedder = "all-MiniLM-L6-v2"  # default one

# Instantiate model's object
model_0 = FASTopic(
    K,
    preprocessing,
    doc_embed_model=doc_embedder,
)

# Train
topic_top_words, doc_topic_dist = model_0.fit_transform(docs)
# Read beta after training, once the topic-word distribution has been learned
beta_0 = model_0.beta

# Save
model_0.save(
    path="/root/fastopic_test/models/",
    model_name=f"{doc_embedder}_{dataset_name}",
    overwrite=True,
)

# Load a pre-trained FASTopic model
model_1 = FASTopic.from_pretrained(
    file_path=f"/root/fastopic_test/models/{doc_embedder}_{dataset_name}/fastopic.pkl"
)
beta_1 = model_1.beta

# Assert outputs are loaded correctly
assert_almost_equal(beta_0, beta_1)
```
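The check above compares only `beta`. A broader round-trip check compares every tensor in the model's `state_dict()`. This is a minimal sketch under stated assumptions: `Wrapper` is a hypothetical stand-in for an object like `FASTopic` that owns a PyTorch model plus derived attributes, `state_dicts_equal` is a hypothetical helper, and `pickle` stands in for `joblib`.

```python
import pickle

import torch
from torch import nn


class Wrapper:
    """Hypothetical stand-in for an object that owns a PyTorch model plus derived attributes."""

    def __init__(self):
        self.model = nn.Linear(4, 3)
        # Derived attribute, similar in spirit to a topic-word matrix
        self.beta = torch.softmax(self.model.weight.detach(), dim=1)


def state_dicts_equal(a: nn.Module, b: nn.Module) -> bool:
    """Return True if both modules have the same keys and identical tensor values."""
    sa, sb = a.state_dict(), b.state_dict()
    return sa.keys() == sb.keys() and all(torch.equal(sa[k], sb[k]) for k in sa)


original = Wrapper()
# pickle.dumps/loads stand in for joblib.dump/joblib.load here
restored = pickle.loads(pickle.dumps(original))

print(state_dicts_equal(original.model, restored.model))  # True
print(torch.equal(original.beta, restored.beta))  # True
```

Comparing with `torch.equal` checks exact equality, which is appropriate for a pickle round-trip; `assert_almost_equal` is the safer choice when values are recomputed rather than restored.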

Let me know if you see any drawbacks to this approach.

Guille

BobXWu commented 4 months ago

Wow, it looks great! Can you submit a pull request for this update? Thank you.

GuishePerez commented 4 months ago

Sure! Closing this issue since it has been solved.