#default_exp meta

#export
from fastcore.imports import *
from contextlib import contextmanager
from copy import copy
import inspect

from fastcore.test import *
from fastcore.foundation import *
from nbdev.showdoc import *
from fastcore.nb_imports import *

#export
def test_sig(f, b):
    "Test the signature of an object"
    # Compare the rendered `inspect.signature(f)` string against the expected string `b`.
    test_eq(str(inspect.signature(f)), b)

# Fixtures for `test_sig`. The list default in `func_2` is deliberate: it only
# checks how a default value is rendered, and the function is never called.
def func_1(h, i, j): pass
def func_2(h, i=3, j=[5,6]): pass

class T:
    def __init__(self, a, b): pass

test_sig(func_1, '(h, i, j)')
test_sig(func_2, '(h, i=3, j=[5, 6])')
test_sig(T, '(a, b)')

#export
def _rm_self(sig):
    "Return a copy of `sig` without its `self` parameter (raises `KeyError` if there is none)"
    sigd = dict(sig.parameters)
    sigd.pop('self')
    return sig.replace(parameters=sigd.values())

#export
class FixSigMeta(type):
    "A metaclass that fixes the signature on classes that override `__new__`"
    # NOTE: the `dict` parameter shadows the builtin; kept to leave the signature unchanged.
    def __new__(cls, name, bases, dict):
        res = super().__new__(cls, name, bases, dict)
        # Pin the class signature to `__init__` (minus `self`) so an inherited `__new__`
        # or a metaclass `__call__` doesn't mask it in `inspect.signature`.
        if res.__init__ is not object.__init__:
            res.__signature__ = _rm_self(inspect.signature(res.__init__))
        return res

show_doc(FixSigMeta, title_level=3)

class T:
    def __init__(self, a, b, c): pass

inspect.signature(T)

class Foo:
    def __new__(self, **args): pass

class Bar(Foo):
    def __init__(self, d, e, f): pass

# Without `FixSigMeta`, the reported signature may come from `Foo.__new__`
# rather than `Bar.__init__` (depends on the Python version's lookup rules).
inspect.signature(Bar)

class Bar(Foo, metaclass=FixSigMeta):
    def __init__(self, d, e, f): pass

test_sig(Bar, '(d, e, f)')
inspect.signature(Bar)

class TestMeta(FixSigMeta):
    # __new__ comes from FixSigMeta, so the signature is fixed despite the custom `__call__`
    def __call__(cls, *args, **kwargs): pass

class T(metaclass=TestMeta):
    def __init__(self, a, b): pass

test_sig(T, '(a, b)')

class GenericMeta(type):
    "A boilerplate metaclass that doesn't do anything for testing."
    def __new__(cls, name, bases, dict):
        return super().__new__(cls, name, bases, dict)
    def __call__(cls, *args, **kwargs): pass

class T2(metaclass=GenericMeta):
    def __init__(self, a, b): pass

# The signature is taken from `GenericMeta.__call__` here.
# We can avoid this by inheriting from the metaclass `FixSigMeta`
test_sig(T2, '(*args, **kwargs)')

#export
class PrePostInitMeta(FixSigMeta):
    "A metaclass that calls optional `__pre_init__` and `__post_init__` methods"
    def __call__(cls, *args, **kwargs):
        # NOTE(review): `__new__` is called without the constructor args, so classes whose
        # `__new__` requires arguments are not supported.
        res = cls.__new__(cls)
        # Only initialise objects whose exact type is `cls`.
        if type(res) is cls:
            if hasattr(res, '__pre_init__'): res.__pre_init__(*args, **kwargs)
            res.__init__(*args, **kwargs)
            if hasattr(res, '__post_init__'): res.__post_init__(*args, **kwargs)
        return res

show_doc(PrePostInitMeta, title_level=3)

class _T(metaclass=PrePostInitMeta):
    def __pre_init__(self): self.a = 0
    def __init__(self, b=0): self.b = self.a + 1; assert self.b==1
    def __post_init__(self): self.c = self.b + 2; assert self.c==3

t = _T()
test_eq(t.a, 0) # set with __pre_init__
test_eq(t.b, 1) # set with __init__
test_eq(t.c, 3) # set with __post_init__

#exports
class AutoInit(metaclass=PrePostInitMeta):
    "Same as `object`, but no need for subclasses to call `super().__init__`"
    # Runs the next `__init__` in the MRO before the subclass's own `__init__`.
    def __pre_init__(self, *args, **kwargs): super().__init__(*args, **kwargs)

class TestParent():
    def __init__(self): self.h = 10

class TestChild(AutoInit, TestParent):
    def __init__(self): self.k = self.h + 2

t = TestChild()
test_eq(t.h, 10) # h=10 is initialized in the parent class
test_eq(t.k, 12)

#export
class NewChkMeta(FixSigMeta):
    "Metaclass to avoid recreating object passed to constructor"
    def __call__(cls, x=None, *args, **kwargs):
        # A lone argument that is already an instance of `cls` is returned unchanged.
        if not (args or kwargs) and x is not None and isinstance(x, cls): return x
        return super().__call__(x, *args, **kwargs)

show_doc(NewChkMeta, title_level=3)

class _T():
    "Testing"
    def __init__(self, o):
        # copy `o.foo` if it exists, otherwise default to 1
        self.foo = getattr(o, 'foo', 1)

t = _T(3)
test_eq(t.foo, 1) # 3 has no attribute `foo`, so foo = 1

t2 = _T(t) # t is of type _T, but without NewChkMeta a new object is still created
assert t is not t2 # t and t2 are different objects

class _T(metaclass=NewChkMeta):
    "Testing with metaclass NewChkMeta"
    def __init__(self, o=None, b=1):
        # copy `o.foo` if it exists, otherwise default to 1
        self.foo = getattr(o, 'foo', 1)
        self.b = b

t = _T(3)
test_eq(t.foo, 1) # 3 is not of type _T, so a new object is built and foo = 1

t2 = _T(t) # t2 will now reference t

test_is(t, t2) # t and t2 are the same object
t2.foo = 5 # this will also change t.foo to 5 because it is the same object
test_eq(t.foo, 5)
test_eq(t2.foo, 5)

t3 = _T(t, b=1) # an extra argument forces a new object
assert t3 is not t

t4 = _T(t) # with no extra arguments the constructor returns a reference to the same object
assert t4 is t

test_sig(_T, '(o=None, b=1)')

#export
class BypassNewMeta(FixSigMeta):
    "Metaclass: casts `x` to this class if it's of type `cls._bypass_type`"
    def __call__(cls, x=None, *args, **kwargs):
        if hasattr(cls, '_new_meta'):
            # A class-provided factory takes full control of object creation.
            x = cls._new_meta(x, *args, **kwargs)
        elif not isinstance(x, getattr(cls, '_bypass_type', object)) or args or kwargs:
            x = super().__call__(x, *args, **kwargs)
        # NOTE(review): with the default `_bypass_type` (object), `cls()` reaches here with
        # x=None and the `__class__` assignment below raises TypeError — confirm intended.
        # Re-label the (possibly reused) object as an instance of `cls`.
        if x.__class__ is not cls: x.__class__ = cls
        return x

show_doc(BypassNewMeta, title_level=3)

class _TestA: pass
class _TestB: pass

class _T(_TestA, metaclass=BypassNewMeta):
    _bypass_type = _TestB
    def __init__(self, x): self.x = x

t = _TestA()
t2 = _T(t)
assert t is not t2

t = _TestB()
t2 = _T(t)
t2.new_attr = 15

test_is(t, t2)
# since t2 just references t these will be the same
test_eq(t.new_attr, t2.new_attr)

# likewise, changing an attribute on t will also affect t2 because they both point to the same object.
t.new_attr = 9
test_eq(t2.new_attr, 9)

#export
def empty2none(p):
    "Replace `Parameter.empty` with `None`"
    # `Parameter.empty` is a sentinel class: compare by identity so values with a custom
    # `__eq__` (e.g. arrays) don't break the check.
    return None if p is inspect.Parameter.empty else p

#export
def anno_dict(f):
    "`__annotations__` dictionary with `empty` cast to `None`, returning an empty dict if it doesn't exist"
    return {k:empty2none(v) for k,v in getattr(f, '__annotations__', {}).items()}

def _f(a:int, b:L)->str: ...
test_eq(anno_dict(_f), {'a': int, 'b': L, 'return': str})

#export
def _mk_param(n, d=None):
    "Create a keyword-only `inspect.Parameter` named `n` with default `d`"
    return inspect.Parameter(n, inspect.Parameter.KEYWORD_ONLY, default=d)

#export
def use_kwargs_dict(keep=False, **kwargs):
    "Decorator: replace `**kwargs` in signature with keyword-only params from `kwargs` (values become defaults)"
    def _f(f):
        sig = inspect.signature(f)
        sigd = dict(sig.parameters)
        # `f` must have a `**kwargs` parameter literally named `kwargs` (KeyError otherwise).
        var_kw = sigd.pop('kwargs')
        # Names already present in the signature are left untouched.
        sigd.update({n:_mk_param(n, d) for n,d in kwargs.items() if n not in sigd})
        # Re-append `**kwargs` last so the parameter order stays valid.
        if keep: sigd['kwargs'] = var_kw
        f.__signature__ = sig.replace(parameters=sigd.values())
        return f
    return _f

@use_kwargs_dict(y=1, z=None)
def foo(a, b=1, **kwargs): pass
test_sig(foo, '(a, b=1, *, y=1, z=None)')

@use_kwargs_dict(y=1, z=None, keep=True)
def foo(a, b=1, **kwargs): pass
test_sig(foo, '(a, b=1, *, y=1, z=None, **kwargs)')

#export
def use_kwargs(names, keep=False):
    "Decorator: replace `**kwargs` in signature with `names` params"
    def _f(f):
        sig = inspect.signature(f)
        sigd = dict(sig.parameters)
        # `f` must have a `**kwargs` parameter literally named `kwargs` (KeyError otherwise).
        var_kw = sigd.pop('kwargs')
        # Each new name becomes a keyword-only param defaulting to None.
        sigd.update({n:_mk_param(n) for n in names if n not in sigd})
        if keep: sigd['kwargs'] = var_kw
        f.__signature__ = sig.replace(parameters=sigd.values())
        return f
    return _f

@use_kwargs(['y', 'z'])
def foo(a, b=1, **kwargs): pass
test_sig(foo, '(a, b=1, *, y=None, z=None)')

@use_kwargs(['y', 'z'], keep=True)
def foo(a, *args, b=1, **kwargs): pass
test_sig(foo, '(a, *args, b=1, y=None, z=None, **kwargs)')

# export
def delegates(to=None, keep=False, but=None):
    "Decorator: replace `**kwargs` in signature with params from `to`"
    if but is None: but = []
    def _f(f):
        # With no `to`, `f` is a class delegating its `__init__` to its base class's `__init__`.
        if to is None: to_f,from_f = f.__base__.__init__,f.__init__
        else:          to_f,from_f = to.__init__ if isinstance(to,type) else to,f
        # Unwrap classmethods / bound methods to their underlying functions.
        from_f = getattr(from_f, '__func__', from_f)
        to_f = getattr(to_f, '__func__', to_f)
        # Already delegated without `keep`: don't rewrite the signature again.
        if hasattr(from_f, '__delwrap__'): return f
        sig = inspect.signature(from_f)
        sigd = dict(sig.parameters)
        var_kw = sigd.pop('kwargs')
        # Only params of `to_f` that have defaults, aren't already in `from_f`, and aren't excluded.
        # `is not` (identity) avoids `!=` misbehaving on defaults like arrays.
        s2 = {name:p for name,p in inspect.signature(to_f).parameters.items()
              if p.default is not inspect.Parameter.empty and name not in sigd and name not in but}
        sigd.update(s2)
        if keep: sigd['kwargs'] = var_kw
        else: from_f.__delwrap__ = to_f
        from_f.__signature__ = sig.replace(parameters=sigd.values())
        return f
    return _f

def baz(a, b=2, c=3): return a + b + c

def foo(c, a, **kwargs):
    return c + baz(a, **kwargs)

assert foo(c=1, a=1) == 7

inspect.signature(foo)

@delegates(baz)
def foo(c, a, **kwargs):
    return c + baz(a, **kwargs)

test_sig(foo, '(c, a, b=2)')
inspect.signature(foo)

@delegates(baz, keep=True)
def foo(c, a, **kwargs):
    return c + baz(a, **kwargs)

test_sig(foo, '(c, a, b=2, **kwargs)')
inspect.signature(foo)

def basefoo(e, d, c=2): pass

@delegates(basefoo)
def foo(a, b=1, **kwargs): pass

test_sig(foo, '(a, b=1, c=2)') # e and d are not included b/c they don't have default parameters.
inspect.signature(foo)

def basefoo(e, c=2, d=3): pass

@delegates(basefoo, but=['d'])
def foo(a, b=1, **kwargs): pass

test_sig(foo, '(a, b=1, c=2)')
inspect.signature(foo)

# example 1: class methods
class _T():
    @classmethod
    def foo(cls, a=1, b=2): pass

    @classmethod
    @delegates(foo)
    def bar(cls, c=3, **kwargs): pass

test_sig(_T.bar, '(c=3, a=1, b=2)')

# example 2: instance methods
class _T():
    def foo(self, a=1, b=2): pass

    @delegates(foo)
    def bar(self, c=3, **kwargs): pass

t = _T()
test_sig(t.bar, '(c=3, a=1, b=2)')

class BaseFoo:
    def __init__(self, e, c=2): pass

@delegates() # since no argument was passed here we delegate to the superclass
class Foo(BaseFoo):
    def __init__(self, a, b=1, **kwargs): super().__init__(**kwargs)

test_sig(Foo, '(a, b=1, c=2)')

#export
def method(f):
    "Mark `f` as a method"
    # `1` is a dummy instance since Py3 doesn't allow `None` any more
    return MethodType(f, 1)

def a(x=2): return x + 1
assert type(a).__name__ == 'function'

a = method(a)
assert type(a).__name__ == 'method'

#export
def _funcs_kwargs(cls, as_method):
    # Wrap `cls.__init__` so kwargs named in `cls._methods` are popped and stored as
    # instance attributes (shadowing the class's methods) before the original `__init__` runs.
    old_init = cls.__init__
    def _init(self, *args, **kwargs):
        for k in cls._methods:
            arg = kwargs.pop(k, None)
            # A `None` value is treated the same as the kwarg being absent.
            if arg is not None:
                if as_method: arg = method(arg)
                # Re-bind any method to this instance so it receives `self`.
                if isinstance(arg, MethodType): arg = MethodType(arg.__func__, self)
                setattr(self, k, arg)
        old_init(self, *args, **kwargs)
    # Sets `_init.__wrapped__`, so `inspect.signature` inside `use_kwargs` sees the original `__init__`.
    functools.update_wrapper(_init, old_init)
    cls.__init__ = use_kwargs(cls._methods)(_init)
    # Keep a pinned class signature (e.g. from `FixSigMeta`) in sync with the new `__init__`.
    if hasattr(cls, '__signature__'): cls.__signature__ = _rm_self(inspect.signature(cls.__init__))
    return cls

#export
def funcs_kwargs(as_method=False):
    "Replace methods in `cls._methods` with those from `kwargs`"
    # Usable bare (`@funcs_kwargs`, where the class arrives as `as_method`) or called
    # (`@funcs_kwargs(as_method=True)`).
    if callable(as_method): return _funcs_kwargs(as_method, False)
    return partial(_funcs_kwargs, as_method=as_method)

@funcs_kwargs
class T:
    _methods=['b'] # allows you to add method b upon instantiation
    def __init__(self, f=1, **kwargs): pass # don't forget to include **kwargs in __init__
    def a(self): return 1
    def b(self): return 2

t = T()
test_eq(t.a(), 1)
test_eq(t.b(), 2)

test_sig(T, '(f=1, *, b=None)')
inspect.signature(T)

def _new_func(): return 5

t = T(b = _new_func)
test_eq(t.b(), 5)

t = T(a = lambda:3)
test_eq(t.a(), 1) # the attempt to add a is ignored and uses the original method instead.

@funcs_kwargs
class T:
    _methods=['c']
    def __init__(self, f=1, **kwargs): pass

t = T(c = lambda: 4)
test_eq(t.c(), 4)

def _f(self, a=1): return self.num + a # access the num attribute from the instance

@funcs_kwargs(as_method=True)
class T:
    _methods=['b']
    num = 5

t = T(b = _f) # adds method b
test_eq(t.b(5), 10) # self.num + 5 = 10

def _f(self, a=1): return self.num * a # multiply instead of add

class T2(T):
    def __init__(self, num):
        super().__init__(b = _f) # add method b from the super class
        self.num = num

t = T2(num=3)
test_eq(t.b(a=5), 15) # 3 * 5 = 15
test_sig(T2, '(num)')

#hide
def _g(a=1): return a+1

class T3(T): b = staticmethod(_g)

t = T3()
test_eq(t.b(2), 3)

#hide
#test funcs_kwargs works with PrePostInitMeta
class A(metaclass=PrePostInitMeta): pass

@funcs_kwargs
class B(A):
    _methods = ['m1']
    def __init__(self, **kwargs): pass

test_sig(B, '(*, m1=None)')

#hide
from nbdev.export import notebook2script
notebook2script()