changeset 10:cd35d356ccce

Fix some bugs with iterators, add some unit tests.
author Ludovic Chabant <ludovic@chabant.com>
date Sun, 17 Aug 2014 21:18:48 -0700
parents 8f7ba2c95025
children 617191dec18e
files piecrust/data/iterators.py tests/test_data_iterators.py
diffstat 2 files changed, 118 insertions(+), 26 deletions(-) [+]
line wrap: on
line diff
--- a/piecrust/data/iterators.py	Sat Aug 16 23:30:26 2014 -0700
+++ b/piecrust/data/iterators.py	Sun Aug 17 21:18:48 2014 -0700
@@ -23,8 +23,14 @@
         if self._cache is None:
             inner_list = list(self.it)
             self.inner_count = len(inner_list)
-            self.has_more = self.inner_count > (self.offset + self.limit)
-            self._cache = inner_list[self.offset:self.offset + self.limit]
+
+            if self.limit > 0:
+                self.has_more = self.inner_count > (self.offset + self.limit)
+                self._cache = inner_list[self.offset:self.offset + self.limit]
+            else:
+                self.has_more = False
+                self._cache = inner_list[self.offset:]
+
             if self.current_page:
                 idx = inner_list.index(self.current_page)
                 if idx >= 0:
@@ -32,6 +38,7 @@
                         self.next_page = inner_list[idx + 1]
                     if idx > 0:
                         self.prev_page = inner_list[idx - 1]
+
         return iter(self._cache)
 
 
@@ -56,6 +63,15 @@
                 yield i
 
 
+class NaturalSortIterator(object):
+    def __init__(self, it, reverse=False):
+        self.it = it
+        self.reverse = reverse
+
+    def __iter__(self):
+        return iter(sorted(self.it, reverse=self.reverse))
+
+
 class SettingSortIterator(object):
     def __init__(self, it, name, reverse=False, value_accessor=None):
         self.it = it
@@ -64,29 +80,13 @@
         self.value_accessor = value_accessor
 
     def __iter__(self):
-        def comparer(x, y):
-            if self.value_accessor:
-                v1 = self.value_accessor(x, self.name)
-                v2 = self.value_accessor(y, self.name)
-            else:
-                v1 = x.config.get(self.name)
-                v2 = y.config.get(self.name)
+        return iter(sorted(self.it, key=self._key_getter,
+                           reverse=self.reverse))
 
-            if v1 is None and v2 is None:
-                return 0
-            if v1 is None and v2 is not None:
-                return 1 if self.reverse else -1
-            if v1 is not None and v2 is None:
-                return -1 if self.reverse else 1
-
-            if v1 == v2:
-                return 0
-            if self.reverse:
-                return 1 if v1 < v2 else -1
-            else:
-                return -1 if v1 < v2 else 1
-
-        return sorted(self.it, cmp=self._comparer, reverse=self.reverse)
+    def _key_getter(self, item):
+        if self.value_accessor:
+            return self.value_accessor(item, self.name)
+        return item.config.get(self.name)
 
 
 class PaginationFilterIterator(object):
@@ -200,10 +200,14 @@
                             (filter_name, self._current_page.path))
         return self._simpleNonSortedWrap(SettingFilterIterator, filter_conf)
 
-    def sort(self, setting_name, reverse=False):
+    def sort(self, setting_name=None, reverse=False):
         self._ensureUnlocked()
         self._unload()
-        self._pages = SettingSortIterator(self._pages, setting_name, reverse)
+        if setting_name is not None:
+            self._pages = SettingSortIterator(self._pages, setting_name,
+                                              reverse)
+        else:
+            self._pages = NaturalSortIterator(self._pages, reverse)
         self._has_sorter = True
         return self
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/test_data_iterators.py	Sun Aug 17 21:18:48 2014 -0700
@@ -0,0 +1,88 @@
+import mock
+from piecrust.data.iterators import PageIterator
+from piecrust.page import Page, PageConfiguration
+
+
+def test_skip():
+    it = PageIterator(range(12))
+    it.skip(5)
+    assert it.total_count == 12
+    assert len(it) == 7
+    assert list(it) == list(range(5, 12))
+
+
+def test_limit():
+    it = PageIterator(range(12))
+    it.limit(4)
+    assert it.total_count == 12
+    assert len(it) == 4
+    assert list(it) == list(range(4))
+
+
+def test_slice():
+    it = PageIterator(range(12))
+    it.slice(3, 4)
+    assert it.total_count == 12
+    assert len(it) == 4
+    assert list(it) == list(range(3, 7))
+
+
+def test_natural_sort():
+    it = PageIterator([4, 3, 1, 2, 0])
+    it.sort()
+    assert it.total_count == 5
+    assert len(it) == 5
+    assert list(it) == list(range(5))
+
+
+def test_natural_sort_reversed():
+    it = PageIterator([4, 3, 1, 2, 0])
+    it.sort(reverse=True)
+    assert it.total_count == 5
+    assert len(it) == 5
+    assert list(it) == list(reversed(range(5)))
+
+
+class TestItem(object):
+    def __init__(self, value):
+        self.name = str(value)
+        self.config = {'foo': value}  # `config` makes it look like a `Page`.
+
+    def __eq__(self, other):
+        return other.name == self.name
+
+
+def test_setting_sort():
+    it = PageIterator([TestItem(v) for v in [4, 3, 1, 2, 0]])
+    it.sort('foo')
+    assert it.total_count == 5
+    assert len(it) == 5
+    assert list(it) == [TestItem(v) for v in range(5)]
+
+
+def test_setting_sort_reversed():
+    it = PageIterator([TestItem(v) for v in [4, 3, 1, 2, 0]])
+    it.sort('foo', reverse=True)
+    assert it.total_count == 5
+    assert len(it) == 5
+    assert list(it) == [TestItem(v) for v in reversed(range(5))]
+
+
+def test_filter():
+    page = mock.MagicMock(spec=Page)
+    page.config = PageConfiguration()
+    page.config.set('threes', {'is_foo': 3})
+    it = PageIterator([TestItem(v) for v in [3, 2, 3, 1, 4, 3]], page)
+    it.filter('threes')
+    assert it.total_count == 3
+    assert len(it) == 3
+    assert list(it) == [TestItem(3), TestItem(3), TestItem(3)]
+
+
+def test_magic_filter():
+    it = PageIterator([TestItem(v) for v in [3, 2, 3, 1, 4, 3]])
+    it.is_foo(3)
+    assert it.total_count == 3
+    assert len(it) == 3
+    assert list(it) == [TestItem(3), TestItem(3), TestItem(3)]
+