diff --git a/examples/demo.py b/examples/demo.py index b332b4f..13979e8 100755 --- a/examples/demo.py +++ b/examples/demo.py @@ -21,11 +21,17 @@ MB = 1024 * 1024 +@memprof class FooClass(object): - def __init__(self): + def __init__(self, limit): self.a = [1] * MB self.b = [1] * MB * 2 + for _ in range(limit): + self.a.append(1) + self.b.append(1) + self.b.append(1) + @memprof def bar(limit=10000): bar_a = [1] * MB * 10 @@ -56,13 +62,8 @@ def foo(limit=500000): elif i == (limit*3)/4: c = [1] * MB * 2 -@memprof -def fooObject(limit=500000): - a = FooClass() - - for i in range(limit): - a.a.append(1) - a.b.append(1) +def fooObject(limit=3000000): + foo = FooClass(limit) foo() diff --git a/memprof/memprof.py b/memprof/memprof.py index 66bef3b..3c10e17 100755 --- a/memprof/memprof.py +++ b/memprof/memprof.py @@ -19,6 +19,7 @@ import sys import time import argparse +import inspect import types from .mp_utils import * @@ -28,8 +29,14 @@ def memprof(*args, **kwargs): def inner(func): return MemProf(func, *args, **kwargs) - # To allow @memprof with parameters - if len(args) and callable(args[0]): + if inspect.isclass(args[0]): + cls = args[0] + args = args[1:] + kwargs["funcname"] = cls.__name__ + cls.__init__ = inner(cls.__init__) + return cls + elif callable(args[0]): + # To allow @memprof with parameters func = args[0] args = args[1:] return inner(func) @@ -38,8 +45,11 @@ def inner(func): class MemProf(object): - def __init__(self, func, threshold=default_threshold, plot=False): + def __init__(self, func, threshold=default_threshold, plot=False, + funcname=None): self.func = func + if funcname is None: + funcname = self.func.__name__ self.__locals = {} self.__start = -1 self.__prev = -1 @@ -47,7 +57,7 @@ def __init__(self, func, threshold=default_threshold, plot=False): self.__refresh = 500000 self.__ticks = 0 self.__checkTimes = [] - self.__logfile = "%s.log" % self.func.__name__ + self.__logfile = "%s.log" % funcname self.__plot = self.func.__globals__["memprof_plot"] if "memprof_plot" in self.func.__globals__ else plot self.threshold = self.func.__globals__["memprof_threshold"] if "memprof_threshold" in self.func.__globals__ else threshold diff --git a/testsuite/test1.py b/testsuite/test1.py index 8997a6e..fd0388d 100644 --- a/testsuite/test1.py +++ b/testsuite/test1.py @@ -18,6 +18,9 @@ import os import sys +sys.path.insert(0, "../memprof") +from memprof import memprof + class Test(unittest.TestCase): # Rough test: just run the example def test_demo(self): @@ -27,3 +30,21 @@ def test_demo(self): os.chdir(examples_path) import demo os.chdir(root) + + def test_class(self): + """memprof works with classes""" + MB = 1024 * 1024 + + @memprof + class FooClass(object): + def __init__(self): + self.a = [1] * MB + self.b = [1] * MB * 2 + + def append(self, limit=500000): + for _ in range(limit): + self.a.append(1) + self.b.append(1) + + foo = FooClass() + foo.append()