Skip to content

Commit 216cb7a

Browse files
ankursharmas and copybara-github
authored and committed
fix: Update routes for run_eval, get_eval_result, list_eval_results and list_metrics_info
PiperOrigin-RevId: 799718430
1 parent ad81aa5 commit 216cb7a

File tree

2 files changed

+109
-18
lines changed

2 files changed

+109
-18
lines changed

src/google/adk/cli/adk_web_server.py

Lines changed: 101 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,18 @@ class AddSessionToEvalSetRequest(common.BaseModel):
173173

174174

175175
class RunEvalRequest(common.BaseModel):
176-
eval_ids: list[str] # if empty, then all evals in the eval set are run.
176+
eval_ids: list[str] = Field(
177+
deprecated=True,
178+
default_factory=list,
179+
description="This field is deprecated, use eval_case_ids instead.",
180+
)
181+
eval_case_ids: list[str] = Field(
182+
default_factory=list,
183+
description=(
184+
"List of eval case ids to evaluate. if empty, then all eval cases in"
185+
" the eval set are run."
186+
),
187+
)
177188
eval_metrics: list[EvalMetric]
178189

179190

@@ -195,6 +206,10 @@ class RunEvalResult(common.BaseModel):
195206
session_id: str
196207

197208

209+
class RunEvalResponse(common.BaseModel):
210+
run_eval_results: list[RunEvalResult]
211+
212+
198213
class GetEventGraphResult(common.BaseModel):
199214
dot_src: str
200215

@@ -207,6 +222,22 @@ class ListEvalSetsResponse(common.BaseModel):
207222
eval_set_ids: list[str]
208223

209224

225+
class EvalResult(EvalSetResult):
226+
"""This class has no field intentionally.
227+
228+
The goal here is to just give a new name to the class to align with the API
229+
endpoint.
230+
"""
231+
232+
233+
class ListEvalResultsResponse(common.BaseModel):
234+
eval_result_ids: list[str]
235+
236+
237+
class ListMetricsInfoResponse(common.BaseModel):
238+
metrics_info: list[MetricInfo]
239+
240+
210241
class AdkWebServer:
211242
"""Helper class for setting up and running the ADK web server on FastAPI.
212243
@@ -690,14 +721,30 @@ async def delete_eval(
690721
except NotFoundError as nfe:
691722
raise HTTPException(status_code=404, detail=str(nfe)) from nfe
692723

724+
@deprecated(
725+
"Please use run_eval instead. This will be removed in future releases."
726+
)
693727
@app.post(
694728
"/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
695729
response_model_exclude_none=True,
696730
tags=[TAG_EVALUATION],
697731
)
698-
async def run_eval(
732+
async def run_eval_legacy(
699733
app_name: str, eval_set_id: str, req: RunEvalRequest
700734
) -> list[RunEvalResult]:
735+
run_eval_response = await run_eval(
736+
app_name=app_name, eval_set_id=eval_set_id, req=req
737+
)
738+
return run_eval_response.run_eval_results
739+
740+
@app.post(
741+
"/apps/{app_name}/eval-sets/{eval_set_id}/run",
742+
response_model_exclude_none=True,
743+
tags=[TAG_EVALUATION],
744+
)
745+
async def run_eval(
746+
app_name: str, eval_set_id: str, req: RunEvalRequest
747+
) -> RunEvalResponse:
701748
"""Runs an eval given the details in the eval request."""
702749
# Create a mapping from eval set file to all the evals that needed to be
703750
# run.
@@ -727,7 +774,7 @@ async def run_eval(
727774
inference_request = InferenceRequest(
728775
app_name=app_name,
729776
eval_set_id=eval_set.eval_set_id,
730-
eval_case_ids=req.eval_ids,
777+
eval_case_ids=req.eval_case_ids or req.eval_ids,
731778
inference_config=InferenceConfig(),
732779
)
733780
inference_results = await _collect_inferences(
@@ -760,18 +807,41 @@ async def run_eval(
760807
)
761808
)
762809

763-
return run_eval_results
810+
return RunEvalResponse(run_eval_results=run_eval_results)
764811

765812
@app.get(
766-
"/apps/{app_name}/eval_results/{eval_result_id}",
813+
"/apps/{app_name}/eval-results/{eval_result_id}",
767814
response_model_exclude_none=True,
768815
tags=[TAG_EVALUATION],
769816
)
770817
async def get_eval_result(
771818
app_name: str,
772819
eval_result_id: str,
773-
) -> EvalSetResult:
820+
) -> EvalResult:
774821
"""Gets the eval result for the given eval id."""
822+
try:
823+
eval_set_result = self.eval_set_results_manager.get_eval_set_result(
824+
app_name, eval_result_id
825+
)
826+
return EvalResult(**eval_set_result.model_dump())
827+
except ValueError as ve:
828+
raise HTTPException(status_code=404, detail=str(ve)) from ve
829+
except ValidationError as ve:
830+
raise HTTPException(status_code=500, detail=str(ve)) from ve
831+
832+
@deprecated(
833+
"Please use get_eval_result instead. This will be removed in future"
834+
" releases."
835+
)
836+
@app.get(
837+
"/apps/{app_name}/eval_results/{eval_result_id}",
838+
response_model_exclude_none=True,
839+
tags=[TAG_EVALUATION],
840+
)
841+
async def get_eval_result_legacy(
842+
app_name: str,
843+
eval_result_id: str,
844+
) -> EvalSetResult:
775845
try:
776846
return self.eval_set_results_manager.get_eval_set_result(
777847
app_name, eval_result_id
@@ -782,27 +852,46 @@ async def get_eval_result(
782852
raise HTTPException(status_code=500, detail=str(ve)) from ve
783853

784854
@app.get(
785-
"/apps/{app_name}/eval_results",
855+
"/apps/{app_name}/eval-results",
786856
response_model_exclude_none=True,
787857
tags=[TAG_EVALUATION],
788858
)
789-
async def list_eval_results(app_name: str) -> list[str]:
859+
async def list_eval_results(app_name: str) -> ListEvalResultsResponse:
790860
"""Lists all eval results for the given app."""
791-
return self.eval_set_results_manager.list_eval_set_results(app_name)
861+
eval_result_ids = self.eval_set_results_manager.list_eval_set_results(
862+
app_name
863+
)
864+
return ListEvalResultsResponse(eval_result_ids=eval_result_ids)
865+
866+
@deprecated(
867+
"Please use list_eval_results instead. This will be removed in future"
868+
" releases."
869+
)
870+
@app.get(
871+
"/apps/{app_name}/eval_results",
872+
response_model_exclude_none=True,
873+
tags=[TAG_EVALUATION],
874+
)
875+
async def list_eval_results_legacy(app_name: str) -> list[str]:
876+
list_eval_results_response = await list_eval_results(app_name)
877+
return list_eval_results_response.eval_result_ids
792878

793879
@app.get(
794-
"/apps/{app_name}/eval_metrics",
880+
"/apps/{app_name}/metrics-info",
795881
response_model_exclude_none=True,
796882
tags=[TAG_EVALUATION],
797883
)
798-
async def list_eval_metrics(app_name: str) -> list[MetricInfo]:
884+
async def list_metrics_info(app_name: str) -> ListMetricsInfoResponse:
799885
"""Lists all eval metrics for the given app."""
800886
try:
801887
from ..evaluation.metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
802888

803889
# Right now we ignore the app_name as eval metrics are not tied to the
804890
# app_name, but they could be moving forward.
805-
return DEFAULT_METRIC_EVALUATOR_REGISTRY.get_registered_metrics()
891+
metrics_info = (
892+
DEFAULT_METRIC_EVALUATOR_REGISTRY.get_registered_metrics()
893+
)
894+
return ListMetricsInfoResponse(metrics_info=metrics_info)
806895
except ModuleNotFoundError as e:
807896
logger.exception("%s\n%s", MISSING_EVAL_DEPENDENCIES_MESSAGE, e)
808897
raise HTTPException(

tests/unittests/cli/test_fast_api.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -845,18 +845,20 @@ def verify_eval_case_result(actual_eval_case_result):
845845
assert data == [f"{info['app_name']}_test_eval_set_id_eval_result"]
846846

847847

848-
def test_list_eval_metrics(test_app):
849-
"""Test listing eval metrics."""
850-
url = "/apps/test_app/eval_metrics"
848+
def test_list_metrics_info(test_app):
849+
"""Test listing metrics info."""
850+
url = "/apps/test_app/metrics-info"
851851
response = test_app.get(url)
852852

853853
# Verify the response
854854
assert response.status_code == 200
855855
data = response.json()
856-
assert isinstance(data, list)
856+
metrics_info_key = "metricsInfo"
857+
assert metrics_info_key in data
858+
assert isinstance(data[metrics_info_key], list)
857859
# Add more assertions based on the expected metrics
858-
assert len(data) > 0
859-
for metric in data:
860+
assert len(data[metrics_info_key]) > 0
861+
for metric in data[metrics_info_key]:
860862
assert "metricName" in metric
861863
assert "description" in metric
862864
assert "metricValueInfo" in metric

0 commit comments

Comments (0)