Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

better 2sat impl #4568

Merged
merged 2 commits into from
Jul 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 44 additions & 52 deletions content/6_Advanced/SCC.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -425,62 +425,55 @@ struct Clause {
bool neg2;
};

class SATSolver {
private:
vector<vector<int>> adj;
bool valid = true;
vector<int> val;

/** @return the negation of the variable v */
int get_neg(int v) { return v % 2 == 1 ? v - 1 : v + 1; }
/**
* @return a vector of booleans that satisfy the given clauses,
* or an empty vector if no such set of booleans exist
*/
vector<int> solve_sat(const vector<Clause> &clauses, int var_num) {
vector<vector<int>> adj(2 * var_num);
// 2 * var is the variable, while 2 * var + 1 is its negation
for (const Clause &c : clauses) {
// falseness of the first implies the truth of the second
adj[2 * c.var1 + !c.neg1].push_back(2 * c.var2 + c.neg2);
// and vice versa
adj[2 * c.var2 + !c.neg2].push_back(2 * c.var1 + c.neg1);
}

public:
SATSolver(const vector<Clause> &clauses, int var_num)
: adj(2 * var_num), val(2 * var_num, -1) {
// 2 * var is the variable, while 2 * var + 1 is its negation
for (const Clause &c : clauses) {
// falseness of the first implies the truth of the second
adj[2 * c.var1 + !c.neg1].push_back(2 * c.var2 + c.neg2);
// and vice versa
adj[2 * c.var2 + !c.neg2].push_back(2 * c.var1 + c.neg1);
}
TarjanSolver scc(adj);
// a list of all the components in the graph
vector<vector<int>> comps(scc.comp_num());
for (int i = 0; i < 2 * var_num; i += 2) {
// do a node and its negation share the same component?
if (scc.get_comp(i) == scc.get_comp(i + 1)) { return {}; }
comps[scc.get_comp(i)].push_back(i);
comps[scc.get_comp(i + 1)].push_back(i + 1);
}

TarjanSolver scc(adj);
// a list of all the components in the graph
vector<vector<int>> comps(scc.comp_num());
for (int i = 0; i < 2 * var_num; i += 2) {
// do a node and its negation share the same component?
if (scc.get_comp(i) == scc.get_comp(i + 1)) {
valid = false;
return;
vector<int> val(2 * var_num, -1);
/*
* because of how our tarjan solver works, starting from
* starting from comp 0 and going up process the graph
* in reverse topological order- neat, huh?
*/
for (const vector<int> &comp : comps) {
int set_to = 1; // set all to true by default
// check if any values have had their negations set yet
for (int v : comp) {
int neg = v % 2 == 1 ? v - 1 : v + 1;
if (val[neg] != -1) {
set_to = !val[neg];
break;
}
comps[scc.get_comp(i)].push_back(i);
comps[scc.get_comp(i + 1)].push_back(i + 1);
}

/*
* because of how our tarjan solver works, starting from
* starting from comp 0 and going up process the graph
* in reverse topological order- neat, huh?
*/
for (const vector<int> &comp : comps) {
int set_to = 1; // set all to true by default
// check if any values have had their negations set yet
for (int v : comp) {
if (val[get_neg(v)] != -1) {
set_to = !val[get_neg(v)];
break;
}
}

for (int v : comp) { val[v] = set_to; }
}
for (int v : comp) { val[v] = set_to; }
}

bool is_valid() const { return valid; }
vector<int> actual_val(var_num);
for (int i = 0; i < var_num; i++) { actual_val[i] = val[2 * i]; }

bool get_var(int var) const { return val[2 * var]; }
};
return actual_val;
}

int main() {
int req_num;
Expand All @@ -496,14 +489,13 @@ int main() {
c.neg2 = neg2 == '-';
}

SATSolver sat(clauses, topping_num);
if (!sat.is_valid()) {
vector<int> sat_res = solve_sat(clauses, topping_num);
if (sat_res.empty()) {
cout << "IMPOSSIBLE" << endl;
} else {
for (int t = 0; t < topping_num; t++) {
cout << (sat.get_var(t) ? '+' : '-') << ' ';
cout << (sat_res[t] ? '+' : '-') << " \n"[t == topping_num - 1];
}
cout << endl;
}
}
```
Expand Down
Loading