-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbvr_regression.py
64 lines (53 loc) · 2.25 KB
/
bvr_regression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import csv
import numpy as np
from sklearn.svm import SVR
import matplotlib.pyplot as plt
def get_data(filename):
    """Load a semicolon-delimited price series from *filename*.

    The first row is treated as a header and skipped. For every remaining
    non-blank row, the second column (index 1) is parsed as a float price.
    The row's position in the series is used as its "date", so dates are
    0, 1, 2, ...

    Args:
        filename: Path to the CSV file (``;`` delimiter).

    Returns:
        A tuple ``(dates, prices, count)``:
        * ``dates`` is the list of consecutive integer indices.
        * ``prices`` is the list of parsed prices.
        * ``count`` is the number of data rows read (``== len(prices)``).
        An empty or header-only file yields ``([], [], 0)``.

    Raises:
        ValueError: If a price cell cannot be parsed as a float.
        IndexError: If a non-blank row has fewer than two columns.
    """
    dates = []
    prices = []
    # newline='' is required by the csv module for correct newline/quote handling.
    with open(filename, 'r', newline='') as csvfile:
        reader = csv.reader(csvfile, delimiter=';')
        next(reader, None)  # skip the header row; tolerate a completely empty file
        for row in reader:
            if not row:
                continue  # csv.reader yields [] for blank lines; there is nothing to parse
            dates.append(len(prices))
            prices.append(float(row[1]))
    return dates, prices, len(prices)
def _error_metrics(predicted, actual):
    """Return ``(rmse, mae)`` of *predicted* against *actual* (equal-length, non-empty)."""
    errors = np.asarray(predicted, dtype=float) - np.asarray(actual, dtype=float)
    rmse = float(np.sqrt(np.mean(errors ** 2)))
    mae = float(np.mean(np.abs(errors)))
    return rmse, mae


def predict_prices(dates, prices, x, fdates, fprices, n_eval=2, offset=2,
                   n_forecast=10):
    """Fit an RBF-kernel SVR to a price series, plot it, and score it on future data.

    Args:
        dates: Training x-values (integer indices, as returned by ``get_data``).
        prices: Training prices, same length as *dates*.
        x: Index at which the future series starts. The script passes the
            training file's row count, so future point ``k`` sits at ``x + k``.
        fdates: Indices of the future series (0-based, shifted by *x* for plotting).
        fprices: Real future prices; must be non-empty.
        n_eval: Number of leading future points used for the error metrics
            (clamped to ``len(fprices)``).
        offset: Shift applied to the model's x-value when comparing against
            future data. The model at ``x + k + offset`` is compared with
            ``fprices[k]``.
        n_forecast: Number of steps after *x* drawn as the forecast curve.

    Returns:
        A tuple ``(pred_at_x, rmse, mae, pred_at_x_plus_offset, fprices[0])``:
        * ``pred_at_x`` is the model prediction at *x*.
        * ``rmse`` and ``mae`` are the root-mean-square and mean absolute error
          over the first *n_eval* future points.
        * ``pred_at_x_plus_offset`` is the model prediction at ``x + offset``.
        * ``fprices[0]`` is the first real future price.

    Side effects:
        Prints progress messages and shows two matplotlib figures (blocking).
    """
    dates = np.reshape(dates, (len(dates), 1))  # scikit-learn expects a 2-D X
    print('starting pricing')
    svr_rbf = SVR(kernel='rbf', C=1e3, gamma=0.1)
    svr_rbf.fit(dates, prices)
    print('svr completed')

    forecast_dates = np.arange(x, x + n_forecast).reshape(-1, 1)
    # Move the future series onto the same axis, continuing after the training data.
    fdates = np.reshape([d + x for d in fdates], (len(fdates), 1))

    # Overview: the whole training range plus the forecast.
    plt.scatter(dates, prices, color='black', label='Data')
    plt.plot(dates, svr_rbf.predict(dates), color='red', label='RBF model')
    plt.plot(forecast_dates, svr_rbf.predict(forecast_dates), color='pink', label='RBF forecast')
    plt.scatter(fdates, fprices, color='orange', label='Future data')
    plt.ylabel(r'USD\RUB')  # raw string: '\R' is not a valid escape sequence
    plt.legend()
    plt.show()

    # Zoomed view: last 5 training points and the offset-shifted forecast.
    # NOTE(review): the reason for the `offset` shift (default 2) is not documented
    # in the original; presumably it aligns the two series' calendars — confirm.
    plt.scatter(dates[-5:], prices[-5:], color='black', label='Данные')
    plt.plot(dates[-5:], svr_rbf.predict(dates)[-5:], color='red', label='RBF модель')
    plt.plot(forecast_dates, svr_rbf.predict(forecast_dates + offset), color='red', label='Прогноз')
    plt.scatter(fdates, fprices, color='orange', label='Реальные данные')
    plt.ylabel(r'USD\RUB')
    plt.legend()
    plt.show()

    # Score the shifted model on the first n real future prices. Predict on a 2-D
    # array: scikit-learn rejects scalar inputs. Normalise by the number of points
    # actually evaluated; the original divided 2 points by 3.
    n = min(n_eval, len(fprices))
    eval_dates = np.arange(n).reshape(-1, 1) + x + offset
    r, M = _error_metrics(svr_rbf.predict(eval_dates), fprices[:n])

    pred_at_x, pred_at_shifted = svr_rbf.predict(np.array([[x], [x + offset]]))
    return pred_at_x, r, M, pred_at_shifted, fprices[0]
# Script entry point: load the training and future USD/RUB series, fit and plot
# the model, and print the training row count and the prediction summary.
# Guarded so that importing this module does not read files or open plot windows.
if __name__ == '__main__':
    dates, prices, i = get_data('UsdRubLight.csv')
    fdates, fprices, _ = get_data('UsdRubLightFuture.csv')  # future row count unused
    predicted_price = predict_prices(dates, prices, i, fdates, fprices)
    print(i)
    print(predicted_price)