[#8314] make @memoize keep a per-instance cache for methods instead of a global one, and support kwargs properly (used in icon_url())
diff --git a/Allura/allura/lib/decorators.py b/Allura/allura/lib/decorators.py
index 589b327..8bb6413 100644
--- a/Allura/allura/lib/decorators.py
+++ b/Allura/allura/lib/decorators.py
@@ -22,11 +22,11 @@
from Cookie import Cookie
from collections import defaultdict
from urllib import unquote
-
from datetime import datetime
-
from datetime import timedelta
+
from decorator import decorator
+import wrapt
from paste.deploy.converters import asint
from tg.decorators import before_validate
from tg import request, redirect, session, config
@@ -236,21 +236,36 @@
return default
-@decorator
-def memoize(func, *args):
+@wrapt.decorator
+def memoize(func, instance, args, kwargs):
     """
     Cache the method's result, for the given args
+    (kwargs included), in a cache that memoize never evicts from. Methods get a
+    separate cache per bound instance, kept under the name _memoize_dic__<method name>
+    via getattr_; plain functions share one, kept under _memoize_dic on the function.
+    Positional args and kwarg values must be hashable (they form the cache key).
     """
-    dic = getattr_(func, "memoize_dic", dict)
-    # memoize_dic is created at the first call
-    if args in dic:
-        return dic[args]
+    if instance is not None:
+        dic = getattr_(instance, "_memoize_dic__{}".format(func.__name__), dict)
     else:
-        result = func(*args)
-        dic[args] = result
+        dic = getattr_(func, "_memoize_dic", dict)
+    cache_key = (args, frozenset(kwargs.items()))
+    if cache_key not in dic:
+        result = func(*args, **kwargs)
+        dic[cache_key] = result
         return result
+    return dic[cache_key]
+def memoize_cleanup(obj):
+    """
+    Remove (in place) every key of dict-like obj starting with '_memoize_dic', i.e. the caches @memoize stored on a dict/obj hybrid.
+    """
+    for k in list(obj.keys()):  # snapshot: deleting while iterating a live keys view raises RuntimeError
+        if k.startswith('_memoize_dic'):
+            del obj[k]
+
+
def memorable_forget():
"""
Decorator to mark a controller action as needing to "forget" remembered input values on the next
diff --git a/Allura/allura/tests/test_decorators.py b/Allura/allura/tests/test_decorators.py
index 0d1d338..7261711 100644
--- a/Allura/allura/tests/test_decorators.py
+++ b/Allura/allura/tests/test_decorators.py
@@ -14,12 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+import inspect
from unittest import TestCase
-
from mock import patch
+import random
+import gc
-from allura.lib.decorators import task
+from nose.tools import assert_equal, assert_not_equal
+
+from allura.lib.decorators import task, memoize
class TestTask(TestCase):
@@ -54,3 +57,85 @@
c.project.notifications_disabled = False
MonQTask.post.side_effect = mock_post
func.post('test', foo=2, delay=1)
+
+
+class TestMemoize(object):
+    """Tests for @memoize on plain functions and on methods (which must get per-instance caches)."""
+    def test_function(self):
+        @memoize
+        def remember_randomy(do_random, foo=None):
+            if do_random:
+                return random.random()
+            else:
+                return "constant"
+
+        rand1 = remember_randomy(True)
+        rand2 = remember_randomy(True)
+        const1 = remember_randomy(False)
+        rand_kwargs1 = remember_randomy(True, foo='asdf')
+        rand_kwargs2 = remember_randomy(True, foo='xyzzy')
+        assert_equal(rand1, rand2)
+        assert_equal(const1, "constant")
+        assert_not_equal(rand1, rand_kwargs1)
+        assert_not_equal(rand_kwargs1, rand_kwargs2)
+        assert_equal(rand_kwargs1, remember_randomy(True, foo='asdf'))  # same kwargs hit the cache
+
+    def test_methods(self):
+        class Randomy(object):
+            @memoize
+            def randomy(self, do_random):
+                if do_random:
+                    return random.random()
+                else:
+                    return "constant"
+
+            @memoize
+            def other(self, do_random):
+                if do_random:
+                    return random.random()
+                else:
+                    return "constant"
+
+        r = Randomy()
+        rand1 = r.randomy(True)
+        rand2 = r.randomy(True)
+        const1 = r.randomy(False)
+        other1 = r.other(True)
+        other2 = r.other(True)
+
+        assert_equal(rand1, rand2)
+        assert_equal(const1, "constant")
+        assert_not_equal(rand1, other1)
+        assert_equal(other1, other2)
+        assert_equal(r.randomy(do_random=True), r.randomy(do_random=True))  # kwargs are cached on methods too
+
+        r2 = Randomy()
+        r2rand1 = r2.randomy(True)
+        r2rand2 = r2.randomy(True)
+        assert_equal(r2.randomy(False), "constant")
+        r2other1 = r2.other(True)
+        r2other2 = r2.other(True)
+
+        assert_not_equal(r2rand1, rand1)
+        assert_equal(r2rand1, r2rand2)
+        assert_not_equal(r2other1, other1)
+        assert_equal(r2other1, r2other2)
+
+    def test_methods_garbage_collection(self):
+
+        class Randomy(object):
+            @memoize
+            def randomy(self, do_random):
+                if do_random:
+                    return random.random()
+                else:
+                    return "constant"
+
+        r = Randomy()
+        r.randomy(True)  # populate the per-instance cache
+
+        # only stack frames (such as this test's own) may refer to r; anything else means @memoize kept a reference
+        for gc_ref in gc.get_referrers(r):
+            if not inspect.isframe(gc_ref):
+                raise AssertionError('Unexpected reference to `r` instance: {!r}\n'
+                                     '@memoize probably made a reference to it and has created a circular reference loop'.format(gc_ref))
diff --git a/requirements.in b/requirements.in
index e82dc30..e37e2a7 100644
--- a/requirements.in
+++ b/requirements.in
@@ -42,6 +42,7 @@
TurboGears2==2.3.12
WebHelpers==1.3
WebOb==1.7.4
+wrapt==1.11.2
# testing
datadiff==1.1.5
diff --git a/requirements.txt b/requirements.txt
index b953b53..6bcbf3c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -80,6 +80,7 @@
webhelpers==1.3
webob==1.7.4
webtest==2.0.33
+wrapt==1.11.2
# The following packages are considered to be unsafe in a requirements file:
# setuptools==41.0.1 # via logilab-common