-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgenerator_twitter_same_dist.py
70 lines (51 loc) · 1.85 KB
/
generator_twitter_same_dist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import sys
import pickle
from datetime import datetime, timedelta
import time
import random
random.seed(42)
import re
import yaml
# Embedding variants this script accepts as its single command-line argument.
modes = ['bert_768', 'bow_50', 'bow_768']
if len(sys.argv) < 2 or sys.argv[1] not in modes:
    print('Need mode {mode} as parameter!'.format(mode=modes))
    # Use sys.exit rather than the bare exit() helper: exit() is injected by
    # the site module for interactive sessions and is not guaranteed to exist
    # (e.g. when running with `python -S`).
    sys.exit(1)
# The validated mode is substituted into the input and output file names.
mode = sys.argv[1]
# Number of items taken from each class for one permutation, and the number
# of permutations to build.
num_samples = 500
num_permutations = 20 + 1

# Input pickles (one per class) for the selected mode.
start_file_b = 'data/twitter/biden_{}_embeddings.pickle'.format(mode)
start_file_t = 'data/twitter/trump_{}_embeddings.pickle'.format(mode)
# Output file that this script generates.
embeddings_file = 'data/twitter/twitter_{mode}_same_dist.pickle'.format(mode=mode)

if os.path.isfile(embeddings_file):  # Do not overwrite an existing result
    print("Embeddings file already exists, exiting.", embeddings_file)
    # sys.exit() instead of the site-provided exit(), which is intended for
    # interactive use only; the exit status stays 0 as before.
    sys.exit()
# Read the two per-class input pickles, first the Biden file, then the Trump
# file. Later steps take len() of each object, shuffle it in place and slice it.
# NOTE(review): pickle.load can execute code embedded in the file; these
# inputs are presumably generated locally -- confirm they are trusted.
with open(start_file_b, 'rb') as biden_file:
    biden = pickle.load(biden_file)

with open(start_file_t, 'rb') as trump_file:
    trump = pickle.load(trump_file)
# Sanity check: don't create too many permutations if there are not enough
# data points. Every permutation takes a disjoint slice of num_samples items
# from each class, so the smallest class limits how many can be built.
classes = [biden, trump]
mini = min(len(x) for x in classes)
if mini < num_permutations * num_samples:
    # The placeholder {n} is a named field, so the value has to be passed by
    # keyword; passing it positionally raises KeyError and aborts the script
    # exactly when this warning is supposed to be shown.
    print('WARNING: too few samples for {n} permutations, doing '.format(n=num_permutations), end='')
    # Floor division: only complete slices of num_samples items are usable.
    num_permutations = mini // num_samples
    print(' {} permutations instead!'.format(num_permutations))
# Shuffle every class in place, in list order, so the seeded random module
# produces the same ordering on every run.
for class_items in classes:
    random.shuffle(class_items)

# Build one entry per permutation: the perm-th block of num_samples items
# from each class, concatenated in class order.
data = []
for perm in range(num_permutations):
    start = num_samples * perm
    stop = num_samples * (perm + 1)
    entry = [item for class_items in classes for item in class_items[start:stop]]
    data.append(entry)
# Split every entry into two parallel lists. Each item of an entry must be a
# pair, because the transposed entry is unpacked into exactly two columns.
# NOTE(review): the pair is presumably (embedding/text, key) -- confirm
# against the script that writes the input pickles.
embeddings, e_keys = [], []
for sample in data:
    first_column, second_column = (list(column) for column in zip(*sample))
    embeddings.append(first_column)
    e_keys.append(second_column)

# Persist both lists together under the single 'data' key.
dump_data = {'data': (embeddings, e_keys)}
with open(embeddings_file, 'wb') as out_file:
    pickle.dump(dump_data, out_file)