sys.path.append(parent)

from rank_llm.data import Result
+from rank_llm.rerank import PromptMode


class ResponseAnalyzer:
    def __init__(
        self,
        data: Union[List[str], List[Result]],
        use_alpha: bool = False,
+        prompt_mode: PromptMode = PromptMode.RANK_GPT,
    ) -> None:
        self._data = data
        self._use_alpha = use_alpha
+        self._prompt_mode = prompt_mode

    @staticmethod
    def from_inline_results(
-        results: List[Result], use_alpha: bool = False
+        results: List[Result],
+        use_alpha: bool = False,
+        prompt_mode: PromptMode = PromptMode.RANK_GPT,
    ) -> "ResponseAnalyzer":
        """
        Method to create a ResponseAnalyzer instance from a list of Result objects.

        Args:
            results (List[Result]): A list of Result objects.
+            use_alpha (bool): Whether to evaluate the alphabetical list instead of the numerical one, defaults to False.
+            prompt_mode (PromptMode): The prompt mode to use for analysis, defaults to RANK_GPT.

        Returns:
            ResponseAnalyzer: An instance of the ResponseAnalyzer.
        """
-        return ResponseAnalyzer(data=results, use_alpha=use_alpha)
+        return ResponseAnalyzer(
+            data=results, use_alpha=use_alpha, prompt_mode=prompt_mode
+        )

    @staticmethod
    def from_stored_files(
-        filenames: List[str], use_alpha: bool = False
+        filenames: List[str],
+        use_alpha: bool = False,
+        prompt_mode: PromptMode = PromptMode.RANK_GPT,
    ) -> "ResponseAnalyzer":
        """
        Method to create a ResponseAnalyzer instance from a list of filenames.

        Args:
            filenames (List[str]): A list of filenames where each file contains data to be analyzed.
+            use_alpha (bool): Whether to evaluate the alphabetical list instead of the numerical one, defaults to False.
+            prompt_mode (PromptMode): The prompt mode to use for analysis, defaults to RANK_GPT.

        Returns:
            ResponseAnalyzer: An instance of the ResponseAnalyzer.
        """
-        return ResponseAnalyzer(data=filenames, use_alpha=use_alpha)
+        return ResponseAnalyzer(
+            data=filenames, use_alpha=use_alpha, prompt_mode=prompt_mode
+        )
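
# Illustrative usage sketch, not part of this commit; the filename below is a
# hypothetical placeholder:
#
#     analyzer = ResponseAnalyzer.from_stored_files(
#         ["runs/rank_gpt_ranking_summary.json"],
#         use_alpha=False,
#         prompt_mode=PromptMode.RANK_GPT,
#     )
#     error_counts = analyzer.count_errors(verbose=True, normalize=True)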

    def read_results_responses(self) -> Tuple[List[str], List[int]]:
        """
@@ -106,60 +121,79 @@ def read_responses(self) -> Tuple[List[str], List[int]]:
    def _validate_format(self, response: str) -> bool:
        if self._use_alpha:
            for c in response:
-                if not c.isupper() and c != "[" and c != "]" and c != ">" and c != " ":
+                if not c.isupper() and c not in "[]> ":
                    return False
            return True

        for c in response:
-            if not c.isdigit() and c != "[" and c != "]" and c != ">" and c != " ":
+            if not c.isdigit() and c not in "[]> , ":
                return False
        return True
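
# Quick sanity sketch, not part of this commit, assuming the checks above:
# numerical mode accepts e.g. "[3] > [1] > [2]" and, with the added comma, an
# LRL-style "3,1,2"; alphabetical mode accepts e.g. "[C] > [A] > [B]".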

    def _get_num_passages(self, prompt) -> int:
-        # TODO: support lrl and rank_gpt_apeer prompt formats
-        search_text = ""
-        if type(prompt) == str:
-            search_text = prompt
-
-        elif type(prompt) == list:
-            if not prompt:
-                return 0
-            if "text" in prompt[0]:
-                # For LiT5, there is one "text" entry per passage.
+        match self._prompt_mode:
+            case PromptMode.LRL:
+                assert isinstance(prompt, list)
+                assert len(prompt) == 1
+                search_text = prompt[0]["content"]
+                # Look for PASSAGES=[...] and count the number of passages in the list
+                begin = search_text.find("PASSAGES = [")
+                search_text = search_text[begin:]
+                end = search_text.find("]")
+                search_text = search_text[:end]
+                return len(search_text.split(", "))
+            case PromptMode.LiT5:
+                assert type(prompt) == list
+                if not prompt:
+                    return 0
+                # For LiT5, there is one dict with "text" key per passage.
+                assert "text" in prompt[0]
                return len(prompt)
-            if "content" in prompt[0]:
-                # For GPT runs, the prompt is an array of json objects with "role" and "content" as keys.
-                for message in prompt:
-                    search_text += message["content"]
-            else:
+            case PromptMode.RANK_GPT:
+                search_text = ""
+                if type(prompt) == str:
+                    search_text = prompt
+                elif type(prompt) == list:
+                    for message in prompt:
+                        search_text += message["content"]
+                else:
+                    raise ValueError(f"Unsupported prompt format.")
+                regex = r"(I will provide you with) (\d+) (passages)"
+                match = re.search(regex, search_text)
+                if not match:
+                    raise ValueError(f"Unsupported prompt format.")
+                return int(match.group(2))
+            case PromptMode.RANK_GPT_APEER:
+                assert isinstance(prompt, list)
+                search_text = ""
+                for entry in prompt:
+                    search_text += entry["content"]
+                # No mention of the total number of passages.
+                # Find the last passage identifier instead.
+                matches = re.findall(r"\[\d+\]", search_text)
+                return int(matches[-1][1:-1])
+            case _:
                raise ValueError(f"Unsupported prompt format.")
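
# Illustrative sketch, not part of this commit, of what the match above counts
# for made-up prompts in each mode:
#   LRL:            [{"content": "... PASSAGES = [P1, P2, P3] ..."}]   -> 3
#   LiT5:           [{"text": "..."}, {"text": "..."}]                 -> 2
#   RANK_GPT:       "I will provide you with 20 passages ..."          -> 20
#   RANK_GPT_APEER: [{"content": "... [1] ... [2] ... [10] ..."}]      -> 10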
-        else:
-            raise ValueError(f"Unsupported prompt format.")
-        regex = r"(I will provide you with) (\d+) (passages)"
-        match = re.search(regex, search_text)
-        if not match:
-            raise ValueError(f"Unsupported prompt format.")
-        return int(match.group(2))
-
-    def process_numerical_format(
+
+    def _process_numerical_format(
        self, response: str, num_passage: int, verbose: bool, stats_dict: Dict[str, int]
    ):
        resp = response.replace("[rankstart]", "")
        resp = resp.replace("[rankend]", "")
+        resp = resp.replace("SORTED_PASSAGES =", "")
+        resp = resp.replace(" ", "")
+        resp = resp.replace("PASSAGE", "")
+        resp = resp.replace("[", "")
+        resp = resp.replace("]", "")
        resp = resp.strip()
        if not self._validate_format(resp):
            if verbose:
                print(resp)
            stats_dict["wrong_format"] += 1
            return
-        begin, end = 0, 0
-        while begin < len(resp) and not resp[begin].isdigit():
-            begin += 1
-        while end < len(resp) and not resp[len(resp) - end - 1].isdigit():
-            end += 1
        try:
-            resp = resp[begin : len(resp) - end]
-            ranks = resp.split("] > [")
+            delim = "," if self._prompt_mode == PromptMode.LRL else ">"
+            ranks = resp.split(delim)
            ranks = [int(rank) for rank in ranks]
        except ValueError:
            if verbose:
@@ -178,7 +212,7 @@ def process_numerical_format(
            return
        stats_dict["ok"] += 1
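
# Worked sketch, not part of this commit, assuming the stripping above: an LRL
# response "SORTED_PASSAGES = [PASSAGE3, PASSAGE1, PASSAGE2]" reduces to "3,1,2"
# and splits on "," into [3, 1, 2], while a RankGPT-style "[1] > [3] > [2]"
# reduces to "1>3>2" and splits on ">" into [1, 3, 2].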

-    def process_alphabetical_format(
+    def _process_alphabetical_format(
        self, response: str, num_passage: int, verbose: bool, stats_dict: Dict[str, int]
    ):
        resp = response.strip()
@@ -236,14 +270,14 @@ def count_errors(
        }
        for resp, num_passage in zip(responses, num_passages):
            if self._use_alpha:
-                self.process_alphabetical_format(
+                self._process_alphabetical_format(
                    response=resp,
                    num_passage=num_passage,
                    verbose=verbose,
                    stats_dict=stats_dict,
                )
            else:
-                self.process_numerical_format(
+                self._process_numerical_format(
                    response=resp,
                    num_passage=num_passage,
                    verbose=verbose,
@@ -263,12 +297,16 @@ def count_errors(

def main(args):
    if args.files:
-        response_analyzer = ResponseAnalyzer.from_stored_files(args.files)
+        response_analyzer = ResponseAnalyzer.from_stored_files(
+            args.files, use_alpha=args.use_alpha, prompt_mode=args.prompt_mode
+        )
    else:
        print("Error: Please specify the files containing ranking summaries.")
        sys.exit(1)

-    error_counts = response_analyzer.count_errors(args.verbose)
+    error_counts = response_analyzer.count_errors(
+        verbose=args.verbose, normalize=args.normalize
+    )
    print("Normalized scores:", error_counts)


@@ -277,9 +315,25 @@ def main(args):
    parser.add_argument(
        "--files", nargs="+", help="Filenames of ranking summaries", required=False
    )
+    parser.add_argument(
+        "--use-alpha",
+        action="store_true",
+        help="Use alphabetical identifiers instead of the numerical ids",
+    )
+    parser.add_argument(
+        "--prompt-mode",
+        type=PromptMode,
+        default=PromptMode.RANK_GPT,
+        choices=list(PromptMode),
+    )
    parser.add_argument(
        "--verbose", action="store_true", help="Verbose output of errors"
    )
+    parser.add_argument(
+        "--normalize",
+        action="store_true",
+        help="Normalize the output dictionary of errors",
+    )
    args = parser.parse_args()

    main(args)
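
# Example invocation sketch, not part of this commit; the script path, the run
# filenames, and the exact --prompt-mode value string are assumptions:
#
#   python src/rank_llm/analysis/response_analysis.py \
#       --files runs/summary_1.json runs/summary_2.json \
#       --prompt-mode rank_GPT --verbose --normalize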