From: "Zebediah Figura (she/her)" Subject: Re: [PATCH 5/6] ntoskrnl.exe/tests: Add tests to read/write reports on device. Message-Id: Date: Fri, 18 Jun 2021 11:06:58 -0500 In-Reply-To: <20210618120611.703993-5-rbernon@codeweavers.com> References: <20210618120611.703993-1-rbernon@codeweavers.com> <20210618120611.703993-5-rbernon@codeweavers.com> On 6/18/21 7:06 AM, Rémi Bernon wrote: > Marking input report read requests as pending and completing them on > write, otherwise Windows keeps reading input reports and never finishes > setting up the device. > > Signed-off-by: Rémi Bernon > --- > dlls/ntoskrnl.exe/tests/driver_hid.c | 95 +++++++++++++++++++++++++++- > dlls/ntoskrnl.exe/tests/ntoskrnl.c | 70 ++++++++++++++++++-- > 2 files changed, 157 insertions(+), 8 deletions(-) > > diff --git a/dlls/ntoskrnl.exe/tests/driver_hid.c b/dlls/ntoskrnl.exe/tests/driver_hid.c > index e684e2531db..eb81426823b 100644 > --- a/dlls/ntoskrnl.exe/tests/driver_hid.c > +++ b/dlls/ntoskrnl.exe/tests/driver_hid.c > @@ -42,10 +42,32 @@ static UNICODE_STRING control_symlink; > static unsigned int got_start_device; > static DWORD report_id; > > +struct minidevice_extension > +{ > + LIST_ENTRY irp_queue; > + BOOL removed; > +}; I hate to nitpick, but I really don't like the naming here. I'd rather call this something like "struct hid_device". "mext" used below is also a kind of jarring name for a variable; it looks like "next". > + > +static void cancel_pending_requests(DEVICE_OBJECT *device) > +{ > + HID_DEVICE_EXTENSION *ext = device->DeviceExtension; > + struct minidevice_extension *mext = ext->MiniDeviceExtension; > + LIST_ENTRY *entry; > + IRP *irp; > + > + while ((entry = RemoveHeadList(&mext->irp_queue)) != &mext->irp_queue) > + { > + irp = CONTAINING_RECORD(entry, IRP, Tail.Overlay.ListEntry); > + irp->IoStatus.Status = STATUS_CANCELLED; > + IoCompleteRequest(irp, IO_NO_INCREMENT); > + } > +} > + > static NTSTATUS WINAPI driver_pnp(DEVICE_OBJECT *device, IRP *irp) > { > IO_STACK_LOCATION *stack = IoGetCurrentIrpStackLocation(irp); > HID_DEVICE_EXTENSION *ext = device->DeviceExtension; > + struct minidevice_extension *mext = ext->MiniDeviceExtension; > > if (winetest_debug > 1) trace("pnp %#x\n", stack->MinorFunction); > > @@ -53,6 +75,8 @@ static NTSTATUS WINAPI driver_pnp(DEVICE_OBJECT *device, IRP *irp) > { > case IRP_MN_START_DEVICE: > ++got_start_device; > + InitializeListHead(&mext->irp_queue); > + mext->removed = FALSE; > IoSetDeviceInterfaceState(&control_symlink, TRUE); > irp->IoStatus.Status = STATUS_SUCCESS; > break; > @@ -60,10 +84,14 @@ static NTSTATUS WINAPI driver_pnp(DEVICE_OBJECT *device, IRP *irp) > case IRP_MN_SURPRISE_REMOVAL: > case IRP_MN_QUERY_REMOVE_DEVICE: > case IRP_MN_STOP_DEVICE: > + mext->removed = TRUE; > + cancel_pending_requests(device); > irp->IoStatus.Status = STATUS_SUCCESS; > break; > > case IRP_MN_REMOVE_DEVICE: > + mext->removed = TRUE; > + cancel_pending_requests(device); > IoSetDeviceInterfaceState(&control_symlink, FALSE); > irp->IoStatus.Status = STATUS_SUCCESS; > break; > @@ -290,6 +318,16 @@ static NTSTATUS WINAPI driver_internal_ioctl(DEVICE_OBJECT *device, IRP *irp) > REPORT_SIZE(1, 1), > FEATURE(1, Data|Var|Abs), > END_COLLECTION, > + > + USAGE_PAGE(1, HID_USAGE_PAGE_LED), > + USAGE(1, HID_USAGE_LED_GREEN), > + COLLECTION(1, Report), > + REPORT_ID_OR_USAGE_PAGE(1, report_id, 0), > + USAGE_PAGE(1, HID_USAGE_PAGE_LED), > + REPORT_COUNT(1, 8), > + REPORT_SIZE(1, 1), > + OUTPUT(1, Cnst|Var|Abs), > + END_COLLECTION, > END_COLLECTION, > }; > #undef REPORT_ID_OR_USAGE_PAGE > @@ -297,6 +335,8 @@ static NTSTATUS WINAPI driver_internal_ioctl(DEVICE_OBJECT *device, IRP *irp) > > static BOOL test_failed; > IO_STACK_LOCATION *stack = IoGetCurrentIrpStackLocation(irp); > + HID_DEVICE_EXTENSION *ext = device->DeviceExtension; > + struct minidevice_extension *mext = ext->MiniDeviceExtension; > const ULONG in_size = stack->Parameters.DeviceIoControl.InputBufferLength; > const ULONG out_size = stack->Parameters.DeviceIoControl.OutputBufferLength; > const ULONG code = stack->Parameters.DeviceIoControl.IoControlCode; > @@ -378,7 +418,50 @@ static NTSTATUS WINAPI driver_internal_ioctl(DEVICE_OBJECT *device, IRP *irp) > } > if (out_size != expected_size) test_failed = TRUE; > > - ret = STATUS_NOT_IMPLEMENTED; > + if (mext->removed) ret = STATUS_DEVICE_REMOVED; > + else > + { > + InsertTailList(&mext->irp_queue, &irp->Tail.Overlay.ListEntry); > + ret = STATUS_PENDING; You need to call IoMarkIrpPending() if you're going to return STATUS_PENDING. > + } > + break; > + } > + I don't think this is thread-safe. I know it's a test, but as long as we're in the kernel I'd really rather be careful; having to reboot a test VM is not fun. > + case IOCTL_HID_WRITE_REPORT: > + { > + HID_XFER_PACKET *packet = irp->UserBuffer; > + ULONG expected_size = report_id ? 2 : 1; > + LIST_ENTRY *entry; > + todo_wine > + ok(in_size == sizeof(*packet), "got input size %u\n", in_size); > + todo_wine > + ok(!out_size, "got output size %u\n", out_size); > + todo_wine_if(!report_id) > + ok(packet->reportBufferLen == expected_size, "got report size %u\n", packet->reportBufferLen); > + > + if (report_id) > + { > + todo_wine_if(packet->reportBuffer[0] == 0xa5) > + ok(packet->reportBuffer[0] == report_id, "got report id %x\n", packet->reportBuffer[0]); > + } > + else > + { > + todo_wine > + ok(packet->reportBuffer[0] == 0xcd, "got first byte %x\n", packet->reportBuffer[0]); > + } > + > + if ((entry = RemoveHeadList(&mext->irp_queue)) != &mext->irp_queue) > + { > + IRP *tmp = CONTAINING_RECORD(entry, IRP, Tail.Overlay.ListEntry); > + memset(tmp->UserBuffer, 0xa5, 23); > + if (report_id) ((char *)tmp->UserBuffer)[0] = report_id; > + tmp->IoStatus.Information = 23; > + tmp->IoStatus.Status = STATUS_SUCCESS; > + IoCompleteRequest(tmp, IO_NO_INCREMENT); > + } > + > + irp->IoStatus.Information = packet->reportBufferLen; > + ret = STATUS_SUCCESS; > break; > } > This seems awkward, and may stymie future attempts to test output reports. Can we just use a custom ioctl instead? > @@ -389,7 +472,9 @@ static NTSTATUS WINAPI driver_internal_ioctl(DEVICE_OBJECT *device, IRP *irp) > ok(!in_size, "got input size %u\n", in_size); > ok(out_size == sizeof(*packet), "got output size %u\n", out_size); > > + todo_wine_if(report_id) > ok(packet->reportId == report_id, "got packet report id %u\n", packet->reportId); > + todo_wine_if(report_id) > ok(packet->reportBufferLen == expected_size, "got packet buffer len %u\n", packet->reportBufferLen); > ok(!!packet->reportBuffer, "got packet buffer %p\n", packet->reportBuffer); > > @@ -414,8 +499,11 @@ static NTSTATUS WINAPI driver_internal_ioctl(DEVICE_OBJECT *device, IRP *irp) > ret = STATUS_NOT_IMPLEMENTED; > } > > - irp->IoStatus.Status = ret; > - IoCompleteRequest(irp, IO_NO_INCREMENT); > + if (ret != STATUS_PENDING) > + { > + irp->IoStatus.Status = ret; > + IoCompleteRequest(irp, IO_NO_INCREMENT); > + } > return ret; > } > > @@ -475,6 +563,7 @@ NTSTATUS WINAPI DriverEntry(DRIVER_OBJECT *driver, UNICODE_STRING *registry) > { > .Revision = HID_REVISION, > .DriverObject = driver, > + .DeviceExtensionSize = sizeof(struct minidevice_extension), > .RegistryPath = registry, > }; > UNICODE_STRING name_str; > diff --git a/dlls/ntoskrnl.exe/tests/ntoskrnl.c b/dlls/ntoskrnl.exe/tests/ntoskrnl.c > index 5453af8ff1c..df661327b41 100644 > --- a/dlls/ntoskrnl.exe/tests/ntoskrnl.c > +++ b/dlls/ntoskrnl.exe/tests/ntoskrnl.c > @@ -1609,8 +1609,9 @@ static void test_hidp(HANDLE file, int report_id) > .Usage = HID_USAGE_GENERIC_JOYSTICK, > .UsagePage = HID_USAGE_PAGE_GENERIC, > .InputReportByteLength = 24, > + .OutputReportByteLength = 2, > .FeatureReportByteLength = 18, > - .NumberLinkCollectionNodes = 8, > + .NumberLinkCollectionNodes = 9, > .NumberInputButtonCaps = 13, > .NumberInputValueCaps = 7, > .NumberInputDataIndices = 43, > @@ -1623,8 +1624,9 @@ static void test_hidp(HANDLE file, int report_id) > .Usage = HID_USAGE_GENERIC_JOYSTICK, > .UsagePage = HID_USAGE_PAGE_GENERIC, > .InputReportByteLength = 23, > + .OutputReportByteLength = 2, > .FeatureReportByteLength = 17, > - .NumberLinkCollectionNodes = 8, > + .NumberLinkCollectionNodes = 9, > .NumberInputButtonCaps = 13, > .NumberInputValueCaps = 7, > .NumberInputDataIndices = 43, > @@ -1766,8 +1768,8 @@ static void test_hidp(HANDLE file, int report_id) > .LinkUsage = HID_USAGE_GENERIC_JOYSTICK, > .LinkUsagePage = HID_USAGE_PAGE_GENERIC, > .CollectionType = 1, > - .NumberOfChildren = 5, > - .FirstChild = 7, > + .NumberOfChildren = 6, > + .FirstChild = 8, > }, > { > .LinkUsage = HID_USAGE_GENERIC_JOYSTICK, > @@ -2569,6 +2571,64 @@ static void test_hidp(HANDLE file, int report_id) > todo_wine > ok(!memcmp(buffer, buffer + 16, 16), "unexpected report value\n"); > > + memset(report, 0xcd, sizeof(report)); > + status = HidP_InitializeReportForID(HidP_Input, report_id, preparsed_data, report, caps.InputReportByteLength); > + todo_wine_if(report_id) > + ok(status == HIDP_STATUS_SUCCESS, "HidP_InitializeReportForID returned %#x\n", status); > + > + SetLastError(0xdeadbeef); > + ret = HidD_GetInputReport(file, report, caps.InputReportByteLength); > + ok(ret, "HidD_GetInputReport failed, last error %u\n", GetLastError()); > + > + memset(report, 0xcd, sizeof(report)); > + SetLastError(0xdeadbeef); > + ret = ReadFile(file, report, 0, &value, NULL); > + ok(!ret && GetLastError() == ERROR_INVALID_USER_BUFFER, "ReadFile failed, last error %u\n", GetLastError()); > + ok(value == 0, "ReadFile returned %x\n", value); > + SetLastError(0xdeadbeef); > + ret = ReadFile(file, report, caps.InputReportByteLength - 1, &value, NULL); > + ok(!ret && GetLastError() == ERROR_INVALID_USER_BUFFER, "ReadFile failed, last error %u\n", GetLastError()); > + ok(value == 0, "ReadFile returned %x\n", value); > + > + SetLastError(0xdeadbeef); > + ret = WriteFile(file, report, 0, &value, NULL); > + ok(!ret && GetLastError() == ERROR_INVALID_USER_BUFFER, "WriteFile failed, last error %u\n", GetLastError()); > + ok(value == 0, "WriteFile returned %x\n", value); > + SetLastError(0xdeadbeef); > + ret = WriteFile(file, report, caps.OutputReportByteLength - 1, &value, NULL); > + ok(!ret && GetLastError() == ERROR_INVALID_PARAMETER, "WriteFile failed, last error %u\n", GetLastError()); > + ok(value == 0, "WriteFile returned %x\n", value); > + > + memset(report, 0xcd, sizeof(report)); > + report[0] = 0xa5; > + SetLastError(0xdeadbeef); > + ret = WriteFile(file, report, caps.OutputReportByteLength, &value, NULL); > + if (report_id) > + { > + todo_wine > + ok(!ret && GetLastError() == ERROR_INVALID_PARAMETER, "WriteFile succeeded, last error %u\n", GetLastError()); > + todo_wine > + ok(value == 0, "WriteFile returned %x\n", value); > + > + SetLastError(0xdeadbeef); > + report[0] = report_id; > + ret = WriteFile(file, report, caps.OutputReportByteLength, &value, NULL); > + ok(ret, "WriteFile failed, last error %u\n", GetLastError()); > + ok(value == caps.OutputReportByteLength, "WriteFile returned %x\n", value); > + } > + else > + { > + ok(ret, "WriteFile failed, last error %u\n", GetLastError()); > + ok(value == caps.OutputReportByteLength, "WriteFile returned %x\n", value); > + } > + > + memset( report, 0xcd, sizeof(report) ); > + SetLastError(0xdeadbeef); > + ret = ReadFile( file, report, caps.InputReportByteLength, &value, NULL ); > + ok(ret, "ReadFile failed, last error %u\n", GetLastError()); > + ok(value == caps.InputReportByteLength, "ReadFile returned %x\n", value); > + ok(report[0] == report_id, "unexpected report data\n"); > + > HidD_FreePreparsedData(preparsed_data); > CloseHandle(file); > } > @@ -2616,7 +2676,7 @@ static void test_hid_device(DWORD report_id) > > todo_wine ok(found, "didn't find device\n"); > > - file = CreateFileA(iface_detail->DevicePath, FILE_READ_ACCESS, > + file = CreateFileA(iface_detail->DevicePath, FILE_READ_ACCESS | FILE_WRITE_ACCESS, > FILE_SHARE_READ | FILE_SHARE_WRITE, NULL, OPEN_EXISTING, 0, NULL); > ok(file != INVALID_HANDLE_VALUE, "got error %u\n", GetLastError()); > >