"""
websocket - WebSocket client library for Python

Copyright (C) 2010 Hiroki Ohtani(liris)

    This library is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 2.1 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public
    License along with this library; if not, write to the Free Software
    Foundation, Inc., 51 Franklin Street, Fifth Floor,
    Boston, MA  02110-1335  USA

"""
import hashlib
import hmac
import os

import six

from ._cookiejar import SimpleCookieJar
from ._exceptions import *
from ._http import *
from ._logging import *
from ._socket import *

if hasattr(six, 'PY3') and six.PY3:
    from base64 import encodebytes as base64encode
else:
    from base64 import encodestring as base64encode

if hasattr(six, 'PY3') and six.PY3:
    if hasattr(six, 'PY34') and six.PY34:
        from http import client as HTTPStatus
    else:
        from http import HTTPStatus
else:
    import httplib as HTTPStatus

# Public names of this module, i.e. what a star import binds.
__all__ = [
    "handshake_response",
    "handshake",
    "SUPPORTED_REDIRECT_STATUSES",
]

try:
    compare_digest = hmac.compare_digest
except AttributeError:
    # Interpreters without hmac.compare_digest get a plain equality check
    # (not constant-time).
    def compare_digest(s1, s2):
        return s1 == s2

# Protocol version advertised in the Sec-WebSocket-Version request header.
VERSION = 13

# Responses with these statuses are treated as redirects: handshake() hands
# them back to the caller without validating the upgrade headers.
SUPPORTED_REDIRECT_STATUSES = (
    HTTPStatus.MOVED_PERMANENTLY,
    HTTPStatus.FOUND,
    HTTPStatus.SEE_OTHER,
)
# Every status _get_resp_headers() accepts by default: the redirects above
# plus 101 Switching Protocols.
SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,)

# Module-level jar shared by all handshakes. handshake_response stores each
# response's Set-Cookie header here, and _get_handshake_headers sends back the
# cookies stored for the target host.
CookieJar = SimpleCookieJar()


class handshake_response(object):
    """Result of an opening handshake.

    Holds the HTTP status code, the response headers, and the negotiated
    subprotocol. The subprotocol is ``None`` for redirects, or when no
    subprotocols were requested. Constructing an instance also records the
    response's Set-Cookie header in the module-level CookieJar.
    """

    def __init__(self, status, headers, subprotocol):
        self.status, self.headers, self.subprotocol = status, headers, subprotocol
        # Keep server cookies so later handshake requests can send them back.
        CookieJar.add(headers.get("set-cookie"))


def handshake(sock, hostname, port, resource, **options):
    """Run the client side of the WebSocket opening handshake over *sock*.

    Sends the upgrade request built by _get_handshake_headers and reads the
    response status and headers. Returns a handshake_response.

    Statuses outside SUCCESS_STATUSES make _get_resp_headers raise
    WebSocketBadStatusException. A redirect status (one of
    SUPPORTED_REDIRECT_STATUSES) is returned unvalidated with no subprotocol.
    Any other response must pass _validate, otherwise WebSocketException is
    raised.
    """
    request_lines, expected_key = _get_handshake_headers(resource, hostname, port, options)
    request_text = "\r\n".join(request_lines)
    send(sock, request_text)
    dump("request header", request_text)

    status, response_headers = _get_resp_headers(sock)
    if status in SUPPORTED_REDIRECT_STATUSES:
        return handshake_response(status, response_headers, None)

    is_valid, negotiated = _validate(response_headers, expected_key, options.get("subprotocols"))
    if is_valid:
        return handshake_response(status, response_headers, negotiated)
    raise WebSocketException("Invalid WebSocket Header")


def _pack_hostname(hostname):
    # IPv6 address
    if ':' in hostname:
        return '[' + hostname + ']'

    return hostname

def _get_handshake_headers(resource, host, port, options):
    headers = [
        "GET %s HTTP/1.1" % resource,
l2m2's avatar
l2m2 committed
100
        "Upgrade: websocket"
l2m2's avatar
l2m2 committed
101 102 103 104 105 106 107 108 109 110
    ]
    if port == 80 or port == 443:
        hostport = _pack_hostname(host)
    else:
        hostport = "%s:%d" % (_pack_hostname(host), port)
    if "host" in options and options["host"] is not None:
        headers.append("Host: %s" % options["host"])
    else:
        headers.append("Host: %s" % hostport)

l2m2's avatar
l2m2 committed
111 112 113 114 115
    if "suppress_origin" not in options or not options["suppress_origin"]:
        if "origin" in options and options["origin"] is not None:
            headers.append("Origin: %s" % options["origin"])
        else:
            headers.append("Origin: http://%s" % hostport)
l2m2's avatar
l2m2 committed
116 117

    key = _create_sec_websocket_key()
l2m2's avatar
l2m2 committed
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
    
    # Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified
    if not 'header' in options or 'Sec-WebSocket-Key' not in options['header']:
        key = _create_sec_websocket_key()
        headers.append("Sec-WebSocket-Key: %s" % key)
    else:
        key = options['header']['Sec-WebSocket-Key']

    if not 'header' in options or 'Sec-WebSocket-Version' not in options['header']:
        headers.append("Sec-WebSocket-Version: %s" % VERSION)

    if not 'connection' in options or options['connection'] is None:
        headers.append('Connection: upgrade')
    else:
        headers.append(options['connection'])
l2m2's avatar
l2m2 committed
133 134 135 136 137 138 139 140

    subprotocols = options.get("subprotocols")
    if subprotocols:
        headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols))

    if "header" in options:
        header = options["header"]
        if isinstance(header, dict):
l2m2's avatar
l2m2 committed
141 142 143 144 145
            header = [
                ": ".join([k, v])
                for k, v in header.items()
                if v is not None
            ]
l2m2's avatar
l2m2 committed
146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
        headers.extend(header)

    server_cookie = CookieJar.get(host)
    client_cookie = options.get("cookie", None)

    cookie = "; ".join(filter(None, [server_cookie, client_cookie]))

    if cookie:
        headers.append("Cookie: %s" % cookie)

    headers.append("")
    headers.append("")

    return headers, key


l2m2's avatar
l2m2 committed
162
def _get_resp_headers(sock, success_statuses=SUCCESS_STATUSES):
    """Read the handshake response from *sock*.

    Returns ``(status, headers)`` when the status is in *success_statuses*.
    Otherwise raises WebSocketBadStatusException with the status code, the
    status message and the response headers.
    """
    status, resp_headers, status_message = read_headers(sock)
    if status in success_statuses:
        return status, resp_headers
    raise WebSocketBadStatusException("Handshake status %d %s", status, status_message, resp_headers)


_HEADERS_TO_CHECK = {
    "upgrade": "websocket",
    "connection": "upgrade",
}


def _validate(headers, key, subprotocols):
    subproto = None
    for k, v in _HEADERS_TO_CHECK.items():
        r = headers.get(k, None)
        if not r:
            return False, None
        r = r.lower()
        if v != r:
            return False, None

    if subprotocols:
        subproto = headers.get("sec-websocket-protocol", None).lower()
        if not subproto or subproto not in [s.lower() for s in subprotocols]:
            error("Invalid subprotocol: " + str(subprotocols))
            return False, None

    result = headers.get("sec-websocket-accept", None)
    if not result:
        return False, None
    result = result.lower()

    if isinstance(result, six.text_type):
        result = result.encode('utf-8')

    value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8')
    hashed = base64encode(hashlib.sha1(value).digest()).strip().lower()
    success = compare_digest(hashed, result)

    if success:
        return True, subproto
    else:
        return False, None


def _create_sec_websocket_key():
    randomness = os.urandom(16)
    return base64encode(randomness).decode('utf-8').strip()