#!/usr/bin/env python
# coding: utf-8
# # Solving SAT with gradient descent
#
# Using sigmoid functions, we can define smooth approximations of boolean operations. This can be used to solve small instances of SAT using gradient descent. We map:
#
# \begin{align}
# \texttt{True} &\mapsto 1.0 \\
# \texttt{False} &\mapsto -1.0
# \end{align}
#
# We then use a compressed sigmoid function to approximate the `sgn` function:
# \begin{align}
# \texttt{sgn}(x) &= \frac{2}{e^{-10x} + 1} - 1 \\
# \texttt{sgn}(x) &\approx -1\textrm{ if }x < 0 \\
# \texttt{sgn}(x) &\approx 1\textrm{ otherwise.} \\
# \end{align}
#
# This can be used to define boolean operations:
# \begin{align}
# \texttt{not}(x) &:= -x \\
# \texttt{and}(x, y) &:= \texttt{sgn}(x + y + x \cdot y) \\
# \texttt{or}(x, y) &:= -\texttt{and}(-x, -y)
# \end{align}
#
# For a given instance of SAT, we can then create a smooth function from inputs to its truth value. We can then parse [Tough SAT](https://toughsat.appspot.com/) factoring SAT instances to generate circuits of increasing difficulty, and use gradient descent to solve them.
# In[2]:
from sympy import *
import sympy, math
from itertools import *
init_printing()  # enable sympy's pretty-printing for notebook output
x, y, z = symbols('x y z')
# A steep sigmoid rescaled from (0, 1) to (-1, 1): a smooth approximation of sgn(x).
sgn_expr = 2 / (exp(-10*x) + 1) - 1
plot(sgn_expr, (x, -3, 3))
def sgn(x):
    """Smooth approximation of the sign function, with range (-1, 1).

    Accepts either a sympy expression (returning a symbolic result) or a
    plain number (returning a float).
    """
    # Pick the exponential matching the input type; note the local name
    # avoids clobbering sympy's `exp` pulled in by the star-import above.
    exp_fn = sympy.exp if isinstance(x, Expr) else math.exp
    return 2 / (exp_fn(-10 * x) + 1) - 1
print(f'sgn(-1)={sgn(-1)}')
print(f'sgn(+1)={sgn(1)}')

# Encode booleans as +/-1 for the smooth circuit.
bool_to_value = {True: 1, False: -1}

# Decode: any non-negative activation counts as True.
get_bool = lambda x: x >= 0

# Every (x, y) combination for a two-input truth table.
truth_values = list(product([True, False], repeat=2))
def tabulate(table):
    """Render a list of rows as an HTML table string.

    The first row is treated as a header (``<th>`` cells); all following
    rows use ``<td>`` cells. Values are interpolated with ``str`` formatting.

    NOTE(review): the exported source had mangled f-strings here (unterminated
    quote, missing close tags, empty return) — reconstructed the obvious
    <table>/<tr>/<th|td> structure.
    """
    rows = []
    for i, row in enumerate(table):
        tag = 'td' if i != 0 else 'th'
        cells = ''.join(f'<{tag}>{value}</{tag}>' for value in row)
        rows.append(f'<tr>{cells}</tr>')
    return f"<table>{''.join(rows)}</table>"
def truth_table(f):
    """Display the truth table of a 1- or 2-argument smooth boolean op.

    `f` receives inputs in the +/-1 encoding (via `bool_to_value`); its raw
    float/symbolic outputs are shown in an HTML table in the notebook.

    Raises ValueError for other arities (previously `table` was left
    unbound and a confusing NameError surfaced at `display`).
    """
    from IPython.display import HTML, display
    import inspect
    nargs = len(inspect.signature(f).parameters)
    if nargs == 1:
        # Header row of boolean inputs, then f's output for each.
        table = [[True, False], [f(bool_to_value[b]) for b in [True, False]]]
    elif nargs == 2:
        # First column labels the row input; remaining cells are f(row, col).
        table = [['', 'True', 'False']]
        for b_row in [True, False]:
            row = [str(b_row)]
            for b_col in [True, False]:
                row.append(f(bool_to_value[b_row], bool_to_value[b_col]))
            table.append(row)
    else:
        raise ValueError(f'truth_table supports 1 or 2 arguments, got {nargs}')
    display(HTML(tabulate(table)))
# sgn acts as the identity on the +/-1 encoding.
truth_table(sgn)

# In[3]:
def not_(x):
    """Smooth NOT: negation swaps +1 and -1."""
    return -x

truth_table(not_)

# In[5]:
def and_(x, y):
    """Smooth AND: x + y + x*y is positive only when both inputs are +1."""
    return sgn(x + y + x * y)

truth_table(and_)

# In[6]:
def or_(x, y):
    """Smooth OR via De Morgan: not(and(not x, not y))."""
    return not_(and_(not_(x), not_(y)))

truth_table(or_)
# In[7]:
# Notebook display: the symbolic form of the smooth OR gate.
or_(x, y)
# In[8]:
from sympy.plotting import plot3d
# Surface of the smooth OR over the whole +/-1 input square.
plot3d(or_(x, y), (x, -1, 1), (y, -1, 1))
# and zooming in to view the gradient near the origin:
# In[10]:
plot3d(or_(x, y), (x, -0.25, 0.25), (y, -0.25, 0.25))
# Now we define some helpers to parse the [DIMACS format](http://logic.pdmi.ras.ru/~basolver/dimacs.html) that Tough SAT outputs. Fortunately, `sympy` already has some helpers that we'll copy and slightly modify.
# In[11]:
from functools import reduce
from sympy.logic.utilities.dimacs import load as sympy_load_dimacs
def load_dimacs(s):
    """Loads a boolean expression from a string.

    Based off sympy.logic.utilities.dimacs
    https://github.com/sympy/sympy/blob/57fcd5a941d7c47106bd63fd7b3d79ac032b636b/sympy/logic/utilities/dimacs.py

    Each CNF clause becomes a soft OR of its (possibly negated) literals,
    and the clauses are combined with soft ANDs, yielding a smooth sympy
    expression over symbols cnf_<n> that is close to +1 exactly on
    satisfying assignments.
    """
    import re
    And = lambda *args: reduce(and_, args)
    Or = lambda *args: reduce(or_, args)
    clauses = []
    pComment = re.compile(r'c.*')
    pStats = re.compile(r'p\s*cnf\s*(\d*)\s*(\d*)')
    # Iterate directly instead of the original O(n^2) `lines.pop(0)` loop.
    for line in s.split('\n'):
        # Skip comment lines and the `p cnf <vars> <clauses>` header.
        if pComment.match(line) or pStats.match(line):
            continue
        literals = []  # renamed from `list`, which shadowed the builtin
        for lit in line.rstrip('\n').split(' '):
            if lit == '':
                continue
            num = int(lit)
            if num == 0:
                continue  # 0 is the DIMACS clause terminator
            sym = Symbol("cnf_%s" % abs(num))
            literals.append(sym if num > 0 else not_(sym))
        if len(literals) > 0:
            clauses.append(Or(*literals))
    return And(*clauses)
# Now let's try to factor the number $35=5\cdot7$. The variables in the DIMACS output represent the binary expansion of the factors. We'll use the `sympy.logic` module to check our work.
#
# Recall that `7=0b111` and `5=0b101`, so the circuit should have two solutions: `111101` and `101111` since multiplication is commutative.
# In[12]:
# SAT instance from https://toughsat.appspot.com/ trying to factor 35 using karatsuba multiplication
factor_35_dimacs = '''\
p cnf 6 6
c Factors encoded in variables 1-3 and 4-6
c Target number: 35
-2 -5 0
1 0
3 0
4 0
6 0
2 5 0
'''

float_circuit = load_dimacs(factor_35_dimacs)
exact_circuit = sympy_load_dimacs(factor_35_dimacs)
# Stable variable order shared between the float and exact circuits.
vars = sorted(exact_circuit.atoms(), key=str)

# Exhaustively check that thresholding the smooth circuit's output agrees
# with sympy's exact boolean evaluation on all 2^6 assignments.
possibilities = product(*([[True, False]] * 6))
for possibility in possibilities:
    assignment = dict(zip(vars, possibility))
    float_subs = float_circuit.subs(
        {var: bool_to_value[val] for var, val in assignment.items()}).evalf()
    exact_subs = exact_circuit.subs(assignment)
    assert exact_subs == get_bool(float_subs)
    if exact_subs:
        # Print each satisfying assignment as a bit-string.
        print(''.join(map(str, map(int, possibility))) + f'={exact_subs}, {get_bool(float_subs)}')
# Now we've shown we can load a small circuit, and run it using floating point math to arrive at the same solution as sympy's existing logic. Now let's try to factor 35 by gradient descent.
#
# For efficiency, use `sympy.lambdify` to compile the very complex expression into a python method.
# In[13]:
# The circuit is satisfied if its output is close to 1, so create a loss function to do that:
import numpy as np
# Squared distance of the circuit's output from +1 (the "True" value).
loss_expr = (float_circuit - 1) ** 2
_loss_function = lambdify(vars, loss_expr, modules='numpy')
# Adapt the compiled positional-argument function to take one vector.
loss_function = lambda array: _loss_function(*array)
# Check that the loss function is very close to zero at an actual solution to the SAT instance:
# In[14]:
print(loss_function(np.array([1, 1, 1, 1, -1, 1])))
# In[15]:
from scipy.optimize import minimize
res = minimize(loss_function,
               np.zeros(6),
               method='nelder-mead',
               # Nelder-Mead's tolerance option is 'xatol'; the obsolete
               # 'xtol' spelling is silently ignored by modern scipy.
               options={'xatol': 1e-30, 'disp': True})
evaluated = float_circuit.subs({var: val for var, val in zip(vars, res.x)}).evalf()
print(f'solution={res.x},\nevaluated={evaluated}\n')
# Snap each coordinate to +/-1 to read off a boolean assignment.
normalized_solution = [np.sign(v) for v in res.x]
normalized_evaluated = float_circuit.subs({var: val for var, val in zip(vars, normalized_solution)}).evalf()
# BUG FIX: this previously printed `evaluated` instead of `normalized_evaluated`.
print(f'normalized_solution={normalized_solution},\nevaluated={normalized_evaluated}\n')
print(f'loss_function(res.x)={loss_function(res.x)}')
# It worked! The optimization procedure found the `111101` ($35=7\cdot 5$) solution:
# In[16]:
float_circuit.subs({var: val for var, val in zip(vars, res.x)}).evalf()
# # Putting it all together:
# In[17]:
def solve_SAT_with_gradient_descent(dimacs_text):
    """Search for a satisfying assignment of a DIMACS CNF via optimization.

    Builds the smooth float circuit, forms the squared loss (output - 1)^2
    which is ~0 exactly when the circuit outputs +1 ("True"), and minimizes
    it with Nelder-Mead starting from the all-zeros point.

    Returns a dict mapping each cnf_<n> symbol to its decoded boolean.
    """
    float_circuit = load_dimacs(dimacs_text)
    loss_expr = (float_circuit - 1) ** 2
    # free_symbols is an unordered set; sort for a stable argument order.
    # (Renamed from `vars`, which shadowed the builtin.)
    sat_vars = sorted(float_circuit.free_symbols, key=str)
    print(sat_vars)
    _loss_function = lambdify(sat_vars, loss_expr, modules='numpy')
    loss_function = lambda array: _loss_function(*array)
    res = minimize(
        loss_function,
        np.zeros(len(sat_vars)),
        method='nelder-mead',
        # 'xatol' is Nelder-Mead's tolerance option; the obsolete 'xtol'
        # spelling is silently ignored by modern scipy.
        options={'xatol': 1e-30, 'disp': True})
    print(res.x)
    return {var: get_bool(x) for var, x in zip(sat_vars, res.x)}
def check(dimacs_text):
    """Solve with gradient descent, then verify against sympy's exact logic."""
    solution = solve_SAT_with_gradient_descent(dimacs_text)
    print(solution)
    exact_result = sympy_load_dimacs(dimacs_text).subs(solution)
    print(f'Does the solution work? Does True=={exact_result}?')

check(factor_35_dimacs)
# # Something _slightly_ bigger
#
# When we try to factor $143=11\cdot 13$, even constructing the function to evaluate is slow, and the optimization fails to converge.
#
# Ah well, I've heard this SAT thing is hard.
# In[18]:
# Harder instance: factoring 143 = 11 * 13, with two 4-bit factors.
factor_11_times_13 = '''\
p cnf 8 12
c Factors encoded in variables 1-4 and 5-8
c Target number: 143
c Factors: 13 x 11
3 2 0
2 6 0
1 0
-3 -2 0
8 0
-2 -7 -6 0
4 0
-3 -2 -6 0
5 0
-2 -6 0
-3 -7 0
3 7 0
'''
# In[73]:
# Time the end-to-end solve; per the note above, convergence is poor here.
get_ipython().run_line_magic('time', 'check(factor_11_times_13)')