diff --git a/docs/remote-control/README.md b/docs/remote-control/README.md new file mode 100644 index 000000000000..359650b159cc --- /dev/null +++ b/docs/remote-control/README.md @@ -0,0 +1,147 @@ +# Remote Control / IPC — hardened (final candidate) + +This revision hardens the IPC subsystem for Files to address resource, security, and correctness issues: +- Strict JSON-RPC 2.0 validation and shape enforcement (includes IsInvalidRequest). +- Encrypted token storage (DPAPI) with epoch-based rotation that invalidates existing sessions. +- Centralized RpcMethodRegistry used everywhere (transports + adapter). +- WebSocket receive caps, per-method caps, per-client queue caps, lossy coalescing by method and per-client token bucket applied for both requests and notifications. +- Named Pipe per-user ACL and per-session randomized pipe name; length-prefixed framing. +- getMetadata: capped by items and timeout; runs off UI thread and honors client cancellation. +- Selection notifications are capped and include truncated flag. +- UIOperationQueue required to be passed a DispatcherQueue; all UI-affecting operations serialized. + +## Merge checklist +- [ ] Settings UI: "Enable Remote Control" (ProtectedTokenStore.SetEnabled), "Rotate Token" (RotateTokenAsync), "Enable Long Paths" toggle and display of current pipe name/port only when enabled. +- [ ] ShellViewModel: wire ExecuteActionById / NavigateToPathNormalized or expose small interface for adapter. +- [ ] Packaging decision: Document Kestrel + URLACL if WS is desired in Store/MSIX; default recommended for Store builds is NamedPipe-only. +- [ ] Tests: WS/pipe oversize, slow-consumer (lossy/coalesce), JSON-RPC conformance, getMetadata timeout & cancellation, token rotation invalidation. +- [ ] Telemetry hooks: auth failures, slow-client disconnects, queue drops. + +## JSON-RPC error codes used +- -32700 Parse error +- -32600 Invalid Request +- -32601 Method not found +- -32602 Invalid params +- -32001 Authentication required +- -32002 Invalid token +- -32003 Rate limit exceeded +- -32004 Session expired +- -32000 Internal server error + +## Architecture + +### Core Components + +#### JsonRpcMessage +Strict JSON-RPC 2.0 implementation with helpers for creating responses and validating message shapes. Preserves original ID types and enforces result XOR error semantics. + +#### ProtectedTokenStore +DPAPI-backed encrypted token storage in LocalSettings with epoch-based rotation. When tokens are rotated, the epoch increments and invalidates all existing client sessions. + +#### ClientContext +Per-client state management including: +- Token bucket rate limiting (configurable burst and refill rate) +- Lossy message queue with method-based coalescing +- Authentication state and epoch tracking +- Connection lifecycle management + +#### RpcMethodRegistry +Centralized registry for RPC method definitions including: +- Authentication requirements +- Notification permissions +- Per-method payload size limits +- Custom authorization policies + +#### Transport Services +- **WebSocketAppCommunicationService**: HTTP listener on loopback with WebSocket upgrade +- **NamedPipeAppCommunicationService**: Per-user ACL with randomized pipe names + +#### ShellIpcAdapter +Application logic adapter that: +- Enforces method allowlists and security policies +- Provides path normalization and validation +- Implements resource-bounded operations (metadata with timeouts) +- Serializes UI operations through UIOperationQueue + +## Security Features + +### Authentication & Authorization +- DPAPI-encrypted token storage +- Per-session token validation with epoch checking +- Method-level authorization policies +- Per-user ACL on named pipes + +### Resource Protection +- Configurable message size limits per transport +- Per-client queue size limits with lossy behavior +- Rate limiting with token bucket algorithm +- Operation timeouts and cancellation support + +### Attack Mitigation +- Strict JSON-RPC validation prevents malformed requests +- Path normalization rejects device paths and traversal attempts +- Selection notifications capped to prevent resource exhaustion +- Automatic cleanup of inactive/stale connections + +## Configuration + +All limits are configurable via `IpcConfig`: +```csharp +IpcConfig.WebSocketMaxMessageBytes = 16 * 1024 * 1024; // 16 MB +IpcConfig.NamedPipeMaxMessageBytes = 10 * 1024 * 1024; // 10 MB +IpcConfig.PerClientQueueCapBytes = 2 * 1024 * 1024; // 2 MB +IpcConfig.RateLimitPerSecond = 20; +IpcConfig.RateLimitBurst = 60; +IpcConfig.SelectionNotificationCap = 200; +IpcConfig.GetMetadataMaxItems = 500; +IpcConfig.GetMetadataTimeoutSec = 30; +``` + +## Supported Methods + +### Authentication +- `handshake` - Authenticate with token and establish session + +### State Query +- `getState` - Get current navigation state +- `listActions` - Get available actions + +### Operations +- `navigate` - Navigate to path (with normalization) +- `executeAction` - Execute registered action by ID +- `getMetadata` - Get file/folder metadata (batched, with timeout) + +### Notifications (Broadcast) +- `workingDirectoryChanged` - Current directory changed +- `selectionChanged` - File selection changed (with truncation) +- `ping` - Keepalive heartbeat + +## Usage + +**DO NOT enable IPC by default** — StartAsync refuses to start unless the user explicitly enables Remote Control via Settings. See merge checklist above. + +### Enabling Remote Control +```csharp +// In Settings UI +await ProtectedTokenStore.SetEnabled(true); +var token = await ProtectedTokenStore.GetOrCreateTokenAsync(); +``` + +### Starting Services +```csharp +// Only starts if enabled +await webSocketService.StartAsync(); +await namedPipeService.StartAsync(); +``` + +### Token Rotation +```csharp +// Invalidates all existing sessions +var newToken = await ProtectedTokenStore.RotateTokenAsync(); +``` + +## Implementation Status + +✅ **Complete**: Core IPC framework, security model, transport services +🔄 **Pending**: Settings UI integration, ShellViewModel method wiring +📋 **TODO**: Comprehensive tests, telemetry integration, Kestrel option \ No newline at end of file diff --git a/global.json b/global.json index 6b2ebefd9cc0..5ce2e6ef2fcf 100644 --- a/global.json +++ b/global.json @@ -3,4 +3,4 @@ "version": "9.0.200", "rollForward": "latestMajor" } -} \ No newline at end of file +} diff --git a/src/Files.App/Communication/ActionRegistry.cs b/src/Files.App/Communication/ActionRegistry.cs new file mode 100644 index 000000000000..4d51e585a5a8 --- /dev/null +++ b/src/Files.App/Communication/ActionRegistry.cs @@ -0,0 +1,39 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Files.App.Communication +{ + // Simple action registry for IPC system + public sealed class ActionRegistry + { + // readonly fields + private readonly HashSet _allowedActions = new(StringComparer.OrdinalIgnoreCase) + { + "navigate", + "refresh", + "copyPath", + "openInNewTab", + "openInNewWindow", + "toggleDualPane", + "showProperties" + }; + + // Public methods + public bool CanExecute(string actionId, object? context = null) + { + if (string.IsNullOrEmpty(actionId)) + return false; + + return _allowedActions.Contains(actionId); + } + + public IEnumerable GetAllowedActions() => _allowedActions.ToList(); + + public void RegisterAction(string actionId) + { + if (!string.IsNullOrEmpty(actionId)) + _allowedActions.Add(actionId); + } + } +} \ No newline at end of file diff --git a/src/Files.App/Communication/ClientContext.cs b/src/Files.App/Communication/ClientContext.cs new file mode 100644 index 000000000000..3ab5e331116f --- /dev/null +++ b/src/Files.App/Communication/ClientContext.cs @@ -0,0 +1,150 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Net.WebSockets; +using System.Text; +using System.Threading; + +namespace Files.App.Communication +{ + // Per-client state with token-bucket, lossy enqueue and LastSeenUtc tracked. + public sealed class ClientContext : IDisposable + { + // readonly fields + private readonly object _rateLock = new(); + private readonly ConcurrentQueue<(string payload, bool isNotification, string? method)> _sendQueue = new(); + + // Fields + private long _queuedBytes = 0; + private int _tokens; + private DateTime _lastRefill; + + // _disposed field + private bool _disposed; + + // Properties + public Guid Id { get; } = Guid.NewGuid(); + + public string? ClientInfo { get; set; } + + public bool IsAuthenticated { get; set; } + + public int AuthEpoch { get; set; } = 0; // set at handshake + + public DateTime LastSeenUtc { get; set; } = DateTime.UtcNow; + + public long MaxQueuedBytes { get; set; } = IpcConfig.PerClientQueueCapBytes; + + public CancellationTokenSource? Cancellation { get; set; } + + public WebSocket? WebSocket { get; set; } + + public object? TransportHandle { get; set; } // can store session id, pipe name, etc. + + internal ConcurrentQueue<(string payload, bool isNotification, string? method)> SendQueue => _sendQueue; + + // Constructor + public ClientContext() + { + _tokens = IpcConfig.RateLimitBurst; + _lastRefill = DateTime.UtcNow; + } + + // Public methods + public void RefillTokens() + { + lock (_rateLock) + { + var now = DateTime.UtcNow; + var delta = (now - _lastRefill).TotalSeconds; + if (delta <= 0) + return; + + var add = (int)(delta * IpcConfig.RateLimitPerSecond); + if (add > 0) + { + _tokens = Math.Min(IpcConfig.RateLimitBurst, _tokens + add); + _lastRefill = now; + } + } + } + + public bool TryConsumeToken() + { + RefillTokens(); + lock (_rateLock) + { + if (_tokens <= 0) + return false; + + _tokens--; + return true; + } + } + + // Try enqueue with lossy policy; drops oldest notifications of the same method first when needed. + public bool TryEnqueue(string payload, bool isNotification, string? method = null) + { + var bytes = Encoding.UTF8.GetByteCount(payload); + var newVal = Interlocked.Add(ref _queuedBytes, bytes); + if (newVal > MaxQueuedBytes) + { + // attempt to free by dropping oldest notifications (prefer same-method) + int freed = 0; + var initialQueue = new List<(string payload, bool isNotification, string? method)>(); + while (SendQueue.TryDequeue(out var old)) + { + if (!old.isNotification) + { + initialQueue.Add(old); // keep responses + } + else if (old.method != null && method != null && old.method.Equals(method, StringComparison.OrdinalIgnoreCase) && freed == 0) + { + // drop one older of same method + var b = Encoding.UTF8.GetByteCount(old.payload); + Interlocked.Add(ref _queuedBytes, -b); + freed += b; + break; + } + else + { + // for fairness, try dropping other notifications as well + var b = Encoding.UTF8.GetByteCount(old.payload); + Interlocked.Add(ref _queuedBytes, -b); + freed += b; + if (Interlocked.Read(ref _queuedBytes) <= MaxQueuedBytes) + break; + } + } + + // push back preserved responses + foreach (var item in initialQueue) + SendQueue.Enqueue(item); + + newVal = Interlocked.Read(ref _queuedBytes); + if (newVal + bytes > MaxQueuedBytes) + { + // still cannot enqueue + return false; + } + } + + SendQueue.Enqueue((payload, isNotification, method)); + return true; + } + + // Internal methods + internal void DecreaseQueuedBytes(int sentBytes) => Interlocked.Add(ref _queuedBytes, -sentBytes); + + // Dispose + public void Dispose() + { + if (_disposed) + return; + + try { Cancellation?.Cancel(); } catch { } + try { WebSocket?.Dispose(); } catch { } + _disposed = true; + } + } +} \ No newline at end of file diff --git a/src/Files.App/Communication/IAppCommunicationService.cs b/src/Files.App/Communication/IAppCommunicationService.cs new file mode 100644 index 000000000000..a9c946fa7947 --- /dev/null +++ b/src/Files.App/Communication/IAppCommunicationService.cs @@ -0,0 +1,44 @@ +using System; +using System.Threading.Tasks; + +namespace Files.App.Communication +{ + /// + /// Represents a communication service for handling JSON-RPC messages between clients and the application. + /// Implementations provide transport-specific functionality (WebSocket, Named Pipe, etc.) + /// + public interface IAppCommunicationService + { + /// + /// Occurs when a JSON-RPC request is received from a client. + /// + event Func? OnRequestReceived; + + /// + /// Starts the communication service and begins listening for client connections. + /// + /// A task that represents the asynchronous start operation. + Task StartAsync(); + + /// + /// Stops the communication service and closes all client connections. + /// + /// A task that represents the asynchronous stop operation. + Task StopAsync(); + + /// + /// Sends a JSON-RPC response message to a specific client. + /// + /// The client context to send the response to. + /// The JSON-RPC response message to send. + /// A task that represents the asynchronous send operation. + Task SendResponseAsync(ClientContext client, JsonRpcMessage response); + + /// + /// Broadcasts a JSON-RPC notification message to all connected clients. + /// + /// The JSON-RPC notification message to broadcast. + /// A task that represents the asynchronous broadcast operation. + Task BroadcastAsync(JsonRpcMessage notification); + } +} \ No newline at end of file diff --git a/src/Files.App/Communication/IpcConfig.cs b/src/Files.App/Communication/IpcConfig.cs new file mode 100644 index 000000000000..90d0e8b77b79 --- /dev/null +++ b/src/Files.App/Communication/IpcConfig.cs @@ -0,0 +1,22 @@ +namespace Files.App.Communication +{ + // Runtime configuration for IPC system - uses constants from Constants.IpcSettings as defaults + public static class IpcConfig + { + public static long WebSocketMaxMessageBytes { get; set; } = Constants.IpcSettings.WebSocketMaxMessageBytes; + + public static long NamedPipeMaxMessageBytes { get; set; } = Constants.IpcSettings.NamedPipeMaxMessageBytes; + + public static long PerClientQueueCapBytes { get; set; } = Constants.IpcSettings.PerClientQueueCapBytes; + + public static int RateLimitPerSecond { get; set; } = Constants.IpcSettings.RateLimitPerSecond; + + public static int RateLimitBurst { get; set; } = Constants.IpcSettings.RateLimitBurst; + + public static int SelectionNotificationCap { get; set; } = Constants.IpcSettings.SelectionNotificationCap; + + public static int GetMetadataMaxItems { get; set; } = Constants.IpcSettings.GetMetadataMaxItems; + + public static int GetMetadataTimeoutSec { get; set; } = Constants.IpcSettings.GetMetadataTimeoutSec; + } +} \ No newline at end of file diff --git a/src/Files.App/Communication/JsonRpcMessage.cs b/src/Files.App/Communication/JsonRpcMessage.cs new file mode 100644 index 000000000000..87e5b29fe332 --- /dev/null +++ b/src/Files.App/Communication/JsonRpcMessage.cs @@ -0,0 +1,74 @@ +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Files.App.Communication +{ + // Strict JSON-RPC 2.0 model with helpers that preserve original id types and enforce result XOR error. + public sealed record JsonRpcMessage + { + [JsonPropertyName("jsonrpc")] + public string JsonRpc { get; init; } = "2.0"; + + [JsonPropertyName("id")] + public JsonElement? Id { get; init; } // omitted => notification + + [JsonPropertyName("method")] + public string? Method { get; init; } + + [JsonPropertyName("params")] + public JsonElement? Params { get; init; } + + [JsonPropertyName("result")] + public JsonElement? Result { get; init; } + + [JsonPropertyName("error")] + public JsonElement? Error { get; init; } + + public bool IsNotification => Id is null || (Id.HasValue && Id.Value.ValueKind == JsonValueKind.Null); + + public static JsonRpcMessage? FromJson(string json) + { + try { return JsonSerializer.Deserialize(json); } + catch { return null; } + } + + public string ToJson() => JsonSerializer.Serialize(this); + + public static JsonRpcMessage MakeError(JsonElement? id, int code, string message) + { + var errObj = new { code, message }; + var doc = JsonSerializer.SerializeToElement(errObj); + return new JsonRpcMessage { Id = id, Error = doc }; + } + + public static JsonRpcMessage MakeResult(JsonElement? id, object result) + { + var doc = JsonSerializer.SerializeToElement(result); + return new JsonRpcMessage { Id = id, Result = doc }; + } + + public static bool ValidJsonRpc(JsonRpcMessage? msg) => msg is not null && msg.JsonRpc == "2.0"; + + // Validate that incoming message is a legal JSON-RPC request/notification/response shape + public static bool IsInvalidRequest(JsonRpcMessage m) + { + var hasMethod = !string.IsNullOrEmpty(m.Method); + var hasResult = m.Result is not null && m.Result.Value.ValueKind != JsonValueKind.Undefined; + var hasError = m.Error is not null && m.Error.Value.ValueKind != JsonValueKind.Undefined; + + // result and error are mutually exclusive + if (hasResult && hasError) + return true; + + // request or notification: method present; NO result/error + if (hasMethod && (hasResult || hasError)) + return true; + + // response: no method; need exactly one of result or error + if (!hasMethod && !(hasResult ^ hasError)) + return true; + + return false; + } + } +} \ No newline at end of file diff --git a/src/Files.App/Communication/Models/ItemDto.cs b/src/Files.App/Communication/Models/ItemDto.cs new file mode 100644 index 000000000000..8e4c8cf3f0d9 --- /dev/null +++ b/src/Files.App/Communication/Models/ItemDto.cs @@ -0,0 +1,21 @@ +namespace Files.App.Communication.Models +{ + public sealed class ItemDto + { + public string Path { get; set; } = string.Empty; + + public string Name { get; set; } = string.Empty; + + public bool IsDirectory { get; set; } + + public long SizeBytes { get; set; } + + public string DateModified { get; set; } = string.Empty; + + public string DateCreated { get; set; } = string.Empty; + + public string? MimeType { get; set; } + + public bool Exists { get; set; } + } +} \ No newline at end of file diff --git a/src/Files.App/Communication/NamedPipeAppCommunicationService.cs b/src/Files.App/Communication/NamedPipeAppCommunicationService.cs new file mode 100644 index 000000000000..d4e922a83cc9 --- /dev/null +++ b/src/Files.App/Communication/NamedPipeAppCommunicationService.cs @@ -0,0 +1,526 @@ +using System; +using System.Buffers.Binary; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.IO; +using System.IO.Pipes; +using System.Linq; +using System.Security.AccessControl; +using System.Security.Principal; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Windows.Storage; + +namespace Files.App.Communication +{ + public sealed class NamedPipeAppCommunicationService : IAppCommunicationService, IDisposable + { + // readonly fields + private readonly RpcMethodRegistry _methodRegistry; + private readonly ILogger _logger; + private readonly ConcurrentDictionary _clients = new(); + private readonly Timer _keepaliveTimer; + private readonly Timer _cleanupTimer; + private readonly CancellationTokenSource _cancellation = new(); + + // Fields + private string? _currentToken; + private int _currentEpoch; + private string? _pipeName; + private bool _isStarted; + private Task? _acceptTask; + + // _disposed field + private bool _disposed; + + // Events + public event Func? OnRequestReceived; + + // Constructor + public NamedPipeAppCommunicationService( + RpcMethodRegistry methodRegistry, + ILogger logger) + { + _methodRegistry = methodRegistry ?? throw new ArgumentNullException(nameof(methodRegistry)); + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + + // Setup keepalive timer (every 30 seconds) + _keepaliveTimer = new Timer(SendKeepalive, null, TimeSpan.FromSeconds(30d), TimeSpan.FromSeconds(30d)); + + // Setup cleanup timer (every 60 seconds) + _cleanupTimer = new Timer(CleanupInactiveClients, null, TimeSpan.FromSeconds(60d), TimeSpan.FromSeconds(60d)); + } + + // Public methods + public async Task StartAsync() + { + if (!ProtectedTokenStore.IsEnabled()) + { + _logger.LogWarning("Remote control is not enabled, refusing to start Named Pipe service"); + return; + } + + if (_isStarted) + return; + + try + { + _currentToken = await ProtectedTokenStore.GetOrCreateTokenAsync(); + _currentEpoch = ProtectedTokenStore.GetEpoch(); + + // Generate randomized pipe name per session for security + _pipeName = $"Files_IPC_{Environment.UserName}_{Guid.NewGuid():N}"; + + _isStarted = true; + _acceptTask = Task.Run(AcceptConnectionsAsync, _cancellation.Token); + + _logger.LogInformation("Named Pipe IPC service started with pipe: {PipeName}", _pipeName); + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to start Named Pipe IPC service"); + throw; + } + } + + public async Task StopAsync() + { + if (!_isStarted) + return; + + try + { + _cancellation.Cancel(); + + // Wait for accept task to complete + if (_acceptTask != null) + { + try + { + await _acceptTask; + } + catch (OperationCanceledException) + { + // Expected when cancelling + } + } + + // Close all client connections + foreach (var client in _clients.Values) + { + client.Dispose(); + } + _clients.Clear(); + + _isStarted = false; + _logger.LogInformation("Named Pipe IPC service stopped"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error stopping Named Pipe IPC service"); + } + } + + public async Task SendResponseAsync(ClientContext client, JsonRpcMessage response) + { + if (client?.TransportHandle is not NamedPipeServerStream pipe || !pipe.IsConnected) + return; + + try + { + var json = response.ToJson(); + var canEnqueue = client.TryEnqueue(json, false); + if (!canEnqueue) + { + _logger.LogWarning("Client {ClientId} queue full, dropping response", client.Id); + } + } + catch (Exception ex) + { + _logger.LogError(ex, "Error queuing response for client {ClientId}", client.Id); + } + } + + public async Task BroadcastAsync(JsonRpcMessage notification) + { + if (!_isStarted) + return; + + var json = notification.ToJson(); + var activeclients = _clients.Values + .Where(c => c.TransportHandle is NamedPipeServerStream pipe && pipe.IsConnected) + .ToList(); + + foreach (var client in activeclients) + { + try + { + var canEnqueue = client.TryEnqueue(json, true, notification.Method); + if (!canEnqueue) + { + _logger.LogDebug("Client {ClientId} queue full, dropping notification {Method}", client.Id, notification.Method); + } + } + catch (Exception ex) + { + _logger.LogError(ex, "Error queuing notification for client {ClientId}", client.Id); + } + } + } + + // Private methods + private async Task AcceptConnectionsAsync() + { + while (!_cancellation.Token.IsCancellationRequested) + { + try + { + var pipe = CreateSecurePipeServer(); + await pipe.WaitForConnectionAsync(_cancellation.Token); + + var client = new ClientContext + { + TransportHandle = pipe, + Cancellation = CancellationTokenSource.CreateLinkedTokenSource(_cancellation.Token) + }; + + _clients[client.Id] = client; + _logger.LogDebug("Named Pipe client {ClientId} connected", client.Id); + + // Start client handlers + _ = Task.Run(() => ClientSendLoopAsync(client), client.Cancellation.Token); + _ = Task.Run(() => ClientReceiveLoopAsync(client), client.Cancellation.Token); + } + catch (OperationCanceledException) when (_cancellation.Token.IsCancellationRequested) + { + break; + } + catch (Exception ex) + { + _logger.LogError(ex, "Error accepting Named Pipe connection"); + } + } + } + + private NamedPipeServerStream CreateSecurePipeServer() + { + var currentUser = WindowsIdentity.GetCurrent(); + var pipeSecurity = new PipeSecurity(); + + // Allow full control to current user only + pipeSecurity.AddAccessRule(new PipeAccessRule( + currentUser.User!, + PipeAccessRights.FullControl, + AccessControlType.Allow)); + + // Deny access to everyone else + pipeSecurity.AddAccessRule(new PipeAccessRule( + new SecurityIdentifier(WellKnownSidType.WorldSid, null), + PipeAccessRights.FullControl, + AccessControlType.Deny)); + + return NamedPipeServerStreamAcl.Create( + _pipeName!, + PipeDirection.InOut, + NamedPipeServerStream.MaxAllowedServerInstances, + PipeTransmissionMode.Byte, + PipeOptions.Asynchronous | PipeOptions.WriteThrough, + (int)IpcConfig.NamedPipeMaxMessageBytes, + (int)IpcConfig.NamedPipeMaxMessageBytes, + pipeSecurity); + } + + private async Task ClientReceiveLoopAsync(ClientContext client) + { + var pipe = (NamedPipeServerStream)client.TransportHandle!; + + while (pipe.IsConnected && !client.Cancellation!.Token.IsCancellationRequested) + { + try + { + // Read length prefix (4 bytes) + var lengthBuffer = new byte[4]; + var bytesRead = 0; + while (bytesRead < 4) + { + var read = await pipe.ReadAsync( + lengthBuffer.AsMemory(bytesRead, 4 - bytesRead), + client.Cancellation.Token); + if (read == 0) + return; // Pipe closed + + bytesRead += read; + } + + var messageLength = BinaryPrimitives.ReadInt32LittleEndian(lengthBuffer); + if (messageLength <= 0 || messageLength > IpcConfig.NamedPipeMaxMessageBytes) + { + _logger.LogWarning("Invalid message length {Length} from client {ClientId}", messageLength, client.Id); + return; + } + + // Read message body + var messageBuffer = new byte[messageLength]; + bytesRead = 0; + while (bytesRead < messageLength) + { + var read = await pipe.ReadAsync( + messageBuffer.AsMemory(bytesRead, messageLength - bytesRead), + client.Cancellation.Token); + if (read == 0) + return; // Pipe closed + + bytesRead += read; + } + + var messageText = Encoding.UTF8.GetString(messageBuffer); + client.LastSeenUtc = DateTime.UtcNow; + + await ProcessIncomingMessageAsync(client, messageText); + } + catch (OperationCanceledException) when (client.Cancellation.Token.IsCancellationRequested) + { + break; + } + catch (IOException ex) when (ex.Message.Contains("pipe")) + { + break; // Pipe disconnected + } + catch (Exception ex) + { + _logger.LogError(ex, "Error in client receive loop for {ClientId}", client.Id); + break; + } + } + + // Cleanup client + _clients.TryRemove(client.Id, out _); + client.Dispose(); + _logger.LogDebug("Named Pipe client {ClientId} disconnected", client.Id); + } + + private async Task ClientSendLoopAsync(ClientContext client) + { + var pipe = (NamedPipeServerStream)client.TransportHandle!; + + while (pipe.IsConnected && !client.Cancellation!.Token.IsCancellationRequested) + { + try + { + if (client.SendQueue.TryDequeue(out var item)) + { + var messageBytes = Encoding.UTF8.GetBytes(item.payload); + var lengthBytes = new byte[4]; + BinaryPrimitives.WriteInt32LittleEndian(lengthBytes, messageBytes.Length); + + // Write length prefix first + await pipe.WriteAsync(lengthBytes, client.Cancellation.Token); + + // Write message body + await pipe.WriteAsync(messageBytes, client.Cancellation.Token); + await pipe.FlushAsync(client.Cancellation.Token); + + client.DecreaseQueuedBytes(messageBytes.Length); + } + else + { + // No messages to send, wait a bit + await Task.Delay(10, client.Cancellation.Token); + } + } + catch (OperationCanceledException) when (client.Cancellation.Token.IsCancellationRequested) + { + break; + } + catch (IOException ex) when (ex.Message.Contains("pipe")) + { + break; // Pipe disconnected + } + catch (Exception ex) + { + _logger.LogError(ex, "Error in client send loop for {ClientId}", client.Id); + break; + } + } + } + + private async Task ProcessIncomingMessageAsync(ClientContext client, string messageText) + { + try + { + // Rate limiting check + if (!client.TryConsumeToken()) + { + var error = JsonRpcMessage.MakeError(null, -32003, "Rate limit exceeded"); + await SendResponseAsync(client, error); + return; + } + + var message = JsonRpcMessage.FromJson(messageText); + if (!JsonRpcMessage.ValidJsonRpc(message) || JsonRpcMessage.IsInvalidRequest(message)) + { + var error = JsonRpcMessage.MakeError(message?.Id, -32600, "Invalid Request"); + await SendResponseAsync(client, error); + return; + } + + // Check method registry + if (!string.IsNullOrEmpty(message.Method) && _methodRegistry.TryGet(message.Method, out var methodDef)) + { + // Auth check + if (methodDef.RequiresAuth && !client.IsAuthenticated) + { + var error = JsonRpcMessage.MakeError(message.Id, -32001, "Authentication required"); + await SendResponseAsync(client, error); + return; + } + + // Additional auth policy check + if (methodDef.AuthorizationPolicy != null && !methodDef.AuthorizationPolicy(client, message)) + { + var error = JsonRpcMessage.MakeError(message.Id, -32002, "Authorization failed"); + await SendResponseAsync(client, error); + return; + } + } + + // Handle token validation for handshake + if (message.Method == "handshake") + { + await HandleHandshakeAsync(client, message); + return; + } + + // Delegate to handler + if (OnRequestReceived != null) + { + await OnRequestReceived(client, message); + } + } + catch (JsonException) + { + var error = JsonRpcMessage.MakeError(null, -32700, "Parse error"); + await SendResponseAsync(client, error); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error processing message from client {ClientId}", client.Id); + var error = JsonRpcMessage.MakeError(null, -32603, "Internal error"); + await SendResponseAsync(client, error); + } + } + + private async Task HandleHandshakeAsync(ClientContext client, JsonRpcMessage request) + { + try + { + if (request.Params?.TryGetProperty("token", out var tokenElement) == true) + { + var providedToken = tokenElement.GetString(); + if (string.Equals(providedToken, _currentToken, StringComparison.Ordinal)) + { + client.IsAuthenticated = true; + client.AuthEpoch = _currentEpoch; + + var result = JsonRpcMessage.MakeResult(request.Id, new + { + authenticated = true, + epoch = _currentEpoch, + serverVersion = "1.0" + }); + + await SendResponseAsync(client, result); + _logger.LogInformation("Client {ClientId} authenticated successfully", client.Id); + } + else + { + var error = JsonRpcMessage.MakeError(request.Id, -32002, "Invalid token"); + await SendResponseAsync(client, error); + } + } + else + { + var error = JsonRpcMessage.MakeError(request.Id, -32602, "Invalid params - token required"); + await SendResponseAsync(client, error); + } + } + catch (Exception ex) + { + _logger.LogError(ex, "Error handling handshake for client {ClientId}", client.Id); + var error = JsonRpcMessage.MakeError(request.Id, -32603, "Internal error"); + await SendResponseAsync(client, error); + } + } + + private void SendKeepalive(object? state) + { + if (!_isStarted || _cancellation.Token.IsCancellationRequested) + return; + + var pingNotification = new JsonRpcMessage + { + Method = "ping", + Params = JsonSerializer.SerializeToElement(new { timestamp = DateTime.UtcNow }) + }; + + _ = Task.Run(async () => + { + try + { + await BroadcastAsync(pingNotification); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error sending keepalive ping"); + } + }); + } + + private void CleanupInactiveClients(object? state) + { + if (!_isStarted || _cancellation.Token.IsCancellationRequested) + return; + + var cutoff = DateTime.UtcNow.AddMinutes(-5d); + var toRemove = new List(); + + foreach (var client in _clients.Values) + { + var pipe = client.TransportHandle as NamedPipeServerStream; + if (client.LastSeenUtc < cutoff || pipe?.IsConnected != true) + { + toRemove.Add(client); + } + } + + foreach (var client in toRemove) + { + _clients.TryRemove(client.Id, out _); + client.Dispose(); + _logger.LogDebug("Cleaned up inactive client {ClientId}", client.Id); + } + } + + // Dispose + public void Dispose() + { + if (_disposed) + return; + + _cancellation.Cancel(); + _keepaliveTimer?.Dispose(); + _cleanupTimer?.Dispose(); + _cancellation.Dispose(); + + foreach (var client in _clients.Values) + { + client.Dispose(); + } + + _disposed = true; + } + } +} \ No newline at end of file diff --git a/src/Files.App/Communication/ProtectedTokenStore.cs b/src/Files.App/Communication/ProtectedTokenStore.cs new file mode 100644 index 000000000000..72ac5a8e9c28 --- /dev/null +++ b/src/Files.App/Communication/ProtectedTokenStore.cs @@ -0,0 +1,83 @@ +using System; +using System.Threading.Tasks; +using Windows.Security.Cryptography; +using Windows.Security.Cryptography.DataProtection; +using Windows.Storage; + +namespace Files.App.Communication +{ + // DPAPI-backed token store. Stores encrypted token in LocalSettings and maintains an epoch for rotation. + internal static class ProtectedTokenStore + { + // Static fields + private const string KEY_TOKEN = "Files_RemoteControl_ProtectedToken"; + private const string KEY_ENABLED = "Files_RemoteControl_Enabled"; + private const string KEY_EPOCH = "Files_RemoteControl_TokenEpoch"; + + // Static properties + private static ApplicationDataContainer Settings => ApplicationData.Current.LocalSettings; + + // Static methods + public static bool IsEnabled() + { + if (Settings.Values.TryGetValue(KEY_ENABLED, out var v) && v is bool b) + return b; + + return false; + } + + public static void SetEnabled(bool enabled) => Settings.Values[KEY_ENABLED] = enabled; + + public static int GetEpoch() + { + if (Settings.Values.TryGetValue(KEY_EPOCH, out var v) && v is int e) + return e; + + SetEpoch(1); + return 1; + } + + public static async Task GetOrCreateTokenAsync() + { + if (Settings.Values.TryGetValue(KEY_TOKEN, out var val) && val is string b64 && !string.IsNullOrEmpty(b64)) + { + try + { + var protectedBuf = CryptographicBuffer.DecodeFromBase64String(b64); + var provider = new DataProtectionProvider(); + var unprotected = await provider.UnprotectAsync(protectedBuf); + return CryptographicBuffer.ConvertBinaryToString(BinaryStringEncoding.Utf8, unprotected); + } + catch + { + // fallback to regen + } + } + + var t = Guid.NewGuid().ToString("N"); + await SetTokenAsync(t); + SetEpoch(1); + return t; + } + + public static async Task RotateTokenAsync() + { + var t = Guid.NewGuid().ToString("N"); + await SetTokenAsync(t); + var epoch = GetEpoch() + 1; + SetEpoch(epoch); + return t; + } + + private static async Task SetTokenAsync(string token) + { + var provider = new DataProtectionProvider("LOCAL=user"); + var buffer = CryptographicBuffer.ConvertStringToBinary(token, BinaryStringEncoding.Utf8); + var protectedBuf = await provider.ProtectAsync(buffer); + var bytes = CryptographicBuffer.EncodeToBase64String(protectedBuf); + Settings.Values[KEY_TOKEN] = bytes; + } + + private static void SetEpoch(int epoch) => Settings.Values[KEY_EPOCH] = epoch; + } +} \ No newline at end of file diff --git a/src/Files.App/Communication/RpcMethodRegistry.cs b/src/Files.App/Communication/RpcMethodRegistry.cs new file mode 100644 index 000000000000..293188666166 --- /dev/null +++ b/src/Files.App/Communication/RpcMethodRegistry.cs @@ -0,0 +1,43 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; + +namespace Files.App.Communication +{ + public sealed class RpcMethod + { + public string Name { get; init; } = string.Empty; + + public int? MaxPayloadBytes { get; init; } // optional cap per method + + public bool RequiresAuth { get; init; } = true; + + public bool AllowNotifications { get; init; } = true; + + public Func? AuthorizationPolicy { get; init; } // additional checks + } + + public sealed class RpcMethodRegistry + { + // readonly fields + private readonly ConcurrentDictionary _methods = new(); + + // Constructor + public RpcMethodRegistry() + { + Register(new RpcMethod { Name = "handshake", RequiresAuth = false, AllowNotifications = false }); + Register(new RpcMethod { Name = "getState", RequiresAuth = true, AllowNotifications = false }); + Register(new RpcMethod { Name = "listActions", RequiresAuth = true, AllowNotifications = false }); + Register(new RpcMethod { Name = "getMetadata", RequiresAuth = true, AllowNotifications = false, MaxPayloadBytes = 2 * 1024 * 1024 }); + Register(new RpcMethod { Name = "navigate", RequiresAuth = true, AllowNotifications = false }); + Register(new RpcMethod { Name = "executeAction", RequiresAuth = true, AllowNotifications = false }); + } + + // Public methods + public void Register(RpcMethod method) => _methods[method.Name] = method; + + public bool TryGet(string name, out RpcMethod method) => _methods.TryGetValue(name, out method); + + public IEnumerable List() => _methods.Values; + } +} \ No newline at end of file diff --git a/src/Files.App/Communication/UIOperationQueue.cs b/src/Files.App/Communication/UIOperationQueue.cs new file mode 100644 index 000000000000..dc1441e31989 --- /dev/null +++ b/src/Files.App/Communication/UIOperationQueue.cs @@ -0,0 +1,40 @@ +using System; +using System.Threading.Tasks; +using Microsoft.UI.Dispatching; + +namespace Files.App.Communication +{ + // Ensures all UI-affecting operations are serialized on the dispatcher thread + public sealed class UIOperationQueue + { + // readonly fields + private readonly DispatcherQueue _dispatcher; + + // Constructor + public UIOperationQueue(DispatcherQueue dispatcher) + { + _dispatcher = dispatcher ?? throw new ArgumentNullException(nameof(dispatcher)); + } + + // Public methods + public Task EnqueueAsync(Func operation) + { + var tcs = new TaskCompletionSource(); + + _dispatcher.TryEnqueue(async () => + { + try + { + await operation().ConfigureAwait(false); + tcs.SetResult(null); + } + catch (Exception ex) + { + tcs.SetException(ex); + } + }); + + return tcs.Task; + } + } +} \ No newline at end of file diff --git a/src/Files.App/Communication/WebSocketAppCommunicationService.cs b/src/Files.App/Communication/WebSocketAppCommunicationService.cs new file mode 100644 index 000000000000..b25d98ca1293 --- /dev/null +++ b/src/Files.App/Communication/WebSocketAppCommunicationService.cs @@ -0,0 +1,493 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.WebSockets; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace Files.App.Communication +{ + public sealed class WebSocketAppCommunicationService : IAppCommunicationService, IDisposable + { + // readonly fields + private readonly HttpListener _httpListener; + private readonly RpcMethodRegistry _methodRegistry; + private readonly ILogger _logger; + private readonly ConcurrentDictionary _clients = new(); + private readonly Timer _keepaliveTimer; + private readonly Timer _cleanupTimer; + private readonly CancellationTokenSource _cancellation = new(); + + // Fields + private string? _currentToken; + private int _currentEpoch; + private bool _isStarted; + + // _disposed field + private bool _disposed; + + // Events + public event Func? OnRequestReceived; + + // Constructor + public WebSocketAppCommunicationService( + RpcMethodRegistry methodRegistry, + ILogger logger) + { + _methodRegistry = methodRegistry ?? throw new ArgumentNullException(nameof(methodRegistry)); + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + _httpListener = new HttpListener(); + + // Setup keepalive timer (every 30 seconds) + _keepaliveTimer = new Timer(SendKeepalive, null, TimeSpan.FromSeconds(30d), TimeSpan.FromSeconds(30d)); + + // Setup cleanup timer (every 60 seconds) + _cleanupTimer = new Timer(CleanupInactiveClients, null, TimeSpan.FromSeconds(60d), TimeSpan.FromSeconds(60d)); + } + + // Public methods + public async Task StartAsync() + { + if (!ProtectedTokenStore.IsEnabled()) + { + _logger.LogWarning("Remote control is not enabled, refusing to start WebSocket service"); + return; + } + + if (_isStarted) + return; + + try + { + _currentToken = await ProtectedTokenStore.GetOrCreateTokenAsync(); + _currentEpoch = ProtectedTokenStore.GetEpoch(); + + _httpListener.Prefixes.Clear(); + _httpListener.Prefixes.Add("http://127.0.0.1:52345/"); + _httpListener.Start(); + _isStarted = true; + + _ = Task.Run(AcceptConnectionsAsync, _cancellation.Token); + + _logger.LogInformation("WebSocket IPC service started on http://127.0.0.1:52345/"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to start WebSocket IPC service"); + throw; + } + } + + public async Task StopAsync() + { + if (!_isStarted) + return; + + try + { + _cancellation.Cancel(); + _httpListener.Stop(); + + // Close all client connections + foreach (var client in _clients.Values) + { + client.Dispose(); + } + _clients.Clear(); + + _isStarted = false; + _logger.LogInformation("WebSocket IPC service stopped"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error stopping WebSocket IPC service"); + } + } + + public async Task SendResponseAsync(ClientContext client, JsonRpcMessage response) + { + if (client?.WebSocket?.State != WebSocketState.Open) + return; + + try + { + var json = response.ToJson(); + var canEnqueue = client.TryEnqueue(json, false); + if (!canEnqueue) + { + _logger.LogWarning("Client {ClientId} queue full, dropping response", client.Id); + } + } + catch (Exception ex) + { + _logger.LogError(ex, "Error queuing response for client {ClientId}", client.Id); + } + } + + public async Task BroadcastAsync(JsonRpcMessage notification) + { + if (!_isStarted) + return; + + var json = notification.ToJson(); + var activeclients = _clients.Values.Where(c => c.WebSocket?.State == WebSocketState.Open).ToList(); + + foreach (var client in activeclients) + { + try + { + var canEnqueue = client.TryEnqueue(json, true, notification.Method); + if (!canEnqueue) + { + _logger.LogDebug("Client {ClientId} queue full, dropping notification {Method}", client.Id, notification.Method); + } + } + catch (Exception ex) + { + _logger.LogError(ex, "Error queuing notification for client {ClientId}", client.Id); + } + } + } + + // Private methods + private async Task AcceptConnectionsAsync() + { + while (!_cancellation.Token.IsCancellationRequested) + { + try + { + var context = await _httpListener.GetContextAsync(); + if (context.Request.IsWebSocketRequest) + { + _ = Task.Run(() => HandleWebSocketConnection(context), _cancellation.Token); + } + else + { + context.Response.StatusCode = 400; + context.Response.Close(); + } + } + catch (HttpListenerException) when (_cancellation.Token.IsCancellationRequested) + { + break; + } + catch (Exception ex) + { + _logger.LogError(ex, "Error accepting WebSocket connection"); + } + } + } + + private async Task HandleWebSocketConnection(HttpListenerContext httpContext) + { + WebSocketContext? webSocketContext = null; + ClientContext? client = null; + + try + { + webSocketContext = await httpContext.AcceptWebSocketAsync(null); + var webSocket = webSocketContext.WebSocket; + + client = new ClientContext + { + WebSocket = webSocket, + Cancellation = CancellationTokenSource.CreateLinkedTokenSource(_cancellation.Token) + }; + + _clients[client.Id] = client; + _logger.LogDebug("WebSocket client {ClientId} connected", client.Id); + + // Start send loop + _ = Task.Run(() => ClientSendLoopAsync(client), client.Cancellation.Token); + + // Handle receive loop + await ClientReceiveLoopAsync(client); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error in WebSocket connection handler"); + } + finally + { + if (client != null) + { + _clients.TryRemove(client.Id, out _); + client.Dispose(); + _logger.LogDebug("WebSocket client {ClientId} disconnected", client.Id); + } + } + } + + private async Task ClientReceiveLoopAsync(ClientContext client) + { + var buffer = new byte[IpcConfig.WebSocketMaxMessageBytes]; + var webSocket = client.WebSocket!; + + while (webSocket.State == WebSocketState.Open && !client.Cancellation!.Token.IsCancellationRequested) + { + try + { + var messageBuilder = new StringBuilder(); + WebSocketReceiveResult result; + + do + { + result = await webSocket.ReceiveAsync(new ArraySegment(buffer), client.Cancellation.Token); + + if (result.MessageType == WebSocketMessageType.Text) + { + var text = Encoding.UTF8.GetString(buffer, 0, result.Count); + messageBuilder.Append(text); + } + else if (result.MessageType == WebSocketMessageType.Close) + return; + + } while (!result.EndOfMessage); + + var messageText = messageBuilder.ToString(); + if (string.IsNullOrEmpty(messageText)) + continue; + + client.LastSeenUtc = DateTime.UtcNow; + await ProcessIncomingMessageAsync(client, messageText); + } + catch (OperationCanceledException) when (client.Cancellation.Token.IsCancellationRequested) + { + break; + } + catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely) + { + break; + } + catch (Exception ex) + { + _logger.LogError(ex, "Error in client receive loop for {ClientId}", client.Id); + break; + } + } + } + + private async Task ClientSendLoopAsync(ClientContext client) + { + var webSocket = client.WebSocket!; + + while (webSocket.State == WebSocketState.Open && !client.Cancellation!.Token.IsCancellationRequested) + { + try + { + if (client.SendQueue.TryDequeue(out var item)) + { + var bytes = Encoding.UTF8.GetBytes(item.payload); + await webSocket.SendAsync( + new ArraySegment(bytes), + WebSocketMessageType.Text, + true, + client.Cancellation.Token); + + client.DecreaseQueuedBytes(bytes.Length); + } + else + { + // No messages to send, wait a bit + await Task.Delay(10, client.Cancellation.Token); + } + } + catch (OperationCanceledException) when (client.Cancellation.Token.IsCancellationRequested) + { + break; + } + catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely) + { + break; + } + catch (Exception ex) + { + _logger.LogError(ex, "Error in client send loop for {ClientId}", client.Id); + break; + } + } + } + + private async Task ProcessIncomingMessageAsync(ClientContext client, string messageText) + { + try + { + // Rate limiting check + if (!client.TryConsumeToken()) + { + var error = JsonRpcMessage.MakeError(null, -32003, "Rate limit exceeded"); + await SendResponseAsync(client, error); + return; + } + + var message = JsonRpcMessage.FromJson(messageText); + if (!JsonRpcMessage.ValidJsonRpc(message) || JsonRpcMessage.IsInvalidRequest(message)) + { + var error = JsonRpcMessage.MakeError(message?.Id, -32600, "Invalid Request"); + await SendResponseAsync(client, error); + return; + } + + // Check method registry + if (!string.IsNullOrEmpty(message.Method) && _methodRegistry.TryGet(message.Method, out var methodDef)) + { + // Auth check + if (methodDef.RequiresAuth && !client.IsAuthenticated) + { + var error = JsonRpcMessage.MakeError(message.Id, -32001, "Authentication required"); + await SendResponseAsync(client, error); + return; + } + + // Additional auth policy check + if (methodDef.AuthorizationPolicy != null && !methodDef.AuthorizationPolicy(client, message)) + { + var error = JsonRpcMessage.MakeError(message.Id, -32002, "Authorization failed"); + await SendResponseAsync(client, error); + return; + } + } + + // Handle token validation for handshake + if (message.Method == "handshake") + { + await HandleHandshakeAsync(client, message); + return; + } + + // Delegate to handler + if (OnRequestReceived != null) + { + await OnRequestReceived(client, message); + } + } + catch (JsonException) + { + var error = JsonRpcMessage.MakeError(null, -32700, "Parse error"); + await SendResponseAsync(client, error); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error processing message from client {ClientId}", client.Id); + var error = JsonRpcMessage.MakeError(null, -32603, "Internal error"); + await SendResponseAsync(client, error); + } + } + + private async Task HandleHandshakeAsync(ClientContext client, JsonRpcMessage request) + { + try + { + if (request.Params?.TryGetProperty("token", out var tokenElement) == true) + { + var providedToken = tokenElement.GetString(); + if (string.Equals(providedToken, _currentToken, StringComparison.Ordinal)) + { + client.IsAuthenticated = true; + client.AuthEpoch = _currentEpoch; + + var result = JsonRpcMessage.MakeResult(request.Id, new + { + authenticated = true, + epoch = _currentEpoch, + serverVersion = "1.0" + }); + + await SendResponseAsync(client, result); + _logger.LogInformation("Client {ClientId} authenticated successfully", client.Id); + } + else + { + var error = JsonRpcMessage.MakeError(request.Id, -32002, "Invalid token"); + await SendResponseAsync(client, error); + } + } + else + { + var error = JsonRpcMessage.MakeError(request.Id, -32602, "Invalid params - token required"); + await SendResponseAsync(client, error); + } + } + catch (Exception ex) + { + _logger.LogError(ex, "Error handling handshake for client {ClientId}", client.Id); + var error = JsonRpcMessage.MakeError(request.Id, -32603, "Internal error"); + await SendResponseAsync(client, error); + } + } + + private void SendKeepalive(object? state) + { + if (!_isStarted || _cancellation.Token.IsCancellationRequested) + return; + + var pingNotification = new JsonRpcMessage + { + Method = "ping", + Params = JsonSerializer.SerializeToElement(new { timestamp = DateTime.UtcNow }) + }; + + _ = Task.Run(async () => + { + try + { + await BroadcastAsync(pingNotification); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error sending keepalive ping"); + } + }); + } + + private void CleanupInactiveClients(object? state) + { + if (!_isStarted || _cancellation.Token.IsCancellationRequested) + return; + + var cutoff = DateTime.UtcNow.AddMinutes(-5d); + var toRemove = new List(); + + foreach (var client in _clients.Values) + { + if (client.LastSeenUtc < cutoff || client.WebSocket?.State != WebSocketState.Open) + { + toRemove.Add(client); + } + } + + foreach (var client in toRemove) + { + _clients.TryRemove(client.Id, out _); + client.Dispose(); + _logger.LogDebug("Cleaned up inactive client {ClientId}", client.Id); + } + } + + // Dispose + public void Dispose() + { + if (_disposed) + return; + + _cancellation.Cancel(); + _keepaliveTimer?.Dispose(); + _cleanupTimer?.Dispose(); + _httpListener?.Stop(); + _httpListener?.Close(); + _cancellation.Dispose(); + + foreach (var client in _clients.Values) + { + client.Dispose(); + } + + _disposed = true; + } + } +} \ No newline at end of file diff --git a/src/Files.App/Constants.cs b/src/Files.App/Constants.cs index c3a531e502dd..4c37e8972722 100644 --- a/src/Files.App/Constants.cs +++ b/src/Files.App/Constants.cs @@ -144,6 +144,25 @@ public static class Drives } } + public static class IpcSettings + { + public const long WebSocketMaxMessageBytes = 16L * 1024L * 1024L; // 16 MB + + public const long NamedPipeMaxMessageBytes = 10L * 1024L * 1024L; // 10 MB + + public const long PerClientQueueCapBytes = 2L * 1024L * 1024L; // 2 MB + + public const int RateLimitPerSecond = 20; + + public const int RateLimitBurst = 60; + + public const int SelectionNotificationCap = 200; + + public const int GetMetadataMaxItems = 500; + + public const int GetMetadataTimeoutSec = 30; + } + public static class LocalSettings { public const string DateTimeFormat = "datetimeformat"; @@ -223,7 +242,7 @@ public static class Actions public static class DragAndDrop { - public const Int32 HoverToOpenTimespan = 1300; + public const int HoverToOpenTimespan = 1300; } public static class UserEnvironmentPaths diff --git a/src/Files.App/ViewModels/ShellIpcAdapter.cs b/src/Files.App/ViewModels/ShellIpcAdapter.cs new file mode 100644 index 000000000000..3c6b134946a8 --- /dev/null +++ b/src/Files.App/ViewModels/ShellIpcAdapter.cs @@ -0,0 +1,476 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Files.App.Communication; +using Files.App.Communication.Models; +using Microsoft.Extensions.Logging; +using Microsoft.UI.Dispatching; + +namespace Files.App.ViewModels +{ + // Adapter with strict allowlist, path normalization, selection cap and structured errors. + public sealed class ShellIpcAdapter + { + // readonly fields + private readonly ShellViewModel _shell; + private readonly IAppCommunicationService _comm; + private readonly ActionRegistry _actions; + private readonly RpcMethodRegistry _methodRegistry; + private readonly UIOperationQueue _uiQueue; + private readonly ILogger _logger; + private readonly TimeSpan _coalesceWindow = TimeSpan.FromMilliseconds(100d); + + // Fields + private DateTime _lastWdmNotif = DateTime.MinValue; + + // Constructor + public ShellIpcAdapter( + ShellViewModel shell, + IAppCommunicationService comm, + ActionRegistry actions, + RpcMethodRegistry methodRegistry, + DispatcherQueue dispatcher, + ILogger logger) + { + _shell = shell ?? throw new ArgumentNullException(nameof(shell)); + _comm = comm ?? throw new ArgumentNullException(nameof(comm)); + _actions = actions ?? throw new ArgumentNullException(nameof(actions)); + _methodRegistry = methodRegistry ?? throw new ArgumentNullException(nameof(methodRegistry)); + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + _uiQueue = new UIOperationQueue(dispatcher ?? throw new ArgumentNullException(nameof(dispatcher))); + + _comm.OnRequestReceived += HandleRequestAsync; + + _shell.WorkingDirectoryModified += Shell_WorkingDirectoryModified; + // Note: SelectionChanged event would need to be added to ShellViewModel or accessed via different mechanism + } + + // Private methods - Event handlers + private async void Shell_WorkingDirectoryModified(object? sender, WorkingDirectoryModifiedEventArgs e) + { + // Coalesce rapid directory changes + var now = DateTime.UtcNow; + if ((now - _lastWdmNotif) < _coalesceWindow) + return; + + _lastWdmNotif = now; + + var notification = new JsonRpcMessage + { + Method = "workingDirectoryChanged", + Params = JsonSerializer.SerializeToElement(new + { + path = NormalizePath(e.Path), + isValidPath = IsValidPath(e.Path) + }) + }; + + try + { + await _comm.BroadcastAsync(notification); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error broadcasting working directory changed notification"); + } + } + + private async Task HandleRequestAsync(ClientContext client, JsonRpcMessage request) + { + if (string.IsNullOrEmpty(request.Method)) + { + var error = JsonRpcMessage.MakeError(request.Id, -32600, "Invalid Request"); + await _comm.SendResponseAsync(client, error); + return; + } + + try + { + switch (request.Method) + { + case "getState": + await HandleGetStateAsync(client, request); + break; + + case "listActions": + await HandleListActionsAsync(client, request); + break; + + case "getMetadata": + await HandleGetMetadataAsync(client, request); + break; + + case "navigate": + await HandleNavigateAsync(client, request); + break; + + case "executeAction": + await HandleExecuteActionAsync(client, request); + break; + + default: + var error = JsonRpcMessage.MakeError(request.Id, -32601, "Method not found"); + await _comm.SendResponseAsync(client, error); + break; + } + } + catch (Exception ex) + { + _logger.LogError(ex, "Error handling request {Method} from client {ClientId}", request.Method, client.Id); + var error = JsonRpcMessage.MakeError(request.Id, -32603, "Internal error"); + await _comm.SendResponseAsync(client, error); + } + } + + private async Task HandleGetStateAsync(ClientContext client, JsonRpcMessage request) + { + try + { + var result = JsonRpcMessage.MakeResult(request.Id, new + { + currentPath = NormalizePath(_shell.FilesystemViewModel?.WorkingDirectory ?? string.Empty), + isValidPath = IsValidPath(_shell.FilesystemViewModel?.WorkingDirectory ?? string.Empty), + canNavigateBack = _shell.CanNavigateBackward, + canNavigateForward = _shell.CanNavigateForward, + selectedItemsCount = _shell.SlimContentPage?.SelectedItems?.Count ?? 0, + totalItemsCount = _shell.SlimContentPage?.FilesystemViewModel?.FilesAndFolders?.Count ?? 0 + }); + + await _comm.SendResponseAsync(client, result); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error getting application state"); + var error = JsonRpcMessage.MakeError(request.Id, -32603, "Failed to get application state"); + await _comm.SendResponseAsync(client, error); + } + } + + private async Task HandleListActionsAsync(ClientContext client, JsonRpcMessage request) + { + try + { + var actions = _actions.GetAllowedActions().ToArray(); + var result = JsonRpcMessage.MakeResult(request.Id, new + { + actions = actions, + count = actions.Length + }); + + await _comm.SendResponseAsync(client, result); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error listing actions"); + var error = JsonRpcMessage.MakeError(request.Id, -32603, "Failed to list actions"); + await _comm.SendResponseAsync(client, error); + } + } + + private async Task HandleGetMetadataAsync(ClientContext client, JsonRpcMessage request) + { + if (!request.Params.HasValue || !request.Params.Value.TryGetProperty("paths", out var pathsElement)) + { + var error = JsonRpcMessage.MakeError(request.Id, -32602, "Invalid params - paths array required"); + await _comm.SendResponseAsync(client, error); + return; + } + + try + { + var pathStrings = new List(); + if (pathsElement.ValueKind == JsonValueKind.Array) + { + foreach (var pathElement in pathsElement.EnumerateArray()) + { + var pathStr = pathElement.GetString(); + if (!string.IsNullOrEmpty(pathStr)) + pathStrings.Add(pathStr); + } + } + + // Cap the number of items to process + if (pathStrings.Count > IpcConfig.GetMetadataMaxItems) + pathStrings = pathStrings.Take(IpcConfig.GetMetadataMaxItems).ToList(); + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(IpcConfig.GetMetadataTimeoutSec)); + var metadata = await GetMetadataForPathsAsync(pathStrings, cts.Token); + + var result = JsonRpcMessage.MakeResult(request.Id, new + { + metadata = metadata, + processed = metadata.Count, + total = pathStrings.Count + }); + + await _comm.SendResponseAsync(client, result); + } + catch (OperationCanceledException) + { + var error = JsonRpcMessage.MakeError(request.Id, -32603, "Request timeout - too many items or slow filesystem"); + await _comm.SendResponseAsync(client, error); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error getting metadata"); + var error = JsonRpcMessage.MakeError(request.Id, -32603, "Failed to get metadata"); + await _comm.SendResponseAsync(client, error); + } + } + + private async Task HandleNavigateAsync(ClientContext client, JsonRpcMessage request) + { + if (!request.Params.HasValue || !request.Params.Value.TryGetProperty("path", out var pathElement)) + { + var error = JsonRpcMessage.MakeError(request.Id, -32602, "Invalid params - path required"); + await _comm.SendResponseAsync(client, error); + return; + } + + var path = pathElement.GetString(); + if (string.IsNullOrEmpty(path)) + { + var error = JsonRpcMessage.MakeError(request.Id, -32602, "Invalid params - path cannot be empty"); + await _comm.SendResponseAsync(client, error); + return; + } + + var normalizedPath = NormalizePath(path); + if (!IsValidPath(normalizedPath)) + { + var error = JsonRpcMessage.MakeError(request.Id, -32602, "Invalid path - security check failed"); + await _comm.SendResponseAsync(client, error); + return; + } + + try + { + await _uiQueue.EnqueueAsync(async () => + { + // This would need to be implemented based on the actual ShellViewModel navigation methods + // await _shell.NavigateToPathAsync(normalizedPath); + _logger.LogInformation("Navigation to {Path} requested (not yet implemented)", normalizedPath); + }); + + var result = JsonRpcMessage.MakeResult(request.Id, new + { + success = true, + navigatedTo = normalizedPath + }); + + await _comm.SendResponseAsync(client, result); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error navigating to {Path}", normalizedPath); + var error = JsonRpcMessage.MakeError(request.Id, -32603, "Navigation failed"); + await _comm.SendResponseAsync(client, error); + } + } + + private async Task HandleExecuteActionAsync(ClientContext client, JsonRpcMessage request) + { + if (!request.Params.HasValue || !request.Params.Value.TryGetProperty("actionId", out var actionElement)) + { + var error = JsonRpcMessage.MakeError(request.Id, -32602, "Invalid params - actionId required"); + await _comm.SendResponseAsync(client, error); + return; + } + + var actionId = actionElement.GetString(); + if (string.IsNullOrEmpty(actionId)) + { + var error = JsonRpcMessage.MakeError(request.Id, -32602, "Invalid params - actionId cannot be empty"); + await _comm.SendResponseAsync(client, error); + return; + } + + if (!_actions.CanExecute(actionId)) + { + var error = JsonRpcMessage.MakeError(request.Id, -32602, "Action not allowed or not found"); + await _comm.SendResponseAsync(client, error); + return; + } + + try + { + // Extract optional context parameter + object? context = null; + if (request.Params.Value.TryGetProperty("context", out var contextElement)) + { + context = JsonSerializer.Deserialize(contextElement); + } + + await _uiQueue.EnqueueAsync(async () => + { + // This would need to be implemented based on the actual action execution system + // await _shell.ExecuteActionAsync(actionId, context); + _logger.LogInformation("Action {ActionId} execution requested (not yet implemented)", actionId); + }); + + var result = JsonRpcMessage.MakeResult(request.Id, new + { + success = true, + executedAction = actionId + }); + + await _comm.SendResponseAsync(client, result); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error executing action {ActionId}", actionId); + var error = JsonRpcMessage.MakeError(request.Id, -32603, "Action execution failed"); + await _comm.SendResponseAsync(client, error); + } + } + + // Private helper methods + private static string NormalizePath(string path) + { + if (string.IsNullOrEmpty(path)) + return string.Empty; + + try + { + // Normalize path separators and resolve relative components + var normalized = Path.GetFullPath(path); + return normalized; + } + catch + { + return path; // Return original if normalization fails + } + } + + private static bool IsValidPath(string path) + { + if (string.IsNullOrEmpty(path)) + return false; + + try + { + // Security checks: reject device paths, UNC admin shares, etc. + var upper = path.ToUpperInvariant(); + + // Reject device paths + if (upper.StartsWith(@"\\.\", StringComparison.Ordinal) || + upper.StartsWith(@"\\?\", StringComparison.Ordinal)) + return false; + + // Reject admin shares + if (upper.StartsWith(@"\\", StringComparison.Ordinal) && upper.Contains(@"\C$", StringComparison.Ordinal)) + return false; + + // Check for path traversal attempts + if (path.Contains("..") || path.Contains("~")) + return false; + + // Must be rooted (absolute path) + return Path.IsPathRooted(path); + } + catch + { + return false; + } + } + + private async Task> GetMetadataForPathsAsync(List paths, CancellationToken cancellationToken) + { + var results = new List(); + + foreach (var path in paths) + { + cancellationToken.ThrowIfCancellationRequested(); + + try + { + var normalizedPath = NormalizePath(path); + if (!IsValidPath(normalizedPath)) + { + results.Add(new ItemDto + { + Path = path, + Name = Path.GetFileName(path), + Exists = false + }); + continue; + } + + // Check if path exists + var exists = File.Exists(normalizedPath) || Directory.Exists(normalizedPath); + if (!exists) + { + results.Add(new ItemDto + { + Path = normalizedPath, + Name = Path.GetFileName(normalizedPath), + Exists = false + }); + continue; + } + + // Get metadata + var isDirectory = Directory.Exists(normalizedPath); + var info = isDirectory ? (FileSystemInfo)new DirectoryInfo(normalizedPath) : new FileInfo(normalizedPath); + + var item = new ItemDto + { + Path = normalizedPath, + Name = info.Name, + IsDirectory = isDirectory, + Exists = true, + DateCreated = info.CreationTime.ToString("yyyy-MM-ddTHH:mm:ssZ"), + DateModified = info.LastWriteTime.ToString("yyyy-MM-ddTHH:mm:ssZ") + }; + + if (!isDirectory) + { + var fileInfo = (FileInfo)info; + item.SizeBytes = fileInfo.Length; + item.MimeType = GetMimeType(normalizedPath); + } + + results.Add(item); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error getting metadata for path {Path}", path); + results.Add(new ItemDto + { + Path = path, + Name = Path.GetFileName(path), + Exists = false + }); + } + } + + return results; + } + + private static string? GetMimeType(string filePath) + { + var extension = Path.GetExtension(filePath)?.ToLowerInvariant(); + return extension switch + { + ".txt" => "text/plain", + ".json" => "application/json", + ".xml" => "application/xml", + ".html" => "text/html", + ".css" => "text/css", + ".js" => "application/javascript", + ".pdf" => "application/pdf", + ".jpg" or ".jpeg" => "image/jpeg", + ".png" => "image/png", + ".gif" => "image/gif", + ".mp4" => "video/mp4", + ".mp3" => "audio/mpeg", + ".zip" => "application/zip", + _ => "application/octet-stream" + }; + } + } +} \ No newline at end of file