Skip to content

Commit fb2b6fd

Browse files
committed
more extensive testing for dill.source
1 parent 8b86f50 commit fb2b6fd

File tree

3 files changed

+253
-63
lines changed

3 files changed

+253
-63
lines changed

dill/source.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ def outdent(code, spaces=None, all=True):
530530
return '\n'.join(_outdent(code.split('\n'), spaces=spaces, all=all))
531531

532532

533-
#XXX: not sure what the point of _wrap is...
533+
# _wrap provides an wrapper to correctly exec and load into locals
534534
__globals__ = globals()
535535
__locals__ = locals()
536536
def _wrap(f):

dill/tests/test_source.py

Lines changed: 62 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
# License: 3-clause BSD. The full license text is available at:
77
# - https://github.com/uqfoundation/dill/blob/master/LICENSE
88

9-
from dill.source import getsource, getname, _wrap, likely_import
10-
from dill.source import getimportable
9+
from dill.source import getsource, getname, _wrap, getimport
10+
from dill.source import importable
1111
from dill._dill import IS_PYPY
1212

1313
import sys
@@ -55,31 +55,31 @@ def test_getsource():
5555

5656
# test itself
5757
def test_itself():
58-
assert likely_import(likely_import)=='from dill.source import likely_import\n'
58+
assert getimport(getimport)=='from dill.source import getimport\n'
5959

6060
# builtin functions and objects
6161
def test_builtin():
62-
assert likely_import(pow) == 'pow\n'
63-
assert likely_import(100) == '100\n'
64-
assert likely_import(True) == 'True\n'
65-
assert likely_import(pow, explicit=True) == 'from builtins import pow\n'
66-
assert likely_import(100, explicit=True) == '100\n'
67-
assert likely_import(True, explicit=True) == 'True\n'
62+
assert getimport(pow) == 'pow\n'
63+
assert getimport(100) == '100\n'
64+
assert getimport(True) == 'True\n'
65+
assert getimport(pow, builtin=True) == 'from builtins import pow\n'
66+
assert getimport(100, builtin=True) == '100\n'
67+
assert getimport(True, builtin=True) == 'True\n'
6868
# this is kinda BS... you can't import a None
69-
assert likely_import(None) == 'None\n'
70-
assert likely_import(None, explicit=True) == 'None\n'
69+
assert getimport(None) == 'None\n'
70+
assert getimport(None, builtin=True) == 'None\n'
7171

7272

7373
# other imported functions
7474
def test_imported():
7575
from math import sin
76-
assert likely_import(sin) == 'from math import sin\n'
76+
assert getimport(sin) == 'from math import sin\n'
7777

7878
# interactively defined functions
7979
def test_dynamic():
80-
assert likely_import(add) == 'from %s import add\n' % __name__
80+
assert getimport(add) == 'from %s import add\n' % __name__
8181
# interactive lambdas
82-
assert likely_import(squared) == 'from %s import squared\n' % __name__
82+
assert getimport(squared) == 'from %s import squared\n' % __name__
8383

8484
# classes and class instances
8585
def test_classes():
@@ -88,61 +88,61 @@ def test_classes():
8888
x = y if (IS_PYPY or sys.hexversion >= PY310b) else "from io import BytesIO\n"
8989
s = StringIO()
9090

91-
assert likely_import(StringIO) == x
92-
assert likely_import(s) == y
91+
assert getimport(StringIO) == x
92+
assert getimport(s) == y
9393
# interactively defined classes and class instances
94-
assert likely_import(Foo) == 'from %s import Foo\n' % __name__
95-
assert likely_import(_foo) == 'from %s import Foo\n' % __name__
94+
assert getimport(Foo) == 'from %s import Foo\n' % __name__
95+
assert getimport(_foo) == 'from %s import Foo\n' % __name__
9696

9797

98-
# test getimportable
98+
# test importable
9999
def test_importable():
100-
assert getimportable(add) == 'from %s import add\n' % __name__
101-
assert getimportable(squared) == 'from %s import squared\n' % __name__
102-
assert getimportable(Foo) == 'from %s import Foo\n' % __name__
103-
assert getimportable(Foo.bar) == 'from %s import bar\n' % __name__
104-
assert getimportable(_foo.bar) == 'from %s import bar\n' % __name__
105-
assert getimportable(None) == 'None\n'
106-
assert getimportable(100) == '100\n'
107-
108-
assert getimportable(add, byname=False) == 'def add(x,y):\n return x+y\n'
109-
assert getimportable(squared, byname=False) == 'squared = lambda x:x**2\n'
110-
assert getimportable(None, byname=False) == 'None\n'
111-
assert getimportable(Bar, byname=False) == 'class Bar:\n pass\n'
112-
assert getimportable(Foo, byname=False) == 'class Foo(object):\n def bar(self, x):\n return x*x+x\n'
113-
assert getimportable(Foo.bar, byname=False) == 'def bar(self, x):\n return x*x+x\n'
114-
assert getimportable(Foo.bar, byname=True) == 'from %s import bar\n' % __name__
115-
assert getimportable(Foo.bar, alias='memo', byname=True) == 'from %s import bar as memo\n' % __name__
116-
assert getimportable(Foo, alias='memo', byname=True) == 'from %s import Foo as memo\n' % __name__
117-
assert getimportable(squared, alias='memo', byname=True) == 'from %s import squared as memo\n' % __name__
118-
assert getimportable(squared, alias='memo', byname=False) == 'memo = squared = lambda x:x**2\n'
119-
assert getimportable(add, alias='memo', byname=False) == 'def add(x,y):\n return x+y\n\nmemo = add\n'
120-
assert getimportable(None, alias='memo', byname=False) == 'memo = None\n'
121-
assert getimportable(100, alias='memo', byname=False) == 'memo = 100\n'
122-
assert getimportable(add, explicit=True) == 'from %s import add\n' % __name__
123-
assert getimportable(squared, explicit=True) == 'from %s import squared\n' % __name__
124-
assert getimportable(Foo, explicit=True) == 'from %s import Foo\n' % __name__
125-
assert getimportable(Foo.bar, explicit=True) == 'from %s import bar\n' % __name__
126-
assert getimportable(_foo.bar, explicit=True) == 'from %s import bar\n' % __name__
127-
assert getimportable(None, explicit=True) == 'None\n'
128-
assert getimportable(100, explicit=True) == '100\n'
100+
assert importable(add, source=False) == 'from %s import add\n' % __name__
101+
assert importable(squared, source=False) == 'from %s import squared\n' % __name__
102+
assert importable(Foo, source=False) == 'from %s import Foo\n' % __name__
103+
assert importable(Foo.bar, source=False) == 'from %s import bar\n' % __name__
104+
assert importable(_foo.bar, source=False) == 'from %s import bar\n' % __name__
105+
assert importable(None, source=False) == 'None\n'
106+
assert importable(100, source=False) == '100\n'
107+
108+
assert importable(add, source=True) == 'def add(x,y):\n return x+y\n'
109+
assert importable(squared, source=True) == 'squared = lambda x:x**2\n'
110+
assert importable(None, source=True) == 'None\n'
111+
assert importable(Bar, source=True) == 'class Bar:\n pass\n'
112+
assert importable(Foo, source=True) == 'class Foo(object):\n def bar(self, x):\n return x*x+x\n'
113+
assert importable(Foo.bar, source=True) == 'def bar(self, x):\n return x*x+x\n'
114+
assert importable(Foo.bar, source=False) == 'from %s import bar\n' % __name__
115+
assert importable(Foo.bar, alias='memo', source=False) == 'from %s import bar as memo\n' % __name__
116+
assert importable(Foo, alias='memo', source=False) == 'from %s import Foo as memo\n' % __name__
117+
assert importable(squared, alias='memo', source=False) == 'from %s import squared as memo\n' % __name__
118+
assert importable(squared, alias='memo', source=True) == 'memo = squared = lambda x:x**2\n'
119+
assert importable(add, alias='memo', source=True) == 'def add(x,y):\n return x+y\n\nmemo = add\n'
120+
assert importable(None, alias='memo', source=True) == 'memo = None\n'
121+
assert importable(100, alias='memo', source=True) == 'memo = 100\n'
122+
assert importable(add, builtin=True, source=False) == 'from %s import add\n' % __name__
123+
assert importable(squared, builtin=True, source=False) == 'from %s import squared\n' % __name__
124+
assert importable(Foo, builtin=True, source=False) == 'from %s import Foo\n' % __name__
125+
assert importable(Foo.bar, builtin=True, source=False) == 'from %s import bar\n' % __name__
126+
assert importable(_foo.bar, builtin=True, source=False) == 'from %s import bar\n' % __name__
127+
assert importable(None, builtin=True, source=False) == 'None\n'
128+
assert importable(100, builtin=True, source=False) == '100\n'
129129

130130

131131
def test_numpy():
132132
try:
133133
import numpy as np
134134
y = np.array
135135
x = y([1,2,3])
136-
assert getimportable(x) == 'from numpy import array\narray([1, 2, 3])\n'
137-
assert getimportable(y) == 'from %s import array\n' % y.__module__
138-
assert getimportable(x, byname=False) == 'from numpy import array\narray([1, 2, 3])\n'
139-
assert getimportable(y, byname=False) == 'from %s import array\n' % y.__module__
136+
assert importable(x, source=False) == 'from numpy import array\narray([1, 2, 3])\n'
137+
assert importable(y, source=False) == 'from %s import array\n' % y.__module__
138+
assert importable(x, source=True) == 'from numpy import array\narray([1, 2, 3])\n'
139+
assert importable(y, source=True) == 'from %s import array\n' % y.__module__
140140
y = np.int64
141141
x = y(0)
142-
assert getimportable(x) == 'from numpy import int64\nint64(0)\n'
143-
assert getimportable(y) == 'from %s import int64\n' % y.__module__
144-
assert getimportable(x, byname=False) == 'from numpy import int64\nint64(0)\n'
145-
assert getimportable(y, byname=False) == 'from %s import int64\n' % y.__module__
142+
assert importable(x, source=False) == 'from numpy import int64\nint64(0)\n'
143+
assert importable(y, source=False) == 'from %s import int64\n' % y.__module__
144+
assert importable(x, source=True) == 'from numpy import int64\nint64(0)\n'
145+
assert importable(y, source=True) == 'from %s import int64\n' % y.__module__
146146
y = np.bool_
147147
x = y(0)
148148
import warnings
@@ -151,15 +151,15 @@ def test_numpy():
151151
warnings.filterwarnings('ignore', category=DeprecationWarning)
152152
if hasattr(np, 'bool'): b = 'bool_' if np.bool is bool else 'bool'
153153
else: b = 'bool_'
154-
assert getimportable(x) == 'from numpy import %s\n%s(False)\n' % (b,b)
155-
assert getimportable(y) == 'from %s import %s\n' % (y.__module__,b)
156-
assert getimportable(x, byname=False) == 'from numpy import %s\n%s(False)\n' % (b,b)
157-
assert getimportable(y, byname=False) == 'from %s import %s\n' % (y.__module__,b)
154+
assert importable(x, source=False) == 'from numpy import %s\n%s(False)\n' % (b,b)
155+
assert importable(y, source=False) == 'from %s import %s\n' % (y.__module__,b)
156+
assert importable(x, source=True) == 'from numpy import %s\n%s(False)\n' % (b,b)
157+
assert importable(y, source=True) == 'from %s import %s\n' % (y.__module__,b)
158158
except ImportError: pass
159159

160-
#NOTE: if before likely_import(pow), will cause pow to throw AssertionError
160+
#NOTE: if before getimport(pow), will cause pow to throw AssertionError
161161
def test_foo():
162-
assert getimportable(_foo, byname=False).startswith("import dill\nclass Foo(object):\n def bar(self, x):\n return x*x+x\ndill.loads(")
162+
assert importable(_foo, source=True).startswith("import dill\nclass Foo(object):\n def bar(self, x):\n return x*x+x\ndill.loads(")
163163

164164
if __name__ == '__main__':
165165
test_getsource()

dill/tests/test_sources.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
#!/usr/bin/env python
2+
#
3+
# Author: Mike McKerns (mmckerns @uqfoundation)
4+
# Copyright (c) 2024 The Uncertainty Quantification Foundation.
5+
# License: 3-clause BSD. The full license text is available at:
6+
# - https://github.com/uqfoundation/dill/blob/master/LICENSE
7+
"""
8+
check that dill.source performs as expected with changes to locals in 3.13.0b1
9+
see: https://github.com/python/cpython/issues/118888
10+
"""
11+
# repeat functions from test_source.py
12+
f = lambda x: x**2
13+
def g(x): return f(x) - x
14+
15+
def h(x):
16+
def g(x): return x
17+
return g(x) - x
18+
19+
class Foo(object):
20+
def bar(self, x):
21+
return x*x+x
22+
_foo = Foo()
23+
24+
def add(x,y):
25+
return x+y
26+
27+
squared = lambda x:x**2
28+
29+
class Bar:
30+
pass
31+
_bar = Bar()
32+
33+
# repeat, but from test_source.py
34+
import test_source as ts
35+
36+
# test objects created in other test modules
37+
import test_mixins as tm
38+
39+
import dill.source as ds
40+
41+
42+
def test_isfrommain():
43+
assert ds.isfrommain(add) == True
44+
assert ds.isfrommain(squared) == True
45+
assert ds.isfrommain(Bar) == True
46+
assert ds.isfrommain(_bar) == True
47+
assert ds.isfrommain(ts.add) == False
48+
assert ds.isfrommain(ts.squared) == False
49+
assert ds.isfrommain(ts.Bar) == False
50+
assert ds.isfrommain(ts._bar) == False
51+
assert ds.isfrommain(tm.quad) == False
52+
assert ds.isfrommain(tm.double_add) == False
53+
assert ds.isfrommain(tm.quadratic) == False
54+
assert ds.isdynamic(add) == False
55+
assert ds.isdynamic(squared) == False
56+
assert ds.isdynamic(ts.add) == False
57+
assert ds.isdynamic(ts.squared) == False
58+
assert ds.isdynamic(tm.double_add) == False
59+
assert ds.isdynamic(tm.quadratic) == False
60+
61+
62+
def test_matchlambda():
63+
assert ds._matchlambda(f, 'f = lambda x: x**2\n')
64+
assert ds._matchlambda(squared, 'squared = lambda x:x**2\n')
65+
assert ds._matchlambda(ts.f, 'f = lambda x: x**2\n')
66+
assert ds._matchlambda(ts.squared, 'squared = lambda x:x**2\n')
67+
68+
69+
def test_findsource():
70+
lines, lineno = ds.findsource(add)
71+
assert lines[lineno] == 'def add(x,y):\n'
72+
lines, lineno = ds.findsource(ts.add)
73+
assert lines[lineno] == 'def add(x,y):\n'
74+
lines, lineno = ds.findsource(squared)
75+
assert lines[lineno] == 'squared = lambda x:x**2\n'
76+
lines, lineno = ds.findsource(ts.squared)
77+
assert lines[lineno] == 'squared = lambda x:x**2\n'
78+
lines, lineno = ds.findsource(Bar)
79+
assert lines[lineno] == 'class Bar:\n'
80+
lines, lineno = ds.findsource(ts.Bar)
81+
assert lines[lineno] == 'class Bar:\n'
82+
lines, lineno = ds.findsource(_bar)
83+
assert lines[lineno] == 'class Bar:\n'
84+
lines, lineno = ds.findsource(ts._bar)
85+
assert lines[lineno] == 'class Bar:\n'
86+
lines, lineno = ds.findsource(tm.quad)
87+
assert lines[lineno] == 'def quad(a=1, b=1, c=0):\n'
88+
lines, lineno = ds.findsource(tm.double_add)
89+
assert lines[lineno] == ' def func(*args, **kwds):\n'
90+
lines, lineno = ds.findsource(tm.quadratic)
91+
assert lines[lineno] == ' def dec(f):\n'
92+
93+
94+
def test_getsourcelines():
95+
assert ''.join(ds.getsourcelines(add)[0]) == 'def add(x,y):\n return x+y\n'
96+
assert ''.join(ds.getsourcelines(ts.add)[0]) == 'def add(x,y):\n return x+y\n'
97+
assert ''.join(ds.getsourcelines(squared)[0]) == 'squared = lambda x:x**2\n'
98+
assert ''.join(ds.getsourcelines(ts.squared)[0]) == 'squared = lambda x:x**2\n'
99+
assert ''.join(ds.getsourcelines(Bar)[0]) == 'class Bar:\n pass\n'
100+
assert ''.join(ds.getsourcelines(ts.Bar)[0]) == 'class Bar:\n pass\n'
101+
assert ''.join(ds.getsourcelines(_bar)[0]) == 'class Bar:\n pass\n' #XXX: ?
102+
assert ''.join(ds.getsourcelines(ts._bar)[0]) == 'class Bar:\n pass\n' #XXX: ?
103+
assert ''.join(ds.getsourcelines(tm.quad)[0]) == 'def quad(a=1, b=1, c=0):\n inverted = [False]\n def invert():\n inverted[0] = not inverted[0]\n def dec(f):\n def func(*args, **kwds):\n x = f(*args, **kwds)\n if inverted[0]: x = -x\n return a*x**2 + b*x + c\n func.__wrapped__ = f\n func.invert = invert\n func.inverted = inverted\n return func\n return dec\n'
104+
assert ''.join(ds.getsourcelines(tm.quadratic)[0]) == ' def dec(f):\n def func(*args,**kwds):\n fx = f(*args,**kwds)\n return a*fx**2 + b*fx + c\n return func\n'
105+
assert ''.join(ds.getsourcelines(tm.quadratic, lstrip=True)[0]) == 'def dec(f):\n def func(*args,**kwds):\n fx = f(*args,**kwds)\n return a*fx**2 + b*fx + c\n return func\n'
106+
assert ''.join(ds.getsourcelines(tm.quadratic, enclosing=True)[0]) == 'def quad_factory(a=1,b=1,c=0):\n def dec(f):\n def func(*args,**kwds):\n fx = f(*args,**kwds)\n return a*fx**2 + b*fx + c\n return func\n return dec\n'
107+
assert ''.join(ds.getsourcelines(tm.double_add)[0]) == ' def func(*args, **kwds):\n x = f(*args, **kwds)\n if inverted[0]: x = -x\n return a*x**2 + b*x + c\n'
108+
assert ''.join(ds.getsourcelines(tm.double_add, enclosing=True)[0]) == 'def quad(a=1, b=1, c=0):\n inverted = [False]\n def invert():\n inverted[0] = not inverted[0]\n def dec(f):\n def func(*args, **kwds):\n x = f(*args, **kwds)\n if inverted[0]: x = -x\n return a*x**2 + b*x + c\n func.__wrapped__ = f\n func.invert = invert\n func.inverted = inverted\n return func\n return dec\n'
109+
110+
111+
def test_indent():
112+
assert ds.outdent(''.join(ds.getsourcelines(tm.quadratic)[0])) == ''.join(ds.getsourcelines(tm.quadratic, lstrip=True)[0])
113+
assert ds.indent(''.join(ds.getsourcelines(tm.quadratic, lstrip=True)[0]), 2) == ''.join(ds.getsourcelines(tm.quadratic)[0])
114+
115+
116+
def test_dumpsource():
117+
local = {}
118+
exec(ds.dumpsource(add, alias='raw'), {}, local)
119+
exec(ds.dumpsource(ts.add, alias='mod'), {}, local)
120+
assert local['raw'](1,2) == local['mod'](1,2)
121+
exec(ds.dumpsource(squared, alias='raw'), {}, local)
122+
exec(ds.dumpsource(ts.squared, alias='mod'), {}, local)
123+
assert local['raw'](3) == local['mod'](3)
124+
assert ds._wrap(add)(1,2) == ds._wrap(ts.add)(1,2)
125+
assert ds._wrap(squared)(3) == ds._wrap(ts.squared)(3)
126+
127+
128+
def test_name():
129+
assert ds._namespace(add) == ds.getname(add, fqn=True).split('.')
130+
assert ds._namespace(ts.add) == ds.getname(ts.add, fqn=True).split('.')
131+
assert ds._namespace(squared) == ds.getname(squared, fqn=True).split('.')
132+
assert ds._namespace(ts.squared) == ds.getname(ts.squared, fqn=True).split('.')
133+
assert ds._namespace(Bar) == ds.getname(Bar, fqn=True).split('.')
134+
assert ds._namespace(ts.Bar) == ds.getname(ts.Bar, fqn=True).split('.')
135+
assert ds._namespace(tm.quad) == ds.getname(tm.quad, fqn=True).split('.')
136+
#XXX: the following also works, however behavior may be wrong for nested functions
137+
#assert ds._namespace(tm.double_add) == ds.getname(tm.double_add, fqn=True).split('.')
138+
#assert ds._namespace(tm.quadratic) == ds.getname(tm.quadratic, fqn=True).split('.')
139+
assert ds.getname(add) == 'add'
140+
assert ds.getname(ts.add) == 'add'
141+
assert ds.getname(squared) == 'squared'
142+
assert ds.getname(ts.squared) == 'squared'
143+
assert ds.getname(Bar) == 'Bar'
144+
assert ds.getname(ts.Bar) == 'Bar'
145+
assert ds.getname(tm.quad) == 'quad'
146+
assert ds.getname(tm.double_add) == 'func' #XXX: ?
147+
assert ds.getname(tm.quadratic) == 'dec' #XXX: ?
148+
149+
150+
def test_getimport():
151+
local = {}
152+
exec(ds.getimport(add, alias='raw'), {}, local)
153+
exec(ds.getimport(ts.add, alias='mod'), {}, local)
154+
assert local['raw'](1,2) == local['mod'](1,2)
155+
exec(ds.getimport(squared, alias='raw'), {}, local)
156+
exec(ds.getimport(ts.squared, alias='mod'), {}, local)
157+
assert local['raw'](3) == local['mod'](3)
158+
exec(ds.getimport(Bar, alias='raw'), {}, local)
159+
exec(ds.getimport(ts.Bar, alias='mod'), {}, local)
160+
assert ds.getname(local['raw']) == ds.getname(local['mod'])
161+
exec(ds.getimport(tm.quad, alias='mod'), {}, local)
162+
assert local['mod']()(sum)([1,2,3]) == tm.quad()(sum)([1,2,3])
163+
#FIXME: wrong results for nested functions (e.g. tm.double_add, tm.quadratic)
164+
165+
166+
def test_importable():
167+
assert ds.importable(add, source=False) == ds.getimport(add)
168+
assert ds.importable(add) == ds.getsource(add)
169+
assert ds.importable(squared, source=False) == ds.getimport(squared)
170+
assert ds.importable(squared) == ds.getsource(squared)
171+
assert ds.importable(Bar, source=False) == ds.getimport(Bar)
172+
assert ds.importable(Bar) == ds.getsource(Bar)
173+
assert ds.importable(ts.add) == ds.getimport(ts.add)
174+
assert ds.importable(ts.add, source=True) == ds.getsource(ts.add)
175+
assert ds.importable(ts.squared) == ds.getimport(ts.squared)
176+
assert ds.importable(ts.squared, source=True) == ds.getsource(ts.squared)
177+
assert ds.importable(ts.Bar) == ds.getimport(ts.Bar)
178+
assert ds.importable(ts.Bar, source=True) == ds.getsource(ts.Bar)
179+
180+
181+
if __name__ == '__main__':
182+
test_isfrommain()
183+
test_matchlambda()
184+
test_findsource()
185+
test_getsourcelines()
186+
test_indent()
187+
test_dumpsource()
188+
test_name()
189+
test_getimport()
190+
test_importable()

0 commit comments

Comments
 (0)