from datashape import *
from inspect import *
from functools import *
from toolz.curried import *
from toolz.curried.operator import *
def to_datashape(annotation):
    """Convert a single function annotation into a datashape.

    Supported annotation forms:

    * ``str``  -- parsed with ``dshape``, e.g. ``"N*float64"``.
    * ``dict`` -- a record; each value is converted recursively and the
      result is a ``Record`` of ``(field_name, datashape)`` pairs.
    * anything else -- passed to ``from_numpy`` with an empty shape
      (presumably a numpy dtype or scalar type -- TODO confirm).

    Parameters
    ----------
    annotation : str, dict or other
        The raw annotation taken from a function signature.

    Returns
    -------
    The datashape produced by ``dshape``, ``Record`` or ``from_numpy``.
    """
    if isinstance(annotation, str):
        return dshape(annotation)
    if isinstance(annotation, dict):
        # Recurse so field annotations may use any of the supported forms.
        return Record([
            (field_name, to_datashape(field_annotation))
            for field_name, field_annotation in annotation.items()
        ])
    return from_numpy(tuple(), annotation)
def annotation_to_shapes(callable):
    """Convert every annotation of *callable* into a datashape.

    Parameters
    ----------
    callable : function
        Function whose ``getfullargspec(...).annotations`` are converted
        with ``to_datashape``.

    Returns
    -------
    tuple
        ``(shapes, returns)``. ``shapes`` is a ``{parameter_name: datashape}``
        dict for the annotated parameters. ``returns`` is the datashape of
        the return annotation, or ``None`` when there is none.
    """
    shapes = pipe(
        callable, getfullargspec, attrgetter('annotations'),
        valmap(to_datashape),
    )
    # Default to None so functions without a ``-> ...`` annotation are
    # accepted instead of raising KeyError.
    return shapes, shapes.pop('return', None)
def check_shapes(values, shapes):
    """Validate argument values against their annotated datashapes.

    Type variables (e.g. ``N`` in ``"N*float64"``) are bound the first time
    they are seen. A variable in the last (measure) position binds to the
    discovered measure of the value; any other position binds to the matching
    dimension of the discovered value. The binding is reused for every later
    parameter, so a shared variable such as ``N`` must agree across arguments.

    Parameters
    ----------
    values : dict
        ``{parameter_name: value}`` for the call being checked. Not modified.
    shapes : dict
        ``{parameter_name: DataShape}`` as produced by ``annotation_to_shapes``.

    Raises
    ------
    InteractiveTypeError
        If a value lacks a dimension its shape requires, if it does not
        validate against its type-variable-substituted shape, or if
        ``values`` contains parameters that have no annotation.
    """
    remaining = dict(values)  # work on a copy so the caller's dict is untouched
    type_vars = {}
    for key, shape in shapes.items():
        if key not in remaining:
            # Annotated parameter the call did not supply; the wrapped function
            # raises its own TypeError if the argument is genuinely required.
            continue
        value = remaining.pop(key)
        discovered = discover(value)
        last = len(shape.parameters) - 1
        new_parameters = []
        for dim, parameter in enumerate(shape.parameters):
            if not isinstance(parameter, TypeVar):
                new_parameters.append(parameter)
                continue
            if parameter not in type_vars:
                if dim == last:
                    type_vars[parameter] = discovered.measure
                elif dim < len(discovered.shape):
                    # ``discovered.shape`` covers numpy arrays and nested
                    # lists alike; ``len(value)`` only sees the outer axis.
                    type_vars[parameter] = discovered.shape[dim]
                else:
                    # The value has fewer dimensions than the annotation.
                    raise InteractiveTypeError(
                        f"{key} expects {shape}, but recieved {discovered}")
            new_parameters.append(type_vars[parameter])
        shape = DataShape(*new_parameters)
        if not validate(shape, value):
            # Spelling "recieved" kept: the recorded notebook output matches it.
            raise InteractiveTypeError(
                f"{key} expects {shape}, but recieved {discovered}")
    if remaining:
        raise InteractiveTypeError(f"No types for {list(remaining)}.")
class InteractiveTypeError(TypeError):
    """Raised when call arguments do not match their annotated datashapes.

    Derives from ``TypeError`` (and therefore ``Exception``) rather than
    ``BaseException``, so ordinary ``except Exception`` / ``except TypeError``
    handlers catch it like any other argument-type error.
    """
def typecheck(callable):
    """Decorate *callable* so every call is checked against its annotations.

    On each call the effective arguments are collected into a
    ``{name: value}`` dict and validated with ``check_shapes``:
    positional defaults, keyword-only defaults, explicit positional
    arguments, then keyword arguments, with later sources taking precedence.
    The wrapped function runs only if validation succeeds.
    The return annotation is parsed but the result is not validated.

    Raises
    ------
    InteractiveTypeError
        Propagated from ``check_shapes`` when an argument does not match.
    """
    @wraps(callable)
    def caller(*args, **kwargs):
        argspec = getfullargspec(callable)
        shapes, _ = annotation_to_shapes(callable)
        defaults = argspec.defaults or ()
        # ``defaults`` belong to the *last* len(defaults) positional
        # parameters, so align them from the end, not the start.
        defaulted_names = argspec.args[len(argspec.args) - len(defaults):]
        values = merge(
            dict(zip(defaulted_names, defaults)),
            argspec.kwonlydefaults or {},
            dict(zip(argspec.args, args)),
            kwargs)
        check_shapes(values, shapes)
        return callable(*args, **kwargs)
    return caller
@typecheck
def dot(x: "N*float64", y: "N*float64") -> "float64":
    """Return the inner product of ``x`` and ``y``.

    Both annotations share the type variable ``N``, so the ``typecheck``
    wrapper requires the two vectors to have the same length.
    """
    pairs = zip(x, y)
    return sum(left * right for left, right in pairs)


dot([10.], [20.])
@typecheck
def matmul(x: "N*M*float64", y: "M*T*float64") -> "N*T*float64":
    """Return the matrix product ``x @ y``.

    The shared type variable ``M`` makes the ``typecheck`` wrapper require
    that the second dimension of ``x`` equals the first dimension of ``y``.
    """
    product = x @ y
    return product
# NBVAL_IGNORE_OUTPUT
# Multiply random (10, 4) and (4, 2) matrices: the shared dimension M is 4
# for both arguments, so the annotation check passes. The inputs are random,
# so the nbval marker above excludes this cell's output from comparison.
# NOTE(review): `np` is not imported in the visible source -- presumably numpy
# reaching scope via an earlier cell or a star import; confirm.
matmul(np.random.randn(10,4), np.random.randn(4, 2))
# Recorded notebook output of the call above (captured display, not a statement):
array([[ 1.93643169, -8.33513156], [ 1.65499336, -5.2537133 ], [-0.47451908, 3.59120183], [-0.36964418, 1.08542178], [ 0.27122241, -1.42844261], [ 0.36788651, -0.66620386], [ 0.57044369, 0.5412394 ], [ 0.23610315, 0.796602 ], [ 1.31549658, -2.31397078], [-0.19160408, 3.57337919]])
# Mismatched shared dimension: x binds N = 10 and M = 4, but y has only 3 rows.
# The annotation check raises InteractiveTypeError, and its message is printed.
try:
matmul(np.random.randn(10,4), np.random.randn(3, 2))
except InteractiveTypeError as e: print(e)
# Recorded notebook output of the print above (captured text, not a statement):
y expects 4 * 2 * float64, but recieved 3 * 2 * float64