add tool use intent metrics (#871)
* add tool use intent metrics

* address comments
fnan authored Dec 6, 2024
1 parent f61a269 commit cc4a9df
Showing 2 changed files with 225 additions and 1 deletion.
44 changes: 44 additions & 0 deletions axlearn/open_api/metrics/tool_use_execution.py
@@ -424,6 +424,31 @@ def metric_fn(
Returns:
A dictionary of metrics.
- number_of_examples: Total number of examples.
- number_of_parsing_errors: Total number of examples with parsing errors.
- number_of_generation_errors: Total number of examples with generation errors.
- accuracy: Accuracy after standardizing argument values, including removing spaces and
punctuation and converting to lowercase. Matching rules are also applied if provided.
See `_default_value_match()` for string standardization.
- strict_accuracy: Strict accuracy: whether the argument values exactly match the
ground truth.
- lenient_accuracy: Accuracy after removing spaces, punctuation, and a static list of
stop words from the argument values.
See `_is_arg_value_equal()` in ./tool_use_execution_utils.py for details.
- bow_accuracy: Bag-of-words accuracy. Transforms the argument strings in the same way
as lenient matching but, instead of comparing the resulting strings, checks whether
the words in the ground-truth argument values are contained in the predicted
argument values.
See `_is_arg_value_equal()` in ./tool_use_execution_utils.py for details.
- func_name_accuracy: Accuracy of matching the ground truth tool names.
- number_of_expected_tool_calls: Number of tool calls in ground truth.
- number_of_func_call_intents_ground_truth: Number of assistant turns with tool calls
in the ground truth.
- number_of_func_call_intents_pred: Number of assistant turns with tool calls
in predictions.
- func_intent_recall: Recall of tool-call intent.
- func_intent_precision: Precision of tool-call intent.
"""
if len(responses) == 0:
return {
@@ -453,6 +478,10 @@ def get_tool_calls_from_message(message: dict[str, Any]) -> list[dict[str, Any]]
total_lenient_matches = 0
total_bow_matches = 0

number_of_func_call_intents_ground_truth = 0
number_of_func_call_intents_pred = 0
number_of_func_call_intents_ground_truth_pred = 0

number_of_parsing_errors = 0
number_of_generation_errors = 0
generator_cfg: Generator.Config = generators[EvalGeneratorType.RESPONSE].config
@@ -487,9 +516,14 @@ def get_tool_calls_from_message(message: dict[str, Any]) -> list[dict[str, Any]]
if target.tool_calls is not None:
target_tool_calls = get_tool_calls_from_message(target.model_dump())
total_tool_calls += len(target_tool_calls)
number_of_func_call_intents_ground_truth += 1

if len(pred_messages) > 0:
pred = pred_messages[0]
if pred.tool_calls is not None:
number_of_func_call_intents_pred += 1
if target.tool_calls is not None:
number_of_func_call_intents_ground_truth_pred += 1

# Check string match.
if (
@@ -540,6 +574,12 @@ def get_tool_calls_from_message(message: dict[str, Any]) -> list[dict[str, Any]]
if matched:
total_matches += 1

func_intent_recall = number_of_func_call_intents_ground_truth_pred / max(
number_of_func_call_intents_ground_truth, 1
)
func_intent_precision = number_of_func_call_intents_ground_truth_pred / max(
number_of_func_call_intents_pred, 1
)
return {
"accuracy": total_matches / len(responses),
"number_of_examples": len(responses),
@@ -550,4 +590,8 @@ def get_tool_calls_from_message(message: dict[str, Any]) -> list[dict[str, Any]]
"lenient_accuracy": total_lenient_matches / max(1, total_tool_calls),
"bow_accuracy": total_bow_matches / max(1, total_tool_calls),
"number_of_expected_tool_calls": total_tool_calls,
"number_of_func_call_intents_ground_truth": number_of_func_call_intents_ground_truth,
"number_of_func_call_intents_pred": number_of_func_call_intents_pred,
"func_intent_recall": func_intent_recall,
"func_intent_precision": func_intent_precision,
}
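For reference, the intent counters above combine into recall and precision as in the following minimal standalone sketch. It is not part of the commit; the helper name `intent_metrics` and its list-of-flags interface are illustrative assumptions, but the arithmetic mirrors the counter logic in `metric_fn`, including the `max(..., 1)` guard against division by zero.

# A minimal sketch (not from the diff), assuming per-turn boolean flags for
# whether the ground truth and the prediction each contain tool calls.
def intent_metrics(gt_has_calls: list[bool], pred_has_calls: list[bool]) -> dict[str, float]:
    """Computes tool-call intent recall/precision from per-turn flags."""
    n_gt = sum(gt_has_calls)  # assistant turns with tool calls in the ground truth
    n_pred = sum(pred_has_calls)  # assistant turns with tool calls in the prediction
    n_both = sum(g and p for g, p in zip(gt_has_calls, pred_has_calls))
    return {
        "func_intent_recall": n_both / max(n_gt, 1),
        "func_intent_precision": n_both / max(n_pred, 1),
    }

# Ground truth intends tool calls in 2 of 4 turns; the prediction flags 3,
# overlapping the ground truth on 2: recall = 2/2 = 1.0, precision = 2/3.
assert intent_metrics([True, True, False, False], [True, True, True, False]) == {
    "func_intent_recall": 1.0,
    "func_intent_precision": 2 / 3,
}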
182 changes: 181 additions & 1 deletion axlearn/open_api/metrics/tool_use_execution_test.py
@@ -2,6 +2,7 @@
"""Unit test for tool_use_execution.py."""

import json
from typing import Union
from unittest.mock import MagicMock, Mock, patch

import pytest
@@ -41,7 +42,7 @@ def test_empty_responses(self):
self.assertEqual(metrics["accuracy"], 0)
self.assertEqual(metrics["number_of_examples"], 0)

def test_responses_without_tool_calls(self):
def test_responses_without_target_message(self):
"""Tests with responses that lack target_message field."""
responses = [
{
@@ -290,6 +291,10 @@ def test_match_rules(
self.assertEqual(
metrics["number_of_expected_tool_calls"], number_of_expected_tool_calls
)
self.assertEqual(metrics["number_of_func_call_intents_ground_truth"], 1)
self.assertEqual(metrics["number_of_func_call_intents_pred"], 1)
self.assertEqual(metrics["func_intent_recall"], 1)
self.assertEqual(metrics["func_intent_precision"], 1)

def test_empty_pred(self):
pred_message = {
@@ -334,3 +339,178 @@ def test_empty_pred(self):
},
)
self.assertEqual(metrics["accuracy"], 0.0)
self.assertEqual(metrics["number_of_func_call_intents_ground_truth"], 1)
self.assertEqual(metrics["number_of_func_call_intents_pred"], 0)
self.assertEqual(metrics["func_intent_recall"], 0)
self.assertEqual(metrics["func_intent_precision"], 0)

@parameterized.parameters(
# Simple match.
dict(
pred_target_pair=[
[{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}],
[{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}],
],
expected_accuracy=1.0,
expected_number_of_examples=1,
expected_func_name_accuracy=1.0,
expected_strict_accuracy=1.0,
expected_lenient_accuracy=1.0,
expected_bow_accuracy=1.0,
expected_number_of_expected_tool_calls=1,
expected_number_of_func_call_intents_ground_truth=1,
expected_number_of_func_call_intents_pred=1,
expected_func_intent_recall=1.0,
expected_func_intent_precision=1.0,
),
# Function intent: ground truth has no tool_calls.
dict(
pred_target_pair=[
[{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}],
"""Get weather for Boston~""",
],
expected_accuracy=0.0,
expected_number_of_examples=1,
expected_func_name_accuracy=0.0,
expected_strict_accuracy=0.0,
expected_lenient_accuracy=0.0,
expected_bow_accuracy=0.0,
expected_number_of_expected_tool_calls=0,
expected_number_of_func_call_intents_ground_truth=0,
expected_number_of_func_call_intents_pred=1,
expected_func_intent_recall=0.0,
expected_func_intent_precision=0.0,
),
# Function intent: prediction has no tool_calls.
dict(
pred_target_pair=[
"""Get weather for Boston~""",
[{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}],
],
expected_accuracy=0.0,
expected_number_of_examples=1,
expected_func_name_accuracy=0.0,
expected_strict_accuracy=0.0,
expected_lenient_accuracy=0.0,
expected_bow_accuracy=0.0,
expected_number_of_expected_tool_calls=1,
expected_number_of_func_call_intents_ground_truth=1,
expected_number_of_func_call_intents_pred=0,
expected_func_intent_recall=0.0,
expected_func_intent_precision=0.0,
),
# Tool call name mismatch.
dict(
pred_target_pair=[
[{"name": "get_weather", "arguments": {"location": "Boston, MA"}}],
[{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}],
],
expected_accuracy=0.0,
expected_number_of_examples=1,
expected_func_name_accuracy=0.0,
expected_strict_accuracy=0.0,
expected_lenient_accuracy=0.0,
expected_bow_accuracy=0.0,
expected_number_of_expected_tool_calls=1,
expected_number_of_func_call_intents_ground_truth=1,
expected_number_of_func_call_intents_pred=1,
expected_func_intent_recall=1.0,
expected_func_intent_precision=1.0,
),
# Tool call bag-of-words (BOW) match.
dict(
pred_target_pair=[
[{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}],
[{"name": "get_current_weather", "arguments": {"location": "Boston"}}],
],
expected_accuracy=0.0,
expected_number_of_examples=1,
expected_func_name_accuracy=1.0,
expected_strict_accuracy=0.0,
expected_lenient_accuracy=1.0, # 'ma' is part of _STOP_WORDS and removed.
expected_bow_accuracy=1.0,
expected_number_of_expected_tool_calls=1,
expected_number_of_func_call_intents_ground_truth=1,
expected_number_of_func_call_intents_pred=1,
expected_func_intent_recall=1.0,
expected_func_intent_precision=1.0,
),
)
def test_all_metrics(
self,
pred_target_pair,
expected_accuracy,
expected_number_of_examples,
expected_func_name_accuracy,
expected_strict_accuracy,
expected_lenient_accuracy,
expected_bow_accuracy,
expected_number_of_expected_tool_calls,
expected_number_of_func_call_intents_ground_truth,
expected_number_of_func_call_intents_pred,
expected_func_intent_recall,
expected_func_intent_precision,
):
"""Tests all metrics."""

def _make_message(content_tool_call_info: Union[list, str]) -> dict:
if isinstance(content_tool_call_info, str):
content = content_tool_call_info
tool_calls = None
elif isinstance(content_tool_call_info, list):
content = ""
tool_calls = []
for tool_call_info in content_tool_call_info:
tool_calls.append(
{
"type": "function",
"function": tool_call_info,
}
)
return {
"role": "assistant",
"content": content,
"tool_calls": tool_calls,
}

pred_message = _make_message(pred_target_pair[0])
target_message = _make_message(pred_target_pair[1])
responses = [
{
"response": json.dumps({"choices": [{"message": pred_message}]}),
"target_message": target_message,
}
]
mock_target_message = Mock(**target_message)
mock_target_message.model_dump.return_value = target_message
mock_pred_message = Mock(**pred_message)
mock_pred_message.model_dump.return_value = pred_message
self.generator.config.client.klass.parse_generation.return_value = [mock_pred_message]

with patch(
"axlearn.open_api.openai.OpenAIClient.format_message", return_value=mock_target_message
):
metrics = metric_fn(
responses=responses,
generators={
EvalGeneratorType.RESPONSE: self.generator,
},
)
expected_metrics = {
"accuracy": expected_accuracy,
"number_of_examples": expected_number_of_examples,
"number_of_parsing_errors": 0,
"number_of_generation_errors": 0,
"func_name_accuracy": expected_func_name_accuracy,
"strict_accuracy": expected_strict_accuracy,
"lenient_accuracy": expected_lenient_accuracy,
"bow_accuracy": expected_bow_accuracy,
"number_of_expected_tool_calls": expected_number_of_expected_tool_calls,
"number_of_func_call_intents_ground_truth": expected_number_of_func_call_intents_ground_truth, # pylint: disable=line-too-long
"number_of_func_call_intents_pred": expected_number_of_func_call_intents_pred,
"func_intent_recall": expected_func_intent_recall,
"func_intent_precision": expected_func_intent_precision,
}

self.assertEqual(metrics, expected_metrics)
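The new parameterized cases can be exercised in isolation with pytest's -k filter; an illustrative invocation, assuming pytest is set up for the repository:

pytest axlearn/open_api/metrics/tool_use_execution_test.py -k test_all_metrics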
