From: Stefan Dösinger Subject: [PATCH 2/2] ws2_32: Don't install APCs if the callback function is NULL. Message-Id: <1421160543-6742-2-git-send-email-stefan@codeweavers.com> Date: Tue, 13 Jan 2015 15:49:03 +0100 This fixes the crash in RIFT on startup. RIFT performs an interruptible wait while waiting for an event that protects internal data. Unfortunately it doesn't check the return value and assumes it can go ahead and modify said data after the APC interrupts the wait. This leads to a NULL pointer crash. Please beware that I do not fully understand why APCs are used to free ws2_accept_async / ws2_async here. As far as I can see the purpose is to keep the HeapFree logic simple and use the same codepath in all cases. This patch obviously makes the alloc / free logic harder to follow. There may be some other purpose for those APCs that I missed. --- dlls/ws2_32/socket.c | 53 +++++++++++++++-------- dlls/ws2_32/tests/sock.c | 107 ++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 133 insertions(+), 27 deletions(-) diff --git a/dlls/ws2_32/socket.c b/dlls/ws2_32/socket.c index e1b2a05..7dff14d 100644 --- a/dlls/ws2_32/socket.c +++ b/dlls/ws2_32/socket.c @@ -1887,9 +1887,9 @@ static void WINAPI ws2_async_apc( void *arg, IO_STATUS_BLOCK *iosb, ULONG reserv { ws2_async *wsa = arg; - if (wsa->completion_func) wsa->completion_func( NtStatusToWSAError(iosb->u.Status), - iosb->Information, wsa->user_overlapped, - wsa->flags ); + wsa->completion_func( NtStatusToWSAError(iosb->u.Status), + iosb->Information, wsa->user_overlapped, + wsa->flags ); HeapFree( GetProcessHeap(), 0, wsa ); } @@ -2008,7 +2008,10 @@ static NTSTATUS WS2_async_recv( void* user, IO_STATUS_BLOCK* iosb, NTSTATUS stat { iosb->u.Status = status; iosb->Information = result; - *apc = ws2_async_apc; + if (wsa->completion_func) + *apc = ws2_async_apc; + else + HeapFree( GetProcessHeap(), 0, wsa ); } return status; } @@ -2018,7 +2021,6 @@ static void WINAPI ws2_async_accept_apc( void *arg, IO_STATUS_BLOCK *iosb, ULONG { struct ws2_accept_async *wsa = arg; - HeapFree( GetProcessHeap(), 0, wsa->read ); HeapFree( GetProcessHeap(), 0, wsa ); } @@ -2042,6 +2044,10 @@ static NTSTATUS WS2_async_accept_recv( void *arg, IO_STATUS_BLOCK *iosb, NTSTATU if (wsa->cvalue) WS_AddCompletion( HANDLE2SOCKET(wsa->listen_socket), wsa->cvalue, iosb->u.Status, iosb->Information ); + /* AcceptEx is the only caller, and it does not support completion callbacks. + * WS2_async_recv has freed wsa->read. */ + if (*apc) + ERR("Completion routine used with WS2_async_accept_recv\n"); *apc = ws2_async_accept_apc; return status; } @@ -2125,6 +2131,7 @@ finish: if (wsa->user_overlapped->hEvent) SetEvent(wsa->user_overlapped->hEvent); + HeapFree( GetProcessHeap(), 0, wsa->read ); *apc = ws2_async_accept_apc; return status; } @@ -2247,7 +2254,10 @@ static NTSTATUS WS2_async_send(void* user, IO_STATUS_BLOCK* iosb, NTSTATUS statu if (status != STATUS_PENDING) { iosb->u.Status = status; - *apc = ws2_async_apc; + if (wsa->completion_func) + *apc = ws2_async_apc; + else + HeapFree( GetProcessHeap(), 0, wsa ); } return status; } @@ -2279,7 +2289,10 @@ static NTSTATUS WS2_async_shutdown( void* user, PIO_STATUS_BLOCK iosb, NTSTATUS } iosb->u.Status = status; iosb->Information = 0; - *apc = ws2_async_apc; + if (wsa->completion_func) + *apc = ws2_async_apc; + else + HeapFree( GetProcessHeap(), 0, wsa ); return status; } @@ -3780,19 +3793,25 @@ static DWORD server_ioctl_sock( SOCKET s, DWORD code, LPVOID in_buff, DWORD in_s { HANDLE event = overlapped ? overlapped->hEvent : 0; HANDLE handle = SOCKET2HANDLE( s ); - struct ws2_async *wsa; + struct ws2_async *wsa = NULL; NTSTATUS status; - PIO_STATUS_BLOCK io; + static IO_STATUS_BLOCK dummy_io_block; + PIO_STATUS_BLOCK io = &dummy_io_block; - if (!(wsa = RtlAllocateHeap( GetProcessHeap(), 0, sizeof(*wsa) ))) - return WSA_NOT_ENOUGH_MEMORY; - wsa->hSocket = handle; - wsa->user_overlapped = overlapped; - wsa->completion_func = completion; - io = (overlapped ? (PIO_STATUS_BLOCK)overlapped : &wsa->local_iosb); + if (completion) + { + if (!(wsa = RtlAllocateHeap( GetProcessHeap(), 0, sizeof(*wsa) ))) + return WSA_NOT_ENOUGH_MEMORY; + wsa->hSocket = handle; + wsa->user_overlapped = overlapped; + wsa->completion_func = completion; + io = &wsa->local_iosb; + } + if (overlapped) + io = (PIO_STATUS_BLOCK)overlapped; - status = NtDeviceIoControlFile( handle, event, (PIO_APC_ROUTINE)ws2_async_apc, wsa, io, code, - in_buff, in_size, out_buff, out_size ); + status = NtDeviceIoControlFile( handle, event, completion ? (PIO_APC_ROUTINE)ws2_async_apc : NULL, + wsa, io, code, in_buff, in_size, out_buff, out_size ); if (status == STATUS_NOT_SUPPORTED) { FIXME("Unsupported ioctl %x (device=%x access=%x func=%x method=%x)\n", diff --git a/dlls/ws2_32/tests/sock.c b/dlls/ws2_32/tests/sock.c index ce47712..78b2467 100644 --- a/dlls/ws2_32/tests/sock.c +++ b/dlls/ws2_32/tests/sock.c @@ -4467,20 +4467,27 @@ static DWORD WINAPI drain_socket_thread(LPVOID arg) return 0; } +static void CALLBACK completion_callback(DWORD error, DWORD transferred, WSAOVERLAPPED *overlapped, DWORD flags) +{ + BOOL *callback_called = overlapped->Pointer; + *callback_called = TRUE; +} + static void test_send(void) { SOCKET src = INVALID_SOCKET; SOCKET dst = INVALID_SOCKET; - HANDLE hThread = NULL; + HANDLE hThread = NULL, dummy_event = NULL; const int buflen = 1024*1024; - char *buffer = NULL; + char *buffer = NULL, *recv_buffer = NULL; int ret, i, zero = 0; WSABUF buf; - OVERLAPPED ov; - BOOL bret; + OVERLAPPED ov, ov2; + BOOL bret, callback_called = FALSE; DWORD id, bytes_sent, dwRet; memset(&ov, 0, sizeof(ov)); + memset(&ov2, 0, sizeof(ov2)); if (tcp_socketpair(&src, &dst) != 0) { @@ -4499,6 +4506,14 @@ static void test_send(void) ok(0, "CreateThread failed, error %d\n", GetLastError()); goto end; } + dummy_event = CreateEventA(NULL, FALSE, FALSE, NULL); + if (dummy_event == NULL) + { + ok(0, "could not create event object, errno = %d\n", GetLastError()); + goto end; + } + /* Flush outstanding APCs */ + WaitForSingleObjectEx(dummy_event, 1, TRUE); buffer = HeapAlloc(GetProcessHeap(), 0, buflen); if (buffer == NULL) @@ -4506,6 +4521,12 @@ static void test_send(void) ok(0, "HeapAlloc failed, error %d\n", GetLastError()); goto end; } + recv_buffer = HeapAlloc(GetProcessHeap(), 0, buflen); + if (recv_buffer == NULL) + { + ok(0, "HeapAlloc failed, error %d\n", GetLastError()); + goto end; + } /* fill the buffer with some nonsense */ for (i = 0; i < buflen; ++i) @@ -4539,21 +4560,23 @@ static void test_send(void) { int j = 0; - ret = recv(src, buffer, 1, 0); + ret = recv(src, recv_buffer, 1, 0); while (ret == SOCKET_ERROR && GetLastError() == WSAEWOULDBLOCK && j < 100) { j++; Sleep(50); - ret = recv(src, buffer, 1, 0); + ret = recv(src, recv_buffer, 1, 0); } ok(ret == 1, "Failed to receive data %d - %d (got %d/%d)\n", ret, GetLastError(), i, buflen); if (ret != 1) break; - ok(buffer[0] == (char) i, "Received bad data at position %d\n", i); + ok(recv_buffer[0] == (char) i, "Received bad data at position %d\n", i); } + dwRet = WaitForSingleObjectEx(dummy_event, 100, TRUE); + ok(dwRet == WAIT_TIMEOUT, "Failed to wait for recv message: %d - %d\n", dwRet, GetLastError()); dwRet = WaitForSingleObject(ov.hEvent, 1000); ok(dwRet == WAIT_OBJECT_0, "Failed to wait for recv message: %d - %d\n", dwRet, GetLastError()); if (dwRet == WAIT_OBJECT_0) @@ -4563,6 +4586,37 @@ static void test_send(void) "Got %d instead of %d (%d - %d)\n", bytes_sent, buflen, bret, GetLastError()); } + bytes_sent = 0; + WSASetLastError(12345); + ov2.Pointer = &callback_called; + ret = WSASend(dst, &buf, 1, &bytes_sent, 0, &ov2, completion_callback); + ok((ret == SOCKET_ERROR && WSAGetLastError() == ERROR_IO_PENDING) || broken(bytes_sent == buflen), + "Failed to start overlapped send %d - %d - %d/%d\n", ret, WSAGetLastError(), bytes_sent, buflen); + + for (i = 0; i < buflen; ++i) + { + int j = 0; + + ret = recv(src, recv_buffer, 1, 0); + while (ret == SOCKET_ERROR && GetLastError() == WSAEWOULDBLOCK && j < 100) + { + j++; + Sleep(50); + ret = recv(src, recv_buffer, 1, 0); + } + + ok(ret == 1, "Failed to receive data %d - %d (got %d/%d)\n", ret, GetLastError(), i, buflen); + if (ret != 1) + break; + + ok(recv_buffer[0] == (char) i, "Received bad data at position %d\n", i); + } + + ok(!callback_called, "Expected completion callback not to be called\n"); + dwRet = WaitForSingleObjectEx(dummy_event, 100, TRUE); + ok(dwRet == WAIT_IO_COMPLETION, "Failed to wait for recv message: %d - %d\n", dwRet, GetLastError()); + ok(callback_called, "Expected completion callback to be called\n"); + WSASetLastError(12345); ret = WSASend(INVALID_SOCKET, &buf, 1, NULL, 0, &ov, NULL); ok(ret == SOCKET_ERROR && WSAGetLastError() == WSAENOTSOCK, @@ -4586,7 +4640,10 @@ end: } if (ov.hEvent) CloseHandle(ov.hEvent); + if (dummy_event) + CloseHandle(dummy_event); HeapFree(GetProcessHeap(), 0, buffer); + HeapFree(GetProcessHeap(), 0, recv_buffer); } typedef struct async_message @@ -4876,21 +4933,21 @@ static void test_events(int useMessages) SOCKET src = INVALID_SOCKET, src2 = INVALID_SOCKET; SOCKET dst = INVALID_SOCKET, dst2 = INVALID_SOCKET; struct sockaddr_in addr; - HANDLE hThread = NULL; + HANDLE hThread = NULL, dummy_event = INVALID_HANDLE_VALUE; HANDLE hEvent = INVALID_HANDLE_VALUE, hEvent2 = INVALID_HANDLE_VALUE; WNDCLASSEXA wndclass; HWND hWnd = NULL; char *buffer = NULL; int bufferSize = 1024*1024; WSABUF bufs; - OVERLAPPED ov, ov2; + OVERLAPPED ov, ov2, ov3; DWORD flags = 0; DWORD bytesReturned; DWORD id; int len; int ret; DWORD dwRet; - BOOL bret; + BOOL bret, callback_called = FALSE; static char szClassName[] = "wstestclass"; const LPARAM *broken_seq[3]; static const LPARAM empty_seq[] = { 0 }; @@ -4909,6 +4966,7 @@ static void test_events(int useMessages) memset(&ov, 0, sizeof(ov)); memset(&ov2, 0, sizeof(ov2)); + memset(&ov3, 0, sizeof(ov3)); /* don't use socketpair, we want connection event */ src = socket(AF_INET, SOCK_STREAM, 0); @@ -5121,6 +5179,15 @@ static void test_events(int useMessages) goto end; } + dummy_event = CreateEventA(NULL, FALSE, FALSE, NULL); + if (dummy_event == NULL) + { + ok(0, "could not create event object, errno = %d\n", GetLastError()); + goto end; + } + /* Flush outstanding APCs */ + WaitForSingleObjectEx(dummy_event, 1, TRUE); + /* FD_WRITE should be set initially, and allow us to send at least 1 byte */ ok_event_seq(src, hEvent, connect_seq, NULL, 1); ok_event_seq(src2, hEvent2, connect_seq, NULL, 1); @@ -5194,6 +5261,9 @@ todo_wine broken_seq[1] = NULL; ok_event_seq(src, hEvent, empty_seq, broken_seq, 0); + dwRet = WaitForSingleObjectEx(dummy_event, 100, TRUE); + ok(dwRet == WAIT_TIMEOUT, "Failed to wait for recv message: %d - %d\n", dwRet, GetLastError()); + dwRet = WaitForSingleObject(ov.hEvent, 100); ok(dwRet == WAIT_OBJECT_0, "Failed to wait for recv message: %d - %d\n", dwRet, GetLastError()); if (dwRet == WAIT_OBJECT_0) @@ -5304,6 +5374,21 @@ todo_wine ok(0, "sending a lot of data failed with error %d\n", WSAGetLastError()); } + bufs.len = sizeof(char); + bufs.buf = buffer; + ov3.Pointer = &callback_called; + ret = WSARecv(src, &bufs, 1, &bytesReturned, &flags, &ov3, completion_callback); + ok(ret == SOCKET_ERROR && GetLastError() == ERROR_IO_PENDING, + "WSARecv failed - %d error %d\n", ret, GetLastError()); + + ret = send(dst, "2", 1, 0); + ok(ret == 1, "Failed to send buffer %d err %d\n", ret, GetLastError()); + + ok(!callback_called, "Expected completion callback not to be called\n"); + dwRet = WaitForSingleObjectEx(dummy_event, 100, TRUE); + ok(dwRet == WAIT_IO_COMPLETION, "Failed to wait for recv message: %d - %d\n", dwRet, GetLastError()); + ok(callback_called, "Expected completion callback to be called\n"); + /* Test how FD_CLOSE is handled */ ret = send(dst, "12", 2, 0); ok(ret == 2, "Failed to send buffer %d err %d\n", ret, GetLastError()); @@ -5442,6 +5527,8 @@ end: CloseHandle(ov.hEvent); if (ov2.hEvent != NULL) CloseHandle(ov2.hEvent); + if (dummy_event != NULL) + CloseHandle(dummy_event); } static void test_ipv6only(void) -- 2.0.5