-
Notifications
You must be signed in to change notification settings - Fork 49
/
Copy pathutils.py
29 lines (24 loc) · 939 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import pandas as pd
from typing import List
def generate_prompts(prompt_template: str, data_df: pd.DataFrame) -> List[str]:
    """
    Generates prompts for the rows in data_df using the template prompt_template.

    Each ``{{column}}`` placeholder in the template is replaced with the string
    form of the row's value for that column. Replacement is done column by
    column, in the order the columns appear in data_df.

    Args:
        prompt_template: a prompt template containing ``{{column}}`` fields
        data_df: pandas dataframe of samples to generate prompts for

    Returns:
        prompts: a list of prompts corresponding to the rows of data_df

    Raises:
        AssertionError: if the template contains no ``{{`` field, if a
            generated prompt still contains ``{{`` after substitution (an
            unfilled field, or a substituted value that itself contains
            ``{{``), or if two rows produce the same prompt.
    """
    # The checks below are raised explicitly (rather than via `assert`) so they
    # are not stripped when Python runs with -O; the exception type is kept as
    # AssertionError so existing callers see the same error class.
    if "{{" not in prompt_template:
        raise AssertionError(
            f"Prompt template has no fields to fill, {prompt_template}"
        )

    prompts = []
    for row in data_df.to_dict(orient="records"):
        prompt = str(prompt_template)
        for column, value in row.items():
            # Column labels are not guaranteed to be strings (e.g. integer
            # labels), so convert before building the placeholder.
            prompt = prompt.replace("{{" + str(column) + "}}", str(value))
        if "{{" in prompt:
            # Put the offending prompt in the error message itself instead of
            # printing it as a side effect.
            raise AssertionError(f"Prompt still has unfilled fields: {prompt}")
        prompts.append(prompt)

    if len(set(prompts)) != len(prompts):
        raise AssertionError("Duplicated prompts detected")
    return prompts