diff piecrust/dataproviders/pageiterator.py @ 867:757fba54bfd3

refactor: Improve pagination and iterators to work with other sources. This makes the assets work as a pagination source again.
author Ludovic Chabant <ludovic@chabant.com>
date Mon, 12 Jun 2017 22:20:58 -0700
parents 9bb22bbe093c
children d6d35b2efd04
line wrap: on
line diff
--- a/piecrust/dataproviders/pageiterator.py	Mon Jun 12 22:10:50 2017 -0700
+++ b/piecrust/dataproviders/pageiterator.py	Mon Jun 12 22:20:58 2017 -0700
@@ -3,7 +3,7 @@
 from piecrust.data.paginationdata import PaginationData
 from piecrust.events import Event
 from piecrust.dataproviders.base import DataProvider
-from piecrust.sources.base import AbortedSourceUseError
+from piecrust.sources.base import ContentSource, AbortedSourceUseError
 
 
 logger = logging.getLogger(__name__)
@@ -13,6 +13,7 @@
     def __init__(self):
         self.it = None
         self.iterated = False
+        self.source_name = None
 
 
 class PageIteratorDataProvider(DataProvider):
@@ -51,13 +52,14 @@
             i = _ItInfo()
             i.it = PageIterator(source, current_page=self._page)
             i.it._iter_event += self._onIteration
+            i.source_name = source.name
             self._its.append(i)
 
     def _onIteration(self, it):
         ii = next(filter(lambda i: i.it == it, self._its))
         if not ii.iterated:
             rcs = self._app.env.render_ctx_stack
-            rcs.current_ctx.addUsedSource(self._source.name)
+            rcs.current_ctx.addUsedSource(ii.source_name)
             ii.iterated = True
 
     def _debugRenderDoc(self):
@@ -67,6 +69,7 @@
 class PageIterator:
     def __init__(self, source, *, current_page=None):
         self._source = source
+        self._is_content_source = isinstance(source, ContentSource)
         self._cache = None
         self._pagination_slicer = None
         self._has_sorter = False
@@ -75,7 +78,7 @@
         self._locked = False
         self._iter_event = Event()
         self._current_page = current_page
-        self._it = PageContentSourceIterator(self._source)
+        self._initIterator()
 
     @property
     def total_count(self):
@@ -215,11 +218,18 @@
     def _ensureSorter(self):
         if self._has_sorter:
             return
-        self._it = DateSortIterator(self._it, reverse=True)
+        if self._is_content_source:
+            self._it = DateSortIterator(self._it, reverse=True)
         self._has_sorter = True
 
+    def _initIterator(self):
+        if self._is_content_source:
+            self._it = PageContentSourceIterator(self._source)
+        else:
+            self._it = GenericSourceIterator(self._source)
+
     def _unload(self):
-        self._it = PageContentSourceIterator(self._source)
+        self._initIterator()
         self._cache = None
         self._paginationSlicer = None
         self._has_sorter = False
@@ -230,27 +240,29 @@
         if self._cache is not None:
             return
 
-        if self._source.app.env.abort_source_use:
-            if self._current_page is not None:
-                logger.debug("Aborting iteration of '%s' from: %s." %
-                             (self._source.name,
-                              self._current_page.content_spec))
-            else:
-                logger.debug("Aborting iteration of '%s'." %
-                             self._source.name)
-            raise AbortedSourceUseError()
+        if self._is_content_source:
+            if self._source.app.env.abort_source_use:
+                if self._current_page is not None:
+                    logger.debug("Aborting iteration of '%s' from: %s." %
+                                 (self._source.name,
+                                  self._current_page.content_spec))
+                else:
+                    logger.debug("Aborting iteration of '%s'." %
+                                 self._source.name)
+                raise AbortedSourceUseError()
 
         self._ensureSorter()
 
-        tail_it = PaginationDataBuilderIterator(self._it, self._source.route)
-        self._cache = list(tail_it)
+        if self._is_content_source:
+            self._it = PaginationDataBuilderIterator(self._it)
+
+        self._cache = list(self._it)
 
         if (self._current_page is not None and
                 self._pagination_slicer is not None):
             pn = [self._pagination_slicer.prev_page,
                   self._pagination_slicer.next_page]
-            pn_it = PaginationDataBuilderIterator(iter(pn),
-                                                  self._source.route)
+            pn_it = PaginationDataBuilderIterator(iter(pn))
             self._prev_page, self._next_page = (list(pn_it))
 
         self._iter_event.fire(self)
@@ -367,9 +379,8 @@
 
 
 class PaginationDataBuilderIterator:
-    def __init__(self, it, route):
+    def __init__(self, it):
         self.it = it
-        self.route = route
 
     def __iter__(self):
         for page in self.it:
@@ -378,3 +389,11 @@
             else:
                 yield None
 
+
+class GenericSourceIterator:
+    def __init__(self, source):
+        self.source = source
+        self.it = None
+
+    def __iter__(self):
+        yield from self.source