From 7d4b73a85972c08c892c9f94347ac38bb391144d Mon Sep 17 00:00:00 2001 From: hexawyz <8518235+hexawyz@users.noreply.github.com> Date: Sun, 9 Feb 2025 23:26:05 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Wire=20in=20the=20ability=20to=20wa?= =?UTF-8?q?tch=20device=20arrivals=20through=20DevQuery=20=F0=9F=98=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Honnestly don't remember why I didn't add this earlier, but it works as expected. --- .../DeviceObjectWatchWatchNotification.cs | 7 + .../DeviceTools.Core/DeviceQuery.cs | 130 +++++++++++++++++- .../DeviceTools.Core/WatchNotificationKind.cs | 9 ++ .../StreamDeckPlayground/MainViewModel.cs | 24 +++- 4 files changed, 165 insertions(+), 5 deletions(-) create mode 100644 src/DeviceTools/DeviceTools.Core/DeviceObjectWatchWatchNotification.cs create mode 100644 src/DeviceTools/DeviceTools.Core/WatchNotificationKind.cs diff --git a/src/DeviceTools/DeviceTools.Core/DeviceObjectWatchWatchNotification.cs b/src/DeviceTools/DeviceTools.Core/DeviceObjectWatchWatchNotification.cs new file mode 100644 index 0000000..b28cb09 --- /dev/null +++ b/src/DeviceTools/DeviceTools.Core/DeviceObjectWatchWatchNotification.cs @@ -0,0 +1,7 @@ +namespace DeviceTools; + +public readonly struct DeviceObjectWatchWatchNotification(WatchNotificationKind kind, DeviceObjectInformation @object) +{ + public WatchNotificationKind Kind { get; } = kind; + public DeviceObjectInformation Object { get; } = @object; +} diff --git a/src/DeviceTools/DeviceTools.Core/DeviceQuery.cs b/src/DeviceTools/DeviceTools.Core/DeviceQuery.cs index c5c6bdf..cd730ce 100644 --- a/src/DeviceTools/DeviceTools.Core/DeviceQuery.cs +++ b/src/DeviceTools/DeviceTools.Core/DeviceQuery.cs @@ -21,6 +21,56 @@ public sealed class DeviceQuery AllowSynchronousContinuations = false }; +#if NET5_0_OR_GREATER + [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })] +#endif + private static unsafe void WatchAllCallback(IntPtr handle, IntPtr context, NativeMethods.DeviceQueryResultActionData* action) + { + var ctx = Unsafe.As(GCHandle.FromIntPtr(context).Target)!; + switch (action->Action) + { + case NativeMethods.DeviceQueryResultAction.DevQueryResultAdd: + case NativeMethods.DeviceQueryResultAction.DevQueryResultUpdate: + case NativeMethods.DeviceQueryResultAction.DevQueryResultRemove: + ref var @object = ref action->StateOrObject.DeviceObject; + var kind = (WatchNotificationKind)action->Action; + if (!ctx.IsEnumerationComplete && kind == WatchNotificationKind.Add) kind = WatchNotificationKind.Enumeration; +#if NET6_0_OR_GREATER + ctx.Writer.TryWrite(new(kind, new(@object.ObjectType, MemoryMarshal.CreateReadOnlySpanFromNullTerminated(@object.ObjectId).ToString(), ParseProperties(ref @object)))); +#else + ctx.Writer.TryWrite(new(kind, new(@object.ObjectType, Marshal.PtrToStringUni((IntPtr)@object.ObjectId)!, ParseProperties(ref @object)))); +#endif + break; + case NativeMethods.DeviceQueryResultAction.DevQueryResultStateChange: + var state = action->StateOrObject.State; + if (state is NativeMethods.DeviceQueryState.DevQueryStateEnumCompleted) + { + ctx.IsEnumerationComplete = true; + } + else if (state is NativeMethods.DeviceQueryState.DevQueryStateAborted) + { +#if NET5_0_OR_GREATER + ctx.Writer.TryComplete(ExceptionDispatchInfo.SetCurrentStackTrace(new Exception("The query was aborted."))); +#else + try { new Exception("The query was aborted."); } + catch (Exception ex) { ctx.Writer.TryComplete(ex); } +#endif + ctx.Dispose(); + } + else if (state is NativeMethods.DeviceQueryState.DevQueryStateClosed) + { +#if NET5_0_OR_GREATER + ctx.Writer.TryComplete(ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException())); +#else + try { new OperationCanceledException(); } + catch (Exception ex) { ctx.Writer.TryComplete(ex); } +#endif + ctx.Dispose(); + } + break; + } + } + #if NET5_0_OR_GREATER [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })] #endif @@ -167,6 +217,7 @@ private enum Method { EnumerateAll = 1, FindAll, + WatchAll, GetObjectProperties } @@ -203,6 +254,14 @@ private class DevQueryCallbackContext : DevQueryCallbackContext public DevQueryCallbackContext(Method method, TState state) : base(method) => State = state; } + private class DevQueryWatchCallbackContext : DevQueryCallbackContext + { + public ChannelWriter Writer { get; } + public bool IsEnumerationComplete { get; set; } + + public DevQueryWatchCallbackContext(ChannelWriter writer) : base(Method.WatchAll) => Writer = writer; + } + private class DevQueryCallbackContext : DevQueryCallbackContext where TState : class { @@ -332,6 +391,69 @@ public DevQueryCallbackContext(Method method, TState state) : base(method, state return dictionary; } + public static IAsyncEnumerable WatchAllAsync(DeviceObjectKind objectKind, CancellationToken cancellationToken) => + WatchAllAsync(objectKind, null, null, cancellationToken); + + public static IAsyncEnumerable WatchAllAsync(DeviceObjectKind objectKind, DeviceFilterExpression filter, CancellationToken cancellationToken) => + WatchAllAsync(objectKind, null, filter, cancellationToken); + + public static IAsyncEnumerable WatchAllAsync(DeviceObjectKind objectKind, IEnumerable? properties, CancellationToken cancellationToken) => + WatchAllAsync(objectKind, properties, null, cancellationToken); + + public static IAsyncEnumerable WatchAllAsync + ( + DeviceObjectKind objectKind, + IEnumerable? properties, + DeviceFilterExpression? filter, + CancellationToken cancellationToken + ) + { + int count = filter?.GetFilterElementCount(true) ?? 0; + Span filterExpressions = count <= 4 ? + count == 0 ? + new Span() : + stackalloc NativeMethods.DevicePropertyFilterExpression[count] : + new NativeMethods.DevicePropertyFilterExpression[count]; + + Span propertyKeys = GetPropertyKeys(properties); + + SafeDeviceQueryHandle query; + DevQueryWatchCallbackContext context; + + filter?.FillExpressions(filterExpressions, true, out count); + + var channel = Channel.CreateUnbounded(EnumerateAllChannelOptions); + try + { + context = new(channel); + + try + { + query = CreateObjectQuery + ( + objectKind, + properties is null ? + NativeMethods.DeviceQueryFlags.UpdateResults | NativeMethods.DeviceQueryFlags.AllProperties | NativeMethods.DeviceQueryFlags.AsyncClose : + NativeMethods.DeviceQueryFlags.UpdateResults | NativeMethods.DeviceQueryFlags.AsyncClose, + propertyKeys, + filterExpressions, + context.GetHandle() + ); + } + catch + { + context.Dispose(); + throw; + } + } + finally + { + filter?.ReleaseExpressionResources(); + } + + return EnumerateAllAsync(query, (ChannelReader)channel, cancellationToken); + } + public static IAsyncEnumerable EnumerateAllAsync(DeviceObjectKind objectKind, CancellationToken cancellationToken) => EnumerateAllAsync(objectKind, null, null, cancellationToken); @@ -390,13 +512,13 @@ CancellationToken cancellationToken filter?.ReleaseExpressionResources(); } - return EnumerateAllAsync(query, channel, cancellationToken); + return EnumerateAllAsync(query, (ChannelReader)channel, cancellationToken); } - private static async IAsyncEnumerable EnumerateAllAsync + private static async IAsyncEnumerable EnumerateAllAsync ( SafeDeviceQueryHandle queryHandle, - ChannelReader reader, + ChannelReader reader, [EnumeratorCancellation] CancellationToken cancellationToken ) { @@ -821,6 +943,7 @@ private static unsafe IntPtr CreateHelperContext(GCHandle contextHandle, Method { Method.EnumerateAll => &EnumerateAllCallback, Method.FindAll => &FindAllCallback, + Method.WatchAll => &WatchAllCallback, Method.GetObjectProperties => &GetObjectPropertiesCallback, _ => throw new InvalidOperationException() }, @@ -835,6 +958,7 @@ private static unsafe IntPtr CreateHelperContext(GCHandle contextHandle, Method { Method.EnumerateAll => EnumerateAllCallback, Method.FindAll => FindAllCallback, + Method.WatchAll => WatchAllCallback, Method.GetObjectProperties => GetObjectPropertiesCallback, _ => throw new InvalidOperationException() }, diff --git a/src/DeviceTools/DeviceTools.Core/WatchNotificationKind.cs b/src/DeviceTools/DeviceTools.Core/WatchNotificationKind.cs new file mode 100644 index 0000000..02d8a57 --- /dev/null +++ b/src/DeviceTools/DeviceTools.Core/WatchNotificationKind.cs @@ -0,0 +1,9 @@ +namespace DeviceTools; + +public enum WatchNotificationKind +{ + Enumeration = 0, + Add = 1, + Update = 2, + Remove = 3, +} diff --git a/src/Tools/StreamDeckPlayground/MainViewModel.cs b/src/Tools/StreamDeckPlayground/MainViewModel.cs index 73b1cc2..640ff8b 100644 --- a/src/Tools/StreamDeckPlayground/MainViewModel.cs +++ b/src/Tools/StreamDeckPlayground/MainViewModel.cs @@ -42,7 +42,7 @@ public async Task WatchDevicesAsync(CancellationToken cancellationToken) { await foreach ( - var device in DeviceQuery.EnumerateAllAsync + var notification in DeviceQuery.WatchAllAsync ( DeviceObjectKind.DeviceInterface, Properties.System.Devices.InterfaceClassGuid == DeviceInterfaceClassGuids.Hid & @@ -53,7 +53,27 @@ var device in DeviceQuery.EnumerateAllAsync ) ) { - _devices.Add(await StreamDeckViewModel.CreateAsync(device.Id, 0x006C, cancellationToken)); + switch (notification.Kind) + { + case WatchNotificationKind.Enumeration: + case WatchNotificationKind.Add: + _devices.Add(await StreamDeckViewModel.CreateAsync(notification.Object.Id, 0x006C, cancellationToken)); + break; + case WatchNotificationKind.Update: + break; + case WatchNotificationKind.Remove: + for (int i = 0; i < _devices.Count; i++) + { + var device = _devices[i]; + if (device.DeviceName == notification.Object.Id) + { + _devices.RemoveAt(i); + await device.DisposeAsync(); + break; + } + } + break; + } } } catch (OperationCanceledException)