# Solving SAT with gradient descent

Using sigmoid functions, we can define smooth approximations of Boolean operations. These turn a SAT instance into a differentiable function, so small instances can be solved with gradient descent. We map:

\begin{align} \texttt{True} &\mapsto 1.0 \\ \texttt{False} &\mapsto -1.0 \end{align}

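We then use a horizontally compressed sigmoid as a smooth stand-in for the sign function, and call it sgn: \begin{align} \texttt{sgn}(x) &= \frac{2}{e^{-10x} + 1} - 1 \\ \texttt{sgn}(x) &\approx -1\textrm{ if }x < 0 \\ \texttt{sgn}(x) &\approx 1\textrm{ otherwise.} \\ \end{align} The approximation is loose only in a narrow band around $x = 0$, where the function passes smoothly through 0.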
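As a side note, this function is exactly a rescaled hyperbolic tangent, which makes its slope easy to read off. The identities below follow directly from the definition above and are not part of the original notebook:

```latex
\begin{align}
\texttt{sgn}(x) &= \frac{2}{e^{-10x} + 1} - 1 = \frac{1 - e^{-10x}}{1 + e^{-10x}} = \tanh(5x) \\
\frac{d}{dx}\,\texttt{sgn}(x) &= 5\,\mathrm{sech}^2(5x) = 5\left(1 - \texttt{sgn}(x)^2\right)
\end{align}
```

So the slope is 5 at $x = 0$ and shrinks rapidly once the output saturates near $\pm 1$.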

This can be used to define boolean operations: \begin{align} \texttt{not}(x) &:= -x \\ \texttt{and}(x, y) &:= \texttt{sgn}(x + y + x \cdot y) \\ \texttt{or}(x, y) &:= -\texttt{and}(-x, -y) \end{align}
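To see why this choice of and behaves like a conjunction, it may help to evaluate the argument of sgn at the four Boolean corners. This is a worked check of the definition above, not an additional assumption:

```latex
\begin{align}
x + y + x \cdot y &= (1 + x)(1 + y) - 1 \\
(x, y) = (1, 1) &\;\Rightarrow\; 1 + 1 + 1 = 3 \\
(x, y) = (1, -1) &\;\Rightarrow\; 1 - 1 - 1 = -1 \\
(x, y) = (-1, 1) &\;\Rightarrow\; -1 + 1 - 1 = -1 \\
(x, y) = (-1, -1) &\;\Rightarrow\; -1 - 1 + 1 = -1
\end{align}
```

Only the all-true corner gives a positive argument, so sgn maps it close to 1 and the other three close to -1. The definition of or is then De Morgan's law applied to and and not.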

For a given SAT instance, composing these operations yields a smooth function from the inputs to the instance's truth value. We can then parse the factoring instances generated by Tough SAT to obtain circuits of increasing difficulty, and use gradient descent to solve them.
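As a small illustration of such a composed function, the sketch below builds the two-clause CNF (a OR b) AND (NOT a OR NOT b), which is XOR, from the operations defined above and evaluates it at the four Boolean corners. It uses plain `math` instead of sympy, and the name `xor_circuit` is made up for this example.

```python
import math
from itertools import product

def sgn(x):
    # Steep sigmoid rescaled to the range (-1, 1), as defined above
    return 2 / (math.exp(-10 * x) + 1) - 1

def not_(x):
    return -x

def and_(x, y):
    return sgn(x + y + x * y)

def or_(x, y):
    return not_(and_(not_(x), not_(y)))

def xor_circuit(a, b):
    # CNF form of XOR: (a OR b) AND (NOT a OR NOT b)
    return and_(or_(a, b), or_(not_(a), not_(b)))

results = {}
for a, b in product([True, False], repeat=2):
    value = xor_circuit(1.0 if a else -1.0, 1.0 if b else -1.0)
    results[(a, b)] = value
    print(a, b, round(value, 4))
```

The output is positive exactly when the two inputs differ, so thresholding at 0 recovers the Boolean truth value.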

In [2]:
from sympy import *
import sympy, math
from itertools import *
init_printing()
x, y, z = symbols('x y z')

sgn_expr = 2 / (exp(-10*x) + 1) - 1
plot(sgn_expr, (x, -3, 3))

def sgn(x):
    # Use sympy's exp for symbolic input and math.exp for plain numbers
    exp = sympy.exp if isinstance(x, Expr) else math.exp
    return 2 / (exp(-10*x) + 1) - 1

print(f'sgn(-1)={sgn(-1)}')
print(f'sgn(+1)={sgn(1)}')

bool_to_value = {
    True: 1,
    False: -1,
}
get_bool = lambda x: x >= 0
truth_values = list(product([True, False], [True, False]))

def tabulate(table):
    # Render a list of rows as an HTML table; the first row is the header
    rows = []
    for i, row in enumerate(table):
        tag = 'td' if i != 0 else 'th'
        rows.append(f"<tr>{''.join(f'<{tag}>{value}</{tag}>' for value in row)}</tr>")
    return f"<table>{''.join(rows)}</table>"

def truth_table(f):
    # Display the truth table of a one- or two-argument smooth boolean function
    from IPython.display import HTML, display
    import inspect
    nargs = len(inspect.signature(f).parameters)
    if nargs == 1:
        table = [[True, False], [f(bool_to_value[b]) for b in [True, False]]]
    elif nargs == 2:
        table = [['', 'True', 'False']]
        for b_row in [True, False]:
            row = [str(b_row)]
            for b_col in [True, False]:
                row.append(f(bool_to_value[b_row], bool_to_value[b_col]))
            table.append(row)
    display(HTML(tabulate(table)))

truth_table(sgn)

sgn(-1)=-0.9999092042625951
sgn(+1)=0.9999092042625952

| True | False |
| --- | --- |
| 0.9999092042625952 | -0.9999092042625951 |
In [3]:
not_ = lambda x: -x
truth_table(not_)

| True | False |
| --- | --- |
| -1 | 1 |
In [5]:
and_ = lambda x, y: sgn(x + y + x * y)
truth_table(and_)

|  | True | False |
| --- | --- | --- |
| True | 0.999999999999813 | -0.9999092042625951 |
| False | -0.9999092042625951 | -0.9999092042625951 |
In [6]:
or_ = lambda x, y: not_(and_(not_(x), not_(y)))
truth_table(or_)

|  | True | False |
| --- | --- | --- |
| True | 0.9999092042625951 | 0.9999092042625951 |
| False | 0.9999092042625951 | -0.999999999999813 |
In [7]:
or_(x, y)

Out[7]:
$$1 - \frac{2}{e^{- 10 x y + 10 x + 10 y} + 1}$$
In [8]:
from sympy.plotting import plot3d
plot3d(or_(x, y), (x, -1, 1), (y, -1, 1))

Out[8]:
<sympy.plotting.plot.Plot at 0x10b38c588>

Zooming in shows the gradient near the origin:

In [10]:
plot3d(or_(x, y), (x, -0.25, 0.25), (y, -0.25, 0.25))

Out[10]:
<sympy.plotting.plot.Plot at 0x10bca5eb8>
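The slope visible in the zoomed plot can also be computed symbolically. The sketch below rebuilds the or expression from the definitions above (self-contained, so it redefines `sgn` for sympy input) and differentiates it with `sympy.diff`. The choice of evaluation points is only for illustration.

```python
import sympy

x, y = sympy.symbols('x y')

def sgn(v):
    # Same steep sigmoid as above, for symbolic input
    return 2 / (sympy.exp(-10 * v) + 1) - 1

# or(x, y) = -and(-x, -y) = -sgn(-x - y + x*y)
or_expr = -sgn(-x - y + x * y)

grad = [sympy.diff(or_expr, var) for var in (x, y)]
grad_at_origin = [g.subs({x: 0, y: 0}) for g in grad]
grad_at_corner = [float(g.subs({x: 1, y: 1})) for g in grad]

print(grad_at_origin)
print(grad_at_corner)
```

At the origin the partial derivatives are both 5, while at the all-true corner they vanish, which is the flat plateau seen in the wider plot.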

Now we define some helpers to parse the DIMACS format that Tough SAT outputs. Fortunately, sympy already has some helpers that we'll copy and slightly modify.

In [11]:
from functools import reduce

# The `def` line is missing from this export; the name `load_dimacs` is an assumption.
def load_dimacs(s):
    """Loads a boolean expression from a string.

    Based off sympy.logic.utilities.dimacs
    https://github.com/sympy/sympy/blob/57fcd5a941d7c47106bd63fd7b3d79ac032b636b/sympy/logic/utilities/dimacs.py
    """
    import re

    # n-ary versions of the smooth binary operations defined earlier
    And = lambda *args: reduce(and_, args)
    Or = lambda *args: reduce(or_, args)
    clauses = []

    lines = s.split('\n')

    pComment = re.compile(r'c.*')
    pStats = re.compile(r'p\s*cnf\s*(\d*)\s*(\d*)')

    while len(lines) > 0:
        line = lines.pop(0)

        # Only deal with lines that aren't comments
        if not pComment.match(line):
            m = pStats.match(line)

            if not m:
                nums = line.rstrip('\n').split(' ')
                literals = []  # renamed from `list` to avoid shadowing the builtin
                for lit in nums:
                    if lit != '':
                        if int(lit) == 0:
                            continue
                        num = abs(int(lit))
                        sign = True
                        if int(lit) < 0:
                            sign = False

                        if sign:
                            literals.append(Symbol("cnf_%s" % num))
                        else:
                            literals.append(not_(Symbol("cnf_%s" % num)))

                if len(literals) > 0:
                    clauses.append(Or(*literals))

    return And(*clauses)
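For readers unfamiliar with DIMACS, the format itself is simple: lines starting with `c` are comments, the `p cnf` line gives the variable and clause counts, and every other line is a clause written as signed variable indices terminated by `0`. The sketch below parses that into plain lists of integers. The function name `parse_dimacs_clauses` and the two-clause instance are made up for illustration, and like the loader above it assumes one clause per line.

```python
def parse_dimacs_clauses(text):
    """Return a CNF as a list of clauses, each a list of signed ints."""
    clauses = []
    for line in text.splitlines():
        line = line.strip()
        # Skip blank lines, comments ('c ...') and the problem line ('p cnf ...')
        if not line or line[0] in 'cp':
            continue
        # Drop the terminating 0; a negative number is a negated variable
        literals = [int(tok) for tok in line.split() if int(tok) != 0]
        if literals:
            clauses.append(literals)
    return clauses

example = """\
p cnf 3 2
c a made-up two-clause instance
1 -2 0
2 3 0
"""
parsed = parse_dimacs_clauses(example)
print(parsed)
```

Here `[1, -2]` stands for the clause (x1 OR NOT x2). The loader above follows the same steps but emits smooth sympy expressions instead of integer lists.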


Now let's try to factor the number $35=5\cdot7$. The variables in the DIMACS output represent the binary expansions of the two factors. We'll use the sympy.logic module to check our work.

Recall that 7=0b111 and 5=0b101, so the circuit should have two solutions, 111101 and 101111, since multiplication is commutative.
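The correspondence between an assignment and a pair of factors can be made concrete with a few lines of Python. This is a minimal sketch: `decode_factors` is a hypothetical helper, and it reads each half most-significant-bit first, which is an assumption about the encoding (it makes no difference here because 111 and 101 are palindromes).

```python
def decode_factors(bits, width=3):
    # Split the assignment into two `width`-bit halves and read each as binary
    first, second = bits[:width], bits[width:]
    return int(first, 2), int(second, 2)

for bits in ['111101', '101111']:
    p, q = decode_factors(bits)
    print(bits, '->', p, 'x', q, '=', p * q)
```

Both assignments decode to the factors 5 and 7, in the two possible orders.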

In [12]:
# SAT instance from https://toughsat.appspot.com/ trying to factor 35 using karatsuba multiplication
factor_35_dimacs = '''\
p cnf 6 6
c Factors encoded in variables 1-3 and 4-6
c Target number: 35
-2 -5 0
1 0
3 0
4 0
6 0
2 5 0
'''

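# Reconstructed setup: these assignments are missing from the export.
# The exact circuit uses sympy's own DIMACS loader; the float circuit uses the
# smooth loader defined above (called `load_dimacs` here, an assumed name).
from sympy.logic.utilities.dimacs import load as sympy_load_dimacs
exact_circuit = sympy_load_dimacs(factor_35_dimacs)
float_circuit = load_dimacs(factor_35_dimacs)

# Compare both circuits on all 2**6 assignments
possibilities = product(*([[True, False]] * 6))
vars = sorted(exact_circuit.atoms(), key=str)
for possibility in possibilities:
    float_subs = float_circuit.subs({
        var: bool_to_value[val]
        for var, val in zip(vars, possibility)
    }).evalf()
    exact_subs = exact_circuit.subs({
        var: val
        for var, val in zip(vars, possibility)
    })
    assert exact_subs == get_bool(float_subs)
    if exact_subs:
        print(''.join(map(str, map(int, possibility))) + f'={exact_subs}, {get_bool(float_subs)}')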

111101=True, True
101111=True, True


We've now shown that we can load a small circuit and evaluate it with floating-point math, arriving at the same solutions as sympy's exact logic. Next, let's try to factor 35 by gradient descent.
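To make the idea concrete before handing the problem to a library optimizer, here is a hand-rolled descent loop on the smallest possible circuit, a single and gate. It is only a sketch: the fixed step size, the iteration count and the finite-difference gradient are choices made for this illustration, not what the notebook itself uses.

```python
import math

def sgn(x):
    return 2 / (math.exp(-10 * x) + 1) - 1

def and_(x, y):
    return sgn(x + y + x * y)

def loss(p):
    # Squared distance of the circuit output from True (1.0)
    return (and_(p[0], p[1]) - 1) ** 2

def numerical_gradient(f, p, h=1e-6):
    # Central differences, one coordinate at a time
    grad = []
    for i in range(len(p)):
        up = list(p)
        up[i] += h
        down = list(p)
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad

point = [0.0, 0.0]       # start at the undecided midpoint
learning_rate = 0.05     # hypothetical step size, chosen by hand
initial_loss = loss(point)
for _ in range(100):
    grad = numerical_gradient(loss, point)
    point = [v - learning_rate * g for v, g in zip(point, grad)]

final_loss = loss(point)
print(point, final_loss)
```

Starting from the origin, both coordinates move into the positive half-plane, which is the only satisfying assignment of a single and gate. Most of the progress happens in the first step, after which the gradient is nearly zero because sgn has saturated.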

For efficiency, we use sympy.lambdify to compile the very large expression into a Python function.

In [13]:
# The circuit is satisfied if its output is close to 1, so use the squared distance from 1 as the loss:
import numpy as np
loss_expr = (float_circuit - 1) ** 2
_loss_function = lambdify(vars, loss_expr, modules='numpy')
loss_function = lambda array: _loss_function(*array)
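For readers who have not used `lambdify` before, the sketch below shows what it buys on a toy expression. The symbols `a` and `b` and the single-gate loss are made up for this example: the compiled function agrees with `subs(...).evalf()` but evaluates numerically and also accepts numpy arrays.

```python
import numpy as np
import sympy

a, b = sympy.symbols('a b')

# Loss for a single smooth `and` gate, in the same form as the loss above
gate = 2 / (sympy.exp(-10 * (a + b + a * b)) + 1) - 1
expr = (gate - 1) ** 2

# Compile the symbolic expression into a numpy-backed Python function
f = sympy.lambdify([a, b], expr, modules='numpy')

slow = float(expr.subs({a: 0.3, b: -0.2}).evalf())   # symbolic substitution
fast = float(f(0.3, -0.2))                           # compiled evaluation
print(slow, fast)

# The compiled function is vectorized over numpy arrays
grid = f(np.array([-1.0, 1.0]), np.array([1.0, 1.0]))
print(grid)
```

The wrapper `loss_function = lambda array: _loss_function(*array)` above exists because the compiled function takes one positional argument per variable, whereas the optimizer passes a single array.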


Check that the loss function is very close to zero at an actual solution to the SAT instance:

In [14]:
print(loss_function(np.array([1, 1, 1, 1, -1, 1])))

3.5120876361344667e-26

In [15]:
from scipy.optimize import minimize

res = minimize(loss_function,
               np.zeros(6),
               options={'xtol': 1e-30, 'disp': True})
evaluated = float_circuit.subs({var: val for var, val in zip(vars, res.x)}).evalf()
print(f'solution={res.x},\nevaluated={evaluated}\n')
# Snap each coordinate to +/-1 and evaluate the circuit again
normalized_solution = [np.sign(v) for v in res.x]
normalized_evaluated = float_circuit.subs({var: val for var, val in zip(vars, normalized_solution)}).evalf()
print(f'normalized_solution={normalized_solution},\nevaluated={normalized_evaluated}\n')
print(f'loss_function(res.x)={loss_function(res.x)}')

Optimization terminated successfully.
Current function value: 0.000000
Iterations: 266
Function evaluations: 751
solution=[ 0.96007168  1.37241084  0.49484629  0.67831874 -0.70099425  0.13773547],
evaluated=0.999999999999813

normalized_solution=[1.0, 1.0, 1.0, 1.0, -1.0, 1.0],
evaluated=0.999999999999813

loss_function(res.x)=3.495462392556934e-26


It worked! The optimization procedure found the 111101 ($35=7\cdot 5$) solution:
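The step from the continuous optimum to a bit string is just a threshold at zero. The sketch below applies it to the solution vector copied from the optimizer output above, then reads the two halves as binary numbers (most-significant-bit first, which is an assumption about the encoding).

```python
import numpy as np

# Continuous optimum copied from the optimizer output above
solution = np.array([0.96007168, 1.37241084, 0.49484629,
                     0.67831874, -0.70099425, 0.13773547])

# Non-negative coordinates count as True, negative ones as False
bits = ''.join('1' if v >= 0 else '0' for v in solution)
first, second = int(bits[:3], 2), int(bits[3:], 2)
print(bits, first, second, first * second)
```

Note that the coordinates are nowhere near exactly 1 or -1. Only their signs matter, which is why the same circuit value is obtained after normalizing.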

In [16]:
float_circuit.subs({var: val for var, val in zip(vars, res.x)}).evalf()

Out[16]:
$$0.999999999999813$$

# Putting it all together:

In [17]:
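def solve_SAT_with_gradient_descent(dimacs_text):
    # Reconstructed line (missing from the export): build the smooth circuit with
    # the loader defined earlier, called `load_dimacs` here (an assumed name)
    float_circuit = load_dimacs(dimacs_text)
    loss_expr = (float_circuit - 1) ** 2
    vars = sorted(float_circuit.free_symbols, key=str)
    print(vars)
    _loss_function = lambdify(vars, loss_expr, modules='numpy')
    loss_function = lambda array: _loss_function(*array)
    res = minimize(
        loss_function,
        np.zeros(len(vars)),
        options={'xtol': 1e-30, 'disp': True})
    print(res.x)
    return {var: get_bool(x) for var, x in zip(vars, res.x)}

def check(dimacs_text):
    # Reconstructed line (missing from the export): run the solver first
    solution = solve_SAT_with_gradient_descent(dimacs_text)
    print(solution)
    print(f'Does the solution work? Does True=={sympy_load_dimacs(dimacs_text).subs(solution)}?')

check(factor_35_dimacs)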

[cnf_1, cnf_2, cnf_3, cnf_4, cnf_5, cnf_6]
Optimization terminated successfully.
Current function value: 0.000000
Iterations: 266
Function evaluations: 751
[ 0.96007168  1.37241084  0.49484629  0.67831874 -0.70099425  0.13773547]
{cnf_1: True, cnf_2: True, cnf_3: True, cnf_4: True, cnf_5: False, cnf_6: True}
Does the solution work? Does True==True?


# Something slightly bigger

When we try to factor $143=11\cdot 13$, even constructing the function to evaluate is slow, and the optimization fails to converge.

Ah well, I've heard this SAT thing is hard.

In [18]:
factor_11_times_13 = '''\
p cnf 8 12
c Factors encoded in variables 1-4 and 5-8
c Target number: 143
c Factors: 13 x 11
3 2 0
2 6 0
1 0
-3 -2 0
8 0
-2 -7 -6 0
4 0
-3 -2 -6 0
5 0
-2 -6 0
-3 -7 0
3 7 0
'''

In [73]:
%time check(factor_11_times_13)

[cnf_1, cnf_2, cnf_3, cnf_4, cnf_5, cnf_6, cnf_7, cnf_8]
Warning: Maximum number of function evaluations has been exceeded.
[-0.06463293 -0.04806722  0.01353296  0.10877588 -0.1979872   0.0876021
0.05516384  0.01156727]
{cnf_1: False, cnf_2: False, cnf_3: True, cnf_4: True, cnf_5: False, cnf_6: True, cnf_7: True, cnf_8: True}
Does the solution work? Does True==False?
CPU times: user 8min 19s, sys: 2.22 s, total: 8min 22s
Wall time: 8min 23s
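For comparison, this instance is still small enough to check exhaustively, which shows what a correct answer would have looked like. The sketch below hard-codes the twelve clauses from the DIMACS text above and tries all $2^8$ assignments. Reading each 4-bit half most-significant-bit first is an assumption about the encoding, and `satisfies` is a helper made up for this example.

```python
from itertools import product

# Clauses of the 143 instance above, as signed variable indices
clauses = [
    [3, 2], [2, 6], [1], [-3, -2], [8], [-2, -7, -6],
    [4], [-3, -2, -6], [5], [-2, -6], [-3, -7], [3, 7],
]

def satisfies(assignment, clauses):
    # assignment[i] is the value of variable i + 1; a clause holds if any
    # of its literals matches the assignment
    return all(
        any(assignment[abs(lit) - 1] == (lit > 0) for lit in clause)
        for clause in clauses
    )

solutions = []
for assignment in product([True, False], repeat=8):
    if satisfies(assignment, clauses):
        bits = ''.join(str(int(v)) for v in assignment)
        solutions.append((bits, int(bits[:4], 2), int(bits[4:], 2)))

for bits, p, q in solutions:
    print(bits, p, 'x', q, '=', p * q)
```

Exactly two assignments satisfy the instance, corresponding to $13\cdot 11$ and $11\cdot 13$. The assignment returned by the optimizer above is neither of them.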