-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtrain.py
107 lines (84 loc) · 3.31 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 29 10:01:50 2019
@author: a-kojima
"""
import numpy as np
import numpy.matlib as npm
import soundfile as sf
from deep_waveform_dnn import deep_waveform_dnn
from tensorflow.keras import backend as K
from tensorflow.keras import models
import glob
from sklearn.preprocessing import LabelBinarizer
import sys
# ==================
# analysis parameters
# ==================
SAMPLING_FREQUENCY = 16000  # Hz
AUDIO_SAMPLE_DUR = 0.275  # seconds of audio per training example
NUMBER_OF_CHANNELS = 2  # waveform channels per example (see data reshape below)
IMPULSE_RESPONSE_LENGTH = 400  # filter length passed to deep_waveform_dnn
NUMBER_OF_FILTERS = 40  # band-pass filters in the waveform front-end layer
FRAME_DUR = 0.025  # frame length (s) -- not referenced below; TODO confirm intended use
FRAME_SHIFT_DUR = 0.01  # frame shift (s) -- not referenced below; TODO confirm intended use
NUMBER_OF_SAMPLE = 10000  # not referenced below; TODO confirm intended use
DEBUG = False  # when True, the (disabled) prep loop reads only 30 wavs per class
GAMMATONE_WEIGHT_PATH = r'./gammatone_weight.npy'  # for gamma-tone init
number_of_relu_layer = 1  # dense layers passed to get_model()
number_of_relu_node = 320  # units per dense layer passed to get_model()
class opt:
    """Namespace of preprocessing utilities (never instantiated).

    The original methods were defined without ``self`` and only worked when
    called through the class object; marking them ``@staticmethod`` makes that
    contract explicit while keeping ``opt.method(x)`` calls unchanged.
    """

    @staticmethod
    def zero_mean_unit_variance(sample):
        """Return `sample` normalized to zero mean and unit variance.

        NOTE: produces NaN/inf for a constant signal (std == 0) — callers
        feed real recorded audio, where this is not expected.
        """
        return (sample - np.mean(sample)) / np.std(sample)

    @staticmethod
    def convert_one_hot_vector(label_sequence):
        """One-hot encode a sequence of class labels.

        Uses sklearn's LabelBinarizer (imported at module level); returns the
        (n_samples, n_classes) binarized array.
        """
        trans = LabelBinarizer()
        one_hot_vector = trans.fit_transform(label_sequence)
        return one_hot_vector
audio_length = np.int(AUDIO_SAMPLE_DUR * SAMPLING_FREQUENCY)
# One-off data-preparation step, deliberately disabled by wrapping it in a
# bare string literal: it read every wav under ./<model>/, normalized each
# waveform, stacked them into `data`, collected the class labels, and saved
# label.npy / data.npy — the files the live code below loads directly.
# NOTE(review): if re-enabled, np.str (removed in NumPy 1.24) must become str.
'''
model_list = np.load(r'./model_list.npy')
data = np.zeros((1, audio_length, NUMBER_OF_CHANNELS), dtype=np.float32)
label = np.array([], dtype=np.str)
for model in model_list:
    if DEBUG:
        wavform_name_list = glob.glob('./' + str(model) + '/' + '**.wav')[0:30]
    else:
        wavform_name_list = glob.glob('./' + str(model) + '/' + '**.wav')
    #print(len(wavform_name_list))
    for wavform_name in wavform_name_list:
        print(wavform_name)
        wavform, _ = sf.read(wavform_name, dtype='float32')
        wavform = opt.zero_mean_unit_variance(wavform)
        data = np.concatenate((data, npm.reshape(wavform, [1, audio_length, NUMBER_OF_CHANNELS])))
        label = np.append(label, str(model))
data = data[1:, :, :]
one_hot_vector = opt.convert_one_hot_vector(label)
print('label', np.shape(one_hot_vector))
print('data', np.shape(data))
np.save('label.npy', label)
np.save('data.npy', data)
'''
# ------------------------------------------------------------------
# Load the pre-computed dataset (written by the disabled prep step)
# and rebuild the deep waveform DNN with previously trained weights.
# ------------------------------------------------------------------
label = np.load('./label.npy')
data = np.load('./data.npy')
one_hot_vector = opt.convert_one_hot_vector(label)

K.set_learning_phase(1)  # 1 = training phase (affects dropout/batch-norm)
K.clear_session()

# The original script rebound the name `deep_waveform_dnn`, shadowing the
# imported class of the same name; bind the builder to a distinct name so
# the class stays reachable.
model_builder = deep_waveform_dnn(NUMBER_OF_FILTERS,
                                  NUMBER_OF_CHANNELS,
                                  audio_length,
                                  IMPULSE_RESPONSE_LENGTH,
                                  GAMMATONE_WEIGHT_PATH,
                                  sampling_frequency=SAMPLING_FREQUENCY,
                                  number_of_class=len(set(label)))
dwd = model_builder.get_model(number_of_dense_layer=number_of_relu_layer,
                              number_of_ff_node=number_of_relu_node,
                              optimizer='sgd',
                              learning_rate=0.01)
dwd.load_weights(r'./deep_waveform_dnn_weight.hdf5')
"""
dwd = deep_waveform_dnn.gammatone_init(dwd)
print(dwd.summary())
train_model = deep_waveform_dnn.train(data, one_hot_vector, dwd, batch_size=batch_size, epochs=17, validation_split=0.2)
#models.save_model(train_model, model_name)
train_model.save_weights('./deep_waveform_dnn_weight.hdf5')
"""