forked from eoin-cr/PedalToTheMetal
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpedalServer.py
65 lines (51 loc) · 2.24 KB
/
pedalServer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from flask import Flask, request
import os
import data
import spotifyHandler
import torch
import numpy as np
from model import SimpleNN
# Load the pre-trained emotion classifier once at import time (architecture
# comes from model.SimpleNN, imported above).
# NOTE(review): torch.load unpickles arbitrary objects — only ever load a
# trusted checkpoint; consider saving/loading a state_dict with
# weights_only=True instead. TODO confirm checkpoint provenance.
loaded_model = torch.load('trained_model.pth')
loaded_model.eval()
# Emotion name -> model class index. "invalid" (4) presumably marks
# unusable samples — verify against the training labels.
emotionsMap = {"happy": 0, "sad": 1, "chill": 2, "angry": 3, "invalid": 4}
def predict(file):
    """Run the loaded emotion model on one parsed 2-second CSV sample.

    Args:
        file: path to a CSV file with a header row and 4 numeric columns
              (assumed sensor readings; at most 45 data rows — TODO confirm
              the producer, data.parseFile, guarantees this).

    Returns:
        int: the predicted class index (see emotionsMap at module level).
    """
    # names=True consumes the header row and yields a structured array.
    input_data = np.genfromtxt(file, dtype=float, delimiter=',', names=True)
    # The model expects a fixed 45x3 window; shorter recordings are
    # zero-padded at the front. NOTE(review): pad_x goes negative if the
    # file has more than 45 rows and np.pad will then raise.
    target_array_shape = (45, 3)
    pad_x = target_array_shape[0] - input_data.shape[0]
    # Convert structured rows to a plain 2-D list so column ops work.
    input_data = [list(item) for item in input_data]
    input_data = np.delete(input_data, 3, 1)  # drop the 4th column (unused by the model)
    input_data = np.pad(input_data, ((pad_x, 0), (0, 0)), mode="constant")
    input_data = input_data.flatten()  # 45*3 = 135 features
    # Stack a zero row so the model receives a batch of 2; only the first
    # row's prediction is used.
    test_data_array = np.vstack((input_data, np.zeros((135,))))
    test_data_array = np.array(test_data_array, dtype="float32")
    test_data_array = torch.Tensor(test_data_array)
    with torch.no_grad():  # inference only — no autograd bookkeeping
        predicted_outputs = loaded_model(test_data_array)
        _, predicted_labels = torch.max(predicted_outputs, 1)  # argmax over classes
    predicted_labels = predicted_labels.tolist()
    return predicted_labels[0]
# To run: python -m flask --app pedalServer run --host=0.0.0.0
# Flask application object; the /post route below is registered on it.
app = Flask(__name__)
@app.route('/post', methods=['POST'])
def result():
    """Receive raw sensor CSV text, classify the latest 2-second window,
    and return the label plus a Spotify playlist.

    Expects form field 'csv_as_str' containing the CSV body. On success the
    response is the predicted class index, a newline, then one track per
    line. On any failure the handler deliberately swallows the error and
    returns an empty body (best-effort contract kept from the original).
    """
    try:
        csv_str = request.form['csv_as_str']
        with open("./receivedData.csv", "tw", encoding="utf8", newline="") as out_file:
            out_file.write(csv_str)
        # data.parseFile splits the upload into per-window files under
        # ./receivedData/ — presumably named receivedData-Parsed<i>.csv;
        # verify against data.parseFile.
        data.parseFile("./receivedData.csv")
        parsed_dir = "./receivedData/"
        file_count = len([name for name in os.listdir(parsed_dir)
                          if os.path.isfile(os.path.join(parsed_dir, name))])
        last_file = f"./receivedData/receivedData-Parsed{file_count-1}.csv"  # Take the last 2-second entry
        # Trigger the AI model then return the result from spotify
        label = predict(last_file)
        # Clean up the parsed-window directory so the next request starts fresh.
        for name in os.listdir(parsed_dir):
            os.remove(os.path.join(parsed_dir, name))
        os.rmdir(parsed_dir)
        tracks = spotifyHandler.playlistGeneration(spotifyHandler.getGenres(label), 5)
        return str(label) + "\n" + "\n".join(tracks)
    except Exception as e:
        # Boundary handler: log and return an empty body rather than a 500.
        print(e)
        return ""