Status: Open — ahuang11 opened this issue 2 months ago.
async def find_join_tables(self, messages: list | str, join_tables: list[str] | None = None):
    """Ask the LLM which tables a join needs and resolve each one to its source.

    Parameters
    ----------
    messages : list | str
        Conversation forwarded unchanged to ``self.llm.invoke``.
    join_tables : list[str] | None
        Optional table names. Here they only decide whether tables are treated
        as source-qualified (i.e. contain ``//``); the table list that is
        actually resolved always comes from the LLM response.

    Returns
    -------
    dict
        Maps a source name to a ``(source, table)`` tuple. In multi-source
        mode the source is looked up by name in ``memory["available_sources"]``
        and is ``None`` when no source with that name exists.

    Raises
    ------
    ValueError
        In multi-source mode, if the LLM returns a table name that is not
        source-qualified (contains no ``//``).
    """
    if join_tables:
        # Source-qualified names have the form ``//source//table``.
        multi_source = any("//" in jt for jt in join_tables)
    else:
        multi_source = len(memory["available_sources"]) > 1

    if multi_source:
        # Build the prefixes from ``.name`` so they match the
        # ``source.name`` lookup used when parsing the LLM answer below.
        available_tables = ", ".join(
            f"//{a_source.name}//{a_table}"
            for a_source in memory["available_sources"]
            for a_table in a_source.get_tables()
        )
    else:
        # Same comma-separated format as the multi-source branch, rather than
        # interpolating a raw list repr into the prompt.
        available_tables = ", ".join(str(t) for t in memory["current_source"].get_tables())

    with self.interface.add_step(title="Determining tables required for join") as step:
        output = await self.llm.invoke(
            messages,
            system=f"List the tables that need to be joined; be sure to include both `//`: {available_tables}",
            response_model=TableJoins,
        )
        # NOTE(review): this replaces any caller-supplied ``join_tables``;
        # confirm whether provided tables should skip the LLM call instead.
        join_tables = output.tables
        step.stream(f'\nJoin requires following tables: {join_tables}', replace=True)
        step.success_title = 'Found tables required for join'

    sources = {}
    for source_table in join_tables:
        if multi_source:
            # Accept ``//source//table`` and ``source//table``: drop the
            # optional leading ``//``, then split on the first remaining ``//``
            # so the source name is never mistaken for a discarded prefix.
            qualified = source_table[2:] if source_table.startswith("//") else source_table
            a_source_name, sep, a_table = qualified.partition("//")
            if not sep:
                raise ValueError(
                    f"Expected a source-qualified table name like '//source//table', got {source_table!r}"
                )
            a_source = next(
                (source for source in memory["available_sources"] if source.name == a_source_name),
                None,
            )
        else:
            a_source = memory["current_source"]
            a_table = source_table
            # ``a_source_name`` must be bound on this path too; without this
            # line the assignment below raised NameError.
            a_source_name = a_source.name
        # NOTE(review): keyed by source name, so several tables from the same
        # source overwrite each other and only the last one is kept — confirm
        # whether callers expect one table per source.
        sources[a_source_name] = (a_source, a_table)
    return sources
I think this needs the column schemas as context.