#!/usr/bin/env python
# coding: utf-8

# # Solving SAT with gradient descent
#
# Using sigmoid functions, we can define smooth approximations of boolean operations. This can be used to solve small instances of SAT using gradient descent. We map:
#
# \begin{align}
# \texttt{True} &\mapsto 1.0 \\
# \texttt{False} &\mapsto -1.0
# \end{align}
#
# We then use a compressed sigmoid function to approximate the `sgn` function:
# \begin{align}
# \texttt{sgn}(x) &= \frac{2}{e^{-10x} + 1} - 1 \\
# \texttt{sgn}(x) &\approx -1\textrm{ if }x < 0 \\
# \texttt{sgn}(x) &\approx 1\textrm{ otherwise.} \\
# \end{align}
#
# This can be used to define boolean operations:
# \begin{align}
# \texttt{not}(x) &:= -x \\
# \texttt{and}(x, y) &:= \texttt{sgn}(x + y + x \cdot y) \\
# \texttt{or}(x, y) &:= -\texttt{and}(-x, -y)
# \end{align}
#
# For a given instance of SAT, we can then create a smooth function from inputs to its truth value. We can then parse [Tough SAT](https://toughsat.appspot.com/) factoring SAT instances to generate circuits of increasing difficulty, and use gradient descent to solve them.

# In[2]:

from sympy import *
import sympy, math
from itertools import *

init_printing()

x, y, z = symbols('x y z')
sgn_expr = 2 / (exp(-10*x) + 1) - 1
plot(sgn_expr, (x, -3, 3))


def sgn(x):
    """Smooth approximation of sign(x): ~-1 for x < 0, ~+1 for x > 0.

    Works both symbolically (sympy ``Expr`` input) and numerically (floats).
    """
    if isinstance(x, Expr):
        return 2 / (sympy.exp(-10*x) + 1) - 1
    # Numeric path. math.exp raises OverflowError above ~709, i.e. for
    # x < ~-70; the sigmoid has saturated at -1 long before that point.
    if -10 * x > 700:
        return -1.0
    return 2 / (math.exp(-10*x) + 1) - 1


print(f'sgn(-1)={sgn(-1)}')
print(f'sgn(+1)={sgn(1)}')

# Map booleans onto the +/-1 encoding used by the smooth circuits, and back.
bool_to_value = {
    True: 1,
    False: -1,
}
get_bool = lambda x: x >= 0
truth_values = list(product([True, False], [True, False]))


def tabulate(table):
    """Render ``table`` (a list of rows) as an HTML table string.

    The first row is rendered with <th> header cells, the rest with <td>.
    Fix: cells are now wrapped in <tr> rows and a <table> element so the
    output is valid HTML instead of a run of bare cell tags that browsers
    render as plain run-together text.
    """
    rows = []
    for i, row in enumerate(table):
        tag = 'td' if i != 0 else 'th'
        cells = ''.join(f'<{tag}>{value}' for value in row)
        rows.append(f'<tr>{cells}</tr>')
    return f"<table>{''.join(rows)}</table>\n"


def truth_table(f):
    """Display an HTML truth table for a 1- or 2-argument boolean-ish function.

    ``f`` receives +/-1 encoded inputs (via ``bool_to_value``); its raw float
    outputs are shown so one can see how close the smooth approximation of
    each operation is to the ideal +/-1 values.
    """
    from IPython.display import HTML, display
    import inspect
    nargs = len(inspect.signature(f).parameters)
    if nargs == 1:
        table = [[True, False], [f(bool_to_value[b]) for b in [True, False]]]
    elif nargs == 2:
        table = [['', 'True', 'False']]
        for b_row in [True, False]:
            row = [str(b_row)]
            for b_col in [True, False]:
                row.append(f(bool_to_value[b_row], bool_to_value[b_col]))
            table.append(row)
    else:
        # Previously fell through with `table` unbound -> NameError.
        raise ValueError(f'truth_table supports 1 or 2 arguments, got {nargs}')
    display(HTML(tabulate(table)))


truth_table(sgn)

# In[3]:

not_ = lambda x: -x
truth_table(not_)

# In[5]:

and_ = lambda x, y: sgn(x + y + x * y)
truth_table(and_)

# In[6]:

or_ = lambda x, y: not_(and_(not_(x), not_(y)))
truth_table(or_)

# In[7]:

or_(x, y)

# In[8]:

from sympy.plotting import plot3d
plot3d(or_(x, y), (x, -1, 1), (y, -1, 1))

# and zooming in to view the gradient near the origin:

# In[10]:

plot3d(or_(x, y), (x, -0.25, 0.25), (y, -0.25, 0.25))

# Now we define some helpers to parse the [DIMACS format](http://logic.pdmi.ras.ru/~basolver/dimacs.html) that Tough SAT outputs. Fortunately, `sympy` already has some helpers that we'll copy and slightly modify.

# In[11]:

from functools import reduce
from sympy.logic.utilities.dimacs import load as sympy_load_dimacs


def load_dimacs(s):
    """Load a DIMACS CNF string as a *smooth* boolean expression.

    Each clause becomes an ``or_`` of (possibly negated) ``cnf_<n>`` symbols,
    and the clauses are ``and_``-ed together, yielding a differentiable
    sympy expression instead of an exact boolean one.

    Based off sympy.logic.utilities.dimacs
    https://github.com/sympy/sympy/blob/57fcd5a941d7c47106bd63fd7b3d79ac032b636b/sympy/logic/utilities/dimacs.py
    """
    import re
    And = lambda *args: reduce(and_, args)
    Or = lambda *args: reduce(or_, args)
    clauses = []
    pComment = re.compile(r'c.*')
    pStats = re.compile(r'p\s*cnf\s*(\d*)\s*(\d*)')
    # Iterate directly over the lines instead of repeatedly popping from the
    # front of a list, which is quadratic in the number of lines.
    for line in s.split('\n'):
        # Skip comment lines ("c ...") and the "p cnf <vars> <clauses>" header.
        if pComment.match(line) or pStats.match(line):
            continue
        literals = []  # renamed from `list`, which shadowed the builtin
        for lit in line.rstrip('\n').split(' '):
            if lit == '':
                continue
            num = int(lit)
            if num == 0:
                continue  # 0 terminates a clause in DIMACS
            sym = Symbol("cnf_%s" % abs(num))
            literals.append(sym if num > 0 else not_(sym))
        if len(literals) > 0:
            clauses.append(Or(*literals))
    return And(*clauses)

# Now let's try to factor the number $35=5\cdot7$. The variables in the dimacs output represent the binary expansion of the factors. We'll use the `sympy.logic` module to check our work.
#
# Recall that `7=0b111` and `5=0b101`, so the circuit should have two solutions: `111101` and `101111` since multiplication is commutative.
# In[12]:

# SAT instance from https://toughsat.appspot.com/ trying to factor 35 using karatsuba multiplication
factor_35_dimacs = '''\
p cnf 6 6
c Factors encoded in variables 1-3 and 4-6
c Target number: 35
-2 -5 0
1 0
3 0
4 0
6 0
2 5 0
'''

# Build the same instance twice: once as a smooth floating-point circuit and
# once as an exact sympy boolean expression, so the two can be cross-checked.
float_circuit = load_dimacs(factor_35_dimacs)
exact_circuit = sympy_load_dimacs(factor_35_dimacs)

possibilities = product(*([[True, False]] * 6))
vars = sorted(exact_circuit.atoms(), key=str)

# Exhaustively verify: for every assignment of the 6 variables, thresholding
# the smooth circuit's output must agree with the exact boolean evaluation,
# and satisfying assignments are printed as bit strings.
for assignment in possibilities:
    numeric = dict(zip(vars, (bool_to_value[b] for b in assignment)))
    approx = float_circuit.subs(numeric).evalf()
    exact = exact_circuit.subs(dict(zip(vars, assignment)))
    assert exact == get_bool(approx)
    if exact:
        bits = ''.join(str(int(b)) for b in assignment)
        print(bits + f'={exact}, {get_bool(approx)}')

# Now we've shown we can load a small circuit, and run it using floating point math to arrive at the same solution as sympy's existing logic. Now let's try to factor 35 by gradient descent.
#
# For efficiency, use `sympy.lambdify` to compile the very complex expression into a python method.
# In[13]:

# The circuit is satisfied if its output is close to 1, so create a loss function to do that:
import numpy as np

loss_expr = (float_circuit - 1) ** 2
# Compile the symbolic loss into a fast numpy function of the cnf_* symbols.
_loss_function = lambdify(vars, loss_expr, modules='numpy')


def loss_function(array):
    """Evaluate the compiled loss on a vector of per-variable values."""
    return _loss_function(*array)

# Check that the loss function is very close to zero at an actual solution to the SAT instance:

# In[14]:

print(loss_function(np.array([1, 1, 1, 1, -1, 1])))

# In[15]:

from scipy.optimize import minimize

# NOTE: Nelder-Mead's absolute tolerance option is spelled `xatol`; the old
# `xtol` spelling is deprecated/ignored by scipy, so the tight tolerance was
# not actually being applied before this fix.
res = minimize(loss_function, np.zeros(6), method='nelder-mead',
               options={'xatol': 1e-30, 'disp': True})
evaluated = float_circuit.subs({var: val for var, val in zip(vars, res.x)}).evalf()
print(f'solution={res.x},\nevaluated={evaluated}\n')
# Snap each coordinate to +/-1. (np.sign(0) is 0, which get_bool maps to True.)
normalized_solution = [np.sign(v) for v in res.x]
normalized_evaluated = float_circuit.subs({var: val for var, val in zip(vars, normalized_solution)}).evalf()
# Fix: print the normalized evaluation; previously this reprinted `evaluated`,
# leaving `normalized_evaluated` computed but never shown.
print(f'normalized_solution={normalized_solution},\nevaluated={normalized_evaluated}\n')
print(f'loss_function(res.x)={loss_function(res.x)}')

# It worked! The optimization procedure found the `111101` ($35=7\cdot 5$) solution:

# In[16]:

float_circuit.subs({var: val for var, val in zip(vars, res.x)}).evalf()

# # Putting it all together:

# In[17]:


def solve_SAT_with_gradient_descent(dimacs_text):
    """Search for a satisfying assignment of a DIMACS CNF instance.

    Builds the smooth circuit, then minimizes the squared distance of its
    output from 1 (the encoding of True) with Nelder-Mead. Returns a
    {Symbol: bool} mapping of the thresholded solution. There is no
    guarantee the optimizer converges to an actual satisfying assignment.
    """
    float_circuit = load_dimacs(dimacs_text)
    loss_expr = (float_circuit - 1) ** 2
    sat_vars = sorted(float_circuit.free_symbols, key=str)  # avoid shadowing builtin `vars`
    print(sat_vars)
    compiled_loss = lambdify(sat_vars, loss_expr, modules='numpy')
    res = minimize(
        lambda array: compiled_loss(*array),
        np.zeros(len(sat_vars)),
        method='nelder-mead',
        # `xatol` (not the deprecated `xtol`) is Nelder-Mead's tolerance knob.
        options={'xatol': 1e-30, 'disp': True})
    print(res.x)
    return {var: get_bool(value) for var, value in zip(sat_vars, res.x)}


def check(dimacs_text):
    """Solve `dimacs_text` by gradient descent, then verify the thresholded
    solution against sympy's exact boolean evaluation of the same instance."""
    solution = solve_SAT_with_gradient_descent(dimacs_text)
    print(solution)
    print(f'Does the solution work? \nDoes True=={sympy_load_dimacs(dimacs_text).subs(solution)}?')


check(factor_35_dimacs)

# # Something _slightly_ bigger
#
# When we try to factor $143=11\cdot 13$, even constructing the function to evaluate is slow, and the optimization fails to converge.
#
# Ah well, I've heard this SAT thing is hard.

# In[18]:

factor_11_times_13 = '''\
p cnf 8 12
c Factors encoded in variables 1-4 and 5-8
c Target number: 143
c Factors: 13 x 11
3 2 0
2 6 0
1 0
-3 -2 0
8 0
-2 -7 -6 0
4 0
-3 -2 -6 0
5 0
-2 -6 0
-3 -7 0
3 7 0
'''

# In[73]:

get_ipython().run_line_magic('time', 'check(factor_11_times_13)')