from transformers import AutoModelWithLMHead, AutoTokenizer, top_k_top_p_filtering
import torch
from torch.nn import functional as F
from flask import Flask, request, Response, jsonify
from queue import Queue, Empty
import time
import threading

from util import get_bad_word_list
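
# This server exposes a small Flask API around the GPT-2 model fine-tuned on
# Reddit TIFU posts (mrm8488/gpt2-finetuned-reddit-tifu):
#   POST /gpt2-reddit/short -> next-word candidates for a prompt
#   POST /gpt2-reddit/long  -> full text continuations of a prompt
# Incoming requests are placed on a queue and served by a background worker thread.
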
# Server & request handling settings
app = Flask(__name__)
requests_queue = Queue()
BATCH_SIZE = 1
CHECK_INTERVAL = 0.1
tokenizer = AutoTokenizer.from_pretrained("mrm8488/gpt2-finetuned-reddit-tifu")
model = AutoModelWithLMHead.from_pretrained("mrm8488/gpt2-finetuned-reddit-tifu", return_dict=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
bad_word_tokens = get_bad_word_list()
# Queue handling
def handle_requests_by_batch():
    while True:
        requests_batch = []
        while not (len(requests_batch) >= BATCH_SIZE):
            try:
                requests_batch.append(requests_queue.get(timeout=CHECK_INTERVAL))
            except Empty:
                continue

        for requests in requests_batch:
            # A 2-element input means "short" mode (next-word prediction);
            # otherwise it is "long" mode (text generation).
            if len(requests['input']) == 2:
                requests['output'] = run_word(requests['input'][0], requests['input'][1])
            else:
                requests['output'] = run_generate(requests['input'][0], requests['input'][1], requests['input'][2])


# Worker thread
threading.Thread(target=handle_requests_by_batch).start()
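

# run_word: given a prompt, return `num_samples` candidate next tokens.
# The last-position logits are filtered with top-k/top-p and the candidates
# are drawn with multinomial sampling.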
def run_word(sequence, num_samples):
    try:
        input_ids = tokenizer.encode(sequence, return_tensors="pt")
        tokens_tensor = input_ids.to(device)

        # Logits for the last position only
        next_token_logits = model(tokens_tensor).logits[:, -1, :]
        filtered_next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=50, top_p=1.0)

        # Sample `num_samples` candidate next tokens
        probs = F.softmax(filtered_next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=num_samples)

        result = dict()
        for idx, token in enumerate(next_token.tolist()[0]):
            result[idx] = tokenizer.decode(token)

        return result

    except Exception as e:
        print(e)
        return 500
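

# run_generate: generate `num_samples` continuations of `text`, each `length`
# tokens long. Sampling uses top-k filtering, and bad_words_ids keeps the
# blacklisted tokens out of the output.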
def run_generate(text, num_samples, length):
    try:
        input_ids = tokenizer.encode(text, return_tensors="pt")
        tokens_tensor = input_ids.to(device)

        # Generate exactly `length` tokens beyond the prompt
        min_length = len(input_ids.tolist()[0])
        length += min_length

        outputs = model.generate(tokens_tensor,
                                 pad_token_id=50256,
                                 max_length=length,
                                 min_length=length,
                                 do_sample=True,
                                 top_k=50,
                                 num_return_sequences=num_samples,
                                 bad_words_ids=bad_word_tokens)

        result = {}
        for idx, output in enumerate(outputs):
            # Drop the prompt tokens and decode only the newly generated text
            result[idx] = tokenizer.decode(output.tolist()[min_length:], skip_special_tokens=True)

        return result

    except Exception as e:
        print(e)
        return 500
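

# API endpoint. Expects form fields `text` and `num_samples`, plus `length`
# when mode is "long". Requests are queued and answered once the worker
# thread has produced an output.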
@app.route("/gpt2-reddit/<mode>", methods=['POST'])
def run_gpt2_reddit(mode):
if mode not in ["short", "long"]:
return jsonify({'error': 'This is wrong address'}), 400
# 큐에 쌓여있을 경우,
if requests_queue.qsize() > BATCH_SIZE:
return jsonify({'error': 'TooManyReqeusts'}), 429
# 웹페이지로부터 이미지와 스타일 정보를 얻어옴.
try:
args = []
args.append(request.form['text'])
args.append(int(request.form['num_samples']))
if mode == "long":
length = args.append(int(request.form['length']))
except Exception:
print("Empty Text")
return Response("fail", status=400)
# Queue - put data
req = {
'input': args
}
requests_queue.put(req)
# Queue - wait & check
while 'output' not in req:
time.sleep(CHECK_INTERVAL)
result = req['output']
return result
# Health Check
@app.route("/healthz", methods=["GET"])
def healthCheck():
return "", 200
if __name__ == "__main__":
from waitress import serve
serve(app, host='0.0.0.0', port=8000)
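
# Example requests (a rough sketch, assuming the server is running locally on
# port 8000 as configured above; field names match the form keys read in
# run_gpt2_reddit):
#
#   curl -X POST -F "text=Today I tried to" -F "num_samples=3" \
#        http://localhost:8000/gpt2-reddit/short
#
#   curl -X POST -F "text=Today I tried to" -F "num_samples=1" -F "length=50" \
#        http://localhost:8000/gpt2-reddit/long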