diff options
Diffstat (limited to 'WebCore/websockets')
-rw-r--r-- | WebCore/websockets/WebSocketHandshake.cpp | 276 | ||||
-rw-r--r-- | WebCore/websockets/WebSocketHandshake.h | 9 | ||||
-rw-r--r-- | WebCore/websockets/WorkerThreadableWebSocketChannel.cpp | 2 |
3 files changed, 165 insertions, 122 deletions
diff --git a/WebCore/websockets/WebSocketHandshake.cpp b/WebCore/websockets/WebSocketHandshake.cpp index 1449c89..ea4f5e5 100644 --- a/WebCore/websockets/WebSocketHandshake.cpp +++ b/WebCore/websockets/WebSocketHandshake.cpp @@ -44,22 +44,25 @@ #include "ScriptExecutionContext.h" #include "SecurityOrigin.h" #include "StringBuilder.h" -#include <wtf/text/CString.h> + +#include <wtf/MD5.h> +#include <wtf/RandomNumber.h> +#include <wtf/StdLibExtras.h> #include <wtf/StringExtras.h> #include <wtf/Vector.h> +#include <wtf/text/CString.h> namespace WebCore { -const char webSocketServerHandshakeHeader[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"; -const char webSocketUpgradeHeader[] = "Upgrade: WebSocket\r\n"; -const char webSocketConnectionHeader[] = "Connection: Upgrade\r\n"; +static const char randomCharacterInSecWebSocketKey[] = "!\"#$%&'()*+,-./:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"; -static String extractResponseCode(const char* header, int len) +static String extractResponseCode(const char* header, int len, size_t& lineLength) { const char* space1 = 0; const char* space2 = 0; const char* p; - for (p = header; p - header < len; p++) { + lineLength = 0; + for (p = header; p - header < len; p++, lineLength++) { if (*p == ' ') { if (!space1) space1 = p; @@ -87,6 +90,18 @@ static String resourceName(const KURL& url) return name; } +static String hostName(const KURL& url, bool secure) +{ + ASSERT(url.protocolIs("wss") == secure); + StringBuilder builder; + builder.append(url.host().lower()); + if (url.port() && ((!secure && url.port() != 80) || (secure && url.port() != 443))) { + builder.append(":"); + builder.append(String::number(url.port())); + } + return builder.toString(); +} + static String trimConsoleMessage(const char* p, size_t len) { String s = String(p, std::min<size_t>(len, 128)); @@ -95,6 +110,58 @@ static String trimConsoleMessage(const char* p, size_t len) return s; } +static void generateSecWebSocketKey(uint32_t& number, String& key) +{ + uint32_t space = static_cast<uint32_t>(randomNumber() * 12) + 1; + uint32_t max = 4294967295U / space; + number = static_cast<uint32_t>(randomNumber() * max); + uint32_t product = number * space; + + String s = String::number(product); + int n = static_cast<int>(randomNumber() * 12) + 1; + DEFINE_STATIC_LOCAL(String, randomChars, (randomCharacterInSecWebSocketKey)); + for (int i = 0; i < n; i++) { + int pos = static_cast<int>(randomNumber() * (s.length() + 1)); + int chpos = static_cast<int>(randomNumber() * randomChars.length()); + s.insert(randomChars.substring(chpos, 1), pos); + } + DEFINE_STATIC_LOCAL(String, spaceChar, (" ")); + for (uint32_t i = 0; i < space; i++) { + int pos = static_cast<int>(randomNumber() * s.length() - 1) + 1; + s.insert(spaceChar, pos); + } + key = s; +} + +static void generateKey3(unsigned char key3[8]) +{ + for (int i = 0; i < 8; i++) + key3[i] = randomNumber() * 256; +} + +static void setChallengeNumber(unsigned char* buf, uint32_t number) +{ + unsigned char* p = buf + 3; + for (int i = 0; i < 4; i++) { + *p = number & 0xFF; + --p; + number >>= 8; + } +} + +static void generateExpectedChallengeResponse(uint32_t number1, uint32_t number2, unsigned char key3[8], unsigned char expectedChallenge[16]) +{ + unsigned char challenge[16]; + setChallengeNumber(&challenge[0], number1); + setChallengeNumber(&challenge[4], number2); + memcpy(&challenge[8], key3, 8); + MD5 md5; + md5.addBytes(challenge, sizeof(challenge)); + Vector<uint8_t, 16> digest; + md5.checksum(digest); + memcpy(expectedChallenge, digest.data(), 16); +} + WebSocketHandshake::WebSocketHandshake(const KURL& url, const String& protocol, ScriptExecutionContext* context) : m_url(url) , m_clientProtocol(protocol) @@ -102,6 +169,12 @@ WebSocketHandshake::WebSocketHandshake(const KURL& url, const String& protocol, , m_context(context) , m_mode(Incomplete) { + uint32_t number1; + uint32_t number2; + generateSecWebSocketKey(number1, m_secWebSocketKey1); + generateSecWebSocketKey(number2, m_secWebSocketKey2); + generateKey3(m_key3); + generateExpectedChallengeResponse(number1, number2, m_key3, m_expectedChallengeResponse); } WebSocketHandshake::~WebSocketHandshake() @@ -148,13 +221,7 @@ String WebSocketHandshake::clientLocation() const StringBuilder builder; builder.append(m_secure ? "wss" : "ws"); builder.append("://"); - builder.append(m_url.host().lower()); - if (m_url.port()) { - if ((!m_secure && m_url.port() != 80) || (m_secure && m_url.port() != 443)) { - builder.append(":"); - builder.append(String::number(m_url.port())); - } - } + builder.append(hostName(m_url, m_secure)); builder.append(resourceName(m_url)); return builder.toString(); } @@ -167,43 +234,51 @@ CString WebSocketHandshake::clientHandshakeMessage() const builder.append("GET "); builder.append(resourceName(m_url)); builder.append(" HTTP/1.1\r\n"); - builder.append("Upgrade: WebSocket\r\n"); - builder.append("Connection: Upgrade\r\n"); - builder.append("Host: "); - builder.append(m_url.host().lower()); - if (m_url.port() && ((!m_secure && m_url.port() != 80) || (m_secure && m_url.port() != 443))) { - builder.append(":"); - builder.append(String::number(m_url.port())); - } - builder.append("\r\n"); - builder.append("Origin: "); - builder.append(clientOrigin()); - builder.append("\r\n"); - if (!m_clientProtocol.isEmpty()) { - builder.append("WebSocket-Protocol: "); - builder.append(m_clientProtocol); - builder.append("\r\n"); - } + + Vector<String> fields; + fields.append("Upgrade: WebSocket"); + fields.append("Connection: Upgrade"); + fields.append("Host: " + hostName(m_url, m_secure)); + fields.append("Origin: " + clientOrigin()); + if (!m_clientProtocol.isEmpty()) + fields.append("Sec-WebSocket-Protocol: " + m_clientProtocol); KURL url = httpURLForAuthenticationAndCookies(); if (m_context->isDocument()) { Document* document = static_cast<Document*>(m_context); String cookie = cookieRequestHeaderFieldValue(document, url); - if (!cookie.isEmpty()) { - builder.append("Cookie: "); - builder.append(cookie); - builder.append("\r\n"); - } + if (!cookie.isEmpty()) + fields.append("Cookie: " + cookie); // Set "Cookie2: <cookie>" if cookies 2 exists for url? } + fields.append("Sec-WebSocket-Key1: " + m_secWebSocketKey1); + fields.append("Sec-WebSocket-Key2: " + m_secWebSocketKey2); + + // Fields in the handshake are sent by the client in a random order; the + // order is not meaningful. Thus, it's ok to send the order we constructed + // the fields. + + for (size_t i = 0; i < fields.size(); i++) { + builder.append(fields[i]); + builder.append("\r\n"); + } + builder.append("\r\n"); - return builder.toString().utf8(); + + CString handshakeHeader = builder.toString().utf8(); + char* characterBuffer = 0; + CString msg = CString::newUninitialized(handshakeHeader.length() + sizeof(m_key3), characterBuffer); + memcpy(characterBuffer, handshakeHeader.data(), handshakeHeader.length()); + memcpy(characterBuffer + handshakeHeader.length(), m_key3, sizeof(m_key3)); + return msg; } WebSocketHandshakeRequest WebSocketHandshake::clientHandshakeRequest() const { // Keep the following consistent with clientHandshakeMessage(). + // FIXME: do we need to store m_secWebSocketKey1, m_secWebSocketKey2 and + // m_key3 in WebSocketHandshakeRequest? WebSocketHandshakeRequest request(m_url, clientOrigin(), m_clientProtocol); KURL url = httpURLForAuthenticationAndCookies(); @@ -237,89 +312,55 @@ void WebSocketHandshake::clearScriptExecutionContext() int WebSocketHandshake::readServerHandshake(const char* header, size_t len) { m_mode = Incomplete; - if (len < sizeof(webSocketServerHandshakeHeader) - 1) { - // Just hasn't been received fully yet. + size_t lineLength; + const String& code = extractResponseCode(header, len, lineLength); + if (code.isNull()) { + // Just hasn't been received yet. return -1; } - if (!memcmp(header, webSocketServerHandshakeHeader, sizeof(webSocketServerHandshakeHeader) - 1)) - m_mode = Normal; - else { - const String& code = extractResponseCode(header, len); - if (code.isNull()) { - m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Short server handshake: " + trimConsoleMessage(header, len), 0, clientOrigin()); - return -1; - } - if (code.isEmpty()) { - m_mode = Failed; - m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "No response code found: " + trimConsoleMessage(header, len), 0, clientOrigin()); - return len; - } - LOG(Network, "response code: %s", code.utf8().data()); - if (code == "401") { - m_mode = Failed; - m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Authentication required, but not implemented yet.", 0, clientOrigin()); - return len; - } else { - m_mode = Failed; - m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected response code:" + code, 0, clientOrigin()); - return len; - } + if (code.isEmpty()) { + m_mode = Failed; + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "No response code found: " + trimConsoleMessage(header, lineLength), 0, clientOrigin()); + return len; } - const char* p = header + sizeof(webSocketServerHandshakeHeader) - 1; - const char* end = header + len; - - if (m_mode == Normal) { - size_t headerSize = end - p; - if (headerSize < sizeof(webSocketUpgradeHeader) - 1) { - m_mode = Incomplete; - return 0; - } - if (memcmp(p, webSocketUpgradeHeader, sizeof(webSocketUpgradeHeader) - 1)) { - m_mode = Failed; - m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Bad Upgrade header: " + trimConsoleMessage(p, end - p), 0, clientOrigin()); - return p - header + sizeof(webSocketUpgradeHeader) - 1; - } - p += sizeof(webSocketUpgradeHeader) - 1; - - headerSize = end - p; - if (headerSize < sizeof(webSocketConnectionHeader) - 1) { - m_mode = Incomplete; - return -1; - } - if (memcmp(p, webSocketConnectionHeader, sizeof(webSocketConnectionHeader) - 1)) { - m_mode = Failed; - m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Bad Connection header: " + trimConsoleMessage(p, end - p), 0, clientOrigin()); - return p - header + sizeof(webSocketConnectionHeader) - 1; - } - p += sizeof(webSocketConnectionHeader) - 1; + LOG(Network, "response code: %s", code.utf8().data()); + if (code != "101") { + m_mode = Failed; + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected response code:" + code, 0, clientOrigin()); + return len; } - - if (!strnstr(p, "\r\n\r\n", end - p)) { + m_mode = Normal; + if (!strnstr(header, "\r\n\r\n", len)) { // Just hasn't been received fully yet. m_mode = Incomplete; return -1; } HTTPHeaderMap headers; - p = readHTTPHeaders(p, end, &headers); + const char* headerFields = strnstr(header, "\r\n", len); // skip status line + ASSERT(headerFields); + headerFields += 2; // skip "\r\n". + const char* p = readHTTPHeaders(headerFields, header + len, &headers); if (!p) { LOG(Network, "readHTTPHeaders failed"); m_mode = Failed; return len; } - if (!processHeaders(headers)) { + if (!processHeaders(headers) || !checkResponseHeaders()) { LOG(Network, "header process failed"); m_mode = Failed; return p - header; } - switch (m_mode) { - case Normal: - checkResponseHeaders(); - break; - default: + if (len < static_cast<size_t>(p - header + sizeof(m_expectedChallengeResponse))) { + // Just hasn't been received /expected/ yet. + m_mode = Incomplete; + return -1; + } + if (memcmp(p, m_expectedChallengeResponse, sizeof(m_expectedChallengeResponse))) { m_mode = Failed; - break; + return (p - header) + sizeof(m_expectedChallengeResponse); } - return p - header; + m_mode = Connected; + return (p - header) + sizeof(m_expectedChallengeResponse); } WebSocketHandshake::Mode WebSocketHandshake::mode() const @@ -402,10 +443,10 @@ const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* e m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "CR doesn't follow LF at " + trimConsoleMessage(p, end - p), 0, clientOrigin()); return 0; } - m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected CR in name at " + trimConsoleMessage(p, end - p), 0, clientOrigin()); + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected CR in name at " + trimConsoleMessage(name.data(), name.size()), 0, clientOrigin()); return 0; case '\n': - m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in name at " + trimConsoleMessage(p, end - p), 0, clientOrigin()); + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in name at " + trimConsoleMessage(name.data(), name.size()), 0, clientOrigin()); return 0; case ':': break; @@ -429,7 +470,7 @@ const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* e case '\r': break; case '\n': - m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in value at " + trimConsoleMessage(p, end - p), 0, clientOrigin()); + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in value at " + trimConsoleMessage(value.data(), value.size()), 0, clientOrigin()); return 0; default: value.append(*p); @@ -465,11 +506,11 @@ bool WebSocketHandshake::processHeaders(const HTTPHeaderMap& headers) for (HTTPHeaderMap::const_iterator it = headers.begin(); it != headers.end(); ++it) { switch (m_mode) { case Normal: - if (it->first == "websocket-origin") + if (it->first == "sec-websocket-origin") m_wsOrigin = it->second; - else if (it->first == "websocket-location") + else if (it->first == "sec-websocket-location") m_wsLocation = it->second; - else if (it->first == "websocket-protocol") + else if (it->first == "sec-websocket-protocol") m_wsProtocol = it->second; else if (it->first == "set-cookie") m_setCookie = it->second; @@ -486,35 +527,32 @@ bool WebSocketHandshake::processHeaders(const HTTPHeaderMap& headers) return true; } -void WebSocketHandshake::checkResponseHeaders() +bool WebSocketHandshake::checkResponseHeaders() { - ASSERT(m_mode == Normal); - m_mode = Failed; if (m_wsOrigin.isNull()) { - m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'websocket-origin' header is missing", 0, clientOrigin()); - return; + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'sec-websocket-origin' header is missing", 0, clientOrigin()); + return false; } if (m_wsLocation.isNull()) { - m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'websocket-location' header is missing", 0, clientOrigin()); - return; + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'sec-websocket-location' header is missing", 0, clientOrigin()); + return false; } if (clientOrigin() != m_wsOrigin) { m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: origin mismatch: " + clientOrigin() + " != " + m_wsOrigin, 0, clientOrigin()); - return; + return false; } if (clientLocation() != m_wsLocation) { m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: location mismatch: " + clientLocation() + " != " + m_wsLocation, 0, clientOrigin()); - return; + return false; } if (!m_clientProtocol.isEmpty() && m_clientProtocol != m_wsProtocol) { m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: protocol mismatch: " + m_clientProtocol + " != " + m_wsProtocol, 0, clientOrigin()); - return; + return false; } - m_mode = Connected; - return; + return true; } -} // namespace WebCore +} // namespace WebCore -#endif // ENABLE(WEB_SOCKETS) +#endif // ENABLE(WEB_SOCKETS) diff --git a/WebCore/websockets/WebSocketHandshake.h b/WebCore/websockets/WebSocketHandshake.h index df199ff..3e0c66a 100644 --- a/WebCore/websockets/WebSocketHandshake.h +++ b/WebCore/websockets/WebSocketHandshake.h @@ -40,8 +40,8 @@ namespace WebCore { - class ScriptExecutionContext; class HTTPHeaderMap; + class ScriptExecutionContext; class WebSocketHandshake : public Noncopyable { public: @@ -92,7 +92,7 @@ namespace WebCore { // Reads all headers except for the two predefined ones. const char* readHTTPHeaders(const char* start, const char* end, HTTPHeaderMap* headers); bool processHeaders(const HTTPHeaderMap& headers); - void checkResponseHeaders(); + bool checkResponseHeaders(); KURL m_url; String m_clientProtocol; @@ -106,6 +106,11 @@ namespace WebCore { String m_wsProtocol; String m_setCookie; String m_setCookie2; + + String m_secWebSocketKey1; + String m_secWebSocketKey2; + unsigned char m_key3[8]; + unsigned char m_expectedChallengeResponse[16]; }; } // namespace WebCore diff --git a/WebCore/websockets/WorkerThreadableWebSocketChannel.cpp b/WebCore/websockets/WorkerThreadableWebSocketChannel.cpp index fd86604..f8d1230 100644 --- a/WebCore/websockets/WorkerThreadableWebSocketChannel.cpp +++ b/WebCore/websockets/WorkerThreadableWebSocketChannel.cpp @@ -34,7 +34,7 @@ #include "WorkerThreadableWebSocketChannel.h" -#include "GenericWorkerTask.h" +#include "CrossThreadTask.h" #include "PlatformString.h" #include "ScriptExecutionContext.h" #include "ThreadableWebSocketChannelClientWrapper.h" |