Mercurial > piecrust2
diff piecrust/dataproviders/pageiterator.py @ 979:45ad976712ec
tests: Big push to get the tests to pass again.
- Lots of fixes everywhere in the code.
- Try to handle debug logging in the multiprocessing worker pool when running in pytest. Not perfect, but usable for now.
- Replace all `.md` test files with `.html` since now a auto-format extension always sets the format.
- Replace `out` with `outfiles` in most places since now blog archives are added to the bake output and I don't want to add expected outputs for blog archives everywhere.
author | Ludovic Chabant <ludovic@chabant.com> |
---|---|
date | Sun, 29 Oct 2017 22:51:57 -0700 |
parents | abc52a6262a1 |
children | 492b66482f12 |
line wrap: on
line diff
--- a/piecrust/dataproviders/pageiterator.py Sun Oct 29 22:46:41 2017 -0700 +++ b/piecrust/dataproviders/pageiterator.py Sun Oct 29 22:51:57 2017 -0700 @@ -9,11 +9,39 @@ logger = logging.getLogger(__name__) -class _ItInfo: - def __init__(self): +class _CombinedSource: + def __init__(self, sources): + self.sources = sources + self.app = sources[0].app + self.name = None + + # This is for recursive traversal of the iterator chain. + # See later in `PageIterator`. self.it = None - self.iterated = False - self.source_name = None + + def __iter__(self): + sources = self.sources + + if len(sources) == 1: + source = sources[0] + self.name = source.name + yield from source.getAllPages() + self.name = None + return + + # Return the pages from all the combined sources, but skip + # those that are "overridden" -- e.g. a theme page that gets + # replaced by a user page of the same name. + used_uris = set() + for source in sources: + self.name = source.name + for page in source.getAllPages(): + page_uri = page.getUri() + if page_uri not in used_uris: + used_uris.add(page_uri) + yield page + + self.name = None class PageIteratorDataProvider(DataProvider): @@ -31,36 +59,37 @@ def __init__(self, source, page): super().__init__(source, page) - self._its = None self._app = source.app + self._it = None + self._iterated = False def __len__(self): self._load() - return sum([len(i.it) for i in self._its]) + return len(self._it) def __iter__(self): self._load() - for i in self._its: - yield from i.it + yield from self._it def _load(self): - if self._its is not None: + if self._it is not None: return - self._its = [] - for source in self._sources: - i = _ItInfo() - i.it = PageIterator(source, current_page=self._page) - i.it._iter_event += self._onIteration - i.source_name = source.name - self._its.append(i) + combined_source = _CombinedSource(list(reversed(self._sources))) + self._it = PageIterator(combined_source, current_page=self._page) + self._it._iter_event += self._onIteration def _onIteration(self, it): - ii = next(filter(lambda i: i.it == it, self._its)) - if not ii.iterated: + if not self._iterated: rcs = self._app.env.render_ctx_stack - rcs.current_ctx.addUsedSource(ii.source_name) - ii.iterated = True + rcs.current_ctx.addUsedSource(it._source) + self._iterated = True + + def _addSource(self, source): + if self._it is not None: + raise Exception("Can't add sources after the data provider " + "has been loaded.") + super()._addSource(source) def _debugRenderDoc(self): return 'Provides a list of %d items' % len(self) @@ -69,7 +98,8 @@ class PageIterator: def __init__(self, source, *, current_page=None): self._source = source - self._is_content_source = isinstance(source, ContentSource) + self._is_content_source = isinstance( + source, (ContentSource, _CombinedSource)) self._cache = None self._pagination_slicer = None self._has_sorter = False @@ -150,14 +180,11 @@ (filter_name, self._current_page.path)) return self._simpleNonSortedWrap(SettingFilterIterator, filter_conf) - def sort(self, setting_name, reverse=False): - if not setting_name: - raise Exception("You need to specify a configuration setting " - "to sort by.") - self._ensureUnlocked() - self._ensureUnloaded() - self._pages = SettingSortIterator(self._pages, setting_name, reverse) - self._has_sorter = True + def sort(self, setting_name=None, reverse=False): + if setting_name: + self._wrapAsSort(SettingSortIterator, setting_name, reverse) + else: + self._wrapAsSort(NaturalSortIterator, reverse) return self def reset(self): @@ -171,12 +198,15 @@ @property def _has_more(self): - if self._cache is None: - return False + self._load() if self._pagination_slicer: return self._pagination_slicer.has_more return False + @property + def _is_loaded_and_has_more(self): + return self._is_loaded and self._has_more + def _simpleWrap(self, it_class, *args, **kwargs): self._ensureUnlocked() self._ensureUnloaded() @@ -226,7 +256,11 @@ def _initIterator(self): if self._is_content_source: - self._it = PageContentSourceIterator(self._source) + if isinstance(self._source, _CombinedSource): + self._it = self._source + else: + self._it = PageContentSourceIterator(self._source) + app = self._source.app if app.config.get('baker/is_baking'): # While baking, automatically exclude any page with @@ -333,6 +367,15 @@ return iter(self._cache) +class NaturalSortIterator: + def __init__(self, it, reverse=False): + self.it = it + self.reverse = reverse + + def __iter__(self): + return iter(sorted(self.it, reverse=self.reverse)) + + class SettingSortIterator: def __init__(self, it, name, reverse=False): self.it = it @@ -344,7 +387,7 @@ reverse=self.reverse)) def _key_getter(self, item): - key = item.config.get(item) + key = item.config.get(self.name) if key is None: return 0 return key