From 809b06863135661c561625298fb5363cc66c9795 Mon Sep 17 00:00:00 2001 From: HuXiangkun Date: Thu, 5 Sep 2024 02:57:09 +0000 Subject: [PATCH] Additional params for sagemaker --- pyproject.toml | 4 ++-- ragchecker/evaluator.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1a88a8a..470109b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "ragchecker" -version = "0.1.4" +version = "0.1.5" description = "RAGChecker: A Fine-grained Framework For Diagnosing Retrieval-Augmented Generation (RAG) systems." authors = [ "Xiangkun Hu ", @@ -15,7 +15,7 @@ license = "Apache-2.0" [tool.poetry.dependencies] python = "^3.9" -refchecker = "^0.2.5" +refchecker = "^0.2.6" loguru = "^0.7" dataclasses-json = "^0.6" diff --git a/ragchecker/evaluator.py b/ragchecker/evaluator.py index 81c1bd1..8558d78 100644 --- a/ragchecker/evaluator.py +++ b/ragchecker/evaluator.py @@ -52,6 +52,7 @@ def __init__( joint_check=True, joint_check_num=5, sagemaker_client=None, + sagemaker_params=None, **kwargs ): if openai_api_key: @@ -61,6 +62,7 @@ self.joint_check_num = joint_check_num self.kwargs = kwargs self.sagemaker_client = sagemaker_client + self.sagemaker_params = sagemaker_params self.extractor = LLMExtractor( model=extractor_name, @@ -108,6 +110,7 @@ def extract_claims(self, results: List[RAGResult], extract_type="gt_answer"): batch_questions=questions, max_new_tokens=self.extractor_max_new_tokens, sagemaker_client=self.sagemaker_client, + sagemaker_params=self.sagemaker_params, **self.kwargs ) claims = [[c.content for c in res.claims] for res in extraction_results] @@ -169,6 +172,7 @@ def check_claims(self, results: RAGResults, check_type="answer2response"): is_joint=self.joint_check, joint_check_num=self.joint_check_num, sagemaker_client=self.sagemaker_client, + sagemaker_params=self.sagemaker_params, **self.kwargs ) for i, result in enumerate(results):