-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathA1.py
92 lines (73 loc) · 2.91 KB
/
A1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import numpy as np
import matplotlib.pyplot as plt
def main():
    """Simulate a 10-armed bandit and plot the true vs. estimated arm values.

    A random environment is created with setEnv, every estimate starts at
    1/nArms, and 99,999 steps are played. On each step rollGreedy decides
    whether the move is greedy (greedyArm) or random (randomArm). The true
    and final estimated distributions are printed and then plotted.
    """
    nArms = 10
    # rollGreedy(eps) returns True with probability eps, so this value is the
    # chance of making a *greedy* move on any given step.
    eps = 0.9
    trueDistrib = setEnv(nArms)
    # Fixed environment, handy for reproducible debugging:
    # trueDistrib = [0.20667192, 0.21488821, 0.60482498, 0.30061569, 0.44926704, 0.98593824, 0.3833595, 0.98111571, 0.99304015, 0.96186659]
    estDistrib = np.full((nArms,), 1/nArms)
    print(trueDistrib)
    # The step counter starts at 1 because it is used as a divisor when the
    # estimates are updated.
    for i in range(1, 100000):
        takeStep = greedyArm if rollGreedy(eps) else randomArm
        estDistrib = takeStep(trueDistrib, estDistrib, i)
    print(estDistrib)
    plotHistograms(trueDistrib, estDistrib)
def setEnv(nArms):
    """Generate the true reward probability of each of the nArms bandit arms.

    Returns a 1-D array of length nArms whose entries are drawn uniformly
    from [0, 1).
    """
    return np.random.rand(nArms)
def averageReward(currRew, currAvg, n):
    """Return the running average after folding in one more reward.

    currRew is the newest reward, currAvg the average before it, and n the
    divisor applied to the difference (the number of samples so far,
    including the new one).
    """
    stepSize = 1 / n
    error = currRew - currAvg
    return currAvg + stepSize * error
def rollReward(nArms, ind, prob):
    """Pull one arm and report the outcome as a reward vector.

    Args:
        nArms: length of the returned vector.
        ind: index of the arm that was pulled.
        prob: probability that the pull pays out.

    Returns:
        A float array of nArms zeros, with a 1 at position ind if the pull
        was rewarded.
    """
    rewTable = np.zeros((nArms,))
    # np.random.rand() samples the half-open interval [0, 1). A strict "<"
    # makes the success chance exactly prob: prob == 0 can never pay out
    # (a draw of exactly 0.0 would have passed "<="), and prob == 1 always does.
    if np.random.rand() < prob:
        rewTable[ind] = 1
    return rewTable
def updateProbs(reward, table, n):
    """Update the estimated probabilities in place and return them.

    Every entry of table is moved towards the matching entry of reward by a
    fraction 1/n of their difference.

    NOTE(review): the whole table is updated, so entries whose reward is 0
    are pulled towards 0 as well. A per-arm alternative would be
    table[ind] += (1/n) * (reward - table[ind]).
    """
    stepSize = 1 / n
    error = reward - table
    table += stepSize * error
    return table
def rollGreedy(e):
    """Pick whether or not the move made is greedy.

    Returns True with probability e and False otherwise.
    """
    # np.random.rand() samples [0, 1), so the strict "<" guarantees that
    # e == 0 never yields a greedy move and e == 1 always does.
    return bool(np.random.rand() < e)
def greedyArm(trueDistrib, estDistrib, n):
    """Make a greedy move: pull the arm with the highest current estimate.

    Ties between equally good arms are broken uniformly at random. The pull
    is rewarded with the chosen arm's probability in trueDistrib, and the
    estimates are then updated with that outcome.

    Args:
        trueDistrib: true reward probability of each arm.
        estDistrib: current estimates (1-D array).
        n: step counter handed to updateProbs.

    Returns:
        The updated estimates, as returned by updateProbs.
    """
    bestInds = np.flatnonzero(estDistrib == np.amax(estDistrib))
    # Only draw a random number when there is an actual tie to break.
    pick = np.random.randint(0, len(bestInds)) if len(bestInds) > 1 else 0
    # A plain int index keeps both the tie and the no-tie case on the same
    # code path for the lookups below.
    indPicked = int(bestInds[pick])
    armPicked = trueDistrib[indPicked]
    # Check whether or not the action gets a reward. Update estimated table.
    rewardVal = rollReward(len(estDistrib), indPicked, armPicked)
    return updateProbs(rewardVal, estDistrib, n)
def randomArm(trueDistrib, estDistrib, n):
    """Make an exploratory move: pull an arm chosen uniformly at random.

    The pull is rewarded with the chosen arm's probability in trueDistrib,
    and the estimates are updated with the outcome. Returns the updated
    estimates, as returned by updateProbs.
    """
    armCount = len(estDistrib)
    choice = np.random.randint(0, armCount)
    outcome = rollReward(armCount, choice, trueDistrib[choice])
    return updateProbs(outcome, estDistrib, n)
def actionSelect(estDistrib, numActions, t, c=2.0):
    """Select an arm by estimate plus an exploration bonus.

    Each arm is scored as estDistrib + c * sqrt(log(t) / numActions), and
    the index of the highest score is returned. The bonus shrinks as an arm
    is pulled more often.

    Args:
        estDistrib: current value estimate of each arm.
        numActions: number of times each arm has been pulled so far.
        t: current step number (must be >= 1 so that log(t) >= 0).
        c: weight of the exploration bonus; 0 gives a purely greedy choice.

    Returns:
        The integer index of the selected arm.

    Raises:
        ValueError: if t is smaller than 1.
    """
    if t < 1:
        raise ValueError("t must be >= 1")
    estimates = np.asarray(estDistrib, dtype=float)
    counts = np.asarray(numActions, dtype=float)
    # An arm that was never pulled would divide by zero below; try it first.
    untried = np.flatnonzero(counts == 0)
    if untried.size:
        return int(untried[0])
    scores = estimates + c * np.sqrt(np.log(t) / counts)
    return int(np.argmax(scores))
def plotHistograms(true, est):
    """Show the true and the estimated distributions side by side.

    Draws two line plots in a 1x2 grid (true on the left, estimate on the
    right) and then displays the figure.
    """
    panels = (
        ("True Distribution", true),
        ("Estimated Distribution", est),
    )
    for position, (heading, values) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(heading)
        plt.plot(values)
    plt.show()
# Run the simulation only when this file is executed as a script.
if __name__ == '__main__':
    main()