Raphaël G. Git Repositories - youtubedl/blobdiff - youtube_dl/utils.py
debian/control: Update list of supported sites.
[youtubedl] / youtube_dl / utils.py
index 7832ed87f022d5c4891cecb43822ebebd30a6d0b..8f5463f1c9a1e1a2660867abdc0f1f62e9147032 100644 (file)
@@ -32,6 +32,7 @@ import xml.etree.ElementTree
 import zlib
 
 from .compat import (
+    compat_basestring,
     compat_chr,
     compat_getenv,
     compat_html_entities,
@@ -140,7 +141,7 @@ else:
     def find_xpath_attr(node, xpath, key, val):
         # Here comes the crazy part: In 2.6, if the xpath is a unicode,
         # .//node does not match if a node is a direct child of . !
-        if isinstance(xpath, unicode):
+        if isinstance(xpath, compat_str):
             xpath = xpath.encode('ascii')
 
         for f in node.findall(xpath):
@@ -411,25 +412,9 @@ def make_HTTPS_handler(params, **kwargs):
             pass
 
     if sys.version_info < (3, 2):
-        import httplib
-
-        class HTTPSConnectionV3(httplib.HTTPSConnection):
-            def __init__(self, *args, **kwargs):
-                httplib.HTTPSConnection.__init__(self, *args, **kwargs)
-
-            def connect(self):
-                sock = socket.create_connection((self.host, self.port), self.timeout)
-                if getattr(self, '_tunnel_host', False):
-                    self.sock = sock
-                    self._tunnel()
-                try:
-                    self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file, ssl_version=ssl.PROTOCOL_TLSv1)
-                except ssl.SSLError:
-                    self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file, ssl_version=ssl.PROTOCOL_SSLv23)
-
-        return YoutubeDLHTTPSHandler(params, https_conn_class=HTTPSConnectionV3, **kwargs)
+        return YoutubeDLHTTPSHandler(params, **kwargs)
     else:  # Python < 3.4
-        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+        context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
         context.verify_mode = (ssl.CERT_NONE
                                if opts_no_check_certificate
                                else ssl.CERT_REQUIRED)
@@ -560,7 +545,9 @@ def _create_http_connection(ydl_handler, http_class, is_https, *args, **kwargs):
                 sock = compat_socket_create_connection(
                     (self.host, self.port), self.timeout, sa)
                 if is_https:
-                    self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file)
+                    self.sock = ssl.wrap_socket(
+                        sock, self.key_file, self.cert_file,
+                        ssl_version=ssl.PROTOCOL_TLSv1)
                 else:
                     self.sock = sock
             hc.connect = functools.partial(_hc_connect, hc)
@@ -612,17 +599,14 @@ class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
 
     def http_request(self, req):
         for h, v in std_headers.items():
-            if h not in req.headers:
+            # Capitalize is needed because of Python bug 2275: http://bugs.python.org/issue2275
+            # The dict keys are capitalized because of this bug by urllib
+            if h.capitalize() not in req.headers:
                 req.add_header(h, v)
         if 'Youtubedl-no-compression' in req.headers:
             if 'Accept-encoding' in req.headers:
                 del req.headers['Accept-encoding']
             del req.headers['Youtubedl-no-compression']
-        if 'Youtubedl-user-agent' in req.headers:
-            if 'User-agent' in req.headers:
-                del req.headers['User-agent']
-            req.headers['User-agent'] = req.headers['Youtubedl-user-agent']
-            del req.headers['Youtubedl-user-agent']
 
         if sys.version_info < (2, 7) and '#' in req.get_full_url():
             # Python 2.6 is brain-dead when it comes to fragments
@@ -671,9 +655,14 @@ class YoutubeDLHTTPSHandler(compat_urllib_request.HTTPSHandler):
         self._params = params
 
     def https_open(self, req):
+        kwargs = {}
+        if hasattr(self, '_context'):  # python > 2.6
+            kwargs['context'] = self._context
+        if hasattr(self, '_check_hostname'):  # python 3.x
+            kwargs['check_hostname'] = self._check_hostname
         return self.do_open(functools.partial(
             _create_http_connection, self, self._https_conn_class, True),
-            req)
+            req, **kwargs)
 
 
 def parse_iso8601(date_str, delimiter='T'):
@@ -712,7 +701,7 @@ def unified_strdate(date_str, day_first=True):
     # %z (UTC offset) is only supported in python>=3.2
     date_str = re.sub(r' ?(\+|-)[0-9]{2}:?[0-9]{2}$', '', date_str)
     # Remove AM/PM + timezone
-    date_str = re.sub(r'(?i)\s*(?:AM|PM)\s+[A-Z]+', '', date_str)
+    date_str = re.sub(r'(?i)\s*(?:AM|PM)(?:\s+[A-Z]+)?', '', date_str)
 
     format_expressions = [
         '%d %B %Y',
@@ -875,6 +864,9 @@ def _windows_write_string(s, out):
     except AttributeError:
         # If the output stream doesn't have a fileno, it's virtual
         return False
+    except io.UnsupportedOperation:
+        # Some strange Windows pseudo files?
+        return False
     if fileno not in WIN_OUTPUT_IDS:
         return False
 
@@ -1271,7 +1263,7 @@ def float_or_none(v, scale=1, invscale=1, default=None):
 
 
 def parse_duration(s):
-    if not isinstance(s, basestring if sys.version_info < (3, 0) else compat_str):
+    if not isinstance(s, compat_basestring):
         return None
 
     s = s.strip()
@@ -1283,7 +1275,10 @@ def parse_duration(s):
             (?P<only_hours>[0-9.]+)\s*(?:hours?)|
 
             (?:
-                (?:(?P<hours>[0-9]+)\s*(?:[:h]|hours?)\s*)?
+                (?:
+                    (?:(?P<days>[0-9]+)\s*(?:[:d]|days?)\s*)?
+                    (?P<hours>[0-9]+)\s*(?:[:h]|hours?)\s*
+                )?
                 (?P<mins>[0-9]+)\s*(?:[:m]|mins?|minutes?)\s*
             )?
             (?P<secs>[0-9]+)(?P<ms>\.[0-9]+)?\s*(?:s|secs?|seconds?)?
@@ -1301,6 +1296,8 @@ def parse_duration(s):
         res += int(m.group('mins')) * 60
     if m.group('hours'):
         res += int(m.group('hours')) * 60 * 60
+    if m.group('days'):
+        res += int(m.group('days')) * 24 * 60 * 60
     if m.group('ms'):
         res += float(m.group('ms'))
     return res
@@ -1435,7 +1432,7 @@ def uppercase_escape(s):
 
 def escape_rfc3986(s):
     """Escape non-ASCII characters as suggested by RFC 3986"""
-    if sys.version_info < (3, 0) and isinstance(s, unicode):
+    if sys.version_info < (3, 0) and isinstance(s, compat_str):
         s = s.encode('utf-8')
     return compat_urllib_parse.quote(s, b"%/;:@&=+$,!~*'()?#[]")
 
@@ -1551,7 +1548,7 @@ def js_to_json(code):
     res = re.sub(r'''(?x)
         "(?:[^"\\]*(?:\\\\|\\")?)*"|
         '(?:[^'\\]*(?:\\\\|\\')?)*'|
-        [a-zA-Z_][a-zA-Z_0-9]*
+        [a-zA-Z_][.a-zA-Z_0-9]*
         ''', fix_kv, code)
     res = re.sub(r',(\s*\])', lambda m: m.group(1), res)
     return res
@@ -1612,6 +1609,14 @@ def urlhandle_detect_ext(url_handle):
     except AttributeError:  # Python < 3
         getheader = url_handle.info().getheader
 
+    cd = getheader('Content-Disposition')
+    if cd:
+        m = re.match(r'attachment;\s*filename="(?P<filename>[^"]+)"', cd)
+        if m:
+            e = determine_ext(m.group('filename'), default_ext=None)
+            if e:
+                return e
+
     return getheader('Content-Type').split("/")[1]
 
 
@@ -1623,3 +1628,53 @@ def age_restricted(content_limit, age_limit):
     if content_limit is None:
         return False  # Content available for everyone
     return age_limit < content_limit
+
+
+def is_html(first_bytes):
+    """ Detect whether a file contains HTML by examining its first bytes. """
+
+    BOMS = [
+        (b'\xef\xbb\xbf', 'utf-8'),
+        (b'\x00\x00\xfe\xff', 'utf-32-be'),
+        (b'\xff\xfe\x00\x00', 'utf-32-le'),
+        (b'\xff\xfe', 'utf-16-le'),
+        (b'\xfe\xff', 'utf-16-be'),
+    ]
+    for bom, enc in BOMS:
+        if first_bytes.startswith(bom):
+            s = first_bytes[len(bom):].decode(enc, 'replace')
+            break
+    else:
+        s = first_bytes.decode('utf-8', 'replace')
+
+    return re.match(r'^\s*<', s)
+
+
+def determine_protocol(info_dict):
+    protocol = info_dict.get('protocol')
+    if protocol is not None:
+        return protocol
+
+    url = info_dict['url']
+    if url.startswith('rtmp'):
+        return 'rtmp'
+    elif url.startswith('mms'):
+        return 'mms'
+    elif url.startswith('rtsp'):
+        return 'rtsp'
+
+    ext = determine_ext(url)
+    if ext == 'm3u8':
+        return 'm3u8'
+    elif ext == 'f4m':
+        return 'f4m'
+
+    return compat_urllib_parse_urlparse(url).scheme
+
+
+def render_table(header_row, data):
+    """ Render a list of rows, each as a list of values """
+    table = [header_row] + data
+    max_lens = [max(len(compat_str(v)) for v in col) for col in zip(*table)]
+    format_str = ' '.join('%-' + compat_str(ml + 1) + 's' for ml in max_lens[:-1]) + '%s'
+    return '\n'.join(format_str % tuple(row) for row in table)