diff options
Diffstat (limited to 'luni/src/test/java/libcore/javax/net/ssl/SSLSocketTest.java')
-rw-r--r-- | luni/src/test/java/libcore/javax/net/ssl/SSLSocketTest.java | 249 |
1 files changed, 203 insertions, 46 deletions
diff --git a/luni/src/test/java/libcore/javax/net/ssl/SSLSocketTest.java b/luni/src/test/java/libcore/javax/net/ssl/SSLSocketTest.java index 4af7f5a..bf2d0f8 100644 --- a/luni/src/test/java/libcore/javax/net/ssl/SSLSocketTest.java +++ b/luni/src/test/java/libcore/javax/net/ssl/SSLSocketTest.java @@ -16,6 +16,8 @@ package libcore.javax.net.ssl; +import java.io.ByteArrayInputStream; +import java.io.DataInputStream; import java.io.EOFException; import java.io.IOException; import java.io.InputStream; @@ -26,15 +28,18 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.Socket; -import java.net.SocketAddress; import java.net.SocketException; import java.net.SocketTimeoutException; +import java.security.KeyManagementException; +import java.security.NoSuchAlgorithmException; import java.security.Principal; import java.security.PrivateKey; import java.security.cert.Certificate; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -66,11 +71,23 @@ import libcore.io.IoUtils; import libcore.io.Streams; import libcore.java.security.StandardNames; import libcore.java.security.TestKeyStore; +import libcore.tlswire.handshake.CipherSuite; +import libcore.tlswire.handshake.ClientHello; +import libcore.tlswire.handshake.CompressionMethod; +import libcore.tlswire.handshake.HandshakeMessage; +import libcore.tlswire.handshake.HelloExtension; +import libcore.tlswire.handshake.ServerNameHelloExtension; +import libcore.tlswire.record.TlsProtocols; +import libcore.tlswire.record.TlsRecord; +import libcore.tlswire.util.TlsProtocolVersion; +import tests.util.ForEachRunner; +import tests.util.DelegatingSSLSocketFactory; +import tests.util.Pair; public class SSLSocketTest extends TestCase { public void test_SSLSocket_defaultConfiguration() throws Exception { - SSLDefaultConfigurationAsserts.assertSSLSocket( + SSLConfigurationAsserts.assertSSLSocketDefaultConfiguration( (SSLSocket) SSLSocketFactory.getDefault().createSocket()); } @@ -403,6 +420,37 @@ public class SSLSocketTest extends TestCase { c.close(); } + public void test_SSLSocket_NoEnabledCipherSuites_Failure() throws Exception { + TestSSLContext c = TestSSLContext.create(null, null, null, null, null, null, null, null, + SSLContext.getDefault(), SSLContext.getDefault()); + SSLSocket client = (SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host, + c.port); + client.setEnabledCipherSuites(new String[0]); + final SSLSocket server = (SSLSocket) c.serverSocket.accept(); + ExecutorService executor = Executors.newSingleThreadExecutor(); + Future<Void> future = executor.submit(new Callable<Void>() { + @Override + public Void call() throws Exception { + try { + server.startHandshake(); + fail(); + } catch (SSLHandshakeException expected) { + } + return null; + } + }); + executor.shutdown(); + try { + client.startHandshake(); + fail(); + } catch (SSLHandshakeException expected) { + } + future.get(); + server.close(); + client.close(); + c.close(); + } + public void test_SSLSocket_startHandshake_noKeyStore() throws Exception { TestSSLContext c = TestSSLContext.create(null, null, null, null, null, null, null, null, SSLContext.getDefault(), SSLContext.getDefault()); @@ -1458,11 +1506,147 @@ public class SSLSocketTest extends TestCase { test.close(); } - public void test_SSLSocket_ClientHello_size() throws Exception { + public void test_SSLSocket_ClientHello_record_size() throws Exception { // This test checks the size of ClientHello of the default SSLSocket. TLS/SSL handshakes // with older/unpatched F5/BIG-IP appliances are known to stall and time out when // the fragment containing ClientHello is between 256 and 511 (inclusive) bytes long. - // + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(null, null, null); + SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory(); + sslSocketFactory = new DelegatingSSLSocketFactory(sslSocketFactory) { + @Override + protected void configureSocket(SSLSocket socket) { + // Enable SNI extension on the socket (this is typically enabled by default) + // to increase the size of ClientHello. + try { + Method setHostname = + socket.getClass().getMethod("setHostname", String.class); + setHostname.invoke(socket, "sslsockettest.androidcts.google.com"); + } catch (NoSuchMethodException ignored) { + } catch (Exception e) { + throw new RuntimeException("Failed to enable SNI", e); + } + + // Enable Session Tickets extension on the socket (this is typically enabled + // by default) to increase the size of ClientHello. + try { + Method setUseSessionTickets = + socket.getClass().getMethod( + "setUseSessionTickets", boolean.class); + setUseSessionTickets.invoke(socket, true); + } catch (NoSuchMethodException ignored) { + } catch (Exception e) { + throw new RuntimeException("Failed to enable Session Tickets", e); + } + } + }; + + TlsRecord firstReceivedTlsRecord = captureTlsHandshakeFirstTlsRecord(sslSocketFactory); + assertEquals("TLS record type", TlsProtocols.HANDSHAKE, firstReceivedTlsRecord.type); + HandshakeMessage handshakeMessage = HandshakeMessage.read( + new DataInputStream(new ByteArrayInputStream(firstReceivedTlsRecord.fragment))); + assertEquals("HandshakeMessage type", + HandshakeMessage.TYPE_CLIENT_HELLO, handshakeMessage.type); + int fragmentLength = firstReceivedTlsRecord.fragment.length; + if ((fragmentLength >= 256) && (fragmentLength <= 511)) { + fail("Fragment containing ClientHello is of dangerous length: " + + fragmentLength + " bytes"); + } + } + + public void test_SSLSocket_ClientHello_cipherSuites() throws Exception { + ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() { + @Override + public void run(SSLSocketFactory sslSocketFactory) throws Exception { + ClientHello clientHello = captureTlsHandshakeClientHello(sslSocketFactory); + String[] cipherSuites = new String[clientHello.cipherSuites.size()]; + for (int i = 0; i < clientHello.cipherSuites.size(); i++) { + CipherSuite cipherSuite = clientHello.cipherSuites.get(i); + cipherSuites[i] = cipherSuite.getAndroidName(); + } + StandardNames.assertDefaultCipherSuites(cipherSuites); + } + }, getSSLSocketFactoriesToTest()); + } + + public void test_SSLSocket_ClientHello_clientProtocolVersion() throws Exception { + ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() { + @Override + public void run(SSLSocketFactory sslSocketFactory) throws Exception { + ClientHello clientHello = captureTlsHandshakeClientHello(sslSocketFactory); + assertEquals(TlsProtocolVersion.TLSv1_2, clientHello.clientVersion); + } + }, getSSLSocketFactoriesToTest()); + } + + public void test_SSLSocket_ClientHello_compressionMethods() throws Exception { + ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() { + @Override + public void run(SSLSocketFactory sslSocketFactory) throws Exception { + ClientHello clientHello = captureTlsHandshakeClientHello(sslSocketFactory); + assertEquals(Arrays.asList(CompressionMethod.NULL), clientHello.compressionMethods); + } + }, getSSLSocketFactoriesToTest()); + } + + public void test_SSLSocket_ClientHello_SNI() throws Exception { + ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() { + @Override + public void run(SSLSocketFactory sslSocketFactory) throws Exception { + ClientHello clientHello = captureTlsHandshakeClientHello(sslSocketFactory); + ServerNameHelloExtension sniExtension = (ServerNameHelloExtension) + clientHello.findExtensionByType(HelloExtension.TYPE_SERVER_NAME); + assertNotNull(sniExtension); + assertEquals(Arrays.asList("localhost.localdomain"), sniExtension.hostnames); + } + }, getSSLSocketFactoriesToTest()); + } + + private List<Pair<String, SSLSocketFactory>> getSSLSocketFactoriesToTest() + throws NoSuchAlgorithmException, KeyManagementException { + List<Pair<String, SSLSocketFactory>> result = + new ArrayList<Pair<String, SSLSocketFactory>>(); + result.add(Pair.of("default", (SSLSocketFactory) SSLSocketFactory.getDefault())); + for (String sslContextProtocol : StandardNames.SSL_CONTEXT_PROTOCOLS) { + SSLContext sslContext = SSLContext.getInstance(sslContextProtocol); + if (StandardNames.SSL_CONTEXT_PROTOCOLS_DEFAULT.equals(sslContextProtocol)) { + continue; + } + sslContext.init(null, null, null); + result.add(Pair.of( + "SSLContext(\"" + sslContext.getProtocol() + "\")", + sslContext.getSocketFactory())); + } + return result; + } + + private ClientHello captureTlsHandshakeClientHello(SSLSocketFactory sslSocketFactory) + throws Exception { + TlsRecord record = captureTlsHandshakeFirstTlsRecord(sslSocketFactory); + assertEquals("TLS record type", TlsProtocols.HANDSHAKE, record.type); + ByteArrayInputStream fragmentIn = new ByteArrayInputStream(record.fragment); + HandshakeMessage handshakeMessage = HandshakeMessage.read(new DataInputStream(fragmentIn)); + assertEquals("HandshakeMessage type", + HandshakeMessage.TYPE_CLIENT_HELLO, handshakeMessage.type); + // Assert that the fragment does not contain any more messages + assertEquals(0, fragmentIn.available()); + + return (ClientHello) handshakeMessage; + } + + private TlsRecord captureTlsHandshakeFirstTlsRecord(SSLSocketFactory sslSocketFactory) + throws Exception { + byte[] firstReceivedChunk = captureTlsHandshakeFirstTransmittedChunkBytes(sslSocketFactory); + ByteArrayInputStream firstReceivedChunkIn = new ByteArrayInputStream(firstReceivedChunk); + TlsRecord record = TlsRecord.read(new DataInputStream(firstReceivedChunkIn)); + // Assert that the chunk does not contain any more data + assertEquals(0, firstReceivedChunkIn.available()); + + return record; + } + + private byte[] captureTlsHandshakeFirstTransmittedChunkBytes( + final SSLSocketFactory sslSocketFactory) throws Exception { // Since there's no straightforward way to obtain a ClientHello from SSLSocket, this test // does the following: // 1. Creates a listening server socket (a plain one rather than a TLS/SSL one). @@ -1506,61 +1690,31 @@ public class SSLSocketTest extends TestCase { executorService.submit(new Callable<Void>() { @Override public Void call() throws Exception { - SSLContext sslContext = SSLContext.getInstance("TLS"); - sslContext.init(null, null, null); - SSLSocket client = (SSLSocket) sslContext.getSocketFactory().createSocket(); + Socket client = new Socket(); sockets[0] = client; try { - // Enable SNI extension on the socket (this is typically enabled by default) - // to increase the size of ClientHello. - try { - Method setHostname = - client.getClass().getMethod("setHostname", String.class); - setHostname.invoke(client, "sslsockettest.androidcts.google.com"); - } catch (NoSuchMethodException ignored) {} - - // Enable Session Tickets extension on the socket (this is typically enabled - // by default) to increase the size of ClientHello. - try { - Method setUseSessionTickets = - client.getClass().getMethod( - "setUseSessionTickets", boolean.class); - setUseSessionTickets.invoke(client, true); - } catch (NoSuchMethodException ignored) {} - client.connect(finalListeningSocket.getLocalSocketAddress()); // Initiate the TLS/SSL handshake which is expected to fail as soon as the // server socket receives a ClientHello. try { - client.startHandshake(); + SSLSocket sslSocket = (SSLSocket) sslSocketFactory.createSocket( + client, + "localhost.localdomain", + finalListeningSocket.getLocalPort(), + true); + sslSocket.startHandshake(); fail(); return null; } catch (IOException expected) {} return null; } finally { IoUtils.closeQuietly(client); - - // Cancel the reading task. If this task succeeded, then the reading task - // is done and this will have no effect. If this task failed prematurely, - // then the reading task might get unblocked (we're interrupting the thread - // it's running on), will fail early, and we'll thus save some time in this - // test. - readFirstReceivedChunkFuture.cancel(true); } } }); // Wait for the ClientHello to arrive - byte[] clientHello = readFirstReceivedChunkFuture.get(10, TimeUnit.SECONDS); - - // Check for ClientHello length that may cause handshake to fail/time out with older - // F5/BIG-IP appliances. - assertEquals("TLS record type: handshake", 22, clientHello[0]); - int fragmentLength = ((clientHello[3] & 0xff) << 8) | (clientHello[4] & 0xff); - if ((fragmentLength >= 256) && (fragmentLength <= 511)) { - fail("Fragment containing ClientHello is of dangerous length: " - + fragmentLength + " bytes"); - } + return readFirstReceivedChunkFuture.get(10, TimeUnit.SECONDS); } finally { executorService.shutdownNow(); IoUtils.closeQuietly(listeningSocket); @@ -1668,6 +1822,11 @@ public class SSLSocketTest extends TestCase { context.close(); } + private static void assertInappropriateFallbackIsCause(Throwable cause) { + assertTrue(cause.getMessage(), cause.getMessage().contains("inappropriate fallback") + || cause.getMessage().contains("INAPPROPRIATE_FALLBACK")); + } + public void test_SSLSocket_sendsTlsFallbackScsv_InappropriateFallback_Failure() throws Exception { TestSSLContext context = TestSSLContext.create(); @@ -1693,8 +1852,7 @@ public class SSLSocketTest extends TestCase { } catch (SSLHandshakeException expected) { Throwable cause = expected.getCause(); assertEquals(SSLProtocolException.class, cause.getClass()); - assertTrue(cause.getMessage(), - cause.getMessage().contains("inappropriate fallback")); + assertInappropriateFallbackIsCause(cause); } return null; } @@ -1709,8 +1867,7 @@ public class SSLSocketTest extends TestCase { } catch (SSLHandshakeException expected) { Throwable cause = expected.getCause(); assertEquals(SSLProtocolException.class, cause.getClass()); - assertTrue(cause.getMessage(), - cause.getMessage().contains("inappropriate fallback")); + assertInappropriateFallbackIsCause(cause); } return null; } |