Skip to content

Commit 7328a36

Browse files
authored
Merge pull request #150 from dafu-wu/master
Add support for decorative partial functions
2 parents 8d2d977 + ad066f5 commit 7328a36

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

src/decorator.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import inspect
3838
import operator
3939
import itertools
40+
import functools
4041
from contextlib import _GeneratorContextManager
4142
from inspect import getfullargspec, iscoroutinefunction, isgeneratorfunction
4243

@@ -71,7 +72,7 @@ def __init__(self, func=None, name=None, signature=None,
7172
self.name = '_lambda_'
7273
self.doc = func.__doc__
7374
self.module = func.__module__
74-
if inspect.isroutine(func):
75+
if inspect.isroutine(func) or isinstance(func, functools.partial):
7576
argspec = getfullargspec(func)
7677
self.annotations = getattr(func, '__annotations__', {})
7778
for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs',
@@ -214,6 +215,8 @@ def decorate(func, caller, extras=(), kwsyntax=False):
214215
does. By default kwsyntax is False and the the arguments are untouched.
215216
"""
216217
sig = inspect.signature(func)
218+
if isinstance(func, functools.partial):
219+
func = functools.update_wrapper(func, func.func)
217220
if iscoroutinefunction(caller):
218221
async def fun(*args, **kw):
219222
if not kwsyntax:
@@ -230,6 +233,7 @@ def fun(*args, **kw):
230233
if not kwsyntax:
231234
args, kw = fix(args, kw, sig)
232235
return caller(func, *(extras + args), **kw)
236+
233237
fun.__name__ = func.__name__
234238
fun.__doc__ = func.__doc__
235239
fun.__wrapped__ = func

src/tests/test.py

+16
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import unittest
44
import decimal
55
import inspect
6+
import functools
67
from asyncio import get_event_loop
78
from collections import defaultdict, ChainMap, abc as c
89
from decorator import dispatch_on, contextmanager, decorator
@@ -509,5 +510,20 @@ def __len__(self):
509510
h(u)
510511

511512

513+
@decorator
514+
def partial_before_after(func, *args, **kwargs):
515+
return "<before>" + func(*args, **kwargs) + "<after>"
516+
517+
518+
class PartialTestCase(unittest.TestCase):
519+
def test_before_after(self):
520+
def origin_func(x, y):
521+
return x + y
522+
_func = functools.partial(origin_func, "x")
523+
partial_func = partial_before_after(_func)
524+
out = partial_func("y")
525+
self.assertEqual(out, '<before>xy<after>')
526+
527+
512528
if __name__ == '__main__':
513529
unittest.main()

0 commit comments

Comments
 (0)