deep-q-learning.py
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque, namedtuple
import numpy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One environment step; next_state is None when the episode terminated.
Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state'))


class MemoryReplay:
    """Fixed-capacity experience replay buffer; the oldest transitions are evicted first."""

    def __init__(self, capacity):
        self._memory = deque(maxlen=capacity)

    def sample(self, batch_size):
        return random.sample(self._memory, batch_size)

    def save(self, state, action, reward, next_state):
        self._memory.append(Transition(state, action, reward, next_state))

    def __len__(self):
        return len(self._memory)

class DQN(nn.Module):
    """Two-hidden-layer MLP mapping an observation to one Q-value per action."""

    def __init__(self, env: gym.Env):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(env.observation_space.shape[0], 128),
            nn.GELU(),
            nn.Linear(128, 128),
            nn.GELU(),
            nn.Linear(128, env.action_space.n),
        )

    def forward(self, x: torch.Tensor):
        return self.network(x)

def main():
    env = gym.make("CartPole-v1")

    # Online network is trained every step; the target network is a periodic copy
    # used to compute the bootstrap targets.
    q_network = DQN(env).to(device)
    target_network = DQN(env).to(device)
    target_network.load_state_dict(q_network.state_dict())

    # Hyperparameters
    epsilon_start = 1.0
    epsilon_end = 0.01
    epsilon_end_episode = 500   # episode at which epsilon reaches its floor
    lr = 0.00015
    gamma = 0.99
    C = 13                      # target-network update period, in episodes
    batch_size = 256
    max_episode_count = 1000
    replay_bank_size = 5000

    optimizer = optim.AdamW(q_network.parameters(), lr=lr, amsgrad=True)
    replay = MemoryReplay(replay_bank_size)
    previous_total_rewards = []

    for episode in range(max_episode_count):
        state, _ = env.reset()
        done = False
        truncated = False
        total_reward = 0

        while not done and not truncated:
            # Linearly anneal epsilon from epsilon_start to epsilon_end over
            # epsilon_end_episode episodes, then hold it at the floor.
            epsilon = max(episode * ((epsilon_end - epsilon_start) / epsilon_end_episode) + epsilon_start, epsilon_end)

            # Epsilon-greedy action selection.
            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    action = torch.argmax(q_network(torch.tensor(state, device=device))).item()

            next_state, reward, done, truncated, _ = env.step(action)
            total_reward += reward

            # Terminal transitions are stored with next_state = None so they get
            # a pure-reward target below.
            next_state = None if done else next_state
            replay.save(state, [action], [reward], next_state)
            state = next_state

            # Optimize the model on a sampled minibatch; early on the buffer may
            # hold fewer than batch_size transitions, so clamp the sample size.
            transitions = replay.sample(min(batch_size, len(replay)))
            state_batch, action_batch, reward_batch, next_state_batch = list(zip(*transitions))

            # Q(s, a) for the actions actually taken.
            q_values = q_network(torch.tensor(numpy.array(state_batch), device=device)).gather(1, torch.tensor(action_batch, device=device))

            # Bootstrap only from non-terminal next states.
            mask = list(map(lambda x: x is not None, next_state_batch))
            non_terminated_next_states = tuple(next_state for next_state in next_state_batch if next_state is not None)
            with torch.no_grad():
                target_q_values = target_network(torch.tensor(numpy.array(non_terminated_next_states), device=device)).max(1, keepdim=True).values

            # Targets: y = r for terminal transitions, y = r + gamma * max_a' Q_target(s', a') otherwise.
            y = torch.tensor(reward_batch, device=device, dtype=torch.float32)
            y[mask] += target_q_values * gamma

            optimizer.zero_grad()
            loss = torch.nn.HuberLoss()(q_values, y)
            loss.backward()
            optimizer.step()

        # Periodically sync the target network with the online network.
        if episode % C == 0:
            target_network.load_state_dict(q_network.state_dict())

        previous_total_rewards.append(total_reward)
        if len(previous_total_rewards) > 100:
            # 100-episode moving average of the return, plus current training stats.
            print(sum(previous_total_rewards[-100:]) / 100, epsilon, episode, total_reward, len(replay))


if __name__ == "__main__":
    main()