-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcgan.py
146 lines (116 loc) · 6.33 KB
/
cgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Conditional GAN (cGAN) model: wires a generator and a discriminator together and runs the training loop
# Shivan Bhatt 18249
from typing import Tuple, List #importing necessary keras libraries
import numpy as np
from tensorflow.keras import Model, Input
from tensorflow.keras.callbacks import History, BaseLogger, ProgbarLogger, CallbackList, Callback
from tensorflow.keras.optimizers import Adam, Optimizer
from data_generator import DataGenerator
from plotter import Plotter
class CGAN:
    """Conditional GAN combining a generator and a discriminator.

    The generator maps a condition (mask) image to an artificial image; the
    discriminator scores an (image, condition) pair.  ``cgan_model`` stacks
    the generator and a frozen view of the discriminator so that training it
    only updates the generator.
    """

    def __init__(self, data_generator: DataGenerator,
                 discriminative_network_model: Model,
                 generative_network_model: Model,
                 input_shape: Tuple[int, int, int],
                 condition_shape: Tuple[int, int, int],
                 optimizer: Optimizer = None):
        """Build the combined generator + frozen-discriminator model.

        Args:
            data_generator: Source of training batches; ``fit`` calls its
                ``images_df()`` and ``load(batch)`` methods.
            discriminative_network_model: Discriminator taking
                ``[image, condition]`` as inputs.
                NOTE(review): presumably already compiled (with an accuracy
                metric) by the caller, since ``fit`` trains it directly and
                reads ``[loss, acc]`` from it - confirm.
            generative_network_model: Generator taking the condition as input.
            input_shape: Shape of a generated/real image; ``fit`` uses the
                first entry for the label map size and the third entry when
                plotting.
            condition_shape: Shape of the condition (mask) input layer.
            optimizer: Optimizer for the combined model.  When omitted, a new
                ``Adam(0.0002, 0.5)`` is created for this instance.
        """
        if optimizer is None:
            # Build the default per instance: a default evaluated in the
            # signature would be created once and shared (with its state)
            # by every CGAN constructed without an explicit optimizer.
            optimizer = Adam(0.0002, 0.5)

        self.data_generator = data_generator
        self.discriminative_network_model = discriminative_network_model
        self.generative_network_model = generative_network_model
        self.input_shape = input_shape
        self.condition_shape = condition_shape

        # Input layer for the condition mask, fed through the generator.
        condition = Input(shape=condition_shape, name='condition_mask')
        artificial = self.generative_network_model(condition)

        # Wrap the discriminator in a second Model over the same inputs and
        # outputs and mark that wrapper as non-trainable before it is used
        # inside the combined model.
        frozen_discriminative_network_model = Model(
            inputs=discriminative_network_model.inputs,
            outputs=discriminative_network_model.outputs,
            name='discriminator'
        )
        frozen_discriminative_network_model.trainable = False
        discrimination_result = frozen_discriminative_network_model([artificial, condition])

        # Combined model: condition -> (discriminator verdict, generated image).
        self.cgan_model = Model(
            inputs=[condition],
            outputs=[discrimination_result, artificial],
            name='sentinel-cgan'
        )
        # Binary cross-entropy on the verdict (weight 1) plus MAE between the
        # generated and the target image (weight 100).
        self.cgan_model.compile(loss=['binary_crossentropy', 'mae'], optimizer=optimizer, loss_weights=[1, 100])
        self.cgan_model.stop_training = False
        self.plotter = Plotter(generative_network_model, data_generator)

    def fit(self, epochs: int = 1, batch: int = 1, artificial_label: int = 0, real_label: int = 1,
            callbacks: List[Callback] = None) -> History:
        """Train the discriminator and the generator alternately.

        For every batch the discriminator is trained on real pairs, then on
        generated pairs, and finally the combined model is trained so the
        generator learns to produce images the discriminator labels as real.

        Args:
            epochs: Number of passes over the data generator.
            batch: Batch size requested from ``data_generator.load``.
            artificial_label: Target value used for generated images.
            real_label: Target value used for real images.
            callbacks: Extra Keras callbacks appended after the built-in
                loggers and the ``History`` callback.

        Returns:
            The ``History`` callback holding the per-epoch metrics.
        """
        processed_images_count = len(self.data_generator.images_df())
        callback_metrics = [
            'discriminator_artificial_acc', 'discriminator_artificial_loss',
            'discriminator_real_acc', 'discriminator_real_loss',
            'generator_loss'
        ]

        history = History()
        _callbacks = [
            BaseLogger(stateful_metrics=callback_metrics),
            ProgbarLogger(count_mode='steps', stateful_metrics=callback_metrics),
            history
        ]
        if callbacks:
            # list(...) lets callers pass any iterable, not only a list.
            _callbacks = _callbacks + list(callbacks)

        callback_list = CallbackList(_callbacks)
        callback_list.set_model(self.cgan_model)
        callback_list.set_params({
            'epochs': epochs,
            # Ceiling division: a trailing partial batch counts as a step.
            'steps': -(-processed_images_count // batch),
            'samples': processed_images_count,
            'verbose': True,
            'do_validation': False,
            'metrics': callback_metrics
        })

        # Clear a stop request left over from a previous fit() call (e.g. set
        # by an early-stopping callback); otherwise this run would end after
        # its first batch.
        self.cgan_model.stop_training = False

        # Side length of the per-sample label map; loop-invariant, so it is
        # computed once.
        # NOTE(review): presumably this matches the discriminator's output
        # size (input side divided by 2**5) - confirm against the model.
        modifier = int(self.input_shape[0] / 2 ** 5)
        label_map_shape = (modifier, modifier, 1)

        callback_list.on_train_begin()
        for epoch in range(epochs):
            callback_list.on_epoch_begin(epoch)
            epoch_logs = {}
            epoch_artificial_dn_loss = []
            epoch_real_dn_loss = []
            epoch_gn_loss = []

            for i, (real_satellite_images, mask_images) in enumerate(self.data_generator.load(batch)):
                # The last batch may be smaller than ``batch``.
                effective_batch_size = len(real_satellite_images)
                batch_logs = {'batch': i, 'size': effective_batch_size}
                callback_list.on_batch_begin(i, batch_logs)

                # Constant target maps for "artificial" and "real" verdicts.
                artificial_base = np.full((effective_batch_size,) + label_map_shape, artificial_label)
                real_base = np.full((effective_batch_size,) + label_map_shape, real_label)

                artificial_satellite_images = self.generative_network_model.predict(mask_images)

                # Discriminator step on real (image, mask) pairs.
                batch_real_dn_loss = self.discriminative_network_model.train_on_batch(
                    x=[real_satellite_images, mask_images],
                    y=real_base
                )
                epoch_real_dn_loss.append(batch_real_dn_loss)

                # Discriminator step on generated (image, mask) pairs.
                batch_artificial_dn_loss = self.discriminative_network_model.train_on_batch(
                    x=[artificial_satellite_images, mask_images],
                    y=artificial_base
                )
                epoch_artificial_dn_loss.append(batch_artificial_dn_loss)

                # Generator step through the combined model: the verdict
                # target is "real", the image target is the real image.
                batch_gn_loss = self.cgan_model.train_on_batch(
                    x=[mask_images],
                    y=[real_base, real_satellite_images]
                )
                epoch_gn_loss.append(batch_gn_loss)

                callback_list.on_batch_end(i)
                if self.cgan_model.stop_training:
                    break

            # Average the per-batch results element-wise over the epoch.
            epoch_artificial_dn_loss = np.mean(epoch_artificial_dn_loss, axis=0)
            epoch_real_dn_loss = np.mean(epoch_real_dn_loss, axis=0)
            epoch_gn_loss = np.mean(epoch_gn_loss, axis=0)

            # Index 0 is read as the loss and index 1 as the accuracy of the
            # discriminator results; index 0 of the combined model's result
            # is reported as the generator loss.
            epoch_logs.update({
                'discriminator_artificial_acc': epoch_artificial_dn_loss[1],
                'discriminator_artificial_loss': epoch_artificial_dn_loss[0],
                'discriminator_real_acc': epoch_real_dn_loss[1],
                'discriminator_real_loss': epoch_real_dn_loss[0],
                'generator_loss': epoch_gn_loss[0]
            })

            # Plot sample results for this epoch, then the metric history.
            self.plotter.plot_epoch_result(epoch, self.input_shape[2])
            callback_list.on_epoch_end(epoch, epoch_logs)
            self.plotter.plot_history(history)

            if self.cgan_model.stop_training:
                break

        callback_list.on_train_end()
        self.plotter.plot_history(history)
        return history