run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License

[Question]: `SimpleChatStore` encoding for `persist`. #15055

Open Adversarian opened 3 months ago

Adversarian commented 3 months ago

Question Validation

Question

Hi! Hope you're having a swell day.

So I'm using a SimpleChatStore to keep a memory of user interactions with the chatbot. The issue I'm having is that my chatbot speaks Persian, and calling persist saves the memory contents to a file as a JSON string in which every non-ASCII character is escaped (json.dumps defaults to ensure_ascii=True). The contents parse back correctly with utf-8 afterwards, but the escaped form takes up much more space on disk.

For instance, the word "سلام", which is 4 characters in Persian/Arabic, is stored as \u0633\u0644\u0627\u0645 (24 characters). I'm concerned that this might cause scalability issues in the long run.
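The size difference is easy to measure with the standard library alone; a quick sketch (the key name is just illustrative):

```python
import json

# json.dumps defaults to ensure_ascii=True, turning each non-ASCII
# character into a 6-character \uXXXX escape sequence.
message = {"msg": "سلام"}

escaped = json.dumps(message)                     # default behavior
compact = json.dumps(message, ensure_ascii=False)

print(escaped)  # {"msg": "\u0633\u0644\u0627\u0645"}
print(compact)  # {"msg": "سلام"}

# Even counted in UTF-8 bytes (Arabic script costs 2 bytes per character),
# the escaped form is nearly twice as large here.
print(len(escaped.encode("utf-8")), len(compact.encode("utf-8")))  # 35 19
```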

The easiest fix that comes to mind is monkey-patching the SimpleChatStore class to add an encoding=None kwarg to the persist and from_persist_path methods, which is fed to the open function before the memory contents are written to or read from the file. More succinctly, something like this:

# Proposed methods on SimpleChatStore (json, os, fsspec and Optional are
# assumed imported at module level, as in the existing source):
def persist(
    self,
    persist_path: str = "chat_store.json",
    fs: Optional[fsspec.AbstractFileSystem] = None,
    encoding: Optional[str] = None,
) -> None:
    """Persist the docstore to a file."""
    fs = fs or fsspec.filesystem("file")
    dirpath = os.path.dirname(persist_path)
    if not fs.exists(dirpath):
        fs.makedirs(dirpath)

    with fs.open(persist_path, "w", encoding=encoding) as f:
        f.write(json.dumps(self.json()))

@classmethod
def from_persist_path(
    cls,
    persist_path: str = "chat_store.json",
    fs: Optional[fsspec.AbstractFileSystem] = None,
    encoding: Optional[str] = None,
) -> "SimpleChatStore":
    """Create a SimpleChatStore from a persist path."""
    fs = fs or fsspec.filesystem("file")
    if not fs.exists(persist_path):
        return cls()
    with fs.open(persist_path, "r", encoding=encoding) as f:
        data = json.load(f)
    return cls.parse_raw(data)

I was wondering if there was a cleaner way to do this with the current API. If not, I can introduce the change in a small PR if you like the idea.
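In the meantime, the override can also be applied at runtime without forking the package. A stdlib-only sketch (persist_utf8 is an invented name, plain open stands in for fsspec for brevity, and it serializes self.dict() directly because self.json() returns an already-escaped string):

```python
import json
import os
from typing import Any


def persist_utf8(self: Any, persist_path: str = "chat_store.json") -> None:
    """Drop-in replacement for SimpleChatStore.persist without \\uXXXX escaping."""
    dirpath = os.path.dirname(persist_path)
    if dirpath and not os.path.exists(dirpath):
        os.makedirs(dirpath)
    # self.dict() (pydantic) yields the raw data; dumping it directly avoids
    # the escaped intermediate string that self.json() would produce.
    with open(persist_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(self.dict(), ensure_ascii=False))


# Applied once at startup, e.g.:
#   from llama_index.core.storage.chat_store import SimpleChatStore
#   SimpleChatStore.persist = persist_utf8
```

Note this writes the dict single-encoded rather than the library's double-encoded string, so the matching from_persist_path would need a corresponding tweak (e.g. parse_obj on the loaded dict instead of parse_raw).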

dosubot[bot] commented 3 months ago

Hey there, @Adversarian! I'm here to help you out. Let's tackle this SimpleChatStore encoding issue together.

The SimpleChatStore class does not currently support any form of encoding options for its persist and from_persist_path methods. The methods use json.dumps and json.load without any additional encoding parameters [1].

However, you can modify the persist method to handle different encoding options by passing additional parameters to the json.dumps function. Here is an example of how you can modify the persist method to handle different encoding options:

def persist(
    self,
    persist_path: str = "chat_store.json",
    fs: Optional[fsspec.AbstractFileSystem] = None,
    ensure_ascii: bool = True
) -> None:
    """Persist the docstore to a file."""
    fs = fs or fsspec.filesystem("file")
    dirpath = os.path.dirname(persist_path)
    if not fs.exists(dirpath):
        fs.makedirs(dirpath)

    with fs.open(persist_path, "w") as f:
        f.write(json.dumps(self.json(), ensure_ascii=ensure_ascii))

In this modified version, the persist method now accepts an additional ensure_ascii parameter, which is passed to json.dumps. This allows you to control the encoding of the JSON output [1].

For the from_persist_path method, you can similarly add an encoding parameter to handle different encodings when reading the file:

@classmethod
def from_persist_path(
    cls,
    persist_path: str = "chat_store.json",
    fs: Optional[fsspec.AbstractFileSystem] = None,
    encoding: Optional[str] = None
) -> "SimpleChatStore":
    """Create a SimpleChatStore from a persist path."""
    fs = fs or fsspec.filesystem("file")
    if not fs.exists(persist_path):
        return cls()
    with fs.open(persist_path, "r", encoding=encoding) as f:
        data = json.load(f)
    return cls.parse_raw(data)

This approach should help you handle different encoding options more cleanly without needing to monkey-patch the class.

To continue talking to Dosu, mention @dosu.

Adversarian commented 3 months ago

Evidently, the solution isn't as straightforward as I originally thought. This goes all the way up to the BaseComponent class, where the to_json method would need to be modified to return json.dumps(data, ensure_ascii=False) instead. I'm wondering whether forcing ensure_ascii=True is intended and, if so, why? I'm asking because it appears this change would yield a fairly substantial reduction in the persisted chat store's file size.
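The double layer is easy to see in isolation; here json.dumps with its default stands in for the library's to_json, and the outer call mirrors persist:

```python
import json

# Stand-in for BaseComponent.to_json(): json.dumps(data) with its
# ensure_ascii=True default already escapes the Persian text.
inner = json.dumps({"msg": "سلام"})
print(inner)  # {"msg": "\u0633\u0644\u0627\u0645"}

# Stand-in for persist(): by now the string is pure ASCII, so passing
# ensure_ascii=False to the outer call cannot undo the escaping.
outer = json.dumps(inner, ensure_ascii=False)
print("سلام" in outer)  # False: the escapes survive verbatim
```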

dosubot[bot] commented 3 months ago

The current implementation of the to_json method in the BaseComponent class does not explicitly enforce ensure_ascii=True. The method simply calls json.dumps(data) without specifying the ensure_ascii parameter, which defaults to True in Python's json module.

If you want to handle Unicode encoding more cleanly, you can modify the to_json method to set ensure_ascii=False. This will allow Unicode characters to be included in the JSON output without being escaped.

Here is the modified to_json method:

def to_json(self, **kwargs: Any) -> str:
    data = self.to_dict(**kwargs)
    return json.dumps(data, ensure_ascii=False)

This change will ensure that Unicode characters are properly handled in the JSON output, potentially reducing the file size of the persisted chat store [1][2].


Adversarian commented 3 months ago

@dosu The issue with this is that monkey-patching this class requires me to also monkey-patch the entire chain of classes leading up to it (i.e., SimpleChatStore -> BaseChatStore -> BaseComponent in my case). I would also potentially have to repeat this if I ever want this feature again for another component. It seems easier to just modify BaseComponent directly, but I'm wondering if there's an explicit reason why the default behavior of json.dumps is used instead of passing ensure_ascii=False, or at least exposing it through a keyword argument.

dosubot[bot] commented 3 months ago

The default behavior of json.dumps is used in the BaseComponent class without ensure_ascii=False. However, it is possible to modify the BaseComponent class to either use ensure_ascii=False or expose this option through a keyword argument.

Here is how you can modify the BaseComponent class to expose the ensure_ascii option through a keyword argument:

class BaseComponent(BaseModel):
    """Base component object to capture class names."""

    class Config:
        @staticmethod
        def schema_extra(schema: Dict[str, Any], model: "BaseComponent") -> None:
            """Add class name to schema."""
            schema["properties"]["class_name"] = {
                "title": "Class Name",
                "type": "string",
                "default": model.class_name(),
            }

    @classmethod
    def class_name(cls) -> str:
        """
        Get the class name, used as a unique ID in serialization.

        This provides a key that makes serialization robust against actual class
        name changes.
        """
        return "base_component"

    def json(self, **kwargs: Any) -> str:
        return self.to_json(**kwargs)

    def dict(self, **kwargs: Any) -> Dict[str, Any]:
        data = super().dict(**kwargs)
        data["class_name"] = self.class_name()
        return data

    def __getstate__(self) -> Dict[str, Any]:
        state = super().__getstate__()

        # tiktoken is not pickleable
        # state["__dict__"] = self.dict()
        state["__dict__"].pop("tokenizer", None)

        # remove local functions
        keys_to_remove = []
        for key, val in state["__dict__"].items():
            if key.endswith("_fn"):
                keys_to_remove.append(key)
            if "<lambda>" in str(val):
                keys_to_remove.append(key)
        for key in keys_to_remove:
            state["__dict__"].pop(key, None)

        # remove private attributes -- kind of dangerous
        state["__private_attribute_values__"] = {}

        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        # Use the __dict__ and __init__ method to set state
        # so that all variables are initialized
        try:
            self.__init__(**state["__dict__"])  # type: ignore
        except Exception:
            # Fall back to the default __setstate__ method
            super().__setstate__(state)

    def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
        data = self.dict(**kwargs)
        data["class_name"] = self.class_name()
        return data

    def to_json(self, **kwargs: Any) -> str:
        ensure_ascii = kwargs.pop("ensure_ascii", True)
        data = self.to_dict(**kwargs)
        return json.dumps(data, ensure_ascii=ensure_ascii)

    # TODO: return type here not supported by current mypy version
    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self:  # type: ignore
        if isinstance(kwargs, dict):
            data.update(kwargs)

        data.pop("class_name", None)
        return cls(**data)

    @classmethod
    def from_json(cls, data_str: str, **kwargs: Any) -> Self:  # type: ignore
        data = json.loads(data_str)
        return cls.from_dict(data, **kwargs)

With this modification, you can now call to_json with the ensure_ascii parameter:

component = BaseComponent()
json_str = component.to_json(ensure_ascii=False)

This approach is similar to how the JSONReader class in the LlamaIndex framework handles the ensure_ascii parameter, allowing for flexible JSON serialization options [1][2][3][4][5].
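The modified to_json can be checked against a stub standing in for BaseComponent (StubComponent and its to_dict are invented for the test, since the real class isn't importable here):

```python
import json
from typing import Any, Dict


class StubComponent:
    """Stand-in for BaseComponent carrying the modified to_json from above."""

    def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
        return {"msg": "سلام", "class_name": "stub_component"}

    def to_json(self, **kwargs: Any) -> str:
        ensure_ascii = kwargs.pop("ensure_ascii", True)
        data = self.to_dict(**kwargs)
        return json.dumps(data, ensure_ascii=ensure_ascii)


c = StubComponent()
print(c.to_json())                    # escaped: \u0633\u0644\u0627\u0645
print(c.to_json(ensure_ascii=False))  # kept as-is: سلام

# Caveat: SimpleChatStore.persist wraps self.json() in a second json.dumps,
# whose own ensure_ascii=True default re-escapes the characters on the way
# out, so that call would need the same treatment.
print(json.dumps(c.to_json(ensure_ascii=False)))
```

Because BaseComponent.json() dispatches to to_json(), a subclass (or mixin) overriding to_json alone would cover the inner serialization without patching the whole chain, with the persist caveat above still applying.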
