]> Raphaƫl G. Git Repositories - youtubedl/blobdiff - test/helper.py
debian/changelog: Describe a bit what this new version brings us.
[youtubedl] / test / helper.py
index a2b468b509b3effc6ff61c2b00cb9ca1d59f3711..b1f421ac58331bad23328502f42a0e1316df853d 100644 (file)
@@ -1,38 +1,76 @@
+import errno
 import io
 import io
+import hashlib
 import json
 import os.path
 import json
 import os.path
+import re
+import types
+import sys
 
 import youtube_dl.extractor
 
 import youtube_dl.extractor
-from youtube_dl import YoutubeDL, YoutubeDLHandler
-from youtube_dl.utils import (
-    compat_cookiejar,
-    compat_urllib_request,
-)
-
-# General configuration (from __init__, not very elegant...)
-jar = compat_cookiejar.CookieJar()
-cookie_processor = compat_urllib_request.HTTPCookieProcessor(jar)
-proxy_handler = compat_urllib_request.ProxyHandler()
-opener = compat_urllib_request.build_opener(proxy_handler, cookie_processor, YoutubeDLHandler())
-compat_urllib_request.install_opener(opener)
-
-PARAMETERS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "parameters.json")
-with io.open(PARAMETERS_FILE, encoding='utf-8') as pf:
-    parameters = json.load(pf)
+from youtube_dl import YoutubeDL
+from youtube_dl.utils import preferredencoding
+
+
+def get_params(override=None):
+    PARAMETERS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+                                   "parameters.json")
+    with io.open(PARAMETERS_FILE, encoding='utf-8') as pf:
+        parameters = json.load(pf)
+    if override:
+        parameters.update(override)
+    return parameters
+
+
+def try_rm(filename):
+    """ Remove a file if it exists """
+    try:
+        os.remove(filename)
+    except OSError as ose:
+        if ose.errno != errno.ENOENT:
+            raise
+
+
+def report_warning(message):
+    '''
+    Print the message to stderr, it will be prefixed with 'WARNING:'
+    If stderr is a tty file the 'WARNING:' will be colored
+    '''
+    if sys.stderr.isatty() and os.name != 'nt':
+        _msg_header = u'\033[0;33mWARNING:\033[0m'
+    else:
+        _msg_header = u'WARNING:'
+    output = u'%s %s\n' % (_msg_header, message)
+    if 'b' in getattr(sys.stderr, 'mode', '') or sys.version_info[0] < 3:
+        output = output.encode(preferredencoding())
+    sys.stderr.write(output)
+
 
 class FakeYDL(YoutubeDL):
 
 class FakeYDL(YoutubeDL):
-    def __init__(self):
-        self.result = []
+    def __init__(self, override=None):
         # Different instances of the downloader can't share the same dictionary
         # some test set the "sublang" parameter, which would break the md5 checks.
         # Different instances of the downloader can't share the same dictionary
         # some test set the "sublang" parameter, which would break the md5 checks.
-        self.params = dict(parameters)
-    def to_screen(self, s):
+        params = get_params(override=override)
+        super(FakeYDL, self).__init__(params)
+        self.result = []
+        
+    def to_screen(self, s, skip_eol=None):
         print(s)
         print(s)
+
     def trouble(self, s, tb=None):
         raise Exception(s)
     def trouble(self, s, tb=None):
         raise Exception(s)
+
     def download(self, x):
         self.result.append(x)
 
     def download(self, x):
         self.result.append(x)
 
+    def expect_warning(self, regex):
+        # Silence an expected warning matching a regex
+        old_report_warning = self.report_warning
+        def report_warning(self, message):
+            if re.match(regex, message): return
+            old_report_warning(message)
+        self.report_warning = types.MethodType(report_warning, self)
+
 def get_testcases():
     for ie in youtube_dl.extractor.gen_extractors():
         t = getattr(ie, '_TEST', None)
 def get_testcases():
     for ie in youtube_dl.extractor.gen_extractors():
         t = getattr(ie, '_TEST', None)
@@ -42,3 +80,6 @@ def get_testcases():
         for t in getattr(ie, '_TESTS', []):
             t['name'] = type(ie).__name__[:-len('IE')]
             yield t
         for t in getattr(ie, '_TESTS', []):
             t['name'] = type(ie).__name__[:-len('IE')]
             yield t
+
+
+md5 = lambda s: hashlib.md5(s.encode('utf-8')).hexdigest()