Skip to content

Commit fa8fbe1

Browse files
Edouard GraveFacebook Github Bot
Edouard Grave
authored and
Facebook Github Bot
committed
Moved sigmoid and log functions inside Model class
Summary: Moved the sigmoid and log functions inside the Model class. No more need for the thread-unsafe `fasttext::utils::initTables()` function. Reviewed By: ajoulin Differential Revision: D4052203 fbshipit-source-id: ca593fdc83c1ed96abc5db5f09b8db2154062efa
1 parent 2211639 commit fa8fbe1

File tree

5 files changed

+69
-83
lines changed

5 files changed

+69
-83
lines changed

src/main.cc

-2
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ void train(int argc, char** argv) {
129129
}
130130

131131
int main(int argc, char** argv) {
132-
utils::initTables();
133132
if (argc < 2) {
134133
printUsage();
135134
exit(EXIT_FAILURE);
@@ -147,6 +146,5 @@ int main(int argc, char** argv) {
147146
printUsage();
148147
exit(EXIT_FAILURE);
149148
}
150-
utils::freeTables();
151149
return 0;
152150
}

src/model.cc

+51-9
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,24 @@ Model::Model(std::shared_ptr<Matrix> wi,
3232
negpos = 0;
3333
loss_ = 0.0;
3434
nexamples_ = 1;
35+
initSigmoid();
36+
initLog();
37+
}
38+
39+
Model::~Model() {
40+
delete[] t_sigmoid;
41+
delete[] t_log;
3542
}
3643

3744
real Model::binaryLogistic(int32_t target, bool label, real lr) {
38-
real score = utils::sigmoid(wo_->dotRow(hidden_, target));
45+
real score = sigmoid(wo_->dotRow(hidden_, target));
3946
real alpha = lr * (real(label) - score);
4047
grad_.addRow(*wo_, target, alpha);
4148
wo_->addRow(hidden_, target, alpha);
4249
if (label) {
43-
return -utils::log(score);
50+
return -log(score);
4451
} else {
45-
return -utils::log(1.0 - score);
52+
return -log(1.0 - score);
4653
}
4754
}
4855

@@ -98,7 +105,7 @@ real Model::softmax(int32_t target, real lr) {
98105
grad_.addRow(*wo_, i, alpha);
99106
wo_->addRow(hidden_, i, alpha);
100107
}
101-
return -utils::log(output_[target]);
108+
return -log(output_[target]);
102109
}
103110

104111
void Model::computeHidden(const std::vector<int32_t>& input, Vector& hidden) const {
@@ -138,10 +145,10 @@ void Model::findKBest(int32_t k, std::vector<std::pair<real, int32_t>>& heap,
138145
Vector& hidden, Vector& output) const {
139146
computeOutputSoftmax(hidden, output);
140147
for (int32_t i = 0; i < osz_; i++) {
141-
if (heap.size() == k && utils::log(output[i]) < heap.front().first) {
148+
if (heap.size() == k && log(output[i]) < heap.front().first) {
142149
continue;
143150
}
144-
heap.push_back(std::make_pair(utils::log(output[i]), i));
151+
heap.push_back(std::make_pair(log(output[i]), i));
145152
std::push_heap(heap.begin(), heap.end(), comparePairs);
146153
if (heap.size() > k) {
147154
std::pop_heap(heap.begin(), heap.end(), comparePairs);
@@ -167,9 +174,9 @@ void Model::dfs(int32_t k, int32_t node, real score,
167174
return;
168175
}
169176

170-
real f = utils::sigmoid(wo_->dotRow(hidden, node - osz_));
171-
dfs(k, tree[node].left, score + utils::log(1.0 - f), heap, hidden);
172-
dfs(k, tree[node].right, score + utils::log(f), heap, hidden);
177+
real f = sigmoid(wo_->dotRow(hidden, node - osz_));
178+
dfs(k, tree[node].left, score + log(1.0 - f), heap, hidden);
179+
dfs(k, tree[node].right, score + log(f), heap, hidden);
173180
}
174181

175182
void Model::update(const std::vector<int32_t>& input, int32_t target, real lr) {
@@ -275,4 +282,39 @@ real Model::getLoss() const {
275282
return loss_ / nexamples_;
276283
}
277284

285+
void Model::initSigmoid() {
286+
t_sigmoid = new real[SIGMOID_TABLE_SIZE + 1];
287+
for (int i = 0; i < SIGMOID_TABLE_SIZE + 1; i++) {
288+
real x = real(i * 2 * MAX_SIGMOID) / SIGMOID_TABLE_SIZE - MAX_SIGMOID;
289+
t_sigmoid[i] = 1.0 / (1.0 + std::exp(-x));
290+
}
291+
}
292+
293+
void Model::initLog() {
294+
t_log = new real[LOG_TABLE_SIZE + 1];
295+
for (int i = 0; i < LOG_TABLE_SIZE + 1; i++) {
296+
real x = (real(i) + 1e-5) / LOG_TABLE_SIZE;
297+
t_log[i] = std::log(x);
298+
}
299+
}
300+
301+
real Model::log(real x) const {
302+
if (x > 1.0) {
303+
return 0.0;
304+
}
305+
int i = int(x * LOG_TABLE_SIZE);
306+
return t_log[i];
307+
}
308+
309+
real Model::sigmoid(real x) const {
310+
if (x < -MAX_SIGMOID) {
311+
return 0.0;
312+
} else if (x > MAX_SIGMOID) {
313+
return 1.0;
314+
} else {
315+
int i = int((x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2);
316+
return t_sigmoid[i];
317+
}
318+
}
319+
278320
}

src/model.h

+18-7
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
#include "vector.h"
2121
#include "real.h"
2222

23+
#define SIGMOID_TABLE_SIZE 512
24+
#define MAX_SIGMOID 8
25+
#define LOG_TABLE_SIZE 512
26+
2327
namespace fasttext {
2428

2529
struct Node {
@@ -43,24 +47,29 @@ class Model {
4347
int32_t osz_;
4448
real loss_;
4549
int64_t nexamples_;
46-
47-
static bool comparePairs(const std::pair<real, int32_t>&,
48-
const std::pair<real, int32_t>&);
49-
50+
real* t_sigmoid;
51+
real* t_log;
52+
// used for negative sampling:
5053
std::vector<int32_t> negatives;
5154
size_t negpos;
52-
53-
int32_t getNegative(int32_t target);
54-
55+
// used for hierarchical softmax:
5556
std::vector< std::vector<int32_t> > paths;
5657
std::vector< std::vector<bool> > codes;
5758
std::vector<Node> tree;
5859

60+
static bool comparePairs(const std::pair<real, int32_t>&,
61+
const std::pair<real, int32_t>&);
62+
63+
int32_t getNegative(int32_t target);
64+
void initSigmoid();
65+
void initLog();
66+
5967
static const int32_t NEGATIVE_TABLE_SIZE = 10000000;
6068

6169
public:
6270
Model(std::shared_ptr<Matrix>, std::shared_ptr<Matrix>,
6371
std::shared_ptr<Args>, int32_t);
72+
~Model();
6473

6574
real binaryLogistic(int32_t, bool, real);
6675
real negativeSampling(int32_t, real);
@@ -86,6 +95,8 @@ class Model {
8695
void initTableNegatives(const std::vector<int64_t>&);
8796
void buildTree(const std::vector<int64_t>&);
8897
real getLoss() const;
98+
real sigmoid(real) const;
99+
real log(real) const;
89100

90101
std::minstd_rand rng;
91102
};

src/utils.cc

-51
Original file line numberDiff line numberDiff line change
@@ -15,57 +15,6 @@
1515
namespace fasttext {
1616

1717
namespace utils {
18-
real* t_sigmoid = nullptr;
19-
real* t_log = nullptr;
20-
21-
real log(real x) {
22-
if (x > 1.0) {
23-
return 0.0;
24-
}
25-
int i = int(x * LOG_TABLE_SIZE);
26-
return t_log[i];
27-
}
28-
29-
real sigmoid(real x) {
30-
if (x < -MAX_SIGMOID) {
31-
return 0.0;
32-
} else if (x > MAX_SIGMOID) {
33-
return 1.0;
34-
} else {
35-
int i = int((x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2);
36-
return t_sigmoid[i];
37-
}
38-
}
39-
40-
void initTables() {
41-
initSigmoid();
42-
initLog();
43-
}
44-
45-
void initSigmoid() {
46-
if (t_sigmoid != nullptr) return;
47-
t_sigmoid = new real[SIGMOID_TABLE_SIZE + 1];
48-
for (int i = 0; i < SIGMOID_TABLE_SIZE + 1; i++) {
49-
real x = real(i * 2 * MAX_SIGMOID) / SIGMOID_TABLE_SIZE - MAX_SIGMOID;
50-
t_sigmoid[i] = 1.0 / (1.0 + std::exp(-x));
51-
}
52-
}
53-
54-
void initLog() {
55-
if (t_log != nullptr) return;
56-
t_log = new real[LOG_TABLE_SIZE + 1];
57-
for (int i = 0; i < LOG_TABLE_SIZE + 1; i++) {
58-
real x = (real(i) + 1e-5) / LOG_TABLE_SIZE;
59-
t_log[i] = std::log(x);
60-
}
61-
}
62-
63-
void freeTables() {
64-
delete[] t_sigmoid;
65-
delete[] t_log;
66-
t_sigmoid = nullptr;
67-
t_log = nullptr;
68-
}
6918

7019
int64_t size(std::ifstream& ifs) {
7120
ifs.seekg(std::streamoff(0), std::ios::end);

src/utils.h

-14
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,10 @@
1212

1313
#include <fstream>
1414

15-
#include "real.h"
16-
17-
#define SIGMOID_TABLE_SIZE 512
18-
#define MAX_SIGMOID 8
19-
#define LOG_TABLE_SIZE 512
20-
2115
namespace fasttext {
2216

2317
namespace utils {
2418

25-
real log(real);
26-
real sigmoid(real);
27-
28-
void initTables();
29-
void initSigmoid();
30-
void initLog();
31-
void freeTables();
32-
3319
int64_t size(std::ifstream&);
3420
void seek(std::ifstream&, int64_t);
3521
}

0 commit comments

Comments
 (0)