diff piecrust/routing.py @ 792:58ebf50235a5

routing: Simplify how routes are defined. * No more declaring the type of route parameters -- the sources and generators already know what type each parameter is supposed to be. * Same for variadic parameters -- we know already. * Update cache version to force a clear reload of the config. * Update tests. TODO: simplify code in the `Route` class to use source or generator transparently.
author Ludovic Chabant <ludovic@chabant.com>
date Wed, 07 Sep 2016 08:58:41 -0700
parents 504d6817352d
children 4850f8c21b6e
line wrap: on
line diff
--- a/piecrust/routing.py	Mon Sep 05 22:30:05 2016 -0700
+++ b/piecrust/routing.py	Wed Sep 07 08:58:41 2016 -0700
@@ -24,19 +24,25 @@
 
 def create_route_metadata(page):
     route_metadata = copy.deepcopy(page.source_metadata)
-    route_metadata.update(page.getRouteMetadata())
     return route_metadata
 
 
-class IRouteMetadataProvider(object):
-    def getRouteMetadata(self):
-        raise NotImplementedError()
-
-
 ROUTE_TYPE_SOURCE = 0
 ROUTE_TYPE_GENERATOR = 1
 
 
+class RouteParameter(object):
+    TYPE_STRING = 0
+    TYPE_PATH = 1
+    TYPE_INT2 = 2
+    TYPE_INT4 = 3
+
+    def __init__(self, param_name, param_type=TYPE_STRING, *, variadic=False):
+        self.param_name = param_name
+        self.param_type = param_type
+        self.variadic = variadic
+
+
 class Route(object):
     """ Information about a route for a PieCrust application.
         Each route defines the "shape" of an URL and how it maps to
@@ -51,6 +57,13 @@
             raise InvalidRouteError(
                     "Both `source` and `generator` are specified.")
 
+        self.uri_pattern = cfg['url'].lstrip('/')
+
+        if self.is_source_route:
+            self.supported_params = self.source.getSupportedRouteParameters()
+        else:
+            self.supported_params = self.generator.getSupportedRouteParameters()
+
         self.pretty_urls = app.config.get('site/pretty_urls')
         self.trailing_slash = app.config.get('site/trailing_slash')
         self.show_debug_info = app.config.get('site/show_debug_info')
@@ -58,8 +71,7 @@
                 '__cache/pagination_suffix_format')
         self.uri_root = app.config.get('site/root')
 
-        uri = cfg['url']
-        self.uri_pattern = uri.lstrip('/')
+        self.uri_params = []
         self.uri_format = route_re.sub(self._uriFormatRepl, self.uri_pattern)
 
         # Get the straight-forward regex for matching this URI pattern.
@@ -85,31 +97,17 @@
         else:
             self.uri_re_no_path = None
 
-        # Determine the parameters for the route function.
         self.func_name = self._validateFuncName(cfg.get('func'))
-        self.func_parameters = []
         self.func_has_variadic_parameter = False
-        self.param_types = {}
-        variadic_param_idx = -1
-        for m in route_re.finditer(self.uri_pattern):
-            name = m.group('name')
-            self.func_parameters.append(name)
-
-            qual = m.group('qual')
-            if not qual:
-                qual = self._getBackwardCompatibleParamType(name)
-            if qual:
-                self.param_types[name] = qual
-
-            if m.group('var'):
-                self.func_has_variadic_parameter = True
-                variadic_param_idx = len(self.func_parameters) - 1
-
-        if (variadic_param_idx >= 0 and
-                variadic_param_idx != len(self.func_parameters) - 1):
-            raise Exception(
-                "Only the last route URL parameter can be variadic. "
-                "Got: %s" % self.uri_pattern)
+        for p in self.uri_params[:-1]:
+            param = self.getParameter(p)
+            if param.variadic:
+                raise Exception(
+                    "Only the last route URL parameter can be variadic. "
+                    "Got: %s" % self.uri_pattern)
+        if len(self.uri_params) > 0:
+            last_param = self.getParameter(self.uri_params[-1])
+            self.func_has_variadic_parameter = last_param.variadic
 
     @property
     def route_type(self):
@@ -136,7 +134,7 @@
             if src.name == self.source_name:
                 return src
         raise Exception("Can't find source '%s' for route '%s'." % (
-                self.source_name, self.uri))
+                self.source_name, self.uri_pattern))
 
     @cached_property
     def generator(self):
@@ -146,10 +144,23 @@
             if gen.name == self.generator_name:
                 return gen
         raise Exception("Can't find generator '%s' for route '%s'." % (
-                self.generator_name, self.uri))
+                self.generator_name, self.uri_pattern))
+
+    def hasParameter(self, name):
+        return any(lambda p: p.param_name == name, self.supported_params)
+
+    def getParameter(self, name):
+        for p in self.supported_params:
+            if p.param_name == name:
+                return p
+        raise Exception("No such supported route parameter '%s' in: %s" %
+                        (name, self.uri_pattern))
+
+    def getParameterType(self, name):
+        return self.getParameter(name).param_type
 
     def matchesMetadata(self, route_metadata):
-        return set(self.func_parameters).issubset(route_metadata.keys())
+        return set(self.uri_params).issubset(route_metadata.keys())
 
     def matchUri(self, uri, strict=False):
         if not uri.startswith(self.uri_root):
@@ -178,8 +189,10 @@
             # say, a route's pattern is `/foo/%slug%`, and we're matching an
             # URL like `/foo`.
             matched_keys = set(route_metadata.keys())
-            missing_keys = set(self.func_parameters) - matched_keys
+            missing_keys = set(self.uri_params) - matched_keys
             for k in missing_keys:
+                if self.getParameterType(k) != RouteParameter.TYPE_PATH:
+                    return None
                 route_metadata[k] = ''
 
         for k in route_metadata:
@@ -239,7 +252,7 @@
         return uri
 
     def execTemplateFunc(self, *args):
-        fixed_param_count = len(self.func_parameters)
+        fixed_param_count = len(self.uri_params)
         if self.func_has_variadic_parameter:
             fixed_param_count -= 1
 
@@ -258,7 +271,7 @@
             coerced_args = args
 
         metadata = {}
-        for arg_name, arg_val in zip(self.func_parameters, coerced_args):
+        for arg_name, arg_val in zip(self.uri_params, coerced_args):
             metadata[arg_name] = self._coerceRouteParameter(
                     arg_name, arg_val)
 
@@ -268,87 +281,62 @@
         return self.getUri(metadata)
 
     def _uriFormatRepl(self, m):
-        qual = m.group('qual')
-        name = m.group('name')
+        if m.group('qual') or m.group('var'):
+            # Print a warning only if we're not in a worker process.
+            print_warning = not self.app.config.has('baker/worker_id')
+            if print_warning:
+                logger.warning("Route '%s' specified parameter types -- "
+                               "they're not needed anymore." %
+                               self.uri_pattern)
 
-        # Backwards compatibility... this will print a warning later.
-        if qual is None:
-            if name == 'year':
-                qual = 'int4'
-            elif name in ['month', 'day']:
-                qual = 'int2'
-
-        if qual == 'int4':
-            return '%%(%s)04d' % name
-        elif qual == 'int2':
-            return '%%(%s)02d' % name
-        elif qual and qual != 'path':
-            raise Exception("Unknown route parameter type: %s" % qual)
-        return '%%(%s)s' % name
+        name = m.group('name')
+        self.uri_params.append(name)
+        try:
+            param_type = self.getParameterType(name)
+            if param_type == RouteParameter.TYPE_INT4:
+                return '%%(%s)04d' % name
+            elif param_type == RouteParameter.TYPE_INT2:
+                return '%%(%s)02d' % name
+            return '%%(%s)s' % name
+        except:
+            known = [p.name for p in self.supported_params]
+            raise Exception("Unknown route parameter '%s' for route '%s'. "
+                            "Must be one of: %s'" %
+                            (name, self.uri_pattern, known))
 
     def _uriPatternRepl(self, m):
         name = m.group('name')
-        qual = m.group('qual')
-
-        # Backwards compatibility... this will print a warning later.
-        if qual is None:
-            if name == 'year':
-                qual = 'int4'
-            elif name in ['month', 'day']:
-                qual = 'int2'
-
-        if qual == 'path' or m.group('var'):
+        param_type = self.getParameterType(name)
+        if param_type == RouteParameter.TYPE_PATH:
             return r'(?P<%s>[^\?]*)' % name
-        elif qual == 'int4':
+        elif param_type == RouteParameter.TYPE_INT4:
             return r'(?P<%s>\d{4})' % name
-        elif qual == 'int2':
+        elif param_type == RouteParameter.TYPE_INT2:
             return r'(?P<%s>\d{2})' % name
-        elif qual and qual != 'path':
-            raise Exception("Unknown route parameter type: %s" % qual)
         return r'(?P<%s>[^/\?]+)' % name
 
     def _uriNoPathRepl(self, m):
         name = m.group('name')
-        qualifier = m.group('qual')
-        if qualifier == 'path':
+        param_type = self.getParameterType(name)
+        if param_type == RouteParameter.TYPE_PATH:
             return ''
         return r'(?P<%s>[^/\?]+)' % name
 
     def _coerceRouteParameter(self, name, val):
-        param_type = self.param_types.get(name)
-        if param_type is None:
+        try:
+            param_type = self.getParameterType(name)
+        except:
+            # Unknown parameter... just leave it.
             return val
-        if param_type in ['int', 'int2', 'int4']:
+
+        if param_type in [RouteParameter.TYPE_INT2, RouteParameter.TYPE_INT4]:
             try:
                 return int(val)
             except ValueError:
                 raise Exception(
-                    "Expected route parameter '%s' to be of type "
-                    "'%s', but was: %s" %
-                    (name, param_type, val))
-        if param_type == 'path':
-            return val
-        raise Exception("Unknown route parameter type: %s" % param_type)
-
-    def _getBackwardCompatibleParamType(self, name):
-        # Print a warning only if we're not in a worker process.
-        print_warning = not self.app.config.has('baker/worker_id')
-
-        if name in ['year']:
-            if print_warning:
-                logger.warning(
-                    "Route parameter '%%%s%%' has no type qualifier. "
-                    "You probably meant '%%int4:%s%%' so we'll use that." %
-                    (name, name))
-            return 'int4'
-        if name in ['month', 'day']:
-            if print_warning:
-                logger.warning(
-                    "Route parameter '%%%s%%' has no type qualifier. "
-                    "You probably meant '%%int2:%s%%' so we'll use that." %
-                    (name, name))
-            return 'int2'
-        return None
+                    "Expected route parameter '%s' to be an integer, "
+                    "but was: %s" % (name, param_type, val))
+        return val
 
     def _validateFuncName(self, name):
         if not name:
@@ -369,12 +357,12 @@
 
     def addFunc(self, route):
         if self._arg_names is None:
-            self._arg_names = list(route.func_parameters)
+            self._arg_names = list(route.uri_params)
 
-        if route.func_parameters != self._arg_names:
+        if route.uri_params != self._arg_names:
             raise Exception("Cannot merge route function with arguments '%s' "
                             "with route function with arguments '%s'." %
-                            (route.func_parameters, self._arg_names))
+                            (route.uri_params, self._arg_names))
         self._routes.append(route)
 
     def __call__(self, *args, **kwargs):