
Commit 2811604

Add lrl apeer prompt analysis (#185)
- Expand response analysis to other prompt formats.
- Some minor clean-ups.
1 parent c91c011 commit 2811604

4 files changed, +200 −61 lines

src/rank_llm/analysis/response_analysis.py

+96 −42
```diff
@@ -11,46 +11,61 @@
 sys.path.append(parent)
 
 from rank_llm.data import Result
+from rank_llm.rerank import PromptMode
 
 
 class ResponseAnalyzer:
     def __init__(
         self,
         data: Union[List[str], List[Result]],
         use_alpha: bool = False,
+        prompt_mode: PromptMode = PromptMode.RANK_GPT,
     ) -> None:
         self._data = data
         self._use_alpha = use_alpha
+        self._prompt_mode = prompt_mode
 
     @staticmethod
     def from_inline_results(
-        results: List[Result], use_alpha: bool = False
+        results: List[Result],
+        use_alpha: bool = False,
+        prompt_mode: PromptMode = PromptMode.RANK_GPT,
     ) -> "ResponseAnalyzer":
         """
         Method to create a ResponseAnalyzer instance from a list of Result objects.
 
         Args:
             results (List[Result]): A list of Result objects.
+            use_alpha (bool): Whether to evaluate the alphabetical list instead of the numerical one, defaults to False.
+            prompt_mode (PromptMode): The prompt mode to use for analysis, defaults to RANK_GPT.
 
         Returns:
             ResponseAnalyzer: An instance of the ResponseAnalyzer.
         """
-        return ResponseAnalyzer(data=results, use_alpha=use_alpha)
+        return ResponseAnalyzer(
+            data=results, use_alpha=use_alpha, prompt_mode=prompt_mode
+        )
 
     @staticmethod
     def from_stored_files(
-        filenames: List[str], use_alpha: bool = False
+        filenames: List[str],
+        use_alpha: bool = False,
+        prompt_mode: PromptMode = PromptMode.RANK_GPT,
     ) -> "ResponseAnalyzer":
         """
         Method to create to create a ResponseAnalyzer instance from a list of filenames.
 
         Args:
             filenames (List[str]): A list of filenames where each file contains data to be analyzed.
+            use_alpha (bool): Whether to evaluate the alphabetical list instead of the numerical one, defaults to False.
+            prompt_mode (PromptMode): The prompt mode to use for analysis, defaults to RANK_GPT.
 
         Returns:
             ResponseAnalyzer: An instance of the ResponseAnalyzer.
         """
-        return ResponseAnalyzer(data=filenames, use_alpha=use_alpha)
+        return ResponseAnalyzer(
+            data=filenames, use_alpha=use_alpha, prompt_mode=prompt_mode
+        )
 
     def read_results_responses(self) -> Tuple[List[str], List[int]]:
         """
```
```diff
@@ -106,60 +121,79 @@ def read_responses(self) -> Tuple[List[str], List[int]]:
     def _validate_format(self, response: str) -> bool:
         if self._use_alpha:
             for c in response:
-                if not c.isupper() and c != "[" and c != "]" and c != ">" and c != " ":
+                if not c.isupper() and c not in "[]> ":
                     return False
             return True
 
         for c in response:
-            if not c.isdigit() and c != "[" and c != "]" and c != ">" and c != " ":
+            if not c.isdigit() and c not in "[]> ,":
                 return False
         return True
 
     def _get_num_passages(self, prompt) -> int:
-        # TODO: support lrl and rank_gpt_apeer prompt formats
-        search_text = ""
-        if type(prompt) == str:
-            search_text = prompt
-
-        elif type(prompt) == list:
-            if not prompt:
-                return 0
-            if "text" in prompt[0]:
-                # For LiT5, there is one "text" entry per passage.
+        match self._prompt_mode:
+            case PromptMode.LRL:
+                assert isinstance(prompt, list)
+                assert len(prompt) == 1
+                search_text = prompt[0]["content"]
+                # Look for PASSAGES=[...] and count the number of passages in the list
+                begin = search_text.find("PASSAGES = [")
+                search_text = search_text[begin:]
+                end = search_text.find("]")
+                search_text = search_text[:end]
+                return len(search_text.split(", "))
+            case PromptMode.LiT5:
+                assert type(prompt) == list
+                if not prompt:
+                    return 0
+                # For LiT5, there is one dict with "text" key per passage.
+                assert "text" in prompt[0]
                 return len(prompt)
-            if "content" in prompt[0]:
-                # For GPT runs, the prompt is an array of json objects with "role" and "content" as keys.
-                for message in prompt:
-                    search_text += message["content"]
-            else:
+            case PromptMode.RANK_GPT:
+                search_text = ""
+                if type(prompt) == str:
+                    search_text = prompt
+                elif type(prompt) == list:
+                    for message in prompt:
+                        search_text += message["content"]
+                else:
+                    raise ValueError(f"Unsupported prompt format.")
+                regex = r"(I will provide you with) (\d+) (passages)"
+                match = re.search(regex, search_text)
+                if not match:
+                    raise ValueError(f"Unsupported prompt format.")
+                return int(match.group(2))
+            case PromptMode.RANK_GPT_APEER:
+                assert isinstance(prompt, list)
+                search_text = ""
+                for entry in prompt:
+                    search_text += entry["content"]
+                # No mention of the total number of passages.
+                # Find the last passage identifier instead.
+                matches = re.findall(r"\[\d+\]", search_text)
+                return int(matches[-1][1:-1])
+            case _:
                 raise ValueError(f"Unsupported prompt format.")
-        else:
-            raise ValueError(f"Unsupported prompt format.")
-        regex = r"(I will provide you with) (\d+) (passages)"
-        match = re.search(regex, search_text)
-        if not match:
-            raise ValueError(f"Unsupported prompt format.")
-        return int(match.group(2))
-
-    def process_numerical_format(
+
+    def _process_numerical_format(
         self, response: str, num_passage: int, verbose: bool, stats_dict: Dict[str, int]
     ):
         resp = response.replace("[rankstart]", "")
         resp = resp.replace("[rankend]", "")
+        resp = resp.replace("SORTED_PASSAGES =", "")
+        resp = resp.replace(" ", "")
+        resp = resp.replace("PASSAGE", "")
+        resp = resp.replace("[", "")
+        resp = resp.replace("]", "")
         resp = resp.strip()
         if not self._validate_format(resp):
             if verbose:
                 print(resp)
             stats_dict["wrong_format"] += 1
             return
-        begin, end = 0, 0
-        while begin < len(resp) and not resp[begin].isdigit():
-            begin += 1
-        while end < len(resp) and not resp[len(resp) - end - 1].isdigit():
-            end += 1
         try:
-            resp = resp[begin : len(resp) - end]
-            ranks = resp.split("] > [")
+            delim = "," if self._prompt_mode == PromptMode.LRL else ">"
+            ranks = resp.split(delim)
             ranks = [int(rank) for rank in ranks]
         except ValueError:
             if verbose:
```
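To make the two newly supported counting strategies concrete, here is a standalone sketch on toy prompts; the prompt strings are invented for illustration and only mimic the patterns the code above searches for:

```python
# Toy illustration of how _get_num_passages counts passages for the new modes.
import re

# LRL: the prompt embeds "PASSAGES = [...]"; count its comma-separated entries.
lrl_text = 'Rank the passages. PASSAGES = [PASSAGE1, PASSAGE2, PASSAGE3] QUERY = "example"'
chunk = lrl_text[lrl_text.find("PASSAGES = ["):]
chunk = chunk[: chunk.find("]")]
print(len(chunk.split(", ")))  # -> 3

# RANK_GPT_APEER: no explicit total, so take the last bracketed identifier.
apeer_text = "[1] first passage ... [2] second passage ... [3] third passage"
matches = re.findall(r"\[\d+\]", apeer_text)
print(int(matches[-1][1:-1]))  # -> 3
```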
```diff
@@ -178,7 +212,7 @@ def process_numerical_format(
             return
         stats_dict["ok"] += 1
 
-    def process_alphabetical_format(
+    def _process_alphabetical_format(
         self, response: str, num_passage: int, verbose: bool, stats_dict: Dict[str, int]
     ):
         resp = response.strip()
@@ -236,14 +270,14 @@ def count_errors(
         }
         for resp, num_passage in zip(responses, num_passages):
             if self._use_alpha:
-                self.process_alphabetical_format(
+                self._process_alphabetical_format(
                     response=resp,
                     num_passage=num_passage,
                     verbose=verbose,
                     stats_dict=stats_dict,
                 )
             else:
-                self.process_numerical_format(
+                self._process_numerical_format(
                     response=resp,
                     num_passage=num_passage,
                     verbose=verbose,
@@ -263,12 +297,16 @@ def count_errors(
 
 def main(args):
     if args.files:
-        response_analyzer = ResponseAnalyzer.from_stored_files(args.files)
+        response_analyzer = ResponseAnalyzer.from_stored_files(
+            args.files, use_alpha=args.use_alpha, prompt_mode=args.prompt_mode
+        )
     else:
         print("Error: Please specify the files containing ranking summaries.")
         sys.exit(1)
 
-    error_counts = response_analyzer.count_errors(args.verbose)
+    error_counts = response_analyzer.count_errors(
+        verbose=args.verbose, normalize=args.normalize
+    )
     print("Normalized scores:", error_counts)
 
 
@@ -277,9 +315,25 @@ def main(args):
     parser.add_argument(
         "--files", nargs="+", help="Filenames of ranking summaries", required=False
     )
+    parser.add_argument(
+        "--use-alpha",
+        action="store_true",
+        help="Use alphabetical identifiers instead of the numerical ids",
+    )
+    parser.add_argument(
+        "--prompt-mode",
+        type=PromptMode,
+        default=PromptMode.RANK_GPT,
+        choices=list(PromptMode),
+    )
     parser.add_argument(
         "--verbose", action="store_true", help="Verbose output of errors"
     )
+    parser.add_argument(
+        "--normalize",
+        action="store_true",
+        help="Normalize the output dictionary of errors",
+    )
     args = parser.parse_args()
 
     main(args)
```

src/rank_llm/demo/experimental_results.py

+57 −16
```diff
@@ -73,23 +73,39 @@ def create_reranker(name: str):
         return Reranker(
             SafeGenai("gemini-2.0-flash-001", 4096, keys=get_genai_api_key())
         )
+    if name == "qwen":
+        return Reranker(
+            RankListwiseOSLLM(
+                model="Qwen/Qwen2.5-7B-Instruct",
+                vllm_batched=True,
+            )
+        )
+    if name == "llama":
+        return Reranker(
+            RankListwiseOSLLM(
+                model="meta-llama/Llama-3.1-8B-Instruct",
+                vllm_batched=True,
+            )
+        )
 
 
 rerankers = [
     "monot5",
+    "lit5",
     "rv",
     "rz",
-    "lit5",
     "mistral",
+    "qwen",
+    "llama",
     "rank_gpt",
     "gemini",
-    "lrl",
     "rank_gpt_apeer",
+    "lrl",
 ]
 results = {}
 for key in rerankers:
     reranker = create_reranker(key)
-    for dataset in ["dl19", "dl20", "dl21", "dl22"]:  # , "dl23"
+    for dataset in ["dl19", "dl20", "dl21", "dl22", "dl23"]:
         retrieved_results = Retriever.from_dataset_with_prebuilt_index(dataset, k=100)
         topics = TOPICS[dataset]
         ret_ndcg_10 = EvalFunction.from_results(retrieved_results, topics)
@@ -108,22 +124,47 @@ def create_reranker(name: str):
 
         # Eval
         rerank_ndcg_10 = EvalFunction.from_results(rerank_results, topics)
-
-        # Response Analysis
-        # TODO: For now skipping lrl and rank_gpt_apeer since the response analyzer does not support these prompt formats, yet.
-        if key not in ["monot5", "duot5", "lrl", "rank_gpt_apeer"]:
-            use_alpha = True if key == "mistral" else False
-            analyzer = ResponseAnalyzer.from_inline_results(
-                rerank_results, use_alpha=use_alpha
-            )
-            error_counts = analyzer.count_errors()
-        else:
-            error_counts = {}
-        results[(key, dataset)] = (ret_ndcg_10, rerank_ndcg_10, error_counts.__repr__())
+        results[(key, dataset)] = (ret_ndcg_10, rerank_ndcg_10)
         with open(f"{output_path_prefix}/eval_results.txt", "w") as f:
-            f.write(f"{(ret_ndcg_10, rerank_ndcg_10, error_counts.__repr__())}")
+            f.write(f"{(ret_ndcg_10, rerank_ndcg_10)}")
 
     # Free up the memory
     del reranker
 
 print(results)
+
+# Analyze invocations
+results = {}
+for model in [
+    "rv",
+    "rz",
+    "lit5",
+    "mistral",
+    "rank_gpt",
+    "gemini",
+    "rank_gpt_apeer",
+    "lrl",
+    "qwen",
+    "llama",
+]:
+    use_alpha = True if model == "mistral" else False
+    if model == "lit5":
+        prompt_mode = PromptMode.LiT5
+    elif model == "rank_gpt_apeer":
+        prompt_mode = PromptMode.RANK_GPT_APEER
+    elif model == "lrl":
+        prompt_mode = PromptMode.LRL
+    else:
+        prompt_mode = PromptMode.RANK_GPT
+    files = []
+    for dataset in ["dl19", "dl20", "dl21", "dl22", "dl23"]:
+        files.append(
+            f"demo_outputs/{dataset}/{model}/inference_invocations_history.json"
+        )
+    analyzer = ResponseAnalyzer.from_stored_files(
+        files, use_alpha=use_alpha, prompt_mode=prompt_mode
+    )
+    error_counts = analyzer.count_errors(verbose=True, normalize=True)
+    results[model] = error_counts.__repr__()
+
+print(results)
```
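Each `results[model]` entry stores the repr of the normalized dictionary returned by `count_errors(normalize=True)`. A sketch of its shape is below; only the "ok" and "wrong_format" categories are visible in this diff, and the values are placeholders:

```python
# Hypothetical normalized output for one model; category names beyond "ok" and
# "wrong_format" are not confirmed by this diff, and all values are invented.
example_error_counts = {
    "ok": 0.96,            # responses that parsed into a valid ranking
    "wrong_format": 0.04,  # responses rejected by _validate_format
}
```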

src/rank_llm/demo/rerank_qwen.py

+44
```diff
@@ -0,0 +1,44 @@
+import os
+import sys
+from pathlib import Path
+
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+parent = os.path.dirname(SCRIPT_DIR)
+parent = os.path.dirname(parent)
+sys.path.append(parent)
+
+from rank_llm.analysis.response_analysis import ResponseAnalyzer
+from rank_llm.data import DataWriter
+from rank_llm.evaluation.trec_eval import EvalFunction
+from rank_llm.rerank import Reranker
+from rank_llm.rerank.listwise import RankListwiseOSLLM
+from rank_llm.retrieve import Retriever
+
+# By default uses BM25 for retrieval
+dataset_name = "dl19"
+requests = Retriever.from_dataset_with_prebuilt_index(dataset_name)
+model_coordinator = RankListwiseOSLLM(
+    model="Qwen/Qwen2.5-7B-Instruct",
+    vllm_batched=True,
+)
+reranker = Reranker(model_coordinator)
+kwargs = {"populate_invocations_history": True}
+rerank_results = reranker.rerank_batch(requests, **kwargs)
+
+# Analyze the response
+analyzer = ResponseAnalyzer.from_inline_results(rerank_results, use_alpha=False)
+error_counts = analyzer.count_errors()
+print(error_counts.__repr__())
+
+# Eval
+rerank_ndcg_10 = EvalFunction.from_results(rerank_results, topics)
+print(rerank_ndcg_10)
+
+# Write rerank results
+writer = DataWriter(rerank_results)
+Path(f"demo_outputs/").mkdir(parents=True, exist_ok=True)
+writer.write_in_jsonl_format(f"demo_outputs/rerank_results.jsonl")
+writer.write_in_trec_eval_format(f"demo_outputs/rerank_results.txt")
+writer.write_inference_invocations_history(
+    f"demo_outputs/inference_invocations_history.json"
+)
```
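One caveat in the new demo: `rerank_qwen.py` passes `topics` to `EvalFunction.from_results` without defining it in the script. A plausible fix, mirroring how `experimental_results.py` looks up topics per dataset, is sketched below; the `TOPICS` import path is an assumption, not confirmed by this diff:

```python
# Assumed fix: look up evaluation topics for the chosen dataset, as the
# experimental_results demo does. The import location of TOPICS is an assumption.
from rank_llm.retrieve import TOPICS  # hypothetical import path

topics = TOPICS[dataset_name]  # dataset_name = "dl19" earlier in the script
rerank_ndcg_10 = EvalFunction.from_results(rerank_results, topics)
```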
