-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgenerator_amazon_movie_same_dist.py
104 lines (81 loc) · 3.17 KB
/
generator_amazon_movie_same_dist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import sys
import pickle
from datetime import datetime, timedelta
import time
import random
random.seed(42)
import re
import yaml
# Valid embedding back-ends; the first CLI argument selects which one to run.
modes = ['bert_768', 'bow_50', 'bow_768']
if len(sys.argv) < 2 or sys.argv[1] not in modes:
    print('Need mode {mode} as parameter!'.format(mode=modes))
    # sys.exit instead of the site-module exit() helper: exit() is meant for
    # interactive sessions and is not guaranteed to exist in every run mode.
    sys.exit(1)
mode = sys.argv[1]
# Reviews drawn per class for every permutation.
num_samples = 500
# 20 shifted permutations plus one reference split.
num_permutations = 20 + 1

# amazon_raw_file generated by amazon_movie_sorter.py
# gensim_model_file available at https://hobbitdata.informatik.uni-leipzig.de/EML4U/2021-05-17-Amazon-Doc2Vec/
# generated by word2vec/doc2vec.py
# embeddings_file is to be generated here
amazon_raw_file = 'data/movies/embeddings/amazon_raw.pickle'
gensim_model_50_file = 'data/movies/amazonreviews_model/amazonreviews_50.model'
gensim_model_768_file = 'data/movies/amazonreviews_model/amazonreviews_768.model'
bert_model = 'data/movies/movie_9e'
embeddings_file = f'data/movies/embeddings/amazon_{mode}_same_dist.pickle'
## Embed everything, save in chunks so the memory doesn't explode
# Configure the model to use by mode. Every branch binds `embed`, a callable
# that maps a list of text strings to embedding vectors; the rest of the
# script only calls `embed`, so the back-ends are interchangeable here.
if(mode == "bert_768"):
    from embedding import BertHuggingface
    # 5 presumably matches the five rating classes (0..4) — TODO confirm
    # against BertHuggingface's constructor signature.
    bert = BertHuggingface(5, batch_size=8)
    bert.load(bert_model)
    embed = bert.embed
elif(mode == "bow_50"):
    print("gensim_model_50_file", gensim_model_50_file)
    from word2vec.Word2Vec import Word2Vec
    word2vec = Word2Vec(gensim_model_50_file)
    word2vec.prepare()
    embed = word2vec.embed
elif(mode == "bow_768"):
    print("gensim_model_768_file", gensim_model_768_file)
    from word2vec.Word2Vec import Word2Vec
    word2vec = Word2Vec(gensim_model_768_file)
    word2vec.prepare()
    embed = word2vec.embed
else:
    # Defensive; unreachable in practice because the CLI check above
    # already rejected any mode outside `modes`.
    raise ValueError("Unknown mode " + mode)
# Refuse to overwrite an existing output file — and check BEFORE loading the
# (large) raw pickle, so a no-op run does not waste time and memory first.
if os.path.isfile(embeddings_file):
    print("Embeddings file already exists, exiting.", embeddings_file)
    sys.exit()

# texts: list of review strings; keys: per-review metadata lists where
# index 1 holds the star rating (1..5 in the raw file).
with open(amazon_raw_file, 'rb') as handle:
    texts, keys = pickle.load(handle)
for key in keys:
    key[1] -= 1  # shift class labels from 1..5 to 0..4 for easier 1-hot encoding
# Gather amazon reviews of the fourth year only (2011), grouped by class
# label so every permutation can be drawn with the same class distribution.
classes = [[] for _ in range(5)]
data = [x for x in zip(texts, keys) if x[1][-2].year == 2011]
for point in data:
    classes[point[1][1]].append(point)

# Sanity check: don't create more permutations than the smallest class
# can supply samples for.
mini = min(len(x) for x in classes)
if mini < num_permutations * num_samples:
    # BUG FIX: the original called '{n}'.format(num_permutations), which
    # raises KeyError('n') — a positional argument never binds to a named
    # field. Bind it explicitly.
    print('WARNING: too few samples for {n} permutations, doing '.format(n=num_permutations), end='')
    num_permutations = mini // num_samples  # floor division instead of int(a/b)
    print(' {} permutations instead!'.format(num_permutations))

# Shuffle within each class so every permutation slice is a random sample.
for cls in classes:
    random.shuffle(cls)
# Assemble num_permutations disjoint datasets: permutation p takes the p-th
# consecutive slice of num_samples reviews from every class, so each dataset
# shares the same class distribution.
data = []
for perm in range(num_permutations):
    lo = num_samples * perm
    hi = num_samples * (perm + 1)
    chunk = []
    for cls in classes:
        chunk.extend(cls[lo:hi])
    data.append(chunk)

# Embed each permutation's texts, keeping embeddings and keys index-aligned.
embeddings = []
e_keys = []
for dataset in data:
    batch_texts, batch_keys = (list(col) for col in zip(*dataset))
    embeddings.append(embed(batch_texts))
    e_keys.append(batch_keys)
# Persist every permutation in one pickle: {'data': (embeddings, keys)}.
with open(embeddings_file, 'wb') as handle:
    pickle.dump({'data': (embeddings, e_keys)}, handle)