From d0825bca7fe65beaee391d30da42e937db621564 Mon Sep 17 00:00:00 2001 From: Steve Block Date: Tue, 2 Feb 2010 14:57:50 +0000 Subject: Merge webkit.org at r54127 : Initial merge by git Change-Id: Ib661abb595522f50ea406f72d3a0ce17f7193c82 --- .../pywebsocket/mod_pywebsocket/dispatch.py | 17 ++ .../pywebsocket/mod_pywebsocket/handshake.py | 50 ++++- .../pywebsocket/mod_pywebsocket/memorizingfile.py | 81 +++++++++ .../pywebsocket/mod_pywebsocket/standalone.py | 61 ++++++- WebKitTools/pywebsocket/setup.py | 2 +- WebKitTools/pywebsocket/test/test_dispatch.py | 11 ++ WebKitTools/pywebsocket/test/test_handshake.py | 201 ++++++++++++++++++++- .../pywebsocket/test/test_memorizingfile.py | 72 ++++++++ 8 files changed, 485 insertions(+), 10 deletions(-) create mode 100644 WebKitTools/pywebsocket/mod_pywebsocket/memorizingfile.py create mode 100644 WebKitTools/pywebsocket/test/test_memorizingfile.py (limited to 'WebKitTools/pywebsocket') diff --git a/WebKitTools/pywebsocket/mod_pywebsocket/dispatch.py b/WebKitTools/pywebsocket/mod_pywebsocket/dispatch.py index bf9a856..c52e9eb 100644 --- a/WebKitTools/pywebsocket/mod_pywebsocket/dispatch.py +++ b/WebKitTools/pywebsocket/mod_pywebsocket/dispatch.py @@ -142,6 +142,23 @@ class Dispatcher(object): 'root_dir:%s.' % (scan_dir, root_dir)) self._source_files_in_dir(root_dir, scan_dir) + def add_resource_path_alias(self, + alias_resource_path, existing_resource_path): + """Add resource path alias. + + Once added, request to alias_resource_path would be handled by + handler registered for existing_resource_path. + + Args: + alias_resource_path: alias resource path + existing_resource_path: existing resource path + """ + try: + handler = self._handlers[existing_resource_path] + self._handlers[alias_resource_path] = handler + except KeyError: + raise DispatchError('No handler for: %r' % existing_resource_path) + def source_warnings(self): """Return warnings in sourcing handlers.""" diff --git a/WebKitTools/pywebsocket/mod_pywebsocket/handshake.py b/WebKitTools/pywebsocket/mod_pywebsocket/handshake.py index a67aadd..b86278e 100644 --- a/WebKitTools/pywebsocket/mod_pywebsocket/handshake.py +++ b/WebKitTools/pywebsocket/mod_pywebsocket/handshake.py @@ -39,14 +39,14 @@ not suitable because they don't allow direct raw bytes writing/reading. import re +import util + _DEFAULT_WEB_SOCKET_PORT = 80 _DEFAULT_WEB_SOCKET_SECURE_PORT = 443 _WEB_SOCKET_SCHEME = 'ws' _WEB_SOCKET_SECURE_SCHEME = 'wss' -_METHOD_LINE = re.compile(r'^GET ([^ ]+) HTTP/1.1\r\n$') - _MANDATORY_HEADERS = [ # key, expected value or None ['Upgrade', 'WebSocket'], @@ -55,6 +55,22 @@ _MANDATORY_HEADERS = [ ['Origin', None], ] +_FIRST_FIVE_LINES = map(re.compile, [ + r'^GET /[\S]* HTTP/1.1\r\n$', + r'^Upgrade: WebSocket\r\n$', + r'^Connection: Upgrade\r\n$', + r'^Host: [\S]+\r\n$', + r'^Origin: [\S]+\r\n$', +]) + +_SIXTH_AND_LATER = re.compile( + r'^' + r'(WebSocket-Protocol: [\x20-\x7e]+\r\n)?' + r'(Cookie: [^\r]*\r\n)*' + r'(Cookie2: [^\r]*\r\n)?' + r'(Cookie: [^\r]*\r\n)*' + r'\r\n') + def _default_port(is_secure): if is_secure: @@ -75,19 +91,22 @@ def _validate_protocol(protocol): if not protocol: raise HandshakeError('Invalid WebSocket-Protocol: empty') for c in protocol: - if not 0x21 <= ord(c) <= 0x7e: + if not 0x20 <= ord(c) <= 0x7e: raise HandshakeError('Illegal character in protocol: %r' % c) class Handshaker(object): """This class performs Web Socket handshake.""" - def __init__(self, request, dispatcher): + def __init__(self, request, dispatcher, strict=False): """Construct an instance. Args: request: mod_python request. dispatcher: Dispatcher (dispatch.Dispatcher). + strict: Strictly check handshake request. Default: False. + If True, request.connection must provide get_memorized_lines + method. Handshaker will add attributes such as ws_resource in performing handshake. @@ -95,6 +114,7 @@ class Handshaker(object): self._request = request self._dispatcher = dispatcher + self._strict = strict def do_handshake(self): """Perform Web Socket Handshake.""" @@ -173,6 +193,28 @@ class Handshaker(object): if actual_value != expected_value: raise HandshakeError('Illegal value for header %s: %s' % (key, actual_value)) + if self._strict: + try: + lines = self._request.connection.get_memorized_lines() + except AttributeError, e: + util.prepend_message_to_exception( + 'Strict handshake is specified but the connection ' + 'doesn\'t provide get_memorized_lines()', e) + raise + self._check_first_lines(lines) + + def _check_first_lines(self, lines): + if len(lines) < len(_FIRST_FIVE_LINES): + raise HandshakeError('Too few header lines: %d' % len(lines)) + for line, regexp in zip(lines, _FIRST_FIVE_LINES): + if not regexp.search(line): + raise HandshakeError('Unexpected header: %r doesn\'t match %r' + % (line, regexp.pattern)) + sixth_and_later = ''.join(lines[5:]) + if not _SIXTH_AND_LATER.search(sixth_and_later): + raise HandshakeError('Unexpected header: %r doesn\'t match %r' + % (sixth_and_later, + _SIXTH_AND_LATER.pattern)) # vi:sts=4 sw=4 et diff --git a/WebKitTools/pywebsocket/mod_pywebsocket/memorizingfile.py b/WebKitTools/pywebsocket/mod_pywebsocket/memorizingfile.py new file mode 100644 index 0000000..2f8a54e --- /dev/null +++ b/WebKitTools/pywebsocket/mod_pywebsocket/memorizingfile.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# +# Copyright 2009, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""Memorizing file. + +A memorizing file wraps a file and memorizes lines read by readline. +""" + + +import sys + + +class MemorizingFile(object): + """MemorizingFile wraps a file and memorizes lines read by readline. + + Note that data read by other methods are not memorized. This behavior + is good enough for memorizing lines SimpleHTTPServer reads before + the control reaches WebSocketRequestHandler. + """ + def __init__(self, file_, max_memorized_lines=sys.maxint): + """Construct an instance. + + Args: + file_: the file object to wrap. + max_memorized_lines: the maximum number of lines to memorize. + Only the first max_memorized_lines are memorized. + Default: sys.maxint. + """ + self._file = file_ + self._memorized_lines = [] + self._max_memorized_lines = max_memorized_lines + + def __getattribute__(self, name): + if name in ('_file', '_memorized_lines', '_max_memorized_lines', + 'readline', 'get_memorized_lines'): + return object.__getattribute__(self, name) + return self._file.__getattribute__(name) + + def readline(self): + """Override file.readline and memorize the line read.""" + + line = self._file.readline() + if line and len(self._memorized_lines) < self._max_memorized_lines: + self._memorized_lines.append(line) + return line + + def get_memorized_lines(self): + """Get lines memorized so far.""" + return self._memorized_lines + + +# vi:sts=4 sw=4 et diff --git a/WebKitTools/pywebsocket/mod_pywebsocket/standalone.py b/WebKitTools/pywebsocket/mod_pywebsocket/standalone.py index 6217585..0e6a349 100644 --- a/WebKitTools/pywebsocket/mod_pywebsocket/standalone.py +++ b/WebKitTools/pywebsocket/mod_pywebsocket/standalone.py @@ -38,6 +38,7 @@ Usage: python standalone.py [-p ] [-w ] [-s ] [-d ] + [-m ] ... for other options, see _main below ... is the port number to use for ws:// connection. @@ -63,6 +64,7 @@ import logging import logging.handlers import optparse import os +import re import socket import sys @@ -75,6 +77,7 @@ except ImportError: import dispatch import handshake +import memorizingfile import util @@ -88,6 +91,10 @@ _LOG_LEVELS = { _DEFAULT_LOG_MAX_BYTES = 1024 * 256 _DEFAULT_LOG_BACKUP_COUNT = 5 +_DEFAULT_REQUEST_QUEUE_SIZE = 128 + +# 1024 is practically large enough to contain WebSocket handshake lines. +_MAX_MEMORIZED_LINES = 1024 def _print_warnings_if_any(dispatcher): warnings = dispatcher.source_warnings() @@ -129,6 +136,10 @@ class _StandaloneConnection(object): """Mimic mp_conn.read().""" return self._request_handler.rfile.read(length) + def get_memorized_lines(self): + """Get memorized lines.""" + return self._request_handler.rfile.get_memorized_lines() + class _StandaloneRequest(object): """Mimic mod_python request.""" @@ -198,7 +209,9 @@ class WebSocketRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler): """Override SocketServer.StreamRequestHandler.setup.""" self.connection = self.request - self.rfile = socket._fileobject(self.request, 'rb', self.rbufsize) + self.rfile = memorizingfile.MemorizingFile( + socket._fileobject(self.request, 'rb', self.rbufsize), + max_memorized_lines=_MAX_MEMORIZED_LINES) self.wfile = socket._fileobject(self.request, 'wb', self.wbufsize) def __init__(self, *args, **keywords): @@ -206,8 +219,9 @@ class WebSocketRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler): self, WebSocketRequestHandler.options.use_tls) self._dispatcher = WebSocketRequestHandler.options.dispatcher self._print_warnings_if_any() - self._handshaker = handshake.Handshaker(self._request, - self._dispatcher) + self._handshaker = handshake.Handshaker( + self._request, self._dispatcher, + WebSocketRequestHandler.options.strict) SimpleHTTPServer.SimpleHTTPRequestHandler.__init__( self, *args, **keywords) @@ -268,6 +282,31 @@ def _configure_logging(options): handler.setFormatter(formatter) logger.addHandler(handler) +def _alias_handlers(dispatcher, websock_handlers_map_file): + """Set aliases specified in websock_handler_map_file in dispatcher. + + Args: + dispatcher: dispatch.Dispatcher instance + websock_handler_map_file: alias map file + """ + fp = open(websock_handlers_map_file) + try: + for line in fp: + if line[0] == '#' or line.isspace(): + continue + m = re.match('(\S+)\s+(\S+)', line) + if not m: + logging.warning('Wrong format in map file:' + line) + continue + try: + dispatcher.add_resource_path_alias( + m.group(1), m.group(2)) + except dispatch.DispatchError, e: + logging.error(str(e)) + finally: + fp.close() + + def _main(): parser = optparse.OptionParser() @@ -277,6 +316,12 @@ def _main(): parser.add_option('-w', '--websock_handlers', dest='websock_handlers', default='.', help='Web Socket handlers root directory.') + parser.add_option('-m', '--websock_handlers_map_file', + dest='websock_handlers_map_file', + default=None, + help=('Web Socket handlers map file. ' + 'Each line consists of alias_resource_path and ' + 'existing_resource_path, separated by spaces.')) parser.add_option('-s', '--scan_dir', dest='scan_dir', default=None, help=('Web Socket handlers scan directory. ' @@ -302,12 +347,19 @@ def _main(): parser.add_option('--log_count', dest='log_count', type='int', default=_DEFAULT_LOG_BACKUP_COUNT, help='Log backup count') + parser.add_option('--strict', dest='strict', action='store_true', + default=False, help='Strictly check handshake request') + parser.add_option('-q', '--queue', dest='request_queue_size', type='int', + default=_DEFAULT_REQUEST_QUEUE_SIZE, + help='request queue size') options = parser.parse_args()[0] os.chdir(options.document_root) _configure_logging(options) + SocketServer.TCPServer.request_queue_size = options.request_queue_size + if options.use_tls: if not _HAS_OPEN_SSL: logging.critical('To use TLS, install pyOpenSSL.') @@ -325,6 +377,9 @@ def _main(): # instantiation. Dispatcher can be shared because it is thread-safe. options.dispatcher = dispatch.Dispatcher(options.websock_handlers, options.scan_dir) + if options.websock_handlers_map_file: + _alias_handlers(options.dispatcher, + options.websock_handlers_map_file) _print_warnings_if_any(options.dispatcher) WebSocketRequestHandler.options = options diff --git a/WebKitTools/pywebsocket/setup.py b/WebKitTools/pywebsocket/setup.py index df05fef..a49c943 100644 --- a/WebKitTools/pywebsocket/setup.py +++ b/WebKitTools/pywebsocket/setup.py @@ -56,7 +56,7 @@ setup(author='Yuzo Fujishima', name=_PACKAGE_NAME, packages=[_PACKAGE_NAME], url='http://code.google.com/p/pywebsocket/', - version='0.4.3', + version='0.4.7.1', ) diff --git a/WebKitTools/pywebsocket/test/test_dispatch.py b/WebKitTools/pywebsocket/test/test_dispatch.py index b19d706..5403228 100644 --- a/WebKitTools/pywebsocket/test/test_dispatch.py +++ b/WebKitTools/pywebsocket/test/test_dispatch.py @@ -225,6 +225,17 @@ class DispatcherTest(unittest.TestCase): self.assertRaises(dispatch.DispatchError, dispatch.Dispatcher, 'a/b/c', 'a/b') + def test_resource_path_alias(self): + disp = dispatch.Dispatcher(_TEST_HANDLERS_DIR, None) + disp.add_resource_path_alias('/', '/origin_check') + self.assertEqual(4, len(disp._handlers)) + self.failUnless(disp._handlers.has_key('/origin_check')) + self.failUnless(disp._handlers.has_key('/sub/exception_in_transfer')) + self.failUnless(disp._handlers.has_key('/sub/plain')) + self.failUnless(disp._handlers.has_key('/')) + self.assertRaises(dispatch.DispatchError, + disp.add_resource_path_alias, '/alias', '/not-exist') + if __name__ == '__main__': unittest.main() diff --git a/WebKitTools/pywebsocket/test/test_handshake.py b/WebKitTools/pywebsocket/test/test_handshake.py index dd1f65c..8bf07be 100644 --- a/WebKitTools/pywebsocket/test/test_handshake.py +++ b/WebKitTools/pywebsocket/test/test_handshake.py @@ -214,11 +214,168 @@ _BAD_REQUESTS = ( 'Connection':'Upgrade', 'Host':'example.com', 'Origin':'http://example.com', - 'WebSocket-Protocol':'illegal protocol', + 'WebSocket-Protocol':'illegal\x09protocol', } ), ) +_STRICTLY_GOOD_REQUESTS = ( + ( + 'GET /demo HTTP/1.1\r\n', + 'Upgrade: WebSocket\r\n', + 'Connection: Upgrade\r\n', + 'Host: example.com\r\n', + 'Origin: http://example.com\r\n', + '\r\n', + ), + ( # WebSocket-Protocol + 'GET /demo HTTP/1.1\r\n', + 'Upgrade: WebSocket\r\n', + 'Connection: Upgrade\r\n', + 'Host: example.com\r\n', + 'Origin: http://example.com\r\n', + 'WebSocket-Protocol: sample\r\n', + '\r\n', + ), + ( # WebSocket-Protocol and Cookie + 'GET /demo HTTP/1.1\r\n', + 'Upgrade: WebSocket\r\n', + 'Connection: Upgrade\r\n', + 'Host: example.com\r\n', + 'Origin: http://example.com\r\n', + 'WebSocket-Protocol: sample\r\n', + 'Cookie: xyz\r\n' + '\r\n', + ), + ( # Cookie + 'GET /demo HTTP/1.1\r\n', + 'Upgrade: WebSocket\r\n', + 'Connection: Upgrade\r\n', + 'Host: example.com\r\n', + 'Origin: http://example.com\r\n', + 'Cookie: abc/xyz\r\n' + 'Cookie2: $Version=1\r\n' + 'Cookie: abc\r\n' + '\r\n', + ), + ( + 'GET / HTTP/1.1\r\n', + 'Upgrade: WebSocket\r\n', + 'Connection: Upgrade\r\n', + 'Host: example.com\r\n', + 'Origin: http://example.com\r\n', + '\r\n', + ), +) + +_NOT_STRICTLY_GOOD_REQUESTS = ( + ( # Extra space after GET + 'GET /demo HTTP/1.1\r\n', + 'Upgrade: WebSocket\r\n', + 'Connection: Upgrade\r\n', + 'Host: example.com\r\n', + 'Origin: http://example.com\r\n', + '\r\n', + ), + ( # Resource name doesn't stat with '/' + 'GET demo HTTP/1.1\r\n', + 'Upgrade: WebSocket\r\n', + 'Connection: Upgrade\r\n', + 'Host: example.com\r\n', + 'Origin: http://example.com\r\n', + '\r\n', + ), + ( # No space after : + 'GET /demo HTTP/1.1\r\n', + 'Upgrade:WebSocket\r\n', + 'Connection: Upgrade\r\n', + 'Host: example.com\r\n', + 'Origin: http://example.com\r\n', + '\r\n', + ), + ( # Lower case Upgrade header + 'GET /demo HTTP/1.1\r\n', + 'upgrade: WebSocket\r\n', + 'Connection: Upgrade\r\n', + 'Host: example.com\r\n', + 'Origin: http://example.com\r\n', + '\r\n', + ), + ( # Connection comes before Upgrade + 'GET /demo HTTP/1.1\r\n', + 'Connection: Upgrade\r\n', + 'Upgrade: WebSocket\r\n', + 'Host: example.com\r\n', + 'Origin: http://example.com\r\n', + '\r\n', + ), + ( # Origin comes before Host + 'GET /demo HTTP/1.1\r\n', + 'Upgrade: WebSocket\r\n', + 'Connection: Upgrade\r\n', + 'Origin: http://example.com\r\n', + 'Host: example.com\r\n', + '\r\n', + ), + ( # Host continued to the next line + 'GET /demo HTTP/1.1\r\n', + 'Upgrade: WebSocket\r\n', + 'Connection: Upgrade\r\n', + 'Host: example\r\n', + ' .com\r\n', + 'Origin: http://example.com\r\n', + '\r\n', + ), + ( # Cookie comes before WebSocket-Protocol + 'GET /demo HTTP/1.1\r\n', + 'Upgrade: WebSocket\r\n', + 'Connection: Upgrade\r\n', + 'Host: example.com\r\n', + 'Origin: http://example.com\r\n', + 'Cookie: xyz\r\n' + 'WebSocket-Protocol: sample\r\n', + '\r\n', + ), + ( # Unknown header + 'GET /demo HTTP/1.1\r\n', + 'Upgrade: WebSocket\r\n', + 'Connection: Upgrade\r\n', + 'Host: example.com\r\n', + 'Origin: http://example.com\r\n', + 'Content-Type: text/html\r\n' + '\r\n', + ), + ( # Cookie with continuation lines + 'GET /demo HTTP/1.1\r\n', + 'Upgrade: WebSocket\r\n', + 'Connection: Upgrade\r\n', + 'Host: example.com\r\n', + 'Origin: http://example.com\r\n', + 'Cookie: xyz\r\n', + ' abc\r\n', + ' defg\r\n', + '\r\n', + ), + ( # Wrong-case cookie + 'GET /demo HTTP/1.1\r\n', + 'Upgrade: WebSocket\r\n', + 'Connection: Upgrade\r\n', + 'Host: example.com\r\n', + 'Origin: http://example.com\r\n', + 'cookie: abc/xyz\r\n' + '\r\n', + ), + ( # Cookie, no space after colon + 'GET /demo HTTP/1.1\r\n', + 'Upgrade: WebSocket\r\n', + 'Connection: Upgrade\r\n', + 'Host: example.com\r\n', + 'Origin: http://example.com\r\n', + 'Cookie:abc/xyz\r\n' + '\r\n', + ), +) + def _create_request(request_def): conn = mock.MockConn('') @@ -229,13 +386,37 @@ def _create_request(request_def): connection=conn) +def _create_get_memorized_lines(lines): + def get_memorized_lines(): + return lines + return get_memorized_lines + + +def _create_requests_with_lines(request_lines_set): + requests = [] + for lines in request_lines_set: + request = _create_request(_GOOD_REQUEST) + request.connection.get_memorized_lines = _create_get_memorized_lines( + lines) + requests.append(request) + return requests + + class HandshakerTest(unittest.TestCase): def test_validate_protocol(self): handshake._validate_protocol('sample') # should succeed. handshake._validate_protocol('Sample') # should succeed. + handshake._validate_protocol('sample\x20protocol') # should succeed. + handshake._validate_protocol('sample\x7eprotocol') # should succeed. + self.assertRaises(handshake.HandshakeError, + handshake._validate_protocol, + '') self.assertRaises(handshake.HandshakeError, handshake._validate_protocol, - 'sample protocol') + 'sample\x19protocol') + self.assertRaises(handshake.HandshakeError, + handshake._validate_protocol, + 'sample\x7fprotocol') self.assertRaises(handshake.HandshakeError, handshake._validate_protocol, # "Japan" in Japanese @@ -308,6 +489,22 @@ class HandshakerTest(unittest.TestCase): mock.MockDispatcher()) self.assertRaises(handshake.HandshakeError, handshaker.do_handshake) + def test_strictly_good_requests(self): + for request in _create_requests_with_lines(_STRICTLY_GOOD_REQUESTS): + strict_handshaker = handshake.Handshaker(request, + mock.MockDispatcher(), + True) + strict_handshaker.do_handshake() + + def test_not_strictly_good_requests(self): + for request in _create_requests_with_lines(_NOT_STRICTLY_GOOD_REQUESTS): + strict_handshaker = handshake.Handshaker(request, + mock.MockDispatcher(), + True) + self.assertRaises(handshake.HandshakeError, + strict_handshaker.do_handshake) + + if __name__ == '__main__': unittest.main() diff --git a/WebKitTools/pywebsocket/test/test_memorizingfile.py b/WebKitTools/pywebsocket/test/test_memorizingfile.py new file mode 100644 index 0000000..2de77ba --- /dev/null +++ b/WebKitTools/pywebsocket/test/test_memorizingfile.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +# +# Copyright 2009, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""Tests for memorizingfile module.""" + + +import StringIO +import unittest + +import config # This must be imported before mod_pywebsocket. +from mod_pywebsocket import memorizingfile + + +class UtilTest(unittest.TestCase): + def check(self, memorizing_file, num_read, expected_list): + for unused in range(num_read): + memorizing_file.readline() + actual_list = memorizing_file.get_memorized_lines() + self.assertEqual(len(expected_list), len(actual_list)) + for expected, actual in zip(expected_list, actual_list): + self.assertEqual(expected, actual) + + def test_get_memorized_lines(self): + memorizing_file = memorizingfile.MemorizingFile(StringIO.StringIO( + 'Hello\nWorld\nWelcome')) + self.check(memorizing_file, 3, ['Hello\n', 'World\n', 'Welcome']) + + def test_get_memorized_lines_limit_memorized_lines(self): + memorizing_file = memorizingfile.MemorizingFile(StringIO.StringIO( + 'Hello\nWorld\nWelcome'), 2) + self.check(memorizing_file, 3, ['Hello\n', 'World\n']) + + def test_get_memorized_lines_empty_file(self): + memorizing_file = memorizingfile.MemorizingFile(StringIO.StringIO( + '')) + self.check(memorizing_file, 10, []) + + +if __name__ == '__main__': + unittest.main() + + +# vi:sts=4 sw=4 et -- cgit v1.1