-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_data.py
72 lines (51 loc) · 1.88 KB
/
generate_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import pickle
import random
import re
import json
import argparse
# Load the question/answer templates, a dict keyed by topic name.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at a template file you trust.
with open('question_temmplate.pkl', mode='rb') as f:
    final_dict = pickle.load(f)
# All topic names present in the template file.
list_of_topic = [*final_dict]
def run(config, templates=None):
    """Generate random question/answer pairs and dump them to a JSON file.

    For each pair, run does the following:
    - picks a random question template for the topic and a random response
      for that question;
    - replaces every mapping placeholder found in either text with one
      randomly chosen word, the same word in both;
    - replaces ``<DOMAIN>`` with the topic name.

    Args:
        config: dict (e.g. ``vars(argparse.Namespace)``) with keys:
            - ``data_size`` (int): number of pairs to generate.
            - ``data_type`` (str): output path stem; ``.json`` is appended.
            - ``topic`` (str, optional): template key to use; defaults to
              ``'sport'``.
        templates: optional template dict. Defaults to the module-level
            ``final_dict``. From the lookups below, each topic entry must
            provide ``'question_temmplate'`` (question -> list of responses)
            and ``'mapping'`` (placeholder -> list of replacement words).

    Raises:
        KeyError: unknown topic or missing ``data_size``/``data_type``.
        IndexError: empty question list, response list or replacement list.
    """
    if templates is None:
        templates = final_dict
    topic = config.get('topic', 'sport')
    data_size = config['data_size']
    print(data_size)

    # The topic is fixed for the whole run, so resolve its tables once.
    topic_data = templates[topic]
    question_templates = topic_data['question_temmplate']
    mapping = topic_data['mapping']
    questions = list(question_templates)

    lst = []
    for _ in range(data_size):
        question = random.choice(questions)
        answer = random.choice(question_templates[question])
        for word, replacements in mapping.items():
            # Draw one word per placeholder and apply it to both texts so
            # question and answer stay consistent. The draw happens even when
            # the placeholder occurs only in the answer; the old code reused a
            # stale (or undefined) word from the question in that case.
            if word in question or word in answer:
                replace_word = random.choice(replacements)
                question = question.replace(word, replace_word)
                answer = answer.replace(word, replace_word)
        # Literal replacement: the topic must not be parsed as a regex
        # replacement template (backslashes would be treated as escapes).
        question = question.replace('<DOMAIN>', topic)
        answer = answer.replace('<DOMAIN>', topic)
        lst.append({'question': question, 'answer': answer, 'domain': topic})

    print(config['data_type'])
    file_name = config['data_type'] + '.json'
    print(len(lst))
    with open(file_name, 'w') as f:
        json.dump(lst, f)
def build_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` that accepts two options:
        - ``-n/--size`` (dest ``data_size``, int, default 1000);
        - ``-t/--data_type`` (str, default ``'train_sport'``).
        These dests are the keys ``run`` reads from its config dict.
    """
    # ArgumentParser's first positional parameter is ``prog`` (the program
    # name shown in the usage line), not a description; pass it by keyword.
    parser = argparse.ArgumentParser(description="Command line arguments")
    parser.add_argument("-n", "--size", default=1000, type=int,
                        dest="data_size",
                        help="Number of question and answer generate")
    parser.add_argument("-t", "--data_type", default='train_sport', type=str,
                        dest="data_type",
                        help="generate train or test data")
    return parser


if __name__ == '__main__':
    FLAGS = build_parser().parse_args()
    run(vars(FLAGS))