diff piecrust/dataproviders/pageiterator.py @ 979:45ad976712ec

tests: Big push to get the tests to pass again. - Lots of fixes everywhere in the code. - Try to handle debug logging in the multiprocessing worker pool when running in pytest. Not perfect, but usable for now. - Replace all `.md` test files with `.html` since now a auto-format extension always sets the format. - Replace `out` with `outfiles` in most places since now blog archives are added to the bake output and I don't want to add expected outputs for blog archives everywhere.
author Ludovic Chabant <ludovic@chabant.com>
date Sun, 29 Oct 2017 22:51:57 -0700
parents abc52a6262a1
children 492b66482f12
line wrap: on
line diff
--- a/piecrust/dataproviders/pageiterator.py	Sun Oct 29 22:46:41 2017 -0700
+++ b/piecrust/dataproviders/pageiterator.py	Sun Oct 29 22:51:57 2017 -0700
@@ -9,11 +9,39 @@
 logger = logging.getLogger(__name__)
 
 
-class _ItInfo:
-    def __init__(self):
+class _CombinedSource:
+    def __init__(self, sources):
+        self.sources = sources
+        self.app = sources[0].app
+        self.name = None
+
+        # This is for recursive traversal of the iterator chain.
+        # See later in `PageIterator`.
         self.it = None
-        self.iterated = False
-        self.source_name = None
+
+    def __iter__(self):
+        sources = self.sources
+
+        if len(sources) == 1:
+            source = sources[0]
+            self.name = source.name
+            yield from source.getAllPages()
+            self.name = None
+            return
+
+        # Return the pages from all the combined sources, but skip
+        # those that are "overridden" -- e.g. a theme page that gets
+        # replaced by a user page of the same name.
+        used_uris = set()
+        for source in sources:
+            self.name = source.name
+            for page in source.getAllPages():
+                page_uri = page.getUri()
+                if page_uri not in used_uris:
+                    used_uris.add(page_uri)
+                    yield page
+
+        self.name = None
 
 
 class PageIteratorDataProvider(DataProvider):
@@ -31,36 +59,37 @@
 
     def __init__(self, source, page):
         super().__init__(source, page)
-        self._its = None
         self._app = source.app
+        self._it = None
+        self._iterated = False
 
     def __len__(self):
         self._load()
-        return sum([len(i.it) for i in self._its])
+        return len(self._it)
 
     def __iter__(self):
         self._load()
-        for i in self._its:
-            yield from i.it
+        yield from self._it
 
     def _load(self):
-        if self._its is not None:
+        if self._it is not None:
             return
 
-        self._its = []
-        for source in self._sources:
-            i = _ItInfo()
-            i.it = PageIterator(source, current_page=self._page)
-            i.it._iter_event += self._onIteration
-            i.source_name = source.name
-            self._its.append(i)
+        combined_source = _CombinedSource(list(reversed(self._sources)))
+        self._it = PageIterator(combined_source, current_page=self._page)
+        self._it._iter_event += self._onIteration
 
     def _onIteration(self, it):
-        ii = next(filter(lambda i: i.it == it, self._its))
-        if not ii.iterated:
+        if not self._iterated:
             rcs = self._app.env.render_ctx_stack
-            rcs.current_ctx.addUsedSource(ii.source_name)
-            ii.iterated = True
+            rcs.current_ctx.addUsedSource(it._source)
+            self._iterated = True
+
+    def _addSource(self, source):
+        if self._it is not None:
+            raise Exception("Can't add sources after the data provider "
+                            "has been loaded.")
+        super()._addSource(source)
 
     def _debugRenderDoc(self):
         return 'Provides a list of %d items' % len(self)
@@ -69,7 +98,8 @@
 class PageIterator:
     def __init__(self, source, *, current_page=None):
         self._source = source
-        self._is_content_source = isinstance(source, ContentSource)
+        self._is_content_source = isinstance(
+            source, (ContentSource, _CombinedSource))
         self._cache = None
         self._pagination_slicer = None
         self._has_sorter = False
@@ -150,14 +180,11 @@
                             (filter_name, self._current_page.path))
         return self._simpleNonSortedWrap(SettingFilterIterator, filter_conf)
 
-    def sort(self, setting_name, reverse=False):
-        if not setting_name:
-            raise Exception("You need to specify a configuration setting "
-                            "to sort by.")
-        self._ensureUnlocked()
-        self._ensureUnloaded()
-        self._pages = SettingSortIterator(self._pages, setting_name, reverse)
-        self._has_sorter = True
+    def sort(self, setting_name=None, reverse=False):
+        if setting_name:
+            self._wrapAsSort(SettingSortIterator, setting_name, reverse)
+        else:
+            self._wrapAsSort(NaturalSortIterator, reverse)
         return self
 
     def reset(self):
@@ -171,12 +198,15 @@
 
     @property
     def _has_more(self):
-        if self._cache is None:
-            return False
+        self._load()
         if self._pagination_slicer:
             return self._pagination_slicer.has_more
         return False
 
+    @property
+    def _is_loaded_and_has_more(self):
+        return self._is_loaded and self._has_more
+
     def _simpleWrap(self, it_class, *args, **kwargs):
         self._ensureUnlocked()
         self._ensureUnloaded()
@@ -226,7 +256,11 @@
 
     def _initIterator(self):
         if self._is_content_source:
-            self._it = PageContentSourceIterator(self._source)
+            if isinstance(self._source, _CombinedSource):
+                self._it = self._source
+            else:
+                self._it = PageContentSourceIterator(self._source)
+
             app = self._source.app
             if app.config.get('baker/is_baking'):
                 # While baking, automatically exclude any page with
@@ -333,6 +367,15 @@
         return iter(self._cache)
 
 
+class NaturalSortIterator:
+    def __init__(self, it, reverse=False):
+        self.it = it
+        self.reverse = reverse
+
+    def __iter__(self):
+        return iter(sorted(self.it, reverse=self.reverse))
+
+
 class SettingSortIterator:
     def __init__(self, it, name, reverse=False):
         self.it = it
@@ -344,7 +387,7 @@
                            reverse=self.reverse))
 
     def _key_getter(self, item):
-        key = item.config.get(item)
+        key = item.config.get(self.name)
         if key is None:
             return 0
         return key