Closed: DixitAdh closed this issue 3 months ago.
Thanks for reporting this @DixitAdh. However, you should be fine with reconstructing the LLMRails instance from scratch, from the config files (so you should persist those). Is pickling a must?
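That would look roughly like this (a minimal sketch, assuming the config directory is persisted alongside the model):

# Sketch of the suggested workaround: keep the config files and rebuild
# the rails instance from them at load time instead of pickling it.
from nemoguardrails import LLMRails, RailsConfig

config = RailsConfig.from_path("./config")  # path to the persisted config files
rails = LLMRails(config)                    # reconstructed from scratch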
@drazvan Thanks for replying. Pickling is a must and is beyond my control. Please see below: when I wrap this as a runnable chain using MLflow's pyfunc flavor, MLflow tries to serialize the whole chain and fails while saving it with the error below.
import nest_asyncio
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails
from nemoguardrails import RailsConfig

nest_asyncio.apply()

config = RailsConfig.from_path("../_resources/config")
guardrails = RunnableRails(config)
# `chain` is the existing RAG chain, defined elsewhere
chain_with_guardrails = guardrails | chain

# custom class AgentCaller() to wrap chain_with_guardrails in an mlflow pyfunc class
# here goes the class implementation

import json
import mlflow
import langchain
import databricks
import transformers
import langchain_community
import langchain_core
import langchain_openai
import databricks.vector_search as db_vs
import databricks.sdk.version as sdk_ver
from mlflow.models import infer_signature

signature = infer_signature(input_request, prediction)

mlflow.set_registry_uri("databricks-uc")
model_name = f"{catalog}.{db}.cc_troubleshooting_rag"
input_example = json.dumps(input_request)

with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        artifact_path="cc_troubleshooting_rag",
        python_model=AgentCaller(),
        registered_model_name=model_name,
        signature=signature,
        input_example=input_example,
        pip_requirements=[
            "mlflow==" + mlflow.__version__,
            "langchain==" + langchain.__version__,
            # "databricks-vectorsearch==" + db_vs.__version__,
            "databricks-sdk==" + sdk_ver.__version__,
            "transformers==" + transformers.__version__,
            "langchain_community==" + langchain_community.__version__,
            "langchain-core==" + langchain_core.__version__,
            "langchainhub==0.1.15",
            "langchain_openai==0.0.8",
        ],
    )
Output:
TypeError: cannot pickle '_thread.RLock' object
File <command-667055235574013>, line 20
17 input_example = json.dumps(input_request)
19 with mlflow.start_run():
---> 20 model_info = mlflow.pyfunc.log_model(
21 artifact_path="cc_troubleshooting_rag",
22 python_model=AgentCaller(),
23 registered_model_name=model_name,
24 signature=signature,
25 input_example=input_example,
26 pip_requirements=[
27 "mlflow==" + mlflow.__version__,
28 "langchain==" + langchain.__version__,
29 #"databricks-vectorsearch==" + db_vs.__version__,
30 "databricks-sdk==" + sdk_ver.__version__,
31 "transformers==" + transformers.__version__,
32 "langchain_community==" + langchain_community.__version__,
33 "langchain-core==" + langchain_core.__version__,
34 "langchainhub==0.1.15",
35 "langchain_openai==0.0.8"
36 ]
37 )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-48e6b904-fdf2-413b-878b-b781b1aef178/lib/python3.10/site-packages/mlflow/pyfunc/__init__.py:2220, in log_model(artifact_path, loader_module, data_path, code_path, conda_env, python_model, artifacts, registered_model_name, signature, input_example, await_registration_for, pip_requirements, extra_pip_requirements, metadata, model_config, example_no_conversion)
2051 @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn"))
2052 def log_model(
2053 artifact_path,
(...)
2068 example_no_conversion=False,
2069 ):
2070 """
2071 Log a Pyfunc model with custom inference logic and optional data dependencies as an MLflow
2072 artifact for the current run.
(...)
2218 metadata of the logged model.
2219 """
-> 2220 return Model.log(
2221 artifact_path=artifact_path,
2222 flavor=mlflow.pyfunc,
2223 loader_module=loader_module,
2224 data_path=data_path,
2225 code_path=code_path,
2226 python_model=python_model,
2227 artifacts=artifacts,
2228 conda_env=conda_env,
2229 registered_model_name=registered_model_name,
2230 signature=signature,
2231 input_example=input_example,
2232 await_registration_for=await_registration_for,
2233 pip_requirements=pip_requirements,
2234 extra_pip_requirements=extra_pip_requirements,
2235 metadata=metadata,
2236 model_config=model_config,
2237 example_no_conversion=example_no_conversion,
2238 )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-48e6b904-fdf2-413b-878b-b781b1aef178/lib/python3.10/site-packages/mlflow/models/model.py:622, in Model.log(cls, artifact_path, flavor, registered_model_name, await_registration_for, metadata, run_id, **kwargs)
616 if (
617 (tracking_uri == "databricks" or get_uri_scheme(tracking_uri) == "databricks")
618 and kwargs.get("signature") is None
619 and kwargs.get("input_example") is None
620 ):
621 _logger.warning(_LOG_MODEL_MISSING_SIGNATURE_WARNING)
--> 622 flavor.save_model(path=local_path, mlflow_model=mlflow_model, **kwargs)
623 mlflow.tracking.fluent.log_artifacts(local_path, mlflow_model.artifact_path, run_id)
624 try:
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-48e6b904-fdf2-413b-878b-b781b1aef178/lib/python3.10/site-packages/mlflow/pyfunc/__init__.py:2036, in save_model(path, loader_module, data_path, code_path, conda_env, mlflow_model, python_model, artifacts, signature, input_example, pip_requirements, extra_pip_requirements, metadata, model_config, example_no_conversion, **kwargs)
2024 return _save_model_with_loader_module_and_data_path(
2025 path=path,
2026 loader_module=loader_module,
(...)
2033 model_config=model_config,
2034 )
2035 elif second_argument_set_specified:
-> 2036 return mlflow.pyfunc.model._save_model_with_class_artifacts_params(
2037 path=path,
2038 signature=signature,
2039 hints=hints,
2040 python_model=python_model,
2041 artifacts=artifacts,
2042 conda_env=conda_env,
2043 code_paths=code_path,
2044 mlflow_model=mlflow_model,
2045 pip_requirements=pip_requirements,
2046 extra_pip_requirements=extra_pip_requirements,
2047 model_config=model_config,
2048 )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-48e6b904-fdf2-413b-878b-b781b1aef178/lib/python3.10/site-packages/mlflow/pyfunc/model.py:247, in _save_model_with_class_artifacts_params(path, python_model, signature, hints, artifacts, conda_env, code_paths, mlflow_model, pip_requirements, extra_pip_requirements, model_config)
245 saved_python_model_subpath = "python_model.pkl"
246 with open(os.path.join(path, saved_python_model_subpath), "wb") as out:
--> 247 cloudpickle.dump(python_model, out)
248 custom_model_config_kwargs[CONFIG_KEY_PYTHON_MODEL] = saved_python_model_subpath
250 if artifacts:
File /databricks/python/lib/python3.10/site-packages/cloudpickle/cloudpickle_fast.py:57, in dump(obj, file, protocol, buffer_callback)
45 def dump(obj, file, protocol=None, buffer_callback=None):
46 """Serialize obj as bytes streamed into file
47
48 protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
(...)
53 compatibility with older versions of Python.
54 """
55 CloudPickler(
56 file, protocol=protocol, buffer_callback=buffer_callback
---> 57 ).dump(obj)
File /databricks/python/lib/python3.10/site-packages/cloudpickle/cloudpickle_fast.py:602, in CloudPickler.dump(self, obj)
600 def dump(self, obj):
601 try:
--> 602 return Pickler.dump(self, obj)
603 except RuntimeError as e:
604 if "recursion" in e.args[0]:
Important: without guardrails it works as expected, and the whole RAG chain is saved and logged using MLflow.
@Pouyanpi I see that you have assigned this to yourself; any thoughts or ideas for a workaround?
Hi @DixitAdh, this issue arises from the fact that both threading and contextvars objects are not pickleable by default. To address this, we need to implement the __getstate__ and __setstate__ methods to customize the pickling process. I'll be opening a draft PR to introduce these changes. It would be a great help if you could test that branch once it's ready.
Thank you!
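Roughly, the idea looks like this (a minimal sketch; the actual change in the draft PR may differ):

# Sketch of the pickling hooks on LLMRails: drop the unpickleable state
# (event loops, locks, context variables) and rebuild from the config.
class LLMRails:
    def __init__(self, config, llm=None, verbose=False):
        self.config = config
        ...  # sets up threading/contextvars state that cannot be pickled

    def __getstate__(self):
        # Persist only the serializable configuration.
        return {"config": self.config}

    def __setstate__(self, state):
        # Re-run initialization from the config on unpickling.
        self.__init__(config=state["config"])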
Hi @Pouyanpi, thanks for your reply; I appreciate that you are looking into it. I would be more than happy to test it and help with this.
@DixitAdh, could you please test your use case on #627?
This will help us better understand the extent of the changes needed. The current solution is straightforward and functional. However, if there are additional attributes of the LLMRails instance that you need to serialize, we can certainly extend it, as sketched below.
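Extending it would mostly mean carrying more attributes through the two hooks, along these lines (the extra attribute here is illustrative):

def __getstate__(self):
    # Include any additional pickleable attributes that must survive.
    return {"config": self.config, "verbose": self.verbose}

def __setstate__(self, state):
    self.__init__(config=state["config"], verbose=state["verbose"])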
@Pouyanpi I tried the code. As far as I can see, the rails do get serialized, but when MLflow loads them back, i.e. during unpickling, it throws the error below:
[sb9ql] An error occurred while loading the model: There is no current event loop in thread 'ThreadPoolExecutor-1_0'.
[sb9ql] Traceback (most recent call last):
[sb9ql] File "/opt/conda/envs/mlflow-env/lib/python3.10/site-packages/mlflowserving/scoring_server/__init__.py", line 182, in get_model_option_or_exit
[sb9ql] self.model = self.model_future.result()
[sb9ql] File "/opt/conda/envs/mlflow-env/lib/python3.10/concurrent/futures/_base.py", line 451, in result
[sb9ql] return self.__get_result()
[sb9ql] File "/opt/conda/envs/mlflow-env/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
[sb9ql] raise self._exception
[sb9ql] File "/opt/conda/envs/mlflow-env/lib/python3.10/concurrent/futures/thread.py", line 58, in run
[sb9ql] result = self.fn(*self.args, **self.kwargs)
[sb9ql] File "/opt/conda/envs/mlflow-env/lib/python3.10/site-packages/mlflowserving/scoring_server/__init__.py", line 125, in _load_model_closure
[sb9ql] model = load_model_fn(path)
[sb9ql] File "/opt/conda/envs/mlflow-env/lib/python3.10/site-packages/mlflow/tracing/provider.py", line 237, in wrapper
[sb9ql] is_func_called, result = True, f(*args, **kwargs)
[sb9ql] File "/opt/conda/envs/mlflow-env/lib/python3.10/site-packages/mlflow/pyfunc/__init__.py", line 1019, in load_model
[sb9ql] model_impl = importlib.import_module(conf[MAIN])._load_pyfunc(data_path)
[sb9ql] File "/opt/conda/envs/mlflow-env/lib/python3.10/site-packages/mlflow/langchain/__init__.py", line 884, in _load_pyfunc
[sb9ql] return wrapper_cls(_load_model_from_local_fs(path, model_config), path)
[sb9ql] File "/opt/conda/envs/mlflow-env/lib/python3.10/site-packages/mlflow/langchain/__init__.py", line 925, in _load_model_from_local_fs
[sb9ql] return _load_model(local_model_path, flavor_conf)
[sb9ql] File "/opt/conda/envs/mlflow-env/lib/python3.10/site-packages/mlflow/langchain/__init__.py", line 603, in _load_model
[sb9ql] model = _load_runnables(local_model_path, flavor_conf)
[sb9ql] File "/opt/conda/envs/mlflow-env/lib/python3.10/site-packages/mlflow/langchain/runnables.py", line 479, in _load_runnables
[sb9ql] return _load_from_pickle(os.path.join(path, model_data))
[sb9ql] File "/opt/conda/envs/mlflow-env/lib/python3.10/site-packages/mlflow/langchain/utils/__init__.py", line 442, in _load_from_pickle
[sb9ql] return cloudpickle.load(f)
[sb9ql] File "/opt/conda/envs/mlflow-env/lib/python3.10/site-packages/nemoguardrails/rails/llm/llmrails.py", line 1075, in __setstate__
[sb9ql] self.__init__(config=config)
[sb9ql] File "/opt/conda/envs/mlflow-env/lib/python3.10/site-packages/nemoguardrails/rails/llm/llmrails.py", line 223, in __init__
[sb9ql] self.llm_generation_actions = llm_generation_actions_class(
[sb9ql] File "/opt/conda/envs/mlflow-env/lib/python3.10/site-packages/nemoguardrails/actions/llm/generation.py", line 106, in __init__
[sb9ql] loop = asyncio.get_event_loop()
[sb9ql] File "/opt/conda/envs/mlflow-env/lib/python3.10/asyncio/events.py", line 656, in get_event_loop
[sb9ql] raise RuntimeError('There is no current event loop in thread %r.'
[sb9ql] RuntimeError: There is no current event loop in thread 'ThreadPoolExecutor-1_0'.
@Pouyanpi maybe checking for an existing loop and creating one with "loop = asyncio.new_event_loop()" could do the job?
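The usual pattern for that is roughly the following (a sketch of the suggestion, not the library's actual implementation):

import asyncio

def get_or_create_event_loop() -> asyncio.AbstractEventLoop:
    # Worker threads (e.g. a ThreadPoolExecutor) have no current event
    # loop, so asyncio.get_event_loop() raises RuntimeError there.
    try:
        return asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop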
Thank you @DixitAdh! I'll look into it soon. Your suggestion might resolve this problem, but I cannot tell for sure whether other issues would arise. I'll update you soon 👍🏻
@DixitAdh, it is hard to reproduce the error, but please give the new version a try. It would also be helpful if you could share your rails config.
@Pouyanpi Apologies for the delay in my response; I was on vacation. I have tested the changes and am able to log the chain with NeMo rails through the MLflow pyfunc flavor. Before logging, I verified that the Python class I wrote has a predict function that works as expected. But the moment I load the logged MLflow model and call its predict function, it fails with the error below:
TypeError: Object of type Series is not JSON serializable
File <command-2162541516900227>, line 5
1 #import pandas as pd
2
3 #input = pd.DataFrame({"input": "Washing machine showing E20 error"})
----> 5 llm_model.predict({"input": "Washing machine showing E20 error"})
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-b1a9e324-e840-4ec5-8df0-c6b4328b3789/lib/python3.10/site-packages/mlflow/pyfunc/__init__.py:739, in PyFuncModel.predict(self, data, params)
737 with self._try_get_or_generate_prediction_context() as context:
738 self._update_dependencies_schemas_in_prediction_context(context)
--> 739 return self._predict(data, params)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-b1a9e324-e840-4ec5-8df0-c6b4328b3789/lib/python3.10/site-packages/mlflow/pyfunc/__init__.py:777, in PyFuncModel._predict(self, data, params)
775 params_arg = inspect.signature(self._predict_fn).parameters.get("params")
776 if params_arg and params_arg.kind != inspect.Parameter.VAR_KEYWORD:
--> 777 return self._predict_fn(data, params=params)
779 _log_warning_if_params_not_in_predict_signature(_logger, params)
780 return self._predict_fn(data)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-b1a9e324-e840-4ec5-8df0-c6b4328b3789/lib/python3.10/site-packages/mlflow/pyfunc/model.py:641, in _PythonModelPyfuncWrapper.predict(self, model_input, params)
637 return self.python_model.predict(
638 self.context, self._convert_input(model_input), params=params
639 )
640 _log_warning_if_params_not_in_predict_signature(_logger, params)
--> 641 return self.python_model.predict(self.context, self._convert_input(model_input))
File <command-2162541516900220>, line 51, in GuardrailWrapper.predict(self, context, model_input)
49 issue = model_input["input"]
50 input = {"input": issue}
---> 51 res = runnable_wrapper.invoke(input)
52 return res
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-b1a9e324-e840-4ec5-8df0-c6b4328b3789/lib/python3.10/site-packages/langchain_core/runnables/base.py:3963, in RunnableLambda.invoke(self, input, config, **kwargs)
3961 """Invoke this runnable synchronously."""
3962 if hasattr(self, "func"):
-> 3963 return self._call_with_config(
3964 self._invoke,
3965 input,
3966 self._config(config, self.func),
3967 **kwargs,
3968 )
3969 else:
3970 raise TypeError(
3971 "Cannot invoke a coroutine function synchronously."
3972 "Use `ainvoke` instead."
3973 )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-b1a9e324-e840-4ec5-8df0-c6b4328b3789/lib/python3.10/site-packages/langchain_core/runnables/base.py:1626, in Runnable._call_with_config(self, func, input, config, run_type, **kwargs)
1622 context = copy_context()
1623 context.run(var_child_runnable_config.set, child_config)
1624 output = cast(
1625 Output,
-> 1626 context.run(
1627 call_func_with_variable_args, # type: ignore[arg-type]
1628 func, # type: ignore[arg-type]
1629 input, # type: ignore[arg-type]
1630 config,
1631 run_manager,
1632 **kwargs,
1633 ),
1634 )
1635 except BaseException as e:
1636 run_manager.on_chain_error(e)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-b1a9e324-e840-4ec5-8df0-c6b4328b3789/lib/python3.10/site-packages/langchain_core/runnables/config.py:347, in call_func_with_variable_args(func, input, config, run_manager, **kwargs)
345 if run_manager is not None and accepts_run_manager(func):
346 kwargs["run_manager"] = run_manager
--> 347 return func(input, **kwargs)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-b1a9e324-e840-4ec5-8df0-c6b4328b3789/lib/python3.10/site-packages/langchain_core/runnables/base.py:3837, in RunnableLambda._invoke(self, input, run_manager, config, **kwargs)
3835 output = chunk
3836 else:
-> 3837 output = call_func_with_variable_args(
3838 self.func, input, config, run_manager, **kwargs
3839 )
3840 # If the output is a runnable, invoke it
3841 if isinstance(output, Runnable):
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-b1a9e324-e840-4ec5-8df0-c6b4328b3789/lib/python3.10/site-packages/langchain_core/runnables/config.py:347, in call_func_with_variable_args(func, input, config, run_manager, **kwargs)
345 if run_manager is not None and accepts_run_manager(func):
346 kwargs["run_manager"] = run_manager
--> 347 return func(input, **kwargs)
File <command-2162541516900211>, line 25, in run_custom_chain(input_dict, chain)
24 def run_custom_chain(input_dict: dict, chain: RunnableRails = chain_with_guardrails ):
---> 25 return NemoGuardrailWrapperRunnable(chain).invoke(input_dict)
File <command-2162541516900211>, line 16, in NemoGuardrailWrapperRunnable.invoke(self, input, config, **kwargs)
14 def invoke(self, input, config = None, **kwargs):
15 input["output"]= input["input"]
---> 16 return self.formate_output(self.chain.invoke(input, config = None, **kwargs))
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-b1a9e324-e840-4ec5-8df0-c6b4328b3789/lib/python3.10/site-packages/nemoguardrails/integrations/langchain/runnable_rails.py:186, in RunnableRails.invoke(self, input, config, **kwargs)
184 """Invoke this runnable synchronously."""
185 input_messages = self._transform_input_to_rails_format(input)
--> 186 res = self.rails.generate(
187 messages=input_messages, options=GenerationOptions(output_vars=True)
188 )
189 context = res.output_data
190 result = res.response
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-b1a9e324-e840-4ec5-8df0-c6b4328b3789/lib/python3.10/site-packages/nemoguardrails/rails/llm/llmrails.py:887, in LLMRails.generate(self, prompt, messages, return_context, options, state)
880 raise RuntimeError(
881 "You are using the sync `generate` inside async code. "
882 "You should replace with `await generate_async(...)` or use `nest_asyncio.apply()`."
883 )
885 loop = get_or_create_event_loop()
--> 887 return loop.run_until_complete(
888 self.generate_async(
889 prompt=prompt,
890 messages=messages,
891 options=options,
892 state=state,
893 return_context=return_context,
894 )
895 )
File /databricks/python/lib/python3.10/site-packages/nest_asyncio.py:90, in _patch_loop.<locals>.run_until_complete(self, future)
87 if not f.done():
88 raise RuntimeError(
89 'Event loop stopped before Future completed.')
---> 90 return f.result()
File /usr/lib/python3.10/asyncio/futures.py:201, in Future.result(self)
199 self.__log_traceback = False
200 if self._exception is not None:
--> 201 raise self._exception.with_traceback(self._exception_tb)
202 return self._result
File /usr/lib/python3.10/asyncio/tasks.py:232, in Task.__step(***failed resolving arguments***)
228 try:
229 if exc is None:
230 # We use the `send` method directly, because coroutines
231 # don't have `__iter__` and `__next__` methods.
--> 232 result = coro.send(None)
233 else:
234 result = coro.throw(exc)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-b1a9e324-e840-4ec5-8df0-c6b4328b3789/lib/python3.10/site-packages/nemoguardrails/rails/llm/llmrails.py:638, in LLMRails.generate_async(self, prompt, messages, options, state, streaming_handler, return_context)
635 processing_log = []
637 # The array of events corresponding to the provided sequence of messages.
--> 638 events = self._get_events_for_messages(messages, state)
640 if self.config.colang_version == "1.0":
641 # If we had a state object, we also need to prepend the events from the state.
642 state_events = []
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-b1a9e324-e840-4ec5-8df0-c6b4328b3789/lib/python3.10/site-packages/nemoguardrails/rails/llm/llmrails.py:438, in LLMRails._get_events_for_messages(self, messages, state)
436 p = len(messages) - 1
437 while p > 0:
--> 438 cache_key = get_history_cache_key(messages[0:p])
439 if cache_key in self.events_history_cache:
440 events = self.events_history_cache[cache_key].copy()
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-b1a9e324-e840-4ec5-8df0-c6b4328b3789/lib/python3.10/site-packages/nemoguardrails/rails/llm/utils.py:39, in get_history_cache_key(messages)
37 key_items.append(msg["content"])
38 elif msg["role"] == "context":
---> 39 key_items.append(json.dumps(msg["content"]))
40 elif msg["role"] == "event":
41 key_items.append(json.dumps(msg["event"]))
File /usr/lib/python3.10/json/__init__.py:231, in dumps(obj, skipkeys, ensure_ascii, check_circular, allow_nan, cls, indent, separators, default, sort_keys, **kw)
226 # cached encoder
227 if (not skipkeys and ensure_ascii and
228 check_circular and allow_nan and
229 cls is None and indent is None and separators is None and
230 default is None and not sort_keys and not kw):
--> 231 return _default_encoder.encode(obj)
232 if cls is None:
233 cls = JSONEncoder
File /usr/lib/python3.10/json/encoder.py:199, in JSONEncoder.encode(self, o)
195 return encode_basestring(o)
196 # This doesn't pass the iterator directly to ''.join() because the
197 # exceptions aren't as detailed. The list call should be roughly
198 # equivalent to the PySequence_Fast that ''.join() would do.
--> 199 chunks = self.iterencode(o, _one_shot=True)
200 if not isinstance(chunks, (list, tuple)):
201 chunks = list(chunks)
File /usr/lib/python3.10/json/encoder.py:257, in JSONEncoder.iterencode(self, o, _one_shot)
252 else:
253 _iterencode = _make_iterencode(
254 markers, self.default, _encoder, self.indent, floatstr,
255 self.key_separator, self.item_separator, self.sort_keys,
256 self.skipkeys, _one_shot)
--> 257 return _iterencode(o, 0)
File /usr/lib/python3.10/json/encoder.py:179, in JSONEncoder.default(self, o)
160 def default(self, o):
161 """Implement this method in a subclass such that it returns
162 a serializable object for ``o``, or calls the base implementation
163 (to raise a ``TypeError``).
(...)
177
178 """
--> 179 raise TypeError(f'Object of type {o.__class__.__name__} '
180 f'is not JSON serializable')
I have tried both the async and sync approaches while building the chain with NeMo; both show the same error. I have been banging my head against this since yesterday and am running out of options. If you want, we can schedule a call and I can show you as well.
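The traceback hints at the cause: MLflow's pyfunc wrapper converts the input dict to a pandas DataFrame, so model_input["input"] arrives as a pandas Series and later fails in json.dumps. A hypothetical guard in the wrapper's predict would extract a plain string first:

# Hypothetical tweak inside GuardrailWrapper.predict(): unwrap the Series
# that MLflow's input conversion produces before invoking the chain.
import pandas as pd

def predict(self, context, model_input):
    issue = model_input["input"]
    if isinstance(issue, pd.Series):
        issue = issue.iloc[0]  # take the scalar string, not the Series
    return runnable_wrapper.invoke({"input": issue})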
@Pouyanpi please ignore my previous message. It was an issue on the MLflow side, and I got some help from them. I can now confirm that this code change has fixed my serialization issue: I am able to log the chain in MLflow and load it back. I appreciate all your help, and thanks again :)
Summary
I started working with NeMo and instantly adopted it to wrap our RAG chain, which is written using LangChain. It works great when I am working in a notebook and create a chain with guardrails, but the moment I try to save it using MLflow, the whole thing breaks and throws this error: TypeError: cannot pickle '_thread.RLock' object
Steps to reproduce
I have provided a short example above to show how it fails to serialize; a similar thing happens when I try to use MLflow to log the chain with the LangChain flavor, and I have also tried a custom pyfunc flavor.
I guess the problem is just being able to serialize this; if that can be achieved, I can log the whole chain and my RAG chain can work end to end.