diff --git a/axlearn/open_api/metrics/tool_use_execution.py b/axlearn/open_api/metrics/tool_use_execution.py
index 4923a3f9c..3296247c3 100644
--- a/axlearn/open_api/metrics/tool_use_execution.py
+++ b/axlearn/open_api/metrics/tool_use_execution.py
@@ -424,6 +424,27 @@ def metric_fn(
 
     Returns:
         A dictionary of metrics.
+        - number_of_examples: Total number of examples.
+        - number_of_parsing_errors: Total number of examples with parsing errors.
+        - number_of_generation_errors: Total number of examples with generation errors.
+        - accuracy: Accuracy after standardizing argument values by removing spaces and
+            punctuation and converting to lowercase. Also applies matching rules if provided.
+        - strict_accuracy: Strict accuracy, where argument values must exactly match the
+            ground truth.
+        - lenient_accuracy: Accuracy after removing spaces, punctuation, and a static list of
+            stop words from the argument values.
+        - bow_accuracy: Bag-of-words accuracy. Transforms the argument strings in the same way
+            as the lenient matching, but instead of comparing the resulting strings it checks
+            whether the words in the ground-truth argument values are contained in the predicted
+            argument values.
+        - func_name_accuracy: Accuracy of matching the ground-truth tool names.
+        - number_of_expected_tool_calls: Number of expected tool calls.
+        - num_func_call_intents_ground_truth: Number of assistant turns with tool calls
+            in the ground truth.
+        - num_func_call_intents_pred: Number of assistant turns with tool calls in predictions.
+        - func_intent_recall: Recall of tool-call intent.
+        - func_intent_precision: Precision of tool-call intent.
+
     """
     if len(responses) == 0:
         return {
@@ -453,6 +474,10 @@ def get_tool_calls_from_message(message: dict[str, Any]) -> list[dict[str, Any]]
     total_lenient_matches = 0
     total_bow_matches = 0
 
+    num_func_call_intents_ground_truth = 0
+    num_func_call_intents_pred = 0
+    num_func_call_intents_ground_truth_pred = 0
+
     number_of_parsing_errors = 0
     number_of_generation_errors = 0
     generator_cfg: Generator.Config = generators[EvalGeneratorType.RESPONSE].config
@@ -487,9 +512,14 @@ def get_tool_calls_from_message(message: dict[str, Any]) -> list[dict[str, Any]]
         if target.tool_calls is not None:
             target_tool_calls = get_tool_calls_from_message(target.model_dump())
             total_tool_calls += len(target_tool_calls)
+            num_func_call_intents_ground_truth += 1
 
         if len(pred_messages) > 0:
             pred = pred_messages[0]
+            if pred.tool_calls is not None:
+                num_func_call_intents_pred += 1
+                if target.tool_calls is not None:
+                    num_func_call_intents_ground_truth_pred += 1
 
             # Check string match.
             if (
@@ -540,6 +570,12 @@ def get_tool_calls_from_message(message: dict[str, Any]) -> list[dict[str, Any]]
         if matched:
             total_matches += 1
 
+    func_intent_recall = num_func_call_intents_ground_truth_pred / max(
+        num_func_call_intents_ground_truth, 1
+    )
+    func_intent_precision = num_func_call_intents_ground_truth_pred / max(
+        num_func_call_intents_pred, 1
+    )
     return {
         "accuracy": total_matches / len(responses),
         "number_of_examples": len(responses),
@@ -550,4 +586,8 @@ def get_tool_calls_from_message(message: dict[str, Any]) -> list[dict[str, Any]]
         "lenient_accuracy": total_lenient_matches / max(1, total_tool_calls),
         "bow_accuracy": total_bow_matches / max(1, total_tool_calls),
         "number_of_expected_tool_calls": total_tool_calls,
+        "num_func_call_intents_ground_truth": num_func_call_intents_ground_truth,
+        "num_func_call_intents_pred": num_func_call_intents_pred,
+        "func_intent_recall": func_intent_recall,
+        "func_intent_precision": func_intent_precision,
     }
diff --git a/axlearn/open_api/metrics/tool_use_execution_test.py b/axlearn/open_api/metrics/tool_use_execution_test.py
index 9b08ca856..75b2e0983 100644
--- a/axlearn/open_api/metrics/tool_use_execution_test.py
+++ b/axlearn/open_api/metrics/tool_use_execution_test.py
@@ -2,6 +2,7 @@
 """Unit test for tool_use_execution.py."""
 
 import json
+from typing import Union
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
@@ -41,7 +42,7 @@ def test_empty_responses(self):
         self.assertEqual(metrics["accuracy"], 0)
         self.assertEqual(metrics["number_of_examples"], 0)
 
-    def test_responses_without_tool_calls(self):
+    def test_responses_without_target_message(self):
         """Tests with responses that lack target_message field."""
         responses = [
             {
@@ -290,6 +291,10 @@ def test_match_rules(
         self.assertEqual(
             metrics["number_of_expected_tool_calls"], number_of_expected_tool_calls
         )
+        self.assertEqual(metrics["num_func_call_intents_ground_truth"], 1)
+        self.assertEqual(metrics["num_func_call_intents_pred"], 1)
+        self.assertEqual(metrics["func_intent_recall"], 1)
+        self.assertEqual(metrics["func_intent_precision"], 1)
 
     def test_empty_pred(self):
         pred_message = {
@@ -334,3 +339,177 @@ def test_empty_pred(self):
             },
         )
         self.assertEqual(metrics["accuracy"], 0.0)
+        self.assertEqual(metrics["num_func_call_intents_ground_truth"], 1)
+        self.assertEqual(metrics["num_func_call_intents_pred"], 0)
+        self.assertEqual(metrics["func_intent_recall"], 0)
+        self.assertEqual(metrics["func_intent_precision"], 0)
+
+    @parameterized.parameters(
+        # Simple match.
+        dict(
+            pred_target_pair=[
+                [{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}],
+                [{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}],
+            ],
+            expected_accuracy=1.0,
+            expected_number_of_examples=1,
+            expected_func_name_accuracy=1.0,
+            expected_strict_accuracy=1.0,
+            expected_lenient_accuracy=1.0,
+            expected_bow_accuracy=1.0,
+            expected_number_of_expected_tool_calls=1,
+            expected_num_func_call_intents_ground_truth=1,
+            expected_num_func_call_intents_pred=1,
+            expected_func_intent_recall=1.0,
+            expected_func_intent_precision=1.0,
+        ),
+        # Test function intent, ground truth no tool_calls.
+        dict(
+            pred_target_pair=[
+                [{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}],
+                """Get weather for Boston~""",
+            ],
+            expected_accuracy=0.0,
+            expected_number_of_examples=1,
+            expected_func_name_accuracy=0.0,
+            expected_strict_accuracy=0.0,
+            expected_lenient_accuracy=0.0,
+            expected_bow_accuracy=0.0,
+            expected_number_of_expected_tool_calls=0,
+            expected_num_func_call_intents_ground_truth=0,
+            expected_num_func_call_intents_pred=1,
+            expected_func_intent_recall=0.0,
+            expected_func_intent_precision=0.0,
+        ),
+        # Test function intent, pred no tool_calls.
+        dict(
+            pred_target_pair=[
+                """Get weather for Boston~""",
+                [{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}],
+            ],
+            expected_accuracy=0.0,
+            expected_number_of_examples=1,
+            expected_func_name_accuracy=0.0,
+            expected_strict_accuracy=0.0,
+            expected_lenient_accuracy=0.0,
+            expected_bow_accuracy=0.0,
+            expected_number_of_expected_tool_calls=1,
+            expected_num_func_call_intents_ground_truth=1,
+            expected_num_func_call_intents_pred=0,
+            expected_func_intent_recall=0.0,
+            expected_func_intent_precision=0.0,
+        ),
+        # Tool call name mismatch.
+        dict(
+            pred_target_pair=[
+                [{"name": "get_weather", "arguments": {"location": "Boston, MA"}}],
+                [{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}],
+            ],
+            expected_accuracy=0.0,
+            expected_number_of_examples=1,
+            expected_func_name_accuracy=0.0,
+            expected_strict_accuracy=0.0,
+            expected_lenient_accuracy=0.0,
+            expected_bow_accuracy=0.0,
+            expected_number_of_expected_tool_calls=1,
+            expected_num_func_call_intents_ground_truth=1,
+            expected_num_func_call_intents_pred=1,
+            expected_func_intent_recall=1.0,
+            expected_func_intent_precision=1.0,
+        ),
+        # Tool call BOW.
+        dict(
+            pred_target_pair=[
+                [{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}],
+                [{"name": "get_current_weather", "arguments": {"location": "Boston"}}],
+            ],
+            expected_accuracy=0.0,
+            expected_number_of_examples=1,
+            expected_func_name_accuracy=1.0,
+            expected_strict_accuracy=0.0,
+            expected_lenient_accuracy=1.0,  # 'ma' is part of _STOP_WORDS and removed.
+            expected_bow_accuracy=1.0,
+            expected_number_of_expected_tool_calls=1,
+            expected_num_func_call_intents_ground_truth=1,
+            expected_num_func_call_intents_pred=1,
+            expected_func_intent_recall=1.0,
+            expected_func_intent_precision=1.0,
+        ),
+    )
+    def test_all_metrics(
+        self,
+        pred_target_pair,
+        expected_accuracy,
+        expected_number_of_examples,
+        expected_func_name_accuracy,
+        expected_strict_accuracy,
+        expected_lenient_accuracy,
+        expected_bow_accuracy,
+        expected_number_of_expected_tool_calls,
+        expected_num_func_call_intents_ground_truth,
+        expected_num_func_call_intents_pred,
+        expected_func_intent_recall,
+        expected_func_intent_precision,
+    ):
+        """Tests all metrics."""
+
+        def _make_message(content_tool_call_info: Union[list, str]) -> dict:
+            if isinstance(content_tool_call_info, str):
+                content = content_tool_call_info
+                tool_calls = None
+            elif isinstance(content_tool_call_info, list):
+                content = ""
+                tool_calls = []
+                for tool_call_info in content_tool_call_info:
+                    tool_calls.append(
+                        {
+                            "type": "function",
+                            "function": tool_call_info,
+                        }
+                    )
+            return {
+                "role": "assistant",
+                "content": content,
+                "tool_calls": tool_calls,
+            }
+
+        responses = []
+        pred_message = _make_message(pred_target_pair[0])
+        target_message = _make_message(pred_target_pair[1])
+        responses = [
+            {
+                "response": json.dumps({"choices": [{"message": pred_message}]}),
+                "target_message": target_message,
+            }
+        ]
+        mock_target_message = Mock(**target_message)
+        mock_target_message.model_dump.return_value = target_message
+        mock_pred_message = Mock(**pred_message)
+        mock_pred_message.model_dump.return_value = pred_message
+        self.generator.config.client.klass.parse_generation.return_value = [mock_pred_message]
+
+        with patch(
+            "axlearn.open_api.openai.OpenAIClient.format_message", return_value=mock_target_message
+        ):
+            metrics = metric_fn(
+                responses=responses,
+                generators={
+                    EvalGeneratorType.RESPONSE: self.generator,
+                },
+            )
+        expected_metrics = {
+            "accuracy": expected_accuracy,
+            "number_of_examples": expected_number_of_examples,
+            "number_of_parsing_errors": 0,
+            "number_of_generation_errors": 0,
+            "func_name_accuracy": expected_func_name_accuracy,
+            "strict_accuracy": expected_strict_accuracy,
+            "lenient_accuracy": expected_lenient_accuracy,
+            "bow_accuracy": expected_bow_accuracy,
+            "number_of_expected_tool_calls": expected_number_of_expected_tool_calls,
+            "num_func_call_intents_ground_truth": expected_num_func_call_intents_ground_truth,
+            "num_func_call_intents_pred": expected_num_func_call_intents_pred,
+            "func_intent_recall": expected_func_intent_recall,
+            "func_intent_precision": expected_func_intent_precision,
+        }
+        self.assertEqual(metrics, expected_metrics)
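
# Reviewer note (not part of the patch): a minimal standalone sketch of how the new
# intent counters combine into func_intent_recall and func_intent_precision. The
# intent_metrics helper and the example turns are hypothetical; only the counter
# semantics and the max(..., 1) zero-division guard mirror the diff above.

def intent_metrics(turns: list[tuple[bool, bool]]) -> tuple[float, float]:
    """Each turn is (target_has_tool_calls, pred_has_tool_calls)."""
    num_gt = sum(1 for gt, _ in turns if gt)  # num_func_call_intents_ground_truth
    num_pred = sum(1 for _, pred in turns if pred)  # num_func_call_intents_pred
    # A turn counts toward the numerator only when both sides contain tool calls,
    # matching the nested `if` in the third hunk of tool_use_execution.py.
    num_both = sum(1 for gt, pred in turns if gt and pred)
    # max(..., 1) guards against division by zero, as in the patch.
    recall = num_both / max(num_gt, 1)
    precision = num_both / max(num_pred, 1)
    return recall, precision

# One true positive, one missed tool call, one spurious tool call:
assert intent_metrics([(True, True), (True, False), (False, True)]) == (0.5, 0.5)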