from datashape import *
from inspect import *
from functools import *
from toolz.curried import *
from toolz.curried.operator import *
import numpy as np
def to_datashape(annotation):
    """Convert a single function annotation into a datashape.

    Parameters
    ----------
    annotation : str, dict, or other
        * ``str`` -- parsed with ``dshape`` (e.g. ``"N*float64"``).
        * ``dict`` -- each value is converted recursively and the
          ``(name, datashape)`` pairs are wrapped in a ``Record``.
        * anything else -- passed to ``from_numpy`` with an empty shape
          (presumably a numpy dtype or scalar type; confirm with callers).

    Returns
    -------
    The datashape produced by ``dshape``, ``Record`` or ``from_numpy``.
    """
    if isinstance(annotation, str):
        return dshape(annotation)
    if isinstance(annotation, dict):
        # The original iterated the undefined name ``annotations``, which
        # raised NameError for every dict annotation.
        return Record(
            (name, to_datashape(field)) for name, field in annotation.items()
        )
    return from_numpy(tuple(), annotation)
def annotation_to_shapes(callable):
    """Parse the annotations of *callable* into datashapes.

    Returns
    -------
    (shapes, returns)
        *shapes* maps each annotated parameter name to the datashape built
        by ``to_datashape``. *returns* is the datashape of the return
        annotation, or ``None`` when *callable* has no return annotation.
    """
    shapes = pipe(
        callable, getfullargspec, attrgetter('annotations'),
        valmap(to_datashape),
    )
    # ``pop('return')`` without a default raised KeyError for any function
    # lacking a ``-> ...`` annotation, breaking every decorated call.
    return shapes, shapes.pop('return', None)
def _bind_type_vars(shape, value, discovered, type_vars):
    """Return *shape* with every ``TypeVar`` parameter replaced by a concrete type.

    An unbound type variable in the last position (the measure) is bound
    to ``discovered.measure``. In any other position it is bound to the
    size of that dimension of *value*. Bindings are stored in *type_vars*,
    so later arguments sharing a variable must agree with the first one.
    """
    last = len(shape.parameters) - 1
    new_parameters = []
    for dim, parameter in enumerate(shape.parameters):
        if not isinstance(parameter, TypeVar):
            new_parameters.append(parameter)
            continue
        if parameter not in type_vars:
            if dim == last:
                type_vars[parameter] = discovered.measure
            elif hasattr(value, 'shape'):
                type_vars[parameter] = Fixed(value.shape[dim])
            else:
                # NOTE(review): for plain (nested) sequences this uses the
                # outer length for every dimension, not the ``dim``-th
                # level -- confirm whether nested lists should be supported.
                type_vars[parameter] = Fixed(len(value))
        new_parameters.append(type_vars[parameter])
    return DataShape(*new_parameters)


def check_shapes(values, shapes):
    """Validate argument *values* against their annotated datashapes.

    Parameters
    ----------
    values : dict
        Parameter name -> argument value. Consumed in place: every
        annotated name is popped from it.
    shapes : dict
        Parameter name -> datashape, possibly containing type variables
        that are bound from the arguments as they are checked.

    Raises
    ------
    InteractiveTypeError
        If a value does not validate against its (bound) shape, or if
        values remain that have no annotation.
    KeyError
        If an annotated name has no entry in *values*.
    """
    type_vars = {}
    for key, shape in shapes.items():
        value = values.pop(key)
        discovered = discover(value)
        shape = _bind_type_vars(shape, value, discovered, type_vars)
        if not validate(shape, value):
            raise InteractiveTypeError(f"{key} expects {shape}, but recieved {discovered}")
    if values:
        # ``Err`` was never defined, so this branch raised NameError.
        raise InteractiveTypeError(f"No types for {list(values)}.")
class InteractiveTypeError(TypeError):
    """Raised when call arguments do not match their datashape annotations.

    Derives from ``TypeError`` rather than ``BaseException``, so ordinary
    ``except Exception`` / ``except TypeError`` handlers catch it, while
    ``except BaseException`` and ``except InteractiveTypeError`` still do.
    """
def typecheck(callable):
    """Decorate *callable* so its arguments are checked against its annotations.

    On every call the annotations are parsed with ``annotation_to_shapes``.
    The arguments are then collected by parameter name and checked by
    ``check_shapes`` before *callable* runs. Collected arguments are:
    unfilled positional defaults, keyword-only defaults, positional
    arguments, and keyword arguments, in increasing priority. The return
    annotation is parsed but the result is not currently checked.
    """
    @wraps(callable)
    def caller(*args, **kwargs):
        argspec = getfullargspec(callable)
        shapes, _returns = annotation_to_shapes(callable)
        # ``defaults`` belongs to the *last* len(defaults) positional
        # parameters; zipping it against the first ones (as before)
        # attached default values to the wrong names.
        defaults = argspec.defaults or ()
        defaulted = argspec.args[len(argspec.args) - len(defaults):]
        values = merge(
            dict(zip(defaulted, defaults)),
            argspec.kwonlydefaults or {},
            dict(zip(argspec.args, args)),
            kwargs,
        )
        check_shapes(values, shapes)
        return callable(*args, **kwargs)
    return caller
@typecheck
def dot(x: "N*float64", y: "N*float64") -> "float64":
    """Return the inner product of two equal-length float64 vectors."""
    pairwise_products = (left * right for left, right in zip(x, y))
    return sum(pairwise_products)


dot([10.], [20.])
@typecheck
def matmul(x: "N*M*float64", y: "M*T*float64") -> "N*T*float64":
    """Matrix-multiply *x* (N x M) by *y* (M x T).

    The shared type variable ``M`` makes ``typecheck`` reject operands
    whose inner dimensions disagree.
    """
    return x @ y


matmul(np.random.randn(10, 4), np.random.randn(4, 2))
# Output displayed by the notebook for the call above (random data, so the
# values differ on every run). Kept as a comment: evaluating this repr at
# import time requires an ``array`` name that is never imported.
# array([[ 3.72519493e+00, -1.09563082e-01],
#        [ 4.69513259e-01,  3.37871837e+00],
#        [-4.44569978e+00, -2.15308497e+00],
#        [ 2.63104150e+00, -2.90177959e+00],
#        [-7.61140141e-01,  2.16359480e+00],
#        [-2.40202344e-02, -1.48523737e+00],
#        [-3.08959526e+00,  6.93008449e-01],
#        [ 1.29601051e+00, -2.66912294e-03],
#        [ 1.87770045e-01, -9.63482844e-02],
#        [ 2.59286769e+00, -8.12771658e-01]])
from pytest import raises
def _validates_types():
    """Well-shaped arguments pass the checks and produce the expected results.

    NOTE(review): the leading underscore keeps pytest's default ``test_*``
    collection from running this -- confirm that is intentional.
    """
    product = matmul(np.random.randn(10, 4), np.random.randn(4, 2))
    # N=10 comes from x and T=2 from y; the old ``is not None`` check
    # passed for any return value at all.
    assert product.shape == (10, 2)
    # 10.0 * 20.0 is exact in binary floating point, so == is safe.
    assert dot([10.], [20.]) == 200.0
def _finds_type_errors():
    """Reject matmul operands whose shared dimension M disagrees (4 vs 3)."""
    left = np.random.randn(10, 4)
    right = np.random.randn(3, 2)
    with raises(InteractiveTypeError):
        matmul(left, right)