Skip to content

Commit

Permalink
✨ Wire in the ability to watch device arrivals through DevQuery 😊
Browse files Browse the repository at this point in the history
Honnestly don't remember why I didn't add this earlier, but it works as expected.
  • Loading branch information
hexawyz committed Feb 9, 2025
1 parent 4c50f12 commit 7d4b73a
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace DeviceTools;

public readonly struct DeviceObjectWatchWatchNotification(WatchNotificationKind kind, DeviceObjectInformation @object)
{
public WatchNotificationKind Kind { get; } = kind;
public DeviceObjectInformation Object { get; } = @object;
}
130 changes: 127 additions & 3 deletions src/DeviceTools/DeviceTools.Core/DeviceQuery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,56 @@ public sealed class DeviceQuery
AllowSynchronousContinuations = false
};

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
#endif
private static unsafe void WatchAllCallback(IntPtr handle, IntPtr context, NativeMethods.DeviceQueryResultActionData* action)
{
var ctx = Unsafe.As<DevQueryWatchCallbackContext>(GCHandle.FromIntPtr(context).Target)!;
switch (action->Action)
{
case NativeMethods.DeviceQueryResultAction.DevQueryResultAdd:
case NativeMethods.DeviceQueryResultAction.DevQueryResultUpdate:
case NativeMethods.DeviceQueryResultAction.DevQueryResultRemove:
ref var @object = ref action->StateOrObject.DeviceObject;
var kind = (WatchNotificationKind)action->Action;
if (!ctx.IsEnumerationComplete && kind == WatchNotificationKind.Add) kind = WatchNotificationKind.Enumeration;
#if NET6_0_OR_GREATER
ctx.Writer.TryWrite(new(kind, new(@object.ObjectType, MemoryMarshal.CreateReadOnlySpanFromNullTerminated(@object.ObjectId).ToString(), ParseProperties(ref @object))));
#else
ctx.Writer.TryWrite(new(kind, new(@object.ObjectType, Marshal.PtrToStringUni((IntPtr)@object.ObjectId)!, ParseProperties(ref @object))));
#endif
break;
case NativeMethods.DeviceQueryResultAction.DevQueryResultStateChange:
var state = action->StateOrObject.State;
if (state is NativeMethods.DeviceQueryState.DevQueryStateEnumCompleted)
{
ctx.IsEnumerationComplete = true;
}
else if (state is NativeMethods.DeviceQueryState.DevQueryStateAborted)
{
#if NET5_0_OR_GREATER
ctx.Writer.TryComplete(ExceptionDispatchInfo.SetCurrentStackTrace(new Exception("The query was aborted.")));
#else
try { new Exception("The query was aborted."); }
catch (Exception ex) { ctx.Writer.TryComplete(ex); }
#endif
ctx.Dispose();
}
else if (state is NativeMethods.DeviceQueryState.DevQueryStateClosed)
{
#if NET5_0_OR_GREATER
ctx.Writer.TryComplete(ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException()));
#else
try { new OperationCanceledException(); }
catch (Exception ex) { ctx.Writer.TryComplete(ex); }
#endif
ctx.Dispose();
}
break;
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
#endif
Expand Down Expand Up @@ -167,6 +217,7 @@ private enum Method
{
EnumerateAll = 1,
FindAll,
WatchAll,
GetObjectProperties
}

Expand Down Expand Up @@ -203,6 +254,14 @@ private class DevQueryCallbackContext<TState> : DevQueryCallbackContext
public DevQueryCallbackContext(Method method, TState state) : base(method) => State = state;
}

private class DevQueryWatchCallbackContext : DevQueryCallbackContext
{
public ChannelWriter<DeviceObjectWatchWatchNotification> Writer { get; }
public bool IsEnumerationComplete { get; set; }

public DevQueryWatchCallbackContext(ChannelWriter<DeviceObjectWatchWatchNotification> writer) : base(Method.WatchAll) => Writer = writer;
}

private class DevQueryCallbackContext<TState, TValue> : DevQueryCallbackContext<TState>
where TState : class
{
Expand Down Expand Up @@ -332,6 +391,69 @@ public DevQueryCallbackContext(Method method, TState state) : base(method, state
return dictionary;
}

public static IAsyncEnumerable<DeviceObjectWatchWatchNotification> WatchAllAsync(DeviceObjectKind objectKind, CancellationToken cancellationToken) =>
WatchAllAsync(objectKind, null, null, cancellationToken);

public static IAsyncEnumerable<DeviceObjectWatchWatchNotification> WatchAllAsync(DeviceObjectKind objectKind, DeviceFilterExpression filter, CancellationToken cancellationToken) =>
WatchAllAsync(objectKind, null, filter, cancellationToken);

public static IAsyncEnumerable<DeviceObjectWatchWatchNotification> WatchAllAsync(DeviceObjectKind objectKind, IEnumerable<Property>? properties, CancellationToken cancellationToken) =>
WatchAllAsync(objectKind, properties, null, cancellationToken);

public static IAsyncEnumerable<DeviceObjectWatchWatchNotification> WatchAllAsync
(
DeviceObjectKind objectKind,
IEnumerable<Property>? properties,
DeviceFilterExpression? filter,
CancellationToken cancellationToken
)
{
int count = filter?.GetFilterElementCount(true) ?? 0;
Span<NativeMethods.DevicePropertyFilterExpression> filterExpressions = count <= 4 ?
count == 0 ?
new Span<NativeMethods.DevicePropertyFilterExpression>() :
stackalloc NativeMethods.DevicePropertyFilterExpression[count] :
new NativeMethods.DevicePropertyFilterExpression[count];

Span<NativeMethods.DevicePropertyCompoundKey> propertyKeys = GetPropertyKeys(properties);

SafeDeviceQueryHandle query;
DevQueryWatchCallbackContext context;

filter?.FillExpressions(filterExpressions, true, out count);

var channel = Channel.CreateUnbounded<DeviceObjectWatchWatchNotification>(EnumerateAllChannelOptions);
try
{
context = new(channel);

try
{
query = CreateObjectQuery
(
objectKind,
properties is null ?
NativeMethods.DeviceQueryFlags.UpdateResults | NativeMethods.DeviceQueryFlags.AllProperties | NativeMethods.DeviceQueryFlags.AsyncClose :
NativeMethods.DeviceQueryFlags.UpdateResults | NativeMethods.DeviceQueryFlags.AsyncClose,
propertyKeys,
filterExpressions,
context.GetHandle()
);
}
catch
{
context.Dispose();
throw;
}
}
finally
{
filter?.ReleaseExpressionResources();
}

return EnumerateAllAsync(query, (ChannelReader<DeviceObjectWatchWatchNotification>)channel, cancellationToken);
}

public static IAsyncEnumerable<DeviceObjectInformation> EnumerateAllAsync(DeviceObjectKind objectKind, CancellationToken cancellationToken) =>
EnumerateAllAsync(objectKind, null, null, cancellationToken);

Expand Down Expand Up @@ -390,13 +512,13 @@ CancellationToken cancellationToken
filter?.ReleaseExpressionResources();
}

return EnumerateAllAsync(query, channel, cancellationToken);
return EnumerateAllAsync(query, (ChannelReader<DeviceObjectInformation>)channel, cancellationToken);
}

private static async IAsyncEnumerable<DeviceObjectInformation> EnumerateAllAsync
private static async IAsyncEnumerable<T> EnumerateAllAsync<T>
(
SafeDeviceQueryHandle queryHandle,
ChannelReader<DeviceObjectInformation> reader,
ChannelReader<T> reader,
[EnumeratorCancellation] CancellationToken cancellationToken
)
{
Expand Down Expand Up @@ -821,6 +943,7 @@ private static unsafe IntPtr CreateHelperContext(GCHandle contextHandle, Method
{
Method.EnumerateAll => &EnumerateAllCallback,
Method.FindAll => &FindAllCallback,
Method.WatchAll => &WatchAllCallback,
Method.GetObjectProperties => &GetObjectPropertiesCallback,
_ => throw new InvalidOperationException()
},
Expand All @@ -835,6 +958,7 @@ private static unsafe IntPtr CreateHelperContext(GCHandle contextHandle, Method
{
Method.EnumerateAll => EnumerateAllCallback,
Method.FindAll => FindAllCallback,
Method.WatchAll => WatchAllCallback,
Method.GetObjectProperties => GetObjectPropertiesCallback,
_ => throw new InvalidOperationException()
},
Expand Down
9 changes: 9 additions & 0 deletions src/DeviceTools/DeviceTools.Core/WatchNotificationKind.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace DeviceTools;

public enum WatchNotificationKind
{
Enumeration = 0,
Add = 1,
Update = 2,
Remove = 3,
}
24 changes: 22 additions & 2 deletions src/Tools/StreamDeckPlayground/MainViewModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public async Task WatchDevicesAsync(CancellationToken cancellationToken)
{
await foreach
(
var device in DeviceQuery.EnumerateAllAsync
var notification in DeviceQuery.WatchAllAsync
(
DeviceObjectKind.DeviceInterface,
Properties.System.Devices.InterfaceClassGuid == DeviceInterfaceClassGuids.Hid &
Expand All @@ -53,7 +53,27 @@ var device in DeviceQuery.EnumerateAllAsync
)
)
{
_devices.Add(await StreamDeckViewModel.CreateAsync(device.Id, 0x006C, cancellationToken));
switch (notification.Kind)
{
case WatchNotificationKind.Enumeration:
case WatchNotificationKind.Add:
_devices.Add(await StreamDeckViewModel.CreateAsync(notification.Object.Id, 0x006C, cancellationToken));
break;
case WatchNotificationKind.Update:
break;
case WatchNotificationKind.Remove:
for (int i = 0; i < _devices.Count; i++)
{
var device = _devices[i];
if (device.DeviceName == notification.Object.Id)
{
_devices.RemoveAt(i);
await device.DisposeAsync();
break;
}
}
break;
}
}
}
catch (OperationCanceledException)
Expand Down

0 comments on commit 7d4b73a

Please sign in to comment.