simulation.py
from pydeeptoy.computational_graph import *
from itertools import chain


class SimulationContext:
    """Runs forward and backward passes over a ComputationalGraph, storing
    the value and gradient computed for every connection in a data bag."""

    def __init__(self):
        self.data_bag = dict()

    def get_data(self, key):
        # Lazily create the data record for a connection on first access.
        if key not in self.data_bag:
            self.data_bag[key] = ConnectionData(value=key.init_value)
        return self.data_bag[key]

    def __getitem__(self, key):
        return self.get_data(key)

    def __setitem__(self, key, value):
        self.data_bag[key] = value

    @staticmethod
    def sort_topologically(cg: ComputationalGraph, out=None):
        """Return the nodes of cg ordered so that every node appears after
        all of the nodes it depends on (depth-first search; nodes are
        emitted in order of increasing finishing time)."""
        sorted_nodes = []

        def depth_first_search(on_vertex_finished):
            discovered = dict()
            finished = dict()

            def visit(vertex, time):
                time += 1
                discovered[vertex] = time
                # Recurse into every input node before finishing this one.
                for v in cg.get_adjacent_in_nodes(vertex):
                    if v not in discovered:
                        time = visit(v, time)
                time += 1
                finished[vertex] = time
                on_vertex_finished(time, vertex)
                return time

            time = 0
            # If specific output connections are requested, start the search
            # only from the nodes producing them; otherwise cover the graph.
            root_nodes = chain.from_iterable(cg.adjacencyOutMap[c] for c in out) if out else cg.nodes
            for v in root_nodes:
                if v not in discovered:
                    time = visit(v, time)

        # A node finishes only after all of its inputs have finished, so
        # appending nodes as they finish yields a topological ordering.
        depth_first_search(lambda time, node: sorted_nodes.append(node))
        return sorted_nodes

    def forward(self, cg: ComputationalGraph, params=None, out=None):
        # Seed the supplied parameter values, then evaluate each node in
        # dependency order.
        for p, v in (params or {}).items():
            self.get_data(p).value = v
        for node in self.sort_topologically(cg, out):
            node.forward(self)

    def backward(self, cg: ComputationalGraph, reset_gradient=True, out=None):
        if reset_gradient:
            for i in cg.outputs:
                self.get_data(i).reset_gradient(to_value=1)
        # Propagate gradients through the same subgraph used by forward,
        # in reverse topological order.
        for node in reversed(self.sort_topologically(cg, out)):
            node.backward(self)

    def forward_backward(self, cg: ComputationalGraph, params=None, reset_gradient=True, out=None):
        self.forward(cg, params, out=out)
        self.backward(cg, reset_gradient=reset_gradient, out=out)
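
For reference, the ordering produced by sort_topologically can be demonstrated on a toy dependency graph. The sketch below is a minimal, self-contained illustration of the same depth-first, finish-time-ordered traversal; it does not use pydeeptoy, and the dict-of-lists graph, the node names, and the inputs callable are hypothetical stand-ins for ComputationalGraph.get_adjacent_in_nodes.

# Hypothetical toy graph: "loss" depends on "prediction", which depends on
# the inputs "x" and "w". None of these names come from pydeeptoy.
def topological_order(nodes, inputs):
    """Order nodes so that every node appears after all of its inputs."""
    order = []
    seen = set()

    def visit(node):
        seen.add(node)
        for dep in inputs(node):   # finish all inputs first...
            if dep not in seen:
                visit(dep)
        order.append(node)         # ...then emit the node itself

    for node in nodes:
        if node not in seen:
            visit(node)
    return order

graph = {"x": [], "w": [], "prediction": ["x", "w"], "loss": ["prediction"]}
print(topological_order(graph, graph.__getitem__))
# -> ['x', 'w', 'prediction', 'loss']: inputs always precede their consumers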