diff tests/mockutil.py @ 35:e4c345dcf33c

More unit tests, fix a bug with the skip patterns.
author Ludovic Chabant <ludovic@chabant.com>
date Wed, 20 Aug 2014 21:46:27 -0700
parents 43091c9837bf
children 485682a6de50
line wrap: on
line diff
--- a/tests/mockutil.py	Wed Aug 20 14:55:23 2014 -0700
+++ b/tests/mockutil.py	Wed Aug 20 21:46:27 2014 -0700
@@ -2,6 +2,7 @@
 import time
 import random
 import codecs
+import shutil
 import os.path
 import functools
 import mock
@@ -51,22 +52,16 @@
 
     def withDir(self, path):
         path = path.replace('\\', '/')
-        cur = self._fs[self._root]
-        for b in path.split('/'):
-            if b not in cur:
-                cur[b] = {}
-            cur = cur[b]
+        path = path.lstrip('/')
+        path = '/%s/%s' % (self._root, path)
+        self._createDir(path)
         return self
 
     def withFile(self, path, contents):
         path = path.replace('\\', '/')
-        cur = self._fs[self._root]
-        bits = path.split('/')
-        for b in bits[:-1]:
-            if b not in cur:
-                cur[b] = {}
-            cur = cur[b]
-        cur[bits[-1]] = (contents, {'mtime': time.time()})
+        path = path.lstrip('/')
+        path = '/%s/%s' % (self._root, path)
+        self._createFile(path, contents)
         return self
 
     def withAsset(self, path, contents):
@@ -106,16 +101,66 @@
         return self.withAsset('_content/%s/%s' % (dirname, name),
                 contents)
 
+    def getStructure(self, path=None):
+        root = self._fs[self._root]
+        if path:
+            root = self._getEntry(self.path(path))
+
+        res = {}
+        for k, v in root.items():
+            self._getStructureRecursive(v, res, k)
+        return res
+
+    def _getStructureRecursive(self, src, target, name):
+        if isinstance(src, tuple):
+            target[name] = src[0]
+            return
+
+        e = {}
+        for k, v in src.items():
+            self._getStructureRecursive(v, e, k)
+        target[name] = e
+
+    def _getEntry(self, path):
+        cur = self._fs
+        path = path.replace('\\', '/').lstrip('/')
+        bits = path.split('/')
+        for p in bits:
+            try:
+                cur = cur[p]
+            except KeyError:
+                return None
+        return cur
+
+    def _createDir(self, path):
+        cur = self._fs
+        bits = path.strip('/').split('/')
+        for b in bits:
+            if b not in cur:
+                cur[b] = {}
+            cur = cur[b]
+        return self
+
+    def _createFile(self, path, contents):
+        cur = self._fs
+        bits = path.strip('/').split('/')
+        for b in bits[:-1]:
+            if b not in cur:
+                cur[b] = {}
+            cur = cur[b]
+        cur[bits[-1]] = (contents, {'mtime': time.time()})
+        return self
+
 
 class mock_fs_scope(object):
     def __init__(self, fs):
         self._fs = fs
-        self._root = None
         self._patchers = []
         self._originals = {}
-        if isinstance(fs, mock_fs):
-            self._fs = fs._fs
-            self._root = fs._root
+
+    @property
+    def root(self):
+        return self._fs._root
 
     def __enter__(self):
         self._startMock()
@@ -128,10 +173,12 @@
         self._createMock('__main__.open', open, self._open, create=True)
         self._createMock('codecs.open', codecs.open, self._codecsOpen)
         self._createMock('os.listdir', os.listdir, self._listdir)
+        self._createMock('os.makedirs', os.makedirs, self._makedirs)
         self._createMock('os.path.isdir', os.path.isdir, self._isdir)
         self._createMock('os.path.isfile', os.path.isfile, self._isfile)
         self._createMock('os.path.islink', os.path.islink, self._islink)
         self._createMock('os.path.getmtime', os.path.getmtime, self._getmtime)
+        self._createMock('shutil.copyfile', shutil.copyfile, self._copyfile)
         for p in self._patchers:
             p.start()
 
@@ -166,7 +213,7 @@
         return io.StringIO(e[0])
 
     def _listdir(self, path):
-        if not path.startswith('/' + self._root):
+        if not path.startswith('/' + self.root):
             return self._originals['os.listdir'](path)
         e = self._getFsEntry(path)
         if e is None:
@@ -175,39 +222,47 @@
             raise OSError("'%s' is not a directory." % path)
         return list(e.keys())
 
+    def _makedirs(self, path, mode):
+        if not path.startswith('/' + self.root):
+            raise Exception("Shouldn't create directory: %s" % path)
+        self._fs._createDir(path)
+
     def _isdir(self, path):
-        if not path.startswith('/' + self._root):
+        if not path.startswith('/' + self.root):
             return self._originals['os.path.isdir'](path)
         e = self._getFsEntry(path)
         return e is not None and isinstance(e, dict)
 
     def _isfile(self, path):
-        if not path.startswith('/' + self._root):
+        if not path.startswith('/' + self.root):
             return self._originals['os.path.isfile'](path)
         e = self._getFsEntry(path)
         return e is not None and isinstance(e, tuple)
 
     def _islink(self, path):
-        if not path.startswith('/' + self._root):
+        if not path.startswith('/' + self.root):
             return self._originals['os.path.islink'](path)
         return False
 
     def _getmtime(self, path):
-        if not path.startswith('/' + self._root):
+        if not path.startswith('/' + self.root):
             return self._originals['os.path.getmtime'](path)
         e = self._getFsEntry(path)
         if e is None:
             raise OSError("No such file: %s" % path)
         return e[1]['mtime']
 
+    def _copyfile(self, src, dst):
+        if not src.startswith('/' + self.root):
+            with open(src, 'r') as fp:
+                src_text = fp.read()
+        else:
+            e = self._getFsEntry(src)
+            src_text = e[0]
+        if not dst.startswith('/' + self.root):
+            raise Exception("Shouldn't copy to: %s" % dst)
+        self._fs._createFile(dst, src_text)
+
     def _getFsEntry(self, path):
-        cur = self._fs
-        path = path.replace('\\', '/').lstrip('/')
-        bits = path.split('/')
-        for p in bits:
-            try:
-                cur = cur[p]
-            except KeyError:
-                return None
-        return cur
+        return self._fs._getEntry(path)