Skip to content

Commit

Permalink
Allow setting query transformers in the BaseRAGQA
Browse files Browse the repository at this point in the history
Related to pathwaycom#67

Add query transformation behavior to `BaseRAGQuestionAnswerer` initialization.

* **`python/pathway/xpacks/llm/question_answering.py`**
  - Add `query_rewrite_method` parameter to `BaseRAGQuestionAnswerer` initialization.
  - Update `answer_query` method to apply the selected query transformation prompt.
  - Use `query_rewrite_method` to select the appropriate query transformation prompt.
  - Add `query_rewrite_method` parameter to `GeometricRAGQuestionAnswerer` initialization.
  - Update `answer_query` method in `GeometricRAGQuestionAnswerer` to apply the selected query transformation prompt.

* **`python/pathway/xpacks/llm/tests/test_rag.py`**
  - Add tests to verify the new functionality of `query_rewrite_method`.
  - Test the `"default"` and `"hyde"` values of the `query_rewrite_method` parameter.
  • Loading branch information
vishwamartur committed Dec 10, 2024
1 parent 75b0fa8 commit 1f8113c
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 0 deletions.
23 changes: 23 additions & 0 deletions python/pathway/xpacks/llm/question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ class BaseRAGQuestionAnswerer(SummaryQuestionAnswerer):
A pw.udf function is expected. Defaults to ``pathway.xpacks.llm.prompts.prompt_qa``.
summarize_template: Template for text summarization. Defaults to ``pathway.xpacks.llm.prompts.prompt_summarize``.
search_topk: Top k parameter for the retrieval. Adjusts number of chunks in the context.
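query_rewrite_method: Optional query transformation applied to the prompt before retrieval. ``'default'`` applies ``prompts.prompt_query_rewrite``, ``'hyde'`` applies ``prompts.prompt_query_rewrite_hyde``, and ``None`` leaves the query unchanged. Defaults to ``None``.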
Example:
Expand Down Expand Up @@ -357,6 +358,7 @@ def __init__(
long_prompt_template: pw.UDF = prompts.prompt_qa,
summarize_template: pw.UDF = prompts.prompt_summarize,
search_topk: int = 6,
query_rewrite_method: str | None = None,
) -> None:

self.llm = llm
Expand All @@ -372,6 +374,7 @@ def __init__(
self.summarize_template = summarize_template

self.search_topk = search_topk
self.query_rewrite_method = query_rewrite_method

self.server: None | QASummaryRestServer = None
self._pending_endpoints: list[tuple] = []
Expand Down Expand Up @@ -402,6 +405,15 @@ def answer_query(self, pw_ai_queries: pw.Table) -> pw.Table:
"""Main function for RAG applications that answer questions
based on available information."""

if self.query_rewrite_method == "hyde":
pw_ai_queries += pw_ai_queries.select(
prompt=prompts.prompt_query_rewrite_hyde(pw.this.prompt)
)
elif self.query_rewrite_method == "default":
pw_ai_queries += pw_ai_queries.select(
prompt=prompts.prompt_query_rewrite(pw.this.prompt)
)

pw_ai_results = pw_ai_queries + self.indexer.retrieve_query(
pw_ai_queries.select(
metadata_filter=pw.this.filters,
Expand Down Expand Up @@ -653,6 +665,7 @@ def __init__(
factor: int = 2,
max_iterations: int = 4,
strict_prompt: bool = False,
query_rewrite_method: str | None = None,
) -> None:
super().__init__(
llm,
Expand All @@ -661,6 +674,7 @@ def __init__(
short_prompt_template=short_prompt_template,
long_prompt_template=long_prompt_template,
summarize_template=summarize_template,
query_rewrite_method=query_rewrite_method,
)
self.n_starting_documents = n_starting_documents
self.factor = factor
Expand All @@ -677,6 +691,15 @@ def answer_query(self, pw_ai_queries: pw.Table) -> pw.Table:
else:
data_column_name = "text"

if self.query_rewrite_method == "hyde":
pw_ai_queries += pw_ai_queries.select(
prompt=prompts.prompt_query_rewrite_hyde(pw.this.prompt)
)
elif self.query_rewrite_method == "default":
pw_ai_queries += pw_ai_queries.select(
prompt=prompts.prompt_query_rewrite(pw.this.prompt)
)

result = pw_ai_queries.select(
*pw.this,
result=answer_with_geometric_rag_strategy_from_index(
Expand Down
118 changes: 118 additions & 0 deletions python/pathway/xpacks/llm/tests/test_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,121 @@ def test_base_rag():
"""
),
)


def test_base_rag_with_query_rewrite():
    """Check answer and summarize results with ``query_rewrite_method="default"``.

    Builds a ``BaseRAGQuestionAnswerer`` on top of a three-document vector
    store, runs two answer queries and one summarize query through it, and
    compares the ``result`` column of each output with a fixed expected table.
    """
    # NOTE(review): the expected answer rows below are identical to the ones in
    # test_base_rag_with_hyde_query_rewrite -- confirm that the rewrite step is
    # really exercised by these assertions.
    docs_schema = pw.schema_from_types(data=bytes, _metadata=dict)
    doc_rows = [(text, {}) for text in ("foo", "bar", "baz")]
    docs = pw.debug.table_from_rows(schema=docs_schema, rows=doc_rows)

    store = VectorStoreServer(docs, embedder=fake_embeddings_model)

    question_answerer = BaseRAGQuestionAnswerer(
        IdentityMockChat(),
        store,
        short_prompt_template=_short_template,
        long_prompt_template=_long_template,
        summarize_template=_summarize_template,
        search_topk=2,
        query_rewrite_method="default",
    )

    # --- answer_query ---
    # presumably (prompt, filters, model, response_type); verify against
    # AnswerQuerySchema
    query_rows = [
        ("foo", None, "gpt3.5", "short"),
        ("baz", None, "gpt4", "long"),
    ]
    queries = pw.debug.table_from_rows(
        schema=question_answerer.AnswerQuerySchema,
        rows=query_rows,
    )
    answers = question_answerer.answer_query(queries)
    actual_answers = answers.select(result=pw.this.result)
    expected_answers = pw.debug.table_from_markdown(
        "\nresult\ngpt3.5,short,foo,foo,bar\ngpt4,long,baz,baz,bar\n"
    )
    assert_table_equality(actual_answers, expected_answers)

    # --- summarize_query ---
    summary_requests = pw.debug.table_from_rows(
        schema=question_answerer.SummarizeQuerySchema,
        rows=[(["foo", "bar"], "gpt2")],
    )
    summaries = question_answerer.summarize_query(summary_requests)
    actual_summaries = summaries.select(result=pw.this.result)
    expected_summaries = pw.debug.table_from_markdown(
        "\nresult\ngpt2,summarize,foo,bar\n"
    )
    assert_table_equality(actual_summaries, expected_summaries)


def test_base_rag_with_hyde_query_rewrite():
    """Check answer and summarize results with ``query_rewrite_method="hyde"``.

    Builds a ``BaseRAGQuestionAnswerer`` on top of a three-document vector
    store, runs two answer queries and one summarize query through it, and
    compares the ``result`` column of each output with a fixed expected table.
    """
    # NOTE(review): the expected answer rows below are identical to the ones in
    # test_base_rag_with_query_rewrite ("default" method) -- confirm that the
    # HyDE rewrite is really exercised by these assertions.
    corpus_schema = pw.schema_from_types(data=bytes, _metadata=dict)
    corpus = pw.debug.table_from_rows(
        schema=corpus_schema,
        rows=[("foo", {}), ("bar", {}), ("baz", {})],
    )

    index_server = VectorStoreServer(corpus, embedder=fake_embeddings_model)

    rag_app = BaseRAGQuestionAnswerer(
        IdentityMockChat(),
        index_server,
        short_prompt_template=_short_template,
        long_prompt_template=_long_template,
        summarize_template=_summarize_template,
        search_topk=2,
        query_rewrite_method="hyde",
    )

    # Answering: two queries, one per row.
    # presumably (prompt, filters, model, response_type); verify against
    # AnswerQuerySchema
    answer_requests = pw.debug.table_from_rows(
        schema=rag_app.AnswerQuerySchema,
        rows=[
            ("foo", None, "gpt3.5", "short"),
            ("baz", None, "gpt4", "long"),
        ],
    )
    answer_table = rag_app.answer_query(answer_requests)
    answer_results = answer_table.select(result=pw.this.result)
    answer_expected = pw.debug.table_from_markdown(
        "\nresult\ngpt3.5,short,foo,foo,bar\ngpt4,long,baz,baz,bar\n"
    )
    assert_table_equality(answer_results, answer_expected)

    # Summarization: a single request.
    summarize_requests = pw.debug.table_from_rows(
        schema=rag_app.SummarizeQuerySchema,
        rows=[(["foo", "bar"], "gpt2")],
    )
    summarize_table = rag_app.summarize_query(summarize_requests)
    summarize_results = summarize_table.select(result=pw.this.result)
    summarize_expected = pw.debug.table_from_markdown(
        "\nresult\ngpt2,summarize,foo,bar\n"
    )
    assert_table_equality(summarize_results, summarize_expected)

0 comments on commit 1f8113c

Please sign in to comment.