forked from ethanhe42/resnet-cifar10-caffe
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_utils.py
93 lines (80 loc) · 2.55 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import sys
import cPickle as pickle
import numpy as np
import os
from scipy.misc import imread
import numpy as np
import lmdb
import caffe
def load_CIFAR_batch(filename, pad=True):
""" load single batch of cifar """
with open(filename, 'rb') as f:
datadict = pickle.load(f)
X = datadict['data']
Y = datadict['labels']
X = X.reshape(10000, 3, 32, 32).astype(np.uint8)
padded = np.zeros((10000, 3, 40, 40), dtype=np.uint8)
padded[:,:,:,:] = 128
padded[:,:,4:-4, 4:-4] = X
Y = np.array(Y, dtype=np.int64)
if not pad:
return X, Y
return padded, Y
def load_CIFAR10(ROOT):
""" load all of cifar """
xs = []
ys = []
for b in range(1,6):
f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
X, Y = load_CIFAR_batch(f)
xs.append(X)
ys.append(Y)
Xtr = np.concatenate(xs)
Ytr = np.concatenate(ys)
idx = np.arange(len(Ytr))
np.random.shuffle(idx)
print 'shuffle training data', len(idx)
Xtr = Xtr[idx]
Ytr = Ytr[idx]
print idx
print 'tr label',Ytr.min(), Ytr.max()
del X, Y
Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'), pad=False)
print 'te label',Yte.min(), Yte.max()
print Xtr.shape
print Ytr.shape
print Xte.shape
print Yte.shape
return Xtr, Ytr, Xte, Yte
def py2lmdb(X, y, save_path):
# Let's pretend this is interesting data
assert X.dtype == np.uint8
N = X.shape[0]
assert N == y.shape[0], str(N) + ' ' + str(y.shape)
# We need to prepare the database for the size. We'll set it 10 times
# greater than what we theoretically need. There is little drawback to
# setting this too big. If you still run into problem after raising
# this, you might want to try saving fewer entries in a single
# transaction.
map_size = X.nbytes * 10
env = lmdb.open(save_path, map_size=map_size)
with env.begin(write=True) as txn:
# txn is a Transaction object
for i in range(N):
datum = caffe.proto.caffe_pb2.Datum()
datum.channels = X.shape[1]
datum.height = X.shape[2]
datum.width = X.shape[3]
datum.data = X[i].tobytes() # or .tostring() if numpy < 1.9
datum.label = int(y[i])
str_id = '{:08}'.format(i)
# The encode is only essential in Python 3
txn.put(str_id.encode('ascii'), datum.SerializeToString())
if __name__ == '__main__':
root = sys.argv[1]
Xtr, Ytr, Xte, Yte = load_CIFAR10(root)
paths = [ os.path.join(root, i) for i in ['train', 'test']]
py2lmdb(Xtr, Ytr, paths[0])
py2lmdb(Xte, Yte, paths[1])
for i in paths:
print 'saved to', i