diff options
author | Elliott Hughes <enh@google.com> | 2009-09-17 18:32:07 -0700 |
---|---|---|
committer | Elliott Hughes <enh@google.com> | 2009-09-18 11:52:23 -0700 |
commit | 42d8ce7a37c6c2245fa619843e9c240c1d2f3e46 (patch) | |
tree | 1b9439b068e00c38e4f6db53cc53c20dc4af2f73 /luni | |
parent | 3dfa79d310b5d304ab6b6328034271705e8057f9 (diff) | |
download | libcore-42d8ce7a37c6c2245fa619843e9c240c1d2f3e46.zip libcore-42d8ce7a37c6c2245fa619843e9c240c1d2f3e46.tar.gz libcore-42d8ce7a37c6c2245fa619843e9c240c1d2f3e46.tar.bz2 |
Fix heap corruption in nio select(2) code.
The active ingredient in this change is that we now test that the fd isn't -1,
used to represent an invalid fd. There's a race condition where a socket can be
closed between SelectorImpl.prepareChannels and the native code. This caused us
to write to the -1th element of a heap-allocated structure, leading to SIGSEGV.
I've also removed the check for an empty fd_set. It was broken before and will
never have fired, but I don't think it makes sense to fix it, given this race
condition.
The race can't be fixed because the implementation is documented to close the
socket channel and *then* cancel the selection key.
This patch also removes various dead functions and tidies up timeval usage.
Bug: 2093094
Diffstat (limited to 'luni')
-rw-r--r-- | luni/src/main/native/org_apache_harmony_luni_platform_OSNetworkSystem.cpp | 271 |
1 files changed, 93 insertions, 178 deletions
diff --git a/luni/src/main/native/org_apache_harmony_luni_platform_OSNetworkSystem.cpp b/luni/src/main/native/org_apache_harmony_luni_platform_OSNetworkSystem.cpp index 1b42a45..afa0dfc 100644 --- a/luni/src/main/native/org_apache_harmony_luni_platform_OSNetworkSystem.cpp +++ b/luni/src/main/native/org_apache_harmony_luni_platform_OSNetworkSystem.cpp @@ -170,8 +170,6 @@ struct CachedFields { jclass byte_class; jmethodID byte_class_init; jfieldID byte_class_value; - jclass string_class; - jmethodID string_class_init; jclass socketimpl_class; jfieldID socketimpl_address; jfieldID socketimpl_port; @@ -424,22 +422,17 @@ jobject newJavaLangInteger(JNIEnv * env, jint anInt) { return env->NewObject(tempClass, tempMethod, anInt); } -/** - * Answer a new java.lang.String object. - * - * @param env pointer to the JNI library - * @param anInt the byte[] constructor argument - * - * @return the new String - */ - -jobject newJavaLangString(JNIEnv * env, jbyteArray bytes) { - jclass tempClass; - jmethodID tempMethod; +// Converts a number of milliseconds to a timeval. +static timeval toTimeval(long ms) { + timeval tv; + tv.tv_sec = ms / 1000; + tv.tv_usec = (ms - tv.tv_sec*1000) * 1000; + return tv; +} - tempClass = gCachedFields.string_class; - tempMethod = gCachedFields.string_class_init; - return env->NewObject(tempClass, tempMethod, (jbyteArray) bytes); +// Converts a timeval to a number of milliseconds. +static long toMs(const timeval& tv) { + return tv.tv_sec * 1000 + tv.tv_usec / 1000; } /** @@ -456,11 +449,10 @@ jobject newJavaLangString(JNIEnv * env, jbyteArray bytes) { */ static int time_msec_clock() { - struct timeval tp; + timeval tp; struct timezone tzp; - gettimeofday(&tp, &tzp); - return (tp.tv_sec * 1000) + (tp.tv_usec / 1000); + return toMs(tp); } /** @@ -871,53 +863,6 @@ static int getSocketAddressFamily(int socket) { } /** - * A helper method, to set the connect context to a Long object. - * - * @param env pointer to the JNI library - * @param longclass Java Long Object - */ -void setConnectContext(JNIEnv *env,jobject longclass,jbyte * context) { - jclass descriptorCLS; - jfieldID descriptorFID; - descriptorCLS = env->FindClass("java/lang/Long"); - descriptorFID = env->GetFieldID(descriptorCLS, "value", "J"); - env->SetLongField(longclass, descriptorFID, (jlong)((jint)context)); -}; - -/** - * A helper method, to get the connect context. - * - * @param env pointer to the JNI library - * @param longclass Java Long Object - */ -jbyte *getConnectContext(JNIEnv *env, jobject longclass) { - jclass descriptorCLS; - jfieldID descriptorFID; - descriptorCLS = env->FindClass("java/lang/Long"); - descriptorFID = env->GetFieldID(descriptorCLS, "value", "J"); - return (jbyte*) ((jint)env->GetLongField(longclass, descriptorFID)); -}; - -// typical ip checksum -unsigned short ip_checksum(unsigned short* buffer, int size) { - register unsigned short * buf = buffer; - register int bufleft = size; - register unsigned long sum = 0; - - while (bufleft > 1) { - sum = sum + (*buf++); - bufleft = bufleft - sizeof(unsigned short ); - } - if (bufleft) { - sum = sum + (*(unsigned char*)buf); - } - sum = (sum >> 16) + (sum & 0xffff); - sum += (sum >> 16); - - return (unsigned short )(~sum); -} - -/** * Converts an IPv4 address to an IPv4-mapped IPv6 address. Performs no error * checking. * @@ -1010,9 +955,8 @@ static int doBind(int socket, struct sockaddr_storage *socketAddress) { * @return 0, if no errors occurred, otherwise the (negative) error code. */ static int sockConnectWithTimeout(int handle, struct sockaddr_storage addr, - unsigned int timeout, unsigned int step, jbyte *ctxt) { + int timeout, unsigned int step, jbyte *ctxt) { int rc = 0; - struct timeval passedTimeout; int errorVal; socklen_t errorValLen = sizeof(int); struct selectFDSet *context = NULL; @@ -1072,13 +1016,13 @@ static int sockConnectWithTimeout(int handle, struct sockaddr_storage addr, * set the timeout value to be used. Because on some unix platforms we * don't get notified when a socket is closed we only sleep for 100ms * at a time + * + * TODO: is this relevant for Android? */ - passedTimeout.tv_sec = 0; if (timeout > 100) { - passedTimeout.tv_usec = 100 * 1000; - } else if ((int)timeout >= 0) { - passedTimeout.tv_usec = timeout * 1000; + timeout = 100; } + timeval passedTimeout(toTimeval(timeout)); /* initialize the FD sets for the select */ FD_ZERO(&(context->exceptionSet)); @@ -1092,7 +1036,7 @@ static int sockConnectWithTimeout(int handle, struct sockaddr_storage addr, &(context->readSet), &(context->writeSet), &(context->exceptionSet), - (int)timeout >= 0 ? &passedTimeout : NULL); + timeout >= 0 ? &passedTimeout : NULL); /* if there is at least one descriptor ready to be checked */ if (0 < rc) { @@ -1399,7 +1343,6 @@ static void osNetworkSystem_oneTimeInitializationImpl(JNIEnv* env, jobject obj, {&c->integer_class, "java/lang/Integer"}, {&c->boolean_class, "java/lang/Boolean"}, {&c->byte_class, "java/lang/Byte"}, - {&c->string_class, "java/lang/String"}, {&c->socketimpl_class, "java/net/SocketImpl"}, {&c->dpack_class, "java/net/DatagramPacket"} }; @@ -1421,7 +1364,6 @@ static void osNetworkSystem_oneTimeInitializationImpl(JNIEnv* env, jobject obj, {&c->integer_class_init, c->integer_class, "<init>", "(I)V", false}, {&c->boolean_class_init, c->boolean_class, "<init>", "(Z)V", false}, {&c->byte_class_init, c->byte_class, "<init>", "(B)V", false}, - {&c->string_class_init, c->string_class, "<init>", "([B)V", false}, {&c->iaddr_getbyaddress, c->iaddr_class, "getByAddress", "([B)Ljava/net/InetAddress;", true} }; @@ -2474,9 +2416,7 @@ static jint osNetworkSystem_receiveStreamImpl(JNIEnv* env, jclass clazz, jbyte* body = env->GetByteArrayElements(data, NULL); // set timeout - struct timeval tv; - tv.tv_sec = timeout / 1000; - tv.tv_usec = (timeout % 1000) * 1000; + timeval tv(toTimeval(timeout)); setsockopt(handle, SOL_SOCKET, SO_RCVTIMEO, (struct timeval *)&tv, sizeof(struct timeval)); @@ -2660,110 +2600,87 @@ static jint osNetworkSystem_sendDatagramImpl2(JNIEnv* env, jclass clazz, return sent; } -static jint osNetworkSystem_selectImpl(JNIEnv* env, jclass clazz, - jobjectArray readFDArray, jobjectArray writeFDArray, jint countReadC, - jint countWriteC, jintArray outFlags, jlong timeout) { - // LOGD("ENTER selectImpl"); - - struct timeval timeP; - int result = 0; - int size = 0; - jobject gotFD; - fd_set *fdset_read,*fdset_write; - int handle; - jint *flagArray; - int val; - unsigned int time_sec = (unsigned int)timeout/1000; - unsigned int time_msec = (unsigned int)(timeout%1000); - - fdset_read = (fd_set *)malloc(sizeof(fd_set)); - fdset_write = (fd_set *)malloc(sizeof(fd_set)); - - FD_ZERO(fdset_read); - FD_ZERO(fdset_write); - - for (val = 0; val<countReadC; val++) { - - gotFD = env->GetObjectArrayElement(readFDArray,val); - - handle = jniGetFDFromFileDescriptor(env, gotFD); - - FD_SET(handle, fdset_read); - - if (0 > (size - handle)) { - size = handle; +static bool initFdSet(JNIEnv* env, jobjectArray fdArray, jint count, fd_set* fdSet, int* maxFd) { + for (int i = 0; i < count; ++i) { + jobject fileDescriptor = env->GetObjectArrayElement(fdArray, i); + if (fileDescriptor == NULL) { + return false; } - } - - for (val = 0; val<countWriteC; val++) { - - gotFD = env->GetObjectArrayElement(writeFDArray,val); - - handle = jniGetFDFromFileDescriptor(env, gotFD); - - FD_SET(handle, fdset_write); - - if (0 > (size - handle)) { - size = handle; + + const int fd = jniGetFDFromFileDescriptor(env, fileDescriptor); + if (fd < 0 || fd > 1024) { + LOGE("selectImpl: invalid fd %i", fd); + continue; } - } - - /* the size is the max_fd + 1 */ - size =size + 1; - - if (0 > size) { - result = SOCKERR_FDSET_SIZEBAD; - } else { - /* only set when timeout >= 0 (non-block)*/ - if (0 <= timeout) { - - timeP.tv_sec = time_sec; - timeP.tv_usec = time_msec*1000; - - result = sockSelect(size, fdset_read, fdset_write, NULL, &timeP); - - } else { - result = sockSelect(size, fdset_read, fdset_write, NULL, NULL); + + FD_SET(fd, fdSet); + + if (fd > *maxFd) { + *maxFd = fd; } } + return true; +} - if (0 < result) { - /*output the result to a int array*/ - flagArray = env->GetIntArrayElements(outFlags, NULL); - - for (val=0; val<countReadC; val++) { - gotFD = env->GetObjectArrayElement(readFDArray,val); - - handle = jniGetFDFromFileDescriptor(env, gotFD); - - if (FD_ISSET(handle,fdset_read)) { - flagArray[val] = SOCKET_OP_READ; - } else { - flagArray[val] = SOCKET_OP_NONE; - } +static bool translateFdSet(JNIEnv* env, jobjectArray fdArray, jint count, const fd_set& fdSet, jint* flagArray, size_t offset, jint op) { + for (int i = 0; i < count; ++i) { + jobject fileDescriptor = env->GetObjectArrayElement(fdArray, i); + if (fileDescriptor == NULL) { + return false; } - - for (val=0; val<countWriteC; val++) { - - gotFD = env->GetObjectArrayElement(writeFDArray,val); - - handle = jniGetFDFromFileDescriptor(env, gotFD); - - if (FD_ISSET(handle,fdset_write)) { - flagArray[val+countReadC] = SOCKET_OP_WRITE; - } else { - flagArray[val+countReadC] = SOCKET_OP_NONE; - } + + const int fd = jniGetFDFromFileDescriptor(env, fileDescriptor); + const bool valid = fd >= 0 && fd < 1024; + + if (valid && FD_ISSET(fd, &fdSet)) { + flagArray[i + offset] = op; + } else { + flagArray[i + offset] = SOCKET_OP_NONE; } - - env->ReleaseIntArrayElements(outFlags, flagArray, 0); } + return true; +} - free(fdset_write); - free(fdset_read); - - /* return both correct and error result, let java handle the exception*/ - return result; +static jint osNetworkSystem_selectImpl(JNIEnv* env, jclass clazz, + jobjectArray readFDArray, jobjectArray writeFDArray, jint countReadC, + jint countWriteC, jintArray outFlags, jlong timeoutMs) { + // LOGD("ENTER selectImpl"); + + // Initialize the fd_sets. + int maxFd = -1; + fd_set readFds; + fd_set writeFds; + FD_ZERO(&readFds); + FD_ZERO(&writeFds); + bool initialized = initFdSet(env, readFDArray, countReadC, &readFds, &maxFd) && + initFdSet(env, writeFDArray, countWriteC, &writeFds, &maxFd); + if (!initialized) { + return -1; + } + + // Initialize the timeout, if any. + timeval tv; + timeval* tvp = NULL; + if (timeoutMs >= 0) { + tv = toTimeval(timeoutMs); + tvp = &tv; + } + + // Perform the select. + int result = sockSelect(maxFd + 1, &readFds, &writeFds, NULL, tvp); + if (result < 0) { + return result; + } + + // Translate the result into the int[] we're supposed to fill in. + jint* flagArray = env->GetIntArrayElements(outFlags, NULL); + if (flagArray == NULL) { + return -1; + } + bool okay = translateFdSet(env, readFDArray, countReadC, readFds, flagArray, 0, SOCKET_OP_READ) && + translateFdSet(env, writeFDArray, countWriteC, writeFds, flagArray, countReadC, SOCKET_OP_WRITE); + env->ReleaseIntArrayElements(outFlags, flagArray, 0); + return okay ? 0 : -1; } static jobject osNetworkSystem_getSocketLocalAddressImpl(JNIEnv* env, @@ -2995,7 +2912,7 @@ static jobject osNetworkSystem_getSocketOptionImpl(JNIEnv* env, jclass clazz, throwSocketException(env, convertError(errno)); return NULL; } - return newJavaLangInteger(env, timeout.tv_sec * 1000 + timeout.tv_usec/1000); + return newJavaLangInteger(env, toMs(timeout)); } default: { throwSocketException(env, SOCKERR_OPTUNSUPP); @@ -3237,9 +3154,7 @@ static void osNetworkSystem_setSocketOptionImpl(JNIEnv* env, jclass clazz, } case JAVASOCKOPT_SO_RCVTIMEOUT: { - struct timeval timeout; - timeout.tv_sec = intVal / 1000; - timeout.tv_usec = (intVal % 1000) * 1000; + timeval timeout(toTimeval(intVal)); result = setsockopt(handle, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(struct timeval)); if (0 != result) { |