summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rwxr-xr-xtest/all_benchmarks.py75
1 files changed, 35 insertions, 40 deletions
diff --git a/test/all_benchmarks.py b/test/all_benchmarks.py
index 84d6817d..e1b47142 100755
--- a/test/all_benchmarks.py
+++ b/test/all_benchmarks.py
@@ -14,45 +14,40 @@
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
-"""Run all the tests inside the test/ directory as a test suite."""
+"""
+Run all the benchmarks inside this directory.
+Usage: ./all_benchmarks.py [count] [regexp-filters...]
+"""
+
+import os
+import re
+import sys
+import time
+
 if __name__ == '__main__':
-	from re import compile
-	from test import *
-	from time import time
-	from types import FunctionType as function
-	from sys import argv
-	bms = []
-	try:
-		n = int(argv[1])
-	except IndexError:
-		n = 10
-	if len(argv) > 2:
-		args = [compile(re) for re in argv[2:]]
-		def allow(name):
-			for re in args:
-				if re.search(name):
-					return True
+	count   = int(sys.argv[1]) if len(sys.argv) > 1 else 10
+	regexes = [re.compile(fltr) for fltr in sys.argv[2:]]
+	modules = (fname[:-3] for fname in os.listdir(sys.path[0]) \
+			if fname[:3] == 'bm_' and fname[-3:] == '.py')
+
+	benchmarks = []  # find all benchmark (class, methodname) pairs
+	for val in [__import__(module) for module in modules]:
+		for cls in vars(val).values():
+			if type(cls) == type:
+				for methodname in vars(cls):
+					if methodname.startswith('bm_'):
+						benchmarks.append((cls, methodname))
+
+	for cls, methodname in benchmarks:
+		full_method_name = "{0}.{1}".format(cls.__name__, methodname)
+		if all(re.search(full_method_name) for re in regexes):
+			method = getattr(cls(), methodname)
+			t1 = time.time()
+			try:
+				method(count)
+			except:
+				print("{0} failed!".format(full_method_name))
+				raise
 			else:
-				return False
-	else:
-		allow = lambda name: True
-	for key, val in vars().copy().items():
-		if key.startswith('bm_'):
-			bms.extend(v for k,v in vars(val).items() if type(v) == type)
-	for bmclass in bms:
-		for attrname in vars(bmclass):
-			if not attrname.startswith('bm_'):
-				continue
-			bmobj = bmclass()
-			t1 = time()
-			method = getattr(bmobj, attrname)
-			methodname = "{0}.{1}".format(bmobj.__class__.__name__, method.__name__)
-			if allow(methodname):
-				try:
-					method(n)
-				except:
-					print("{0} failed!".format(methodname))
-					raise
-				else:
-					t2 = time()
-					print("{0:60}: {1:10}s".format(methodname, t2 - t1))
+				t2 = time.time()
+				print("{0:60}: {1:10}s".format(full_method_name, t2 - t1))