summaryrefslogtreecommitdiffstats
path: root/WebKitTools/pywebsocket/mod_pywebsocket/handshake.py
diff options
context:
space:
mode:
Diffstat (limited to 'WebKitTools/pywebsocket/mod_pywebsocket/handshake.py')
-rw-r--r--WebKitTools/pywebsocket/mod_pywebsocket/handshake.py178
1 files changed, 178 insertions, 0 deletions
diff --git a/WebKitTools/pywebsocket/mod_pywebsocket/handshake.py b/WebKitTools/pywebsocket/mod_pywebsocket/handshake.py
new file mode 100644
index 0000000..a67aadd
--- /dev/null
+++ b/WebKitTools/pywebsocket/mod_pywebsocket/handshake.py
@@ -0,0 +1,178 @@
+# Copyright 2009, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+"""Web Socket handshaking.
+
+Note: request.connection.write/read are used in this module, even though
+mod_python document says that they should be used only in connection handlers.
+Unfortunately, we have no other options. For example, request.write/read are
+not suitable because they don't allow direct raw bytes writing/reading.
+"""
+
+
+import re
+
+
+_DEFAULT_WEB_SOCKET_PORT = 80
+_DEFAULT_WEB_SOCKET_SECURE_PORT = 443
+_WEB_SOCKET_SCHEME = 'ws'
+_WEB_SOCKET_SECURE_SCHEME = 'wss'
+
+_METHOD_LINE = re.compile(r'^GET ([^ ]+) HTTP/1.1\r\n$')
+
+_MANDATORY_HEADERS = [
+ # key, expected value or None
+ ['Upgrade', 'WebSocket'],
+ ['Connection', 'Upgrade'],
+ ['Host', None],
+ ['Origin', None],
+]
+
+
+def _default_port(is_secure):
+ if is_secure:
+ return _DEFAULT_WEB_SOCKET_SECURE_PORT
+ else:
+ return _DEFAULT_WEB_SOCKET_PORT
+
+
+class HandshakeError(Exception):
+ """Exception in Web Socket Handshake."""
+
+ pass
+
+
+def _validate_protocol(protocol):
+ """Validate WebSocket-Protocol string."""
+
+ if not protocol:
+ raise HandshakeError('Invalid WebSocket-Protocol: empty')
+ for c in protocol:
+ if not 0x21 <= ord(c) <= 0x7e:
+ raise HandshakeError('Illegal character in protocol: %r' % c)
+
+
+class Handshaker(object):
+ """This class performs Web Socket handshake."""
+
+ def __init__(self, request, dispatcher):
+ """Construct an instance.
+
+ Args:
+ request: mod_python request.
+ dispatcher: Dispatcher (dispatch.Dispatcher).
+
+ Handshaker will add attributes such as ws_resource in performing
+ handshake.
+ """
+
+ self._request = request
+ self._dispatcher = dispatcher
+
+ def do_handshake(self):
+ """Perform Web Socket Handshake."""
+
+ self._check_header_lines()
+ self._set_resource()
+ self._set_origin()
+ self._set_location()
+ self._set_protocol()
+ self._dispatcher.do_extra_handshake(self._request)
+ self._send_handshake()
+
+ def _set_resource(self):
+ self._request.ws_resource = self._request.uri
+
+ def _set_origin(self):
+ self._request.ws_origin = self._request.headers_in['Origin']
+
+ def _set_location(self):
+ location_parts = []
+ if self._request.is_https():
+ location_parts.append(_WEB_SOCKET_SECURE_SCHEME)
+ else:
+ location_parts.append(_WEB_SOCKET_SCHEME)
+ location_parts.append('://')
+ host, port = self._parse_host_header()
+ connection_port = self._request.connection.local_addr[1]
+ if port != connection_port:
+ raise HandshakeError('Header/connection port mismatch: %d/%d' %
+ (port, connection_port))
+ location_parts.append(host)
+ if (port != _default_port(self._request.is_https())):
+ location_parts.append(':')
+ location_parts.append(str(port))
+ location_parts.append(self._request.uri)
+ self._request.ws_location = ''.join(location_parts)
+
+ def _parse_host_header(self):
+ fields = self._request.headers_in['Host'].split(':', 1)
+ if len(fields) == 1:
+ return fields[0], _default_port(self._request.is_https())
+ try:
+ return fields[0], int(fields[1])
+ except ValueError, e:
+ raise HandshakeError('Invalid port number format: %r' % e)
+
+ def _set_protocol(self):
+ protocol = self._request.headers_in.get('WebSocket-Protocol')
+ if protocol is not None:
+ _validate_protocol(protocol)
+ self._request.ws_protocol = protocol
+
+ def _send_handshake(self):
+ self._request.connection.write(
+ 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n')
+ self._request.connection.write('Upgrade: WebSocket\r\n')
+ self._request.connection.write('Connection: Upgrade\r\n')
+ self._request.connection.write('WebSocket-Origin: ')
+ self._request.connection.write(self._request.ws_origin)
+ self._request.connection.write('\r\n')
+ self._request.connection.write('WebSocket-Location: ')
+ self._request.connection.write(self._request.ws_location)
+ self._request.connection.write('\r\n')
+ if self._request.ws_protocol:
+ self._request.connection.write('WebSocket-Protocol: ')
+ self._request.connection.write(self._request.ws_protocol)
+ self._request.connection.write('\r\n')
+ self._request.connection.write('\r\n')
+
+ def _check_header_lines(self):
+ for key, expected_value in _MANDATORY_HEADERS:
+ actual_value = self._request.headers_in.get(key)
+ if not actual_value:
+ raise HandshakeError('Header %s is not defined' % key)
+ if expected_value:
+ if actual_value != expected_value:
+ raise HandshakeError('Illegal value for header %s: %s' %
+ (key, actual_value))
+
+
+# vi:sts=4 sw=4 et