From 4f60362bb61f8199f3e97371eff96461fbacba68 Mon Sep 17 00:00:00 2001
From: Kenny Root <kroot@google.com>
Date: Fri, 3 May 2013 10:51:56 -0700
Subject: NativeCrypto: move key conversion to Java

Key type conversion in native code is from the legacy period before the
OpenSSLKey class existed. Use that to hold PKEY reference instead of
converting it in native code.

Change-Id: I84e9a6e1f2e0f95d2f44c18fa9f65cd15e039d63
---
 .../src/main/java/org/conscrypt/NativeCrypto.java  |  26 +---
 crypto/src/main/java/org/conscrypt/OpenSSLKey.java |  13 ++
 .../main/java/org/conscrypt/OpenSSLSocketImpl.java |  46 ++++---
 .../src/main/native/org_conscrypt_NativeCrypto.cpp | 146 +++------------------
 .../test/java/org/conscrypt/NativeCryptoTest.java  |  56 ++++----
 5 files changed, 83 insertions(+), 204 deletions(-)

(limited to 'crypto')

diff --git a/crypto/src/main/java/org/conscrypt/NativeCrypto.java b/crypto/src/main/java/org/conscrypt/NativeCrypto.java
index 1416d1b..208f89e 100644
--- a/crypto/src/main/java/org/conscrypt/NativeCrypto.java
+++ b/crypto/src/main/java/org/conscrypt/NativeCrypto.java
@@ -770,27 +770,7 @@ public final class NativeCrypto {
 
     public static native byte[] SSL_get_tls_channel_id(long ssl) throws SSLException;
 
-    public static native void SSL_use_OpenSSL_PrivateKey_for_tls_channel_id(long ssl, long pkey)
-            throws SSLException;
-
-    public static native void SSL_use_PKCS8_PrivateKey_for_tls_channel_id(
-            long ssl, byte[] pkcs8EncodedPrivateKey) throws SSLException;
-
-    public static void SSL_set1_tls_channel_id(long ssl, PrivateKey privateKey)
-            throws SSLException {
-        if (privateKey == null) {
-            throw new NullPointerException("privateKey == null");
-        } else if (privateKey instanceof OpenSSLECPrivateKey) {
-            OpenSSLKey openSslPrivateKey = ((OpenSSLECPrivateKey) privateKey).getOpenSSLKey();
-            SSL_use_OpenSSL_PrivateKey_for_tls_channel_id(ssl, openSslPrivateKey.getPkeyContext());
-        } else if ("PKCS#8".equals(privateKey.getFormat())) {
-            byte[] pkcs8EncodedKey = privateKey.getEncoded();
-            SSL_use_PKCS8_PrivateKey_for_tls_channel_id(ssl, pkcs8EncodedKey);
-        } else {
-            throw new SSLException("Unsupported Channel ID private key type:" +
-                    " class: " + privateKey.getClass() + ", format: " + privateKey.getFormat());
-        }
-    }
+    public static native void SSL_set1_tls_channel_id(long ssl, long pkey);
 
     public static byte[][] encodeCertificates(Certificate[] certificates)
             throws CertificateEncodingException {
@@ -803,9 +783,7 @@ public final class NativeCrypto {
 
     public static native void SSL_use_certificate(long ssl, byte[][] asn1DerEncodedCertificateChain);
 
-    public static native void SSL_use_OpenSSL_PrivateKey(long ssl, long pkey);
-
-    public static native void SSL_use_PrivateKey(long ssl, byte[] pkcs8EncodedPrivateKey);
+    public static native void SSL_use_PrivateKey(long ssl, long pkey);
 
     public static native void SSL_check_private_key(long ssl) throws SSLException;
 
diff --git a/crypto/src/main/java/org/conscrypt/OpenSSLKey.java b/crypto/src/main/java/org/conscrypt/OpenSSLKey.java
index 9dac153..fa9a258 100644
--- a/crypto/src/main/java/org/conscrypt/OpenSSLKey.java
+++ b/crypto/src/main/java/org/conscrypt/OpenSSLKey.java
@@ -16,6 +16,7 @@
 
 package org.conscrypt;
 
+import java.security.InvalidKeyException;
 import java.security.NoSuchAlgorithmException;
 import java.security.PrivateKey;
 import java.security.PublicKey;
@@ -64,6 +65,18 @@ public class OpenSSLKey {
         return alias;
     }
 
+    public static OpenSSLKey fromPrivateKey(PrivateKey key) throws InvalidKeyException {
+        if (key instanceof OpenSSLKeyHolder) {
+            return ((OpenSSLKeyHolder) key).getOpenSSLKey();
+        }
+
+        if ("PKCS#8".equals(key.getFormat())) {
+            return new OpenSSLKey(NativeCrypto.d2i_PKCS8_PRIV_KEY_INFO(key.getEncoded()));
+        } else {
+            throw new InvalidKeyException("Unknown key format " + key.getFormat());
+        }
+    }
+
     public PublicKey getPublicKey() throws NoSuchAlgorithmException {
         switch (NativeCrypto.EVP_PKEY_type(ctx)) {
             case NativeCrypto.EVP_PKEY_RSA:
diff --git a/crypto/src/main/java/org/conscrypt/OpenSSLSocketImpl.java b/crypto/src/main/java/org/conscrypt/OpenSSLSocketImpl.java
index 3301387..6d51ccb 100644
--- a/crypto/src/main/java/org/conscrypt/OpenSSLSocketImpl.java
+++ b/crypto/src/main/java/org/conscrypt/OpenSSLSocketImpl.java
@@ -25,6 +25,7 @@ import java.io.OutputStream;
 import java.net.InetAddress;
 import java.net.Socket;
 import java.net.SocketException;
+import java.security.InvalidKeyException;
 import java.security.PrivateKey;
 import java.security.SecureRandom;
 import java.security.cert.CertificateEncodingException;
@@ -78,7 +79,7 @@ public class OpenSSLSocketImpl
     /** Whether the TLS Channel ID extension is enabled. This field is server-side only. */
     private boolean channelIdEnabled;
     /** Private key for the TLS Channel ID extension. This field is client-side only. */
-    private PrivateKey channelIdPrivateKey;
+    private OpenSSLKey channelIdPrivateKey;
     private OpenSSLSessionImpl sslSession;
     private final Socket socket;
     private boolean autoClose;
@@ -380,14 +381,16 @@ public class OpenSSLSocketImpl
             }
 
             // TLS Channel ID
-            if (client) {
-                // Client-side TLS Channel ID
-                if (channelIdPrivateKey != null) {
-                    NativeCrypto.SSL_set1_tls_channel_id(sslNativePointer, channelIdPrivateKey);
-                }
-            } else {
-                // Server-side TLS Channel ID
-                if (channelIdEnabled) {
+            if (channelIdEnabled) {
+                if (client) {
+                    // Client-side TLS Channel ID
+                    if (channelIdPrivateKey == null) {
+                        throw new SSLHandshakeException("Invalid TLS channel ID key specified");
+                    }
+                    NativeCrypto.SSL_set1_tls_channel_id(sslNativePointer,
+                            channelIdPrivateKey.getPkeyContext());
+                } else {
+                    // Server-side TLS Channel ID
                     NativeCrypto.SSL_enable_tls_channel_id(sslNativePointer);
                 }
             }
@@ -497,14 +500,11 @@ public class OpenSSLSocketImpl
             return;
         }
 
-        if (privateKey instanceof OpenSSLKeyHolder) {
-            OpenSSLKey key = ((OpenSSLKeyHolder) privateKey).getOpenSSLKey();
-            NativeCrypto.SSL_use_OpenSSL_PrivateKey(sslNativePointer, key.getPkeyContext());
-        } else if ("PKCS#8".equals(privateKey.getFormat())) {
-            byte[] privateKeyBytes = privateKey.getEncoded();
-            NativeCrypto.SSL_use_PrivateKey(sslNativePointer, privateKeyBytes);
-        } else {
-            throw new SSLException("Unsupported PrivateKey format: " + privateKey.getFormat());
+        try {
+            final OpenSSLKey key = OpenSSLKey.fromPrivateKey(privateKey);
+            NativeCrypto.SSL_use_PrivateKey(sslNativePointer, key.getPkeyContext());
+        } catch (InvalidKeyException e) {
+            throw new SSLException(e);
         }
 
         byte[][] certificateBytes = NativeCrypto.encodeCertificates(certificates);
@@ -879,7 +879,17 @@ public class OpenSSLSocketImpl
                     "Could not change Channel ID private key after the initial handshake has"
                     + " begun.");
         }
-        this.channelIdPrivateKey = privateKey;
+        if (privateKey == null) {
+            this.channelIdEnabled = false;
+            this.channelIdPrivateKey = null;
+        } else {
+            this.channelIdEnabled = true;
+            try {
+                this.channelIdPrivateKey = OpenSSLKey.fromPrivateKey(privateKey);
+            } catch (InvalidKeyException e) {
+                // Will have error in startHandshake
+            }
+        }
     }
 
     @Override public boolean getUseClientMode() {
diff --git a/crypto/src/main/native/org_conscrypt_NativeCrypto.cpp b/crypto/src/main/native/org_conscrypt_NativeCrypto.cpp
index 077c391..ac1f321 100644
--- a/crypto/src/main/native/org_conscrypt_NativeCrypto.cpp
+++ b/crypto/src/main/native/org_conscrypt_NativeCrypto.cpp
@@ -6216,18 +6216,19 @@ static jbyteArray NativeCrypto_SSL_get_tls_channel_id(JNIEnv* env, jclass, jlong
     return javaBytes;
 }
 
-static void NativeCrypto_SSL_use_OpenSSL_PrivateKey_for_tls_channel_id(
-        JNIEnv* env, jclass, jlong ssl_address, jlong pkeyRef)
+static void NativeCrypto_SSL_set1_tls_channel_id(JNIEnv* env, jclass,
+        jlong ssl_address, jlong pkeyRef)
 {
     SSL* ssl = to_SSL(env, ssl_address, true);
     EVP_PKEY* pkey = reinterpret_cast<EVP_PKEY*>(pkeyRef);
-    JNI_TRACE("ssl=%p NativeCrypto_SSL_use_OpenSSL_PrivateKey_for_tls_channel_id privatekey=%p",
-            ssl, pkey);
+    JNI_TRACE("ssl=%p SSL_set1_tls_channel_id privatekey=%p", ssl, pkey);
     if (ssl == NULL) {
         return;
     }
 
     if (pkey == NULL) {
+        jniThrowNullPointerException(env, "pkey == null");
+        JNI_TRACE("ssl=%p SSL_set1_tls_channel_id => pkey == null", ssl);
         return;
     }
 
@@ -6242,86 +6243,28 @@ static void NativeCrypto_SSL_use_OpenSSL_PrivateKey_for_tls_channel_id(
         throwSSLExceptionWithSslErrors(
                 env, ssl, SSL_ERROR_NONE, "Error setting private key for Channel ID");
         SSL_clear(ssl);
-        JNI_TRACE("ssl=%p NativeCrypto_SSL_use_OpenSSL_PrivateKey_for_tls_channel_id => error", ssl);
+        JNI_TRACE("ssl=%p SSL_set1_tls_channel_id => error", ssl);
         return;
     }
-    // SSL_use_PrivateKey expects to take ownership of the EVP_PKEY,
-    // but we have an external reference from the caller such as an
-    // OpenSSLKey, so we manually increment the reference count here.
+    // SSL_set1_tls_channel_id expects to take ownership of the EVP_PKEY, but
+    // we have an external reference from the caller such as an OpenSSLKey,
+    // so we manually increment the reference count here.
     CRYPTO_add(&pkey->references,+1,CRYPTO_LOCK_EVP_PKEY);
 
-    JNI_TRACE("ssl=%p NativeCrypto_SSL_use_OpenSSL_PrivateKey_for_tls_channel_id => ok", ssl);
+    JNI_TRACE("ssl=%p SSL_set1_tls_channel_id => ok", ssl);
 }
 
-static void NativeCrypto_SSL_use_PKCS8_PrivateKey_for_tls_channel_id(
-        JNIEnv* env, jclass, jlong ssl_address, jbyteArray privatekey)
-{
-    SSL* ssl = to_SSL(env, ssl_address, true);
-    JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey_for_tls_channel_id privatekey=%p", ssl,
-            privatekey);
-    if (ssl == NULL) {
-        return;
-    }
-
-    ScopedByteArrayRO buf(env, privatekey);
-    if (buf.get() == NULL) {
-        JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey_for_tls_channel_id => threw exception",
-                ssl);
-        return;
-    }
-    const unsigned char* tmp = reinterpret_cast<const unsigned char*>(buf.get());
-    Unique_PKCS8_PRIV_KEY_INFO pkcs8(d2i_PKCS8_PRIV_KEY_INFO(NULL, &tmp, buf.size()));
-    if (pkcs8.get() == NULL) {
-        ALOGE("%s", ERR_error_string(ERR_peek_error(), NULL));
-        throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_NONE,
-                                       "Error parsing private key from DER to PKCS8");
-        SSL_clear(ssl);
-        JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey => error from DER to PKCS8", ssl);
-        return;
-    }
-
-    Unique_EVP_PKEY privatekeyevp(EVP_PKCS82PKEY(pkcs8.get()));
-    if (privatekeyevp.get() == NULL) {
-        ALOGE("%s", ERR_error_string(ERR_peek_error(), NULL));
-        throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_NONE,
-                                       "Error creating private key from PKCS8");
-        SSL_clear(ssl);
-        JNI_TRACE(
-                "ssl=%p NativeCrypto_SSL_use_PrivateKey_for_tls_channel_id => error from PKCS8 to key",
-                ssl);
-        return;
-    }
-
-    JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey_for_tls_channel_id EVP_PKEY_type=%d",
-              ssl, EVP_PKEY_type(privatekeyevp.get()->type));
-
-    // SSL_set1_tls_channel_id requires ssl->server to be set to 0.
-    // Unfortunately, the default value is 1 and it's only changed to 0 just
-    // before the handshake starts (see NativeCrypto_SSL_do_handshake).
-    ssl->server = 0;
-    long ret = SSL_set1_tls_channel_id(ssl, privatekeyevp.get());
-    if (ret == 1L) {
-        OWNERSHIP_TRANSFERRED(privatekeyevp);
-    } else {
-        ALOGE("%s", ERR_error_string(ERR_peek_error(), NULL));
-        throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_NONE, "Error setting private key");
-        SSL_clear(ssl);
-        JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey_for_tls_channel_id => error", ssl);
-        return;
-    }
-
-    JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey_for_tls_channel_id => ok", ssl);
-}
-
-static void NativeCrypto_SSL_use_OpenSSL_PrivateKey(JNIEnv* env, jclass, jlong ssl_address, jlong pkeyRef) {
+static void NativeCrypto_SSL_use_PrivateKey(JNIEnv* env, jclass, jlong ssl_address, jlong pkeyRef) {
     SSL* ssl = to_SSL(env, ssl_address, true);
     EVP_PKEY* pkey = reinterpret_cast<EVP_PKEY*>(pkeyRef);
-    JNI_TRACE("ssl=%p SSL_use_OpenSSL_PrivateKey privatekey=%p", ssl, pkey);
+    JNI_TRACE("ssl=%p SSL_use_PrivateKey privatekey=%p", ssl, pkey);
     if (ssl == NULL) {
         return;
     }
 
     if (pkey == NULL) {
+        jniThrowNullPointerException(env, "pkey == null");
+        JNI_TRACE("ssl=%p SSL_use_PrivateKey => pkey == null", ssl);
         return;
     }
 
@@ -6330,7 +6273,7 @@ static void NativeCrypto_SSL_use_OpenSSL_PrivateKey(JNIEnv* env, jclass, jlong s
         ALOGE("%s", ERR_error_string(ERR_peek_error(), NULL));
         throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_NONE, "Error setting private key");
         SSL_clear(ssl);
-        JNI_TRACE("ssl=%p SSL_use_OpenSSL_PrivateKey => error", ssl);
+        JNI_TRACE("ssl=%p SSL_use_PrivateKey => error", ssl);
         return;
     }
     // SSL_use_PrivateKey expects to take ownership of the EVP_PKEY,
@@ -6338,58 +6281,7 @@ static void NativeCrypto_SSL_use_OpenSSL_PrivateKey(JNIEnv* env, jclass, jlong s
     // OpenSSLKey, so we manually increment the reference count here.
     CRYPTO_add(&pkey->references,+1,CRYPTO_LOCK_EVP_PKEY);
 
-    JNI_TRACE("ssl=%p SSL_use_OpenSSL_PrivateKey => ok", ssl);
-}
-
-static void NativeCrypto_SSL_use_PrivateKey(JNIEnv* env, jclass,
-                                            jlong ssl_address, jbyteArray privatekey)
-{
-    SSL* ssl = to_SSL(env, ssl_address, true);
-    JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey privatekey=%p", ssl, privatekey);
-    if (ssl == NULL) {
-        return;
-    }
-
-    ScopedByteArrayRO buf(env, privatekey);
-    if (buf.get() == NULL) {
-        JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey => threw exception", ssl);
-        return;
-    }
-    const unsigned char* tmp = reinterpret_cast<const unsigned char*>(buf.get());
-    Unique_PKCS8_PRIV_KEY_INFO pkcs8(d2i_PKCS8_PRIV_KEY_INFO(NULL, &tmp, buf.size()));
-    if (pkcs8.get() == NULL) {
-        ALOGE("%s", ERR_error_string(ERR_peek_error(), NULL));
-        throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_NONE,
-                                       "Error parsing private key from DER to PKCS8");
-        SSL_clear(ssl);
-        JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey => error from DER to PKCS8", ssl);
-        return;
-    }
-
-    Unique_EVP_PKEY privatekeyevp(EVP_PKCS82PKEY(pkcs8.get()));
-    if (privatekeyevp.get() == NULL) {
-        ALOGE("%s", ERR_error_string(ERR_peek_error(), NULL));
-        throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_NONE,
-                                       "Error creating private key from PKCS8");
-        SSL_clear(ssl);
-        JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey => error from PKCS8 to key", ssl);
-        return;
-    }
-
-    JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey EVP_PKEY_type=%d",
-              ssl, EVP_PKEY_type(privatekeyevp.get()->type));
-    int ret = SSL_use_PrivateKey(ssl, privatekeyevp.get());
-    if (ret == 1) {
-        OWNERSHIP_TRANSFERRED(privatekeyevp);
-    } else {
-        ALOGE("%s", ERR_error_string(ERR_peek_error(), NULL));
-        throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_NONE, "Error setting private key");
-        SSL_clear(ssl);
-        JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey => error", ssl);
-        return;
-    }
-
-    JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey => ok", ssl);
+    JNI_TRACE("ssl=%p SSL_use_PrivateKey => ok", ssl);
 }
 
 static void NativeCrypto_SSL_use_certificate(JNIEnv* env, jclass,
@@ -8005,10 +7897,8 @@ static JNINativeMethod sNativeCryptoMethods[] = {
     NATIVE_METHOD(NativeCrypto, SSL_new, "(J)J"),
     NATIVE_METHOD(NativeCrypto, SSL_enable_tls_channel_id, "(J)V"),
     NATIVE_METHOD(NativeCrypto, SSL_get_tls_channel_id, "(J)[B"),
-    NATIVE_METHOD(NativeCrypto, SSL_use_OpenSSL_PrivateKey_for_tls_channel_id, "(JJ)V"),
-    NATIVE_METHOD(NativeCrypto, SSL_use_PKCS8_PrivateKey_for_tls_channel_id, "(J[B)V"),
-    NATIVE_METHOD(NativeCrypto, SSL_use_OpenSSL_PrivateKey, "(JJ)V"),
-    NATIVE_METHOD(NativeCrypto, SSL_use_PrivateKey, "(J[B)V"),
+    NATIVE_METHOD(NativeCrypto, SSL_set1_tls_channel_id, "(JJ)V"),
+    NATIVE_METHOD(NativeCrypto, SSL_use_PrivateKey, "(JJ)V"),
     NATIVE_METHOD(NativeCrypto, SSL_use_certificate, "(J[[B)V"),
     NATIVE_METHOD(NativeCrypto, SSL_check_private_key, "(J)V"),
     NATIVE_METHOD(NativeCrypto, SSL_set_client_CA_list, "(J[[B)V"),
diff --git a/crypto/src/test/java/org/conscrypt/NativeCryptoTest.java b/crypto/src/test/java/org/conscrypt/NativeCryptoTest.java
index b55a745..6b8be07 100644
--- a/crypto/src/test/java/org/conscrypt/NativeCryptoTest.java
+++ b/crypto/src/test/java/org/conscrypt/NativeCryptoTest.java
@@ -29,7 +29,6 @@ import java.security.KeyPair;
 import java.security.KeyPairGenerator;
 import java.security.KeyStore;
 import java.security.KeyStore.PrivateKeyEntry;
-import java.security.PrivateKey;
 import java.security.cert.CertificateException;
 import java.security.cert.X509Certificate;
 import java.security.interfaces.DSAPublicKey;
@@ -67,12 +66,12 @@ public class NativeCryptoTest extends TestCase {
 
     private static final long TIMEOUT_SECONDS = 5;
 
-    private static byte[] SERVER_PRIVATE_KEY;
+    private static OpenSSLKey SERVER_PRIVATE_KEY;
     private static byte[][] SERVER_CERTIFICATES;
-    private static byte[] CLIENT_PRIVATE_KEY;
+    private static OpenSSLKey CLIENT_PRIVATE_KEY;
     private static byte[][] CLIENT_CERTIFICATES;
     private static byte[][] CA_PRINCIPALS;
-    private static PrivateKey CHANNEL_ID_PRIVATE_KEY;
+    private static OpenSSLKey CHANNEL_ID_PRIVATE_KEY;
     private static byte[] CHANNEL_ID;
 
     @Override
@@ -80,7 +79,7 @@ public class NativeCryptoTest extends TestCase {
         assertEquals(0, NativeCrypto.ERR_peek_last_error());
     }
 
-    private static byte[] getServerPrivateKey() {
+    private static OpenSSLKey getServerPrivateKey() {
         initCerts();
         return SERVER_PRIVATE_KEY;
     }
@@ -90,7 +89,7 @@ public class NativeCryptoTest extends TestCase {
         return SERVER_CERTIFICATES;
     }
 
-    private static byte[] getClientPrivateKey() {
+    private static OpenSSLKey getClientPrivateKey() {
         initCerts();
         return CLIENT_PRIVATE_KEY;
     }
@@ -116,13 +115,13 @@ public class NativeCryptoTest extends TestCase {
         try {
             PrivateKeyEntry serverPrivateKeyEntry
                     = TestKeyStore.getServer().getPrivateKey("RSA", "RSA");
-            SERVER_PRIVATE_KEY = serverPrivateKeyEntry.getPrivateKey().getEncoded();
+            SERVER_PRIVATE_KEY = OpenSSLKey.fromPrivateKey(serverPrivateKeyEntry.getPrivateKey());
             SERVER_CERTIFICATES = NativeCrypto.encodeCertificates(
                     serverPrivateKeyEntry.getCertificateChain());
 
             PrivateKeyEntry clientPrivateKeyEntry
                     = TestKeyStore.getClientCertificate().getPrivateKey("RSA", "RSA");
-            CLIENT_PRIVATE_KEY = clientPrivateKeyEntry.getPrivateKey().getEncoded();
+            CLIENT_PRIVATE_KEY = OpenSSLKey.fromPrivateKey(clientPrivateKeyEntry.getPrivateKey());
             CLIENT_CERTIFICATES = NativeCrypto.encodeCertificates(
                     clientPrivateKeyEntry.getCertificateChain());
 
@@ -147,7 +146,7 @@ public class NativeCryptoTest extends TestCase {
         BigInteger s = new BigInteger(
                 "229cdbbf489aea584828a261a23f9ff8b0f66f7ccac98bf2096ab3aee41497c5", 16);
         CHANNEL_ID_PRIVATE_KEY = new OpenSSLECPrivateKey(
-                new ECPrivateKeySpec(s, openSslSpec.getECParameterSpec()));
+                new ECPrivateKeySpec(s, openSslSpec.getECParameterSpec())).getOpenSSLKey();
 
         // Channel ID is the concatenation of the X and Y coordinates of the public key.
         CHANNEL_ID = new BigInteger(
@@ -339,7 +338,7 @@ public class NativeCryptoTest extends TestCase {
 
     public void test_SSL_use_PrivateKey_for_tls_channel_id() throws Exception {
         try {
-            NativeCrypto.SSL_set1_tls_channel_id(NULL, null);
+            NativeCrypto.SSL_set1_tls_channel_id(NULL, NULL);
             fail();
         } catch (NullPointerException expected) {
         }
@@ -348,25 +347,14 @@ public class NativeCryptoTest extends TestCase {
         long s = NativeCrypto.SSL_new(c);
 
         try {
-            NativeCrypto.SSL_set1_tls_channel_id(s, null);
+            NativeCrypto.SSL_set1_tls_channel_id(s, NULL);
             fail();
         } catch (NullPointerException expected) {
         }
 
-        // Use the key via the wrapper that decides whether to use PKCS#8 or native OpenSSL.
-        NativeCrypto.SSL_set1_tls_channel_id(s, CHANNEL_ID_PRIVATE_KEY);
-
-        // Use the key via its PKCS#8 representation.
-        assertEquals("PKCS#8", CHANNEL_ID_PRIVATE_KEY.getFormat());
-        byte[] pkcs8EncodedKeyBytes = CHANNEL_ID_PRIVATE_KEY.getEncoded();
-        assertNotNull(pkcs8EncodedKeyBytes);
-        NativeCrypto.SSL_use_PKCS8_PrivateKey_for_tls_channel_id(s, pkcs8EncodedKeyBytes);
-
         // Use the key natively. This works because the initChannelIdKey method ensures that the
         // key is backed by OpenSSL.
-        NativeCrypto.SSL_use_OpenSSL_PrivateKey_for_tls_channel_id(
-                s,
-                ((OpenSSLECPrivateKey) CHANNEL_ID_PRIVATE_KEY).getOpenSSLKey().getPkeyContext());
+        NativeCrypto.SSL_set1_tls_channel_id(s, CHANNEL_ID_PRIVATE_KEY.getPkeyContext());
 
         NativeCrypto.SSL_free(s);
         NativeCrypto.SSL_CTX_free(c);
@@ -374,7 +362,7 @@ public class NativeCryptoTest extends TestCase {
 
     public void test_SSL_use_PrivateKey() throws Exception {
         try {
-            NativeCrypto.SSL_use_PrivateKey(NULL, null);
+            NativeCrypto.SSL_use_PrivateKey(NULL, NULL);
             fail();
         } catch (NullPointerException expected) {
         }
@@ -383,12 +371,12 @@ public class NativeCryptoTest extends TestCase {
         long s = NativeCrypto.SSL_new(c);
 
         try {
-            NativeCrypto.SSL_use_PrivateKey(s, null);
+            NativeCrypto.SSL_use_PrivateKey(s, NULL);
             fail();
         } catch (NullPointerException expected) {
         }
 
-        NativeCrypto.SSL_use_PrivateKey(s, getServerPrivateKey());
+        NativeCrypto.SSL_use_PrivateKey(s, getServerPrivateKey().getPkeyContext());
 
         NativeCrypto.SSL_free(s);
         NativeCrypto.SSL_CTX_free(c);
@@ -430,7 +418,7 @@ public class NativeCryptoTest extends TestCase {
         } catch (SSLException expected) {
         }
 
-        NativeCrypto.SSL_use_PrivateKey(s, getServerPrivateKey());
+        NativeCrypto.SSL_use_PrivateKey(s, getServerPrivateKey().getPkeyContext());
         NativeCrypto.SSL_check_private_key(s);
 
         NativeCrypto.SSL_free(s);
@@ -441,7 +429,7 @@ public class NativeCryptoTest extends TestCase {
         long s = NativeCrypto.SSL_new(c);
 
         // first private, then certificate
-        NativeCrypto.SSL_use_PrivateKey(s, getServerPrivateKey());
+        NativeCrypto.SSL_use_PrivateKey(s, getServerPrivateKey().getPkeyContext());
 
         try {
             NativeCrypto.SSL_check_private_key(s);
@@ -613,7 +601,7 @@ public class NativeCryptoTest extends TestCase {
     private static final boolean DEBUG = false;
 
     public static class Hooks {
-        private PrivateKey channelIdPrivateKey;
+        private OpenSSLKey channelIdPrivateKey;
 
         public long getContext() throws SSLException {
             return NativeCrypto.SSL_CTX_new();
@@ -626,7 +614,7 @@ public class NativeCryptoTest extends TestCase {
             NativeCrypto.SSL_set_cipher_lists(s, new String[] { "RC4-MD5" });
 
             if (channelIdPrivateKey != null) {
-                NativeCrypto.SSL_set1_tls_channel_id(s, channelIdPrivateKey);
+                NativeCrypto.SSL_set1_tls_channel_id(s, channelIdPrivateKey.getPkeyContext());
             }
             return s;
         }
@@ -721,13 +709,13 @@ public class NativeCryptoTest extends TestCase {
     }
 
     public static class ServerHooks extends Hooks {
-        private final byte[] privateKey;
+        private final OpenSSLKey privateKey;
         private final byte[][] certificates;
         private boolean channelIdEnabled;
         private byte[] channelIdAfterHandshake;
         private Throwable channelIdAfterHandshakeException;
 
-        public ServerHooks(byte[] privateKey, byte[][] certificates) {
+        public ServerHooks(OpenSSLKey privateKey, byte[][] certificates) {
             this.privateKey = privateKey;
             this.certificates = certificates;
         }
@@ -736,7 +724,7 @@ public class NativeCryptoTest extends TestCase {
         public long beforeHandshake(long c) throws SSLException {
             long s = super.beforeHandshake(c);
             if (privateKey != null) {
-                NativeCrypto.SSL_use_PrivateKey(s, privateKey);
+                NativeCrypto.SSL_use_PrivateKey(s, privateKey.getPkeyContext());
             }
             if (certificates != null) {
                 NativeCrypto.SSL_use_certificate(s, certificates);
@@ -870,7 +858,7 @@ public class NativeCryptoTest extends TestCase {
             @Override
             public void clientCertificateRequested(long s) {
                 super.clientCertificateRequested(s);
-                NativeCrypto.SSL_use_PrivateKey(s, getClientPrivateKey());
+                NativeCrypto.SSL_use_PrivateKey(s, getClientPrivateKey().getPkeyContext());
                 NativeCrypto.SSL_use_certificate(s, getClientCertificates());
             }
         };
-- 
cgit v1.1