#!/usr/bin/env python
# coding: utf-8
"""Notebook helpers for exact linear algebra with ``fractions.Fraction``.

The first cell defines constructors for Fraction-valued matrices and column
vectors.  The second cell installs IPython hooks so that, inside a running
IPython/Jupyter session:

* ``fractions.Fraction`` values and 1-D/2-D ndarrays render as LaTeX;
* every cell is rewritten by ``NumberWrapper`` so that ``@`` goes through
  ``matrix_dot``, division by an ``int`` yields a Fraction, and float
  literals become exact Fractions.

Outside IPython (plain ``python``) the numeric helpers still import and work;
only the display/AST hooks are skipped.
"""

# In[ ]:


import numpy as np
from fractions import Fraction as frac


def Matrix(*a):
    """Return an object ndarray of Fractions built row by row.

    Rows are passed as separate arguments, e.g. ``Matrix([1, 2], [3, 4])``;
    a single ndarray argument is converted row by row instead.
    """
    if len(a) == 1 and isinstance(a[0], np.ndarray):
        a = a[0]
    return np.array([[frac(x) for x in r] for r in a])


def Vector(*a):
    """Return a column vector (shape ``(n, 1)``) of Fractions.

    Entries are passed as separate arguments, e.g. ``Vector(1, 2, 3)``; a
    single ndarray argument is converted element by element instead.
    """
    if len(a) == 1 and isinstance(a[0], np.ndarray):
        a = a[0]
    return np.array([frac(x) for x in a]).reshape(-1, 1)


# In[ ]:


# "Magic" section: hooks that change how IPython displays values and runs cells.
import ast

try:
    from IPython import get_ipython
    # InteractiveShell, Latex and display are not used below; they are kept so
    # notebook cells can keep using these names.
    from IPython.core.interactiveshell import InteractiveShell
    from IPython.display import Latex, SVG, display
except ImportError:
    # Plain Python without IPython: keep the numeric helpers importable.
    get_ipython = None
    InteractiveShell = Latex = SVG = display = None


def frac_to_latex(self):
    """Render a Fraction as LaTeX: ``p`` if it is an integer, else ``\\frac{p}{q}``."""
    if self.denominator == 1:
        return str(self.numerator)
    return "\\frac{{{}}}{{{}}}".format(self.numerator, self.denominator)


# Globally patch Fraction so str() (used by ndarray_to_latex below) and
# IPython's rich LaTeX display both produce LaTeX.
frac.__str__ = frac_to_latex
frac._repr_latex_ = lambda x: "$" + frac_to_latex(x) + "$"


def ndarray_to_latex(arr):
    """Render a 1-D or 2-D ndarray as a LaTeX ``pmatrix`` string.

    A 1-D array is shown as a single row.  Returns ``None`` (so IPython skips
    the LaTeX representation) for any other rank and for empty arrays, on
    which ``np.vectorize`` would otherwise raise.
    """
    if arr.ndim == 1:
        arr = arr.reshape(1, -1)
    if arr.ndim != 2 or arr.size == 0:
        return None
    str_arr = np.vectorize(str)(arr)
    return r'\begin{{pmatrix}}{}\end{{pmatrix}}'.format(r'\\ '.join(map('&'.join, str_arr)))


def matrix_dot(A, B):
    """Matrix product that stays exact on Fraction arrays, or compose callables.

    * ``A`` ndarray: ``A`` and ``B`` must be 2-D with ``A.shape[1] == B.shape[0]``;
      returns the ``(A.shape[0], B.shape[1])`` product computed with elementwise
      ``*`` and ``sum`` (so object arrays of Fractions keep exact entries).
    * ``A`` callable, ``B`` ndarray: returns ``A(B)``.
    * both callable: returns the composition ``x -> A(B(x))``.

    Raises:
        AssertionError: if an array operand is not 2-D, or a non-array operand
            is not callable.
        ValueError: if the inner dimensions of the two matrices differ.
    """
    if isinstance(A, np.ndarray):
        assert len(A.shape) == 2 == len(B.shape)
        # Without this check broadcasting silently "succeeds" whenever an
        # inner dimension is 1 and returns a meaningless result.
        if A.shape[1] != B.shape[0]:
            raise ValueError(
                "matrix_dot: shapes {} and {} not aligned".format(A.shape, B.shape))
        # Column j of the product is the row-wise sum of A * B[:, j].
        columns = [(A * col).sum(axis=1) for col in B.T]
        # reshape keeps the (m, n) result shape even when B has no columns.
        return np.array(columns).reshape(B.shape[1], A.shape[0]).T
    assert callable(A)
    if isinstance(B, np.ndarray):
        return A(B)
    assert callable(B)
    return lambda x: A(B(x))


def int_to_frac(x):
    """Return ``frac(x)`` when ``x`` is an ``int`` (including ``bool``), else ``x`` unchanged."""
    if isinstance(x, int):
        return frac(x)
    return x


class NumberWrapper(ast.NodeTransformer):
    """AST transformer that keeps notebook arithmetic exact.

    * ``A @ B``  becomes ``matrix_dot(A, B)``;
    * ``x / y``  becomes ``x / int_to_frac(y)`` (likewise ``x /= y``);
    * a float literal such as ``0.5`` becomes ``frac('0.5')``, printing
      ``convert 0.5``.

    The generated calls refer to ``matrix_dot``, ``int_to_frac`` and ``frac``
    by name, so those names must be visible where the rewritten cell runs.
    """

    @staticmethod
    def _call(func_name, *args):
        """Build the AST for ``func_name(*args)``."""
        return ast.Call(func=ast.Name(id=func_name, ctx=ast.Load()),
                        args=list(args), keywords=[])

    def visit_BinOp(self, node):
        node = self.generic_visit(node)
        if isinstance(node.op, ast.MatMult):
            return ast.copy_location(self._call('matrix_dot', node.left, node.right), node)
        if isinstance(node.op, ast.Div):
            node.right = self._call('int_to_frac', node.right)
        return node

    def visit_AugAssign(self, node):
        # Keep ``x /= 2`` consistent with ``x = x / 2``.
        node = self.generic_visit(node)
        if isinstance(node.op, ast.Div):
            node.value = self._call('int_to_frac', node.value)
        return node

    def visit_Constant(self, node):
        # Python 3.8+ parses every literal as ast.Constant; the visit_Num hook is
        # deprecated since 3.8 and is no longer dispatched on newer versions.
        return self._float_to_frac(node, node.value)

    def visit_Num(self, node):
        # Python < 3.8 still produces ast.Num nodes for number literals.
        return self._float_to_frac(node, node.n)

    def _float_to_frac(self, node, value):
        """Replace a float literal node with ``frac('<literal>')``; keep others."""
        if not isinstance(value, float):
            return node
        print("convert", repr(value))
        # Build the Fraction from the decimal text so 0.1 becomes 1/10 rather
        # than the exact value of the nearest binary double.
        return ast.copy_location(self._call('frac', ast.Constant(value=str(value))), node)


# The original image markup was lost (an empty string is not valid SVG and
# makes IPython's SVG constructor raise); this is a minimal replacement.
_SMILEY_SVG = (
    '<svg xmlns="http://www.w3.org/2000/svg" width="60" height="60" viewBox="0 0 60 60">'
    '<circle cx="30" cy="30" r="28" fill="#fd3" stroke="#000" stroke-width="2"/>'
    '<circle cx="21" cy="23" r="4"/>'
    '<circle cx="39" cy="23" r="4"/>'
    '<path d="M17 36 Q30 50 43 36" fill="none" stroke="#000" stroke-width="3"/>'
    '</svg>'
)
smiley = SVG(_SMILEY_SVG) if SVG is not None else _SMILEY_SVG


def check_answer(result):
    """Return ``smiley`` when every entry of ``result`` is truthy.

    Otherwise print "Try again" and return ``result`` itself so it gets
    displayed.  ``result`` must provide an ``.all()`` method (e.g. an ndarray).
    """
    if result.all():
        return smiley
    print("Try again")
    return result


# Install the hooks only inside a running IPython session (get_ipython()
# returns None otherwise, so no throwaway shell is created).
sh = get_ipython() if get_ipython is not None else None
if sh is not None:
    # Show ndarrays as LaTeX matrices.
    sh.display_formatter.formatters['text/latex'].type_printers[np.ndarray] = ndarray_to_latex
    # Drop a NumberWrapper left by an earlier run of this cell so re-running it
    # does not stack duplicate rewrites (compared by name: a re-run defines a
    # new class object).
    sh.ast_transformers[:] = [
        t for t in sh.ast_transformers if type(t).__name__ != NumberWrapper.__name__
    ]
    sh.ast_transformers.append(NumberWrapper())