From 4cda2ddd507855f8cb8d997720ca6ff9ae7932a0 Mon Sep 17 00:00:00 2001 From: gdlol Date: Sun, 19 Oct 2025 14:49:32 +0000 Subject: [PATCH 01/20] IOCP --- .config/csharpier/.csharpierignore | 2 + .config/csharpier/.csharpierrc.json | 3 + .config/cspell/cspell.json | 18 ++ .config/dotnet/.globalconfig | 3 + .config/dotnet/Format.targets | 9 + .config/dotnet/Packages.props | 10 + .config/dotnet/Project.props | 20 ++ .config/dotnet/tools.json | 5 + .gitattributes => .config/git/attributes | 0 .config/git/ignore | 8 + .config/pnpm/rc | 3 + .config/prettier/.prettierrc.json | 9 + .config/workspaces/Directory.Build.props | 9 + .config/workspaces/pnpm-workspace.yaml | 0 .devcontainer/.env | 8 + .devcontainer/Dockerfile | 3 + .devcontainer/compose.yaml | 21 ++ .devcontainer/devcontainer.json | 39 +++ .devcontainer/dot-config.json | 15 + .gitignore | 5 - Automation/Automation.csproj | 9 + Automation/Build.cs | 16 + Automation/Context.cs | 14 + Automation/Format.cs | 9 + Automation/Lint.cs | 10 + Automation/Program.cs | 5 + Automation/Publish/Program.cs | 78 ----- Automation/Publish/Publish.csproj | 14 - Automation/Restore.cs | 14 + Automation/Restore/Program.cs | 33 -- Automation/Restore/Restore.csproj | 14 - Automation/Tasks/CSharpier.cs | 20 ++ Automation/Tasks/CSpell.cs | 12 + Automation/Tasks/DotNetFormat.cs | 26 ++ Automation/Tasks/Prettier.cs | 20 ++ .../UpdatePackages/Directory.Packages.props | 10 - Automation/UpdatePackages/Program.cs | 89 ------ .../UpdatePackages/UpdatePackages.csproj | 15 - Directory.Build.props | 26 +- Divert.Windows/AssemblyInfo.cs | 3 + Divert.Windows/CString.cs | 9 +- Divert.Windows/Divert.Windows.csproj | 13 +- Divert.Windows/DivertAddress.cs | 63 ++-- Divert.Windows/DivertEvent.cs | 20 +- Divert.Windows/DivertFilter.cs | 48 ++- Divert.Windows/DivertFlags.cs | 2 +- Divert.Windows/DivertHandle.cs | 16 + Divert.Windows/DivertHelper.cs | 9 +- Divert.Windows/DivertHelperFlags.cs | 2 +- Divert.Windows/DivertLayer.cs | 10 +- Divert.Windows/DivertReceiveResult.cs | 7 + Divert.Windows/DivertService.cs | 283 ++++++----------- Divert.Windows/DivertShutdown.cs | 20 +- Divert.Windows/DivertValueTaskSource.cs | 287 ++++++++++++++++++ Divert.Windows/NativeMethods.cs | 96 +++--- Divert.Windows/NativeMethods.txt | 2 + Divert.Windows/NativeTypes.cs | 12 +- Examples/Ping/Ping.csproj | 16 +- Examples/Ping/Program.cs | 68 +++-- License => LICENSE | 0 Workspace.proj | 6 + package.json | 19 ++ 62 files changed, 992 insertions(+), 643 deletions(-) create mode 100644 .config/csharpier/.csharpierignore create mode 100644 .config/csharpier/.csharpierrc.json create mode 100644 .config/cspell/cspell.json create mode 100644 .config/dotnet/.globalconfig create mode 100644 .config/dotnet/Format.targets create mode 100644 .config/dotnet/Packages.props create mode 100644 .config/dotnet/Project.props create mode 100644 .config/dotnet/tools.json rename .gitattributes => .config/git/attributes (100%) create mode 100644 .config/git/ignore create mode 100644 .config/pnpm/rc create mode 100644 .config/prettier/.prettierrc.json create mode 100644 .config/workspaces/Directory.Build.props create mode 100644 .config/workspaces/pnpm-workspace.yaml create mode 100644 .devcontainer/.env create mode 100644 .devcontainer/Dockerfile create mode 100644 .devcontainer/compose.yaml create mode 100644 .devcontainer/devcontainer.json create mode 100644 .devcontainer/dot-config.json delete mode 100644 .gitignore create mode 100644 Automation/Automation.csproj create mode 100644 Automation/Build.cs create mode 100644 Automation/Context.cs create mode 100644 Automation/Format.cs create mode 100644 Automation/Lint.cs create mode 100644 Automation/Program.cs delete mode 100644 Automation/Publish/Program.cs delete mode 100644 Automation/Publish/Publish.csproj create mode 100644 Automation/Restore.cs delete mode 100644 Automation/Restore/Program.cs delete mode 100644 Automation/Restore/Restore.csproj create mode 100644 Automation/Tasks/CSharpier.cs create mode 100644 Automation/Tasks/CSpell.cs create mode 100644 Automation/Tasks/DotNetFormat.cs create mode 100644 Automation/Tasks/Prettier.cs delete mode 100644 Automation/UpdatePackages/Directory.Packages.props delete mode 100644 Automation/UpdatePackages/Program.cs delete mode 100644 Automation/UpdatePackages/UpdatePackages.csproj create mode 100644 Divert.Windows/AssemblyInfo.cs create mode 100644 Divert.Windows/DivertHandle.cs create mode 100644 Divert.Windows/DivertReceiveResult.cs create mode 100644 Divert.Windows/DivertValueTaskSource.cs rename License => LICENSE (100%) create mode 100644 Workspace.proj create mode 100644 package.json diff --git a/.config/csharpier/.csharpierignore b/.config/csharpier/.csharpierignore new file mode 100644 index 0000000..51045e2 --- /dev/null +++ b/.config/csharpier/.csharpierignore @@ -0,0 +1,2 @@ +**/* +!**/*.cs diff --git a/.config/csharpier/.csharpierrc.json b/.config/csharpier/.csharpierrc.json new file mode 100644 index 0000000..963354f --- /dev/null +++ b/.config/csharpier/.csharpierrc.json @@ -0,0 +1,3 @@ +{ + "printWidth": 120 +} diff --git a/.config/cspell/cspell.json b/.config/cspell/cspell.json new file mode 100644 index 0000000..55bcb88 --- /dev/null +++ b/.config/cspell/cspell.json @@ -0,0 +1,18 @@ +{ + "version": "0.2", + "enableGlobDot": true, + "useGitignore": true, + "gitignoreRoot": ".", + "ignorePaths": ["LICENSE"], + "words": [ + "csharpierignore", + "csharpierrc", + "devcontainer", + "devcontainers", + "globalconfig", + "msbuild", + "packagejson", + "runas", + "WINDIVERT" + ] +} diff --git a/.config/dotnet/.globalconfig b/.config/dotnet/.globalconfig new file mode 100644 index 0000000..9cd5137 --- /dev/null +++ b/.config/dotnet/.globalconfig @@ -0,0 +1,3 @@ +dotnet_diagnostic.IDE0005.severity=warning +dotnet_diagnostic.CA1852.severity=warning +dotnet_diagnostic.CA2007.severity=warning diff --git a/.config/dotnet/Format.targets b/.config/dotnet/Format.targets new file mode 100644 index 0000000..ddcc8d5 --- /dev/null +++ b/.config/dotnet/Format.targets @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/.config/dotnet/Packages.props b/.config/dotnet/Packages.props new file mode 100644 index 0000000..9e360e8 --- /dev/null +++ b/.config/dotnet/Packages.props @@ -0,0 +1,10 @@ + + + true + + + + + + + diff --git a/.config/dotnet/Project.props b/.config/dotnet/Project.props new file mode 100644 index 0000000..f9260ca --- /dev/null +++ b/.config/dotnet/Project.props @@ -0,0 +1,20 @@ + + + enable + enable + + + + + + + + + + true + true + embedded + true + $(MSBuildProjectDirectory)=/_/$(MSBuildProjectName) + + diff --git a/.config/dotnet/tools.json b/.config/dotnet/tools.json new file mode 100644 index 0000000..7f63919 --- /dev/null +++ b/.config/dotnet/tools.json @@ -0,0 +1,5 @@ +{ + "version": 1, + "isRoot": true, + "tools": { "csharpier": { "version": "1.1.2", "commands": ["csharpier"], "rollForward": false } } +} diff --git a/.gitattributes b/.config/git/attributes similarity index 100% rename from .gitattributes rename to .config/git/attributes diff --git a/.config/git/ignore b/.config/git/ignore new file mode 100644 index 0000000..d1c0ab2 --- /dev/null +++ b/.config/git/ignore @@ -0,0 +1,8 @@ +.* +_* +!.devcontainer/ +!**/.devcontainer/** +!.config/ +!**/.config/** + +node_modules/ diff --git a/.config/pnpm/rc b/.config/pnpm/rc new file mode 100644 index 0000000..863fba9 --- /dev/null +++ b/.config/pnpm/rc @@ -0,0 +1,3 @@ +lockfile=false +resolution-mode=time-based +store-dir=/home/dev/.local/share/pnpm/store diff --git a/.config/prettier/.prettierrc.json b/.config/prettier/.prettierrc.json new file mode 100644 index 0000000..b8ae6eb --- /dev/null +++ b/.config/prettier/.prettierrc.json @@ -0,0 +1,9 @@ +{ + "printWidth": 120, + "plugins": ["prettier-plugin-packagejson", "prettier-plugin-sh", "@prettier/plugin-xml", "prettier-plugin-ini"], + "xmlWhitespaceSensitivity": "ignore", + "overrides": [ + { "files": "app.manifest", "options": { "parser": "xml" } }, + { "files": "*.globalconfig", "options": { "parser": "ini" } } + ] +} diff --git a/.config/workspaces/Directory.Build.props b/.config/workspaces/Directory.Build.props new file mode 100644 index 0000000..b7d1bef --- /dev/null +++ b/.config/workspaces/Directory.Build.props @@ -0,0 +1,9 @@ + + + $(MSBuildThisFileDirectory)artifacts + + + + + + diff --git a/.config/workspaces/pnpm-workspace.yaml b/.config/workspaces/pnpm-workspace.yaml new file mode 100644 index 0000000..e69de29 diff --git a/.devcontainer/.env b/.devcontainer/.env new file mode 100644 index 0000000..9ae87c3 --- /dev/null +++ b/.devcontainer/.env @@ -0,0 +1,8 @@ +WORKSPACES=/workspaces +XDG_CONFIG_HOME=/home/dev/.config +XDG_CACHE_HOME=/home/dev/.cache +XDG_DATA_HOME=/home/dev/.local/share +XDG_STATE_HOME=/home/dev/.local/state +XDG_DATA_DIRS=/usr/local/share:/usr/share +XDG_CONFIG_DIRS=/etc/xdg +NUGET_PACKAGES=/home/dev/.local/share/NuGet/global-packages diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 0000000..cc342aa --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,3 @@ +FROM mcr.microsoft.com/devcontainers/javascript-node:22 + +RUN npm install --global pnpm@latest-10 diff --git a/.devcontainer/compose.yaml b/.devcontainer/compose.yaml new file mode 100644 index 0000000..8cabace --- /dev/null +++ b/.devcontainer/compose.yaml @@ -0,0 +1,21 @@ +services: + devcontainer: + env_file: + - .env + build: + context: . + dockerfile: Dockerfile + init: true + volumes: + - WORKSPACES:${WORKSPACES} + - ..:${WORKSPACES}/divert-windows + - XDG_CONFIG_HOME:${XDG_CONFIG_HOME} + - XDG_CACHE_HOME:${XDG_CACHE_HOME} + - XDG_DATA_HOME:${XDG_DATA_HOME} + - XDG_STATE_HOME:${XDG_STATE_HOME} +volumes: + WORKSPACES: + XDG_CONFIG_HOME: + XDG_CACHE_HOME: + XDG_DATA_HOME: + XDG_STATE_HOME: diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..bb26252 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,39 @@ +// spell-checker:ignore esbenp dotnettools +{ + "name": "divert-windows", + "dockerComposeFile": "compose.yaml", + "service": "devcontainer", + "remoteUser": "dev", + "overrideCommand": true, + "workspaceFolder": "/workspaces/divert-windows", + "features": { + "ghcr.io/devcontainer-config/features/user-init:2": {}, + "ghcr.io/devcontainer-config/features/dot-config:3": {}, + "ghcr.io/devcontainers/features/dotnet:2": { "version": "8.0" } + }, + "customizations": { + "vscode": { + "extensions": [ + "esbenp.prettier-vscode", + "ms-azuretools.vscode-docker", + "streetsidesoftware.code-spell-checker", + "ms-dotnettools.csharp", + "csharpier.csharpier-vscode" + ], + "settings": { + "files.associations": { + "ignore": "ignore", + "attributes": "properties", + "rc": "properties", + "*.globalconfig": "ini" + }, + "editor.formatOnSave": true, + "editor.defaultFormatter": "esbenp.prettier-vscode", + "cSpell.autoFormatConfigFile": true, + "cSpell.checkOnlyEnabledFileTypes": false, + "[csharp]": { "editor.defaultFormatter": "csharpier.csharpier-vscode" } + } + } + }, + "onCreateCommand": "pnpm install && pnpm restore || true" +} diff --git a/.devcontainer/dot-config.json b/.devcontainer/dot-config.json new file mode 100644 index 0000000..cda4619 --- /dev/null +++ b/.devcontainer/dot-config.json @@ -0,0 +1,15 @@ +{ + "git": { "attributes": "/home/dev/.config/git/attributes", "ignore": "/home/dev/.config/git/ignore" }, + "pnpm": { "rc": "/home/dev/.config/pnpm/rc" }, + "prettier": { ".prettierrc.json": ".prettierrc.json" }, + "cspell": { "cspell.json": "cspell.json" }, + "workspaces": { + "../git/attributes": ".gitattributes", + "../git/ignore": ".gitignore", + "../../package.json": "package.json", + "pnpm-workspace.yaml": "pnpm-workspace.yaml", + "Directory.Build.props": "Directory.Build.props" + }, + "csharpier": { ".csharpierrc.json": ".csharpierrc.json", ".csharpierignore": ".csharpierignore" }, + "dotnet": { "tools.json": ".config/dotnet-tools.json" } +} diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 2246bd7..0000000 --- a/.gitignore +++ /dev/null @@ -1,5 +0,0 @@ -.*/ -obj/ -bin/ -/Build/ -/Publish/ diff --git a/Automation/Automation.csproj b/Automation/Automation.csproj new file mode 100644 index 0000000..e9c5ec0 --- /dev/null +++ b/Automation/Automation.csproj @@ -0,0 +1,9 @@ + + + Exe + $(DefaultTargetFramework) + + + + + diff --git a/Automation/Build.cs b/Automation/Build.cs new file mode 100644 index 0000000..3736eb1 --- /dev/null +++ b/Automation/Build.cs @@ -0,0 +1,16 @@ +using Cake.Common.Tools.DotNet; +using Cake.Common.Tools.DotNet.MSBuild; +using Cake.Frosting; + +namespace Automation; + +public class Build : FrostingTask +{ + public override void Run(Context context) + { + context.DotNetBuild( + Context.ProjectRoot, + new() { MSBuildSettings = new() { TreatAllWarningsAs = MSBuildTreatAllWarningsAs.Error } } + ); + } +} diff --git a/Automation/Context.cs b/Automation/Context.cs new file mode 100644 index 0000000..a2ac672 --- /dev/null +++ b/Automation/Context.cs @@ -0,0 +1,14 @@ +using System.Runtime.CompilerServices; +using Cake.Core; +using Cake.Frosting; + +namespace Automation; + +public class Context(ICakeContext context) : FrostingContext(context) +{ + static string GetFilePath([CallerFilePath] string? path = null) => path!; + + public static string ProjectRoot => new FileInfo(GetFilePath()).Directory!.Parent!.FullName; + + public static string Workspaces => new DirectoryInfo(ProjectRoot).Parent!.FullName; +} diff --git a/Automation/Format.cs b/Automation/Format.cs new file mode 100644 index 0000000..96bee98 --- /dev/null +++ b/Automation/Format.cs @@ -0,0 +1,9 @@ +using Automation.Tasks; +using Cake.Frosting; + +namespace Automation; + +[IsDependentOn(typeof(PrettierFormat))] +[IsDependentOn(typeof(DotNetFormat))] +[IsDependentOn(typeof(CSharpierFormat))] +public class Format : FrostingTask; diff --git a/Automation/Lint.cs b/Automation/Lint.cs new file mode 100644 index 0000000..5ebf948 --- /dev/null +++ b/Automation/Lint.cs @@ -0,0 +1,10 @@ +using Automation.Tasks; +using Cake.Frosting; + +namespace Automation; + +[IsDependentOn(typeof(PrettierCheck))] +[IsDependentOn(typeof(DotNetFormatCheck))] +[IsDependentOn(typeof(CSharpierCheck))] +[IsDependentOn(typeof(CSpell))] +public class Lint : FrostingTask; diff --git a/Automation/Program.cs b/Automation/Program.cs new file mode 100644 index 0000000..f2fde84 --- /dev/null +++ b/Automation/Program.cs @@ -0,0 +1,5 @@ +using Automation; +using Cake.Frosting; + +Directory.SetCurrentDirectory(Context.Workspaces); +return new CakeHost().UseContext().Run(args.Concat(["--verbosity", "diagnostic"])); diff --git a/Automation/Publish/Program.cs b/Automation/Publish/Program.cs deleted file mode 100644 index 35d589f..0000000 --- a/Automation/Publish/Program.cs +++ /dev/null @@ -1,78 +0,0 @@ -using System.ComponentModel; -using System.Runtime.CompilerServices; -using Microsoft.DotNet.Cli.Utils; - -static string GetFilePath([CallerFilePath] string? path = null) -{ - if (path is null) - { - throw new InvalidOperationException(nameof(path)); - } - return path; -} - -static void Run(string commandName, params string[] args) -{ - var command = Command.Create(commandName, args); - Console.WriteLine($"{commandName} {command.CommandArgs}"); - var result = command.Execute(); - if (result.ExitCode != 0) - { - throw new Win32Exception(result.ExitCode); - } -} - -static string GetOutput(string commandName, params string[] args) -{ - var command = Command.Create(commandName, args).CaptureStdOut(); - var result = command.Execute(); - if (result.ExitCode != 0) - { - throw new Win32Exception(result.ExitCode); - } - return result.StdOut; -} - -string version = args.Length > 0 ? args[0] : "1.0.0"; - -string filePath = GetFilePath(); -string projectPath = new FileInfo(filePath).Directory?.Parent?.Parent?.FullName!; -Console.WriteLine($"Project path: {projectPath}"); - -string publishPath = Path.Combine(projectPath, "Publish"); -if (Directory.Exists(publishPath)) -{ - Directory.Delete(publishPath, recursive: true); -} - -// Get metadata from Git. -string userName = GetOutput("git", "config", "user.name").Trim(); -var remotes = GetOutput("git", "remote").Trim(); -string? repositoryUrl = remotes.Split('\n').FirstOrDefault() switch -{ - null or "" => null, - string remote => GetOutput("git", "remote", "get-url", remote.Trim()).Trim() -}; -Console.WriteLine($"{nameof(userName)}: {userName}"); -Console.WriteLine($"{nameof(repositoryUrl)}: {repositoryUrl}"); -string projectName = "Divert.Windows"; -string description = "WinDivert .NET APIs."; - -var arguments = new List -{ - "pack", - Path.Combine(projectPath, projectName, $"{projectName}.csproj"), - "--configuration", "Release", - "--output", publishPath, - $"-property:PackageVersion={version}", - $"-property:Authors={userName}", - $"-property:PackageDescription={description}", - "-property:PackageLicenseExpression=MIT", - "-property:PackageRequireLicenseAcceptance=true", - "-property:PackageTags=WinDivert", -}; -if (repositoryUrl is not null) -{ - arguments.Add($"-property:RepositoryUrl={repositoryUrl}"); -} -Run("dotnet", arguments.ToArray()); diff --git a/Automation/Publish/Publish.csproj b/Automation/Publish/Publish.csproj deleted file mode 100644 index 7dbbd88..0000000 --- a/Automation/Publish/Publish.csproj +++ /dev/null @@ -1,14 +0,0 @@ - - - - Exe - net6.0 - enable - enable - - - - - - - \ No newline at end of file diff --git a/Automation/Restore.cs b/Automation/Restore.cs new file mode 100644 index 0000000..50011cb --- /dev/null +++ b/Automation/Restore.cs @@ -0,0 +1,14 @@ +using Cake.Common.Tools.Command; +using Cake.Common.Tools.DotNet; +using Cake.Frosting; + +namespace Automation; + +public class Restore : FrostingTask +{ + public override void Run(Context context) + { + context.Command(["dotnet"], "tool restore"); + context.DotNetRestore(Context.ProjectRoot); + } +} diff --git a/Automation/Restore/Program.cs b/Automation/Restore/Program.cs deleted file mode 100644 index 7a665d5..0000000 --- a/Automation/Restore/Program.cs +++ /dev/null @@ -1,33 +0,0 @@ -using System.ComponentModel; -using System.Runtime.CompilerServices; -using Microsoft.DotNet.Cli.Utils; - -static string GetFilePath([CallerFilePath] string? path = null) -{ - if (path is null) - { - throw new InvalidOperationException(nameof(path)); - } - return path; -} - -static void Run(string commandName, params string[] args) -{ - var command = Command.Create(commandName, args); - Console.WriteLine($"{commandName} {command.CommandArgs}"); - var result = command.Execute(); - if (result.ExitCode != 0) - { - throw new Win32Exception(result.ExitCode); - } -} - -string filePath = GetFilePath(); -string workspacePath = new FileInfo(filePath).Directory?.Parent?.Parent?.FullName!; -Console.WriteLine($"Workspace path: {workspacePath}"); - -Run("dotnet", "nuget", "locals", "http-cache", "--clear"); -Parallel.ForEach(Directory.EnumerateFiles(workspacePath, "*.csproj", SearchOption.AllDirectories), csProjectPath => -{ - Run("dotnet", "restore", csProjectPath); -}); diff --git a/Automation/Restore/Restore.csproj b/Automation/Restore/Restore.csproj deleted file mode 100644 index 7dbbd88..0000000 --- a/Automation/Restore/Restore.csproj +++ /dev/null @@ -1,14 +0,0 @@ - - - - Exe - net6.0 - enable - enable - - - - - - - \ No newline at end of file diff --git a/Automation/Tasks/CSharpier.cs b/Automation/Tasks/CSharpier.cs new file mode 100644 index 0000000..b099697 --- /dev/null +++ b/Automation/Tasks/CSharpier.cs @@ -0,0 +1,20 @@ +using Cake.Common.Tools.DotNet; +using Cake.Frosting; + +namespace Automation.Tasks; + +public class CSharpierCheck : FrostingTask +{ + public override void Run(Context context) + { + context.DotNetTool(Context.ProjectRoot, "csharpier", $"check {Context.ProjectRoot}"); + } +} + +public class CSharpierFormat : FrostingTask +{ + public override void Run(Context context) + { + context.DotNetTool(Context.ProjectRoot, "csharpier", $"format {Context.ProjectRoot}"); + } +} diff --git a/Automation/Tasks/CSpell.cs b/Automation/Tasks/CSpell.cs new file mode 100644 index 0000000..75a23d4 --- /dev/null +++ b/Automation/Tasks/CSpell.cs @@ -0,0 +1,12 @@ +using Cake.Common.Tools.Command; +using Cake.Frosting; + +namespace Automation.Tasks; + +public class CSpell : FrostingTask +{ + public override void Run(Context context) + { + context.Command(["pnpm"], $"cspell {Context.ProjectRoot}"); + } +} diff --git a/Automation/Tasks/DotNetFormat.cs b/Automation/Tasks/DotNetFormat.cs new file mode 100644 index 0000000..f698863 --- /dev/null +++ b/Automation/Tasks/DotNetFormat.cs @@ -0,0 +1,26 @@ +using Cake.Common.Tools.DotNet; +using Cake.Frosting; + +namespace Automation.Tasks; + +public class DotNetFormatCheck : FrostingTask +{ + public override void Run(Context context) + { + context.DotNetMSBuild( + Context.ProjectRoot, + new() { Targets = { "GetTargetPath" }, Properties = { ["DotNetFormatCheck"] = ["true"] } } + ); + } +} + +public class DotNetFormat : FrostingTask +{ + public override void Run(Context context) + { + context.DotNetMSBuild( + Context.ProjectRoot, + new() { Targets = { "GetTargetPath" }, Properties = { ["DotNetFormat"] = ["true"] } } + ); + } +} diff --git a/Automation/Tasks/Prettier.cs b/Automation/Tasks/Prettier.cs new file mode 100644 index 0000000..3b7fe1f --- /dev/null +++ b/Automation/Tasks/Prettier.cs @@ -0,0 +1,20 @@ +using Cake.Common.Tools.Command; +using Cake.Frosting; + +namespace Automation.Tasks; + +public class PrettierCheck : FrostingTask +{ + public override void Run(Context context) + { + context.Command(["pnpm"], $"prettier --check {Context.ProjectRoot}"); + } +} + +public class PrettierFormat : FrostingTask +{ + public override void Run(Context context) + { + context.Command(["pnpm"], $"prettier --write {Context.ProjectRoot}"); + } +} diff --git a/Automation/UpdatePackages/Directory.Packages.props b/Automation/UpdatePackages/Directory.Packages.props deleted file mode 100644 index a6ee229..0000000 --- a/Automation/UpdatePackages/Directory.Packages.props +++ /dev/null @@ -1,10 +0,0 @@ - - - false - - - - - - - \ No newline at end of file diff --git a/Automation/UpdatePackages/Program.cs b/Automation/UpdatePackages/Program.cs deleted file mode 100644 index 723d72d..0000000 --- a/Automation/UpdatePackages/Program.cs +++ /dev/null @@ -1,89 +0,0 @@ -using System.Collections.Immutable; -using System.Runtime.CompilerServices; -using Microsoft.Build.Construction; - -static string GetFilePath([CallerFilePath] string? path = null) -{ - if (path is null) - { - throw new InvalidOperationException(nameof(path)); - } - return path; -} - -string filePath = GetFilePath(); -string workspacePath = new FileInfo(filePath).Directory?.Parent?.Parent?.FullName!; -Console.WriteLine($"Workspace path: {workspacePath}"); - -// Load package versions. -string packagePropsFileName = "Directory.Packages.props"; -var packageVersions = await Task.Run(() => -{ - string packagePropsPath = Path.Combine(Path.GetDirectoryName(filePath)!, packagePropsFileName); - var props = ProjectRootElement.Open(packagePropsPath); - var result = new Dictionary(); - foreach (var item in props.Items) - { - if (item is - { - ElementName: "PackageVersion", - Include: string packageName, - FirstChild: ProjectMetadataElement - { - Name: "Version", - Value: string version - } - }) - { - result.Add(packageName, version); - } - } - return result; -}); - -// Update package versions in .csproj files. -var foundPackages = new HashSet(); -foreach (var csProjectPath in Directory.EnumerateFiles(workspacePath, "*.csproj", SearchOption.AllDirectories)) -{ - Console.WriteLine(csProjectPath); - var projectRoot = ProjectRootElement.Open(csProjectPath, new(), preserveFormatting: true); - - foreach (var item in projectRoot.Items) - { - if (item is - { - ElementName: "PackageReference", - Include: string packageName, - FirstChild: ProjectMetadataElement packageVersion and - { - Name: "Version", - Value: string version - } - }) - { - if (packageVersions.TryGetValue(packageName, out string? specifiedVersion)) - { - foundPackages.Add(packageName); - if (version != specifiedVersion) - { - Console.WriteLine($"Updating {packageName} version from {version} to {specifiedVersion}."); - packageVersion.Value = specifiedVersion; - } - } - else - { - Console.WriteLine($"{packageName} version is not specified in {packagePropsFileName}."); - } - } - } - - projectRoot.Save(); - Console.WriteLine(); -} - -foreach (var packageName in packageVersions.Keys.Except(foundPackages).ToImmutableSortedSet()) -{ - Console.WriteLine($"Package {packageName} is not referenced."); -} - -Console.WriteLine("Done."); diff --git a/Automation/UpdatePackages/UpdatePackages.csproj b/Automation/UpdatePackages/UpdatePackages.csproj deleted file mode 100644 index 8379413..0000000 --- a/Automation/UpdatePackages/UpdatePackages.csproj +++ /dev/null @@ -1,15 +0,0 @@ - - - - Exe - net6.0 - preview - enable - enable - - - - - - - \ No newline at end of file diff --git a/Directory.Build.props b/Directory.Build.props index d936144..2b3bcf1 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -1,20 +1,12 @@ - - embedded - true - $(MSBuildProjectDirectory)=$(MSBuildProjectName) - - - FreeBSD - Linux - OSX - Windows - Unknown - + + - $([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildProjectDirectory), "Directory.Build.props")) - $([MSBuild]::MakeRelative($(WorkspacePath), $(MSBuildProjectDirectory))) - $(WorkspacePath)/Build/$(OSName)/$(MSBuildProjectRelativePath)/obj - $(WorkspacePath)/Build/$(OSName)/$(MSBuildProjectRelativePath)/bin + $(MSBuildThisFileDirectory) + $([MSBuild]::NormalizeDirectory($(ProjectRoot), ".config", "dotnet")) + $(ConfigDirectory)Packages.props + net8.0 - \ No newline at end of file + + + diff --git a/Divert.Windows/AssemblyInfo.cs b/Divert.Windows/AssemblyInfo.cs new file mode 100644 index 0000000..654c4a8 --- /dev/null +++ b/Divert.Windows/AssemblyInfo.cs @@ -0,0 +1,3 @@ +using System.Runtime.Versioning; + +[assembly: SupportedOSPlatform("windows6.0.6000")] diff --git a/Divert.Windows/CString.cs b/Divert.Windows/CString.cs index 6f593f8..7dcdcca 100644 --- a/Divert.Windows/CString.cs +++ b/Divert.Windows/CString.cs @@ -2,14 +2,9 @@ namespace Divert.Windows; -internal class CString : IDisposable +internal sealed class CString(string str) : IDisposable { - internal IntPtr Ptr { get; } - - public CString(string str) - { - Ptr = Marshal.StringToHGlobalAnsi(str); - } + internal IntPtr Ptr { get; } = Marshal.StringToHGlobalAnsi(str); private bool disposed; diff --git a/Divert.Windows/Divert.Windows.csproj b/Divert.Windows/Divert.Windows.csproj index 6ae7dc6..ee50c1c 100644 --- a/Divert.Windows/Divert.Windows.csproj +++ b/Divert.Windows/Divert.Windows.csproj @@ -1,19 +1,14 @@ - - net6.0 + $(DefaultTargetFramework) x64 - enable - enable true - + all - runtime; build; native; contentfiles; analyzers - + - - \ No newline at end of file + diff --git a/Divert.Windows/DivertAddress.cs b/Divert.Windows/DivertAddress.cs index 7a8c5d7..334d9a0 100644 --- a/Divert.Windows/DivertAddress.cs +++ b/Divert.Windows/DivertAddress.cs @@ -1,10 +1,11 @@ using System.Net; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; namespace Divert.Windows; [StructLayout(LayoutKind.Sequential)] -unsafe public struct DivertAddress +public unsafe struct DivertAddress { public struct NetworkData { @@ -47,8 +48,6 @@ public struct ReflectData private WINDIVERT_ADDRESS address; - internal WINDIVERT_ADDRESS Struct => address; - internal DivertAddress(WINDIVERT_ADDRESS address) { this.address = address; @@ -61,30 +60,38 @@ public DivertAddress(int interfaceIndex, int subInterfaceIndex) Network = new WINDIVERT_DATA_NETWORK { IfIdx = checked((uint)interfaceIndex), - SubIfIdx = checked((uint)subInterfaceIndex) - } + SubIfIdx = checked((uint)subInterfaceIndex), + }, }; } + public void Reset() + { + fixed (WINDIVERT_ADDRESS* pAddress = &address) + { + Unsafe.InitBlock(pAddress, 0, (uint)sizeof(WINDIVERT_ADDRESS)); + } + } + public long Timestamp { - get { return address.Timestamp; } + readonly get { return address.Timestamp; } set { address.Timestamp = value; } } public DivertLayer Layer { - get { return (DivertLayer)address.Layer; } + readonly get { return (DivertLayer)address.Layer; } set { address.Layer = (byte)value; } } public DivertEvent Event { - get { return (DivertEvent)address.Event; } + readonly get { return (DivertEvent)address.Event; } set { address.Layer = (byte)value; } } - private bool GetBit(WINDIVERT_ADDRESS_BITS bit) + private readonly bool GetBit(WINDIVERT_ADDRESS_BITS bit) { return (address.Bits & bit) != 0; } @@ -103,49 +110,49 @@ private void SetBit(WINDIVERT_ADDRESS_BITS bit, bool value) public bool IsSniffed { - get { return GetBit(WINDIVERT_ADDRESS_BITS.Sniffed); } + readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.Sniffed); } set { SetBit(WINDIVERT_ADDRESS_BITS.Sniffed, value); } } public bool IsOutbound { - get { return GetBit(WINDIVERT_ADDRESS_BITS.Outbound); } + readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.Outbound); } set { SetBit(WINDIVERT_ADDRESS_BITS.Outbound, value); } } public bool IsLoopback { - get { return GetBit(WINDIVERT_ADDRESS_BITS.Loopback); } + readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.Loopback); } set { SetBit(WINDIVERT_ADDRESS_BITS.Loopback, value); } } public bool IsImpostor { - get { return GetBit(WINDIVERT_ADDRESS_BITS.Impostor); } + readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.Impostor); } set { SetBit(WINDIVERT_ADDRESS_BITS.Impostor, value); } } public bool IsIPv6 { - get { return GetBit(WINDIVERT_ADDRESS_BITS.IPv6); } + readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.IPv6); } set { SetBit(WINDIVERT_ADDRESS_BITS.IPv6, value); } } public bool IsIPChecksumValid { - get { return GetBit(WINDIVERT_ADDRESS_BITS.IPChecksum); } + readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.IPChecksum); } set { SetBit(WINDIVERT_ADDRESS_BITS.IPChecksum, value); } } public bool IsTCPChecksumValid { - get { return GetBit(WINDIVERT_ADDRESS_BITS.TCPChecksum); } + readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.TCPChecksum); } set { SetBit(WINDIVERT_ADDRESS_BITS.TCPChecksum, value); } } public bool IsUDPChecksumValid { - get { return GetBit(WINDIVERT_ADDRESS_BITS.UDPChecksum); } + readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.UDPChecksum); } set { SetBit(WINDIVERT_ADDRESS_BITS.UDPChecksum, value); } } @@ -156,16 +163,18 @@ public NetworkData GetNetworkData() DivertLayer.Network or DivertLayer.Forward => new NetworkData { InterfaceIndex = address.Network.IfIdx, - SubInterfaceIndex = address.Network.SubIfIdx + SubInterfaceIndex = address.Network.SubIfIdx, }, _ => throw new InvalidOperationException($"{nameof(Layer)}: {Layer}"), }; } - private IPAddress GetIPAddress(Span bytes) + private readonly IPAddress GetIPAddress(Span bytes) { - bytes.Reverse(); - var address = new IPAddress(bytes); + Span beBytes = stackalloc byte[bytes.Length]; + bytes.CopyTo(beBytes); + beBytes.Reverse(); + var address = new IPAddress(beBytes); if (!IsIPv6) { address = address.MapToIPv4(); @@ -188,10 +197,10 @@ public FlowData GetFlowData() RemoteAddress = GetIPAddress(new Span(flow.RemoteAddr, 16)), LocalPort = address.Flow.LocalPort, RemotePort = address.Flow.RemotePort, - Protocol = address.Flow.Protocol + Protocol = address.Flow.Protocol, }; default: - throw new InvalidOperationException(Layer.ToString()); + throw new InvalidOperationException($"{nameof(Layer)}: {Layer}"); } } @@ -210,10 +219,10 @@ public SocketData GetSocketData() RemoteAddress = GetIPAddress(new Span(socket.RemoteAddr, 16)), LocalPort = address.Socket.LocalPort, RemotePort = address.Socket.RemotePort, - Protocol = address.Socket.Protocol + Protocol = address.Socket.Protocol, }; default: - throw new InvalidOperationException(Layer.ToString()); + throw new InvalidOperationException($"{nameof(Layer)}: {Layer}"); } } @@ -227,9 +236,9 @@ public ReflectData GetReflectData() ProcessId = address.Reflect.ProcessId, Layer = (DivertLayer)address.Reflect.Layer, Flags = (DivertFlags)address.Reflect.Flags, - Priority = address.Reflect.Priority + Priority = address.Reflect.Priority, }, - _ => throw new InvalidOperationException(Layer.ToString()), + _ => throw new InvalidOperationException($"{nameof(Layer)}: {Layer}"), }; } } diff --git a/Divert.Windows/DivertEvent.cs b/Divert.Windows/DivertEvent.cs index 9342c58..3d1b351 100644 --- a/Divert.Windows/DivertEvent.cs +++ b/Divert.Windows/DivertEvent.cs @@ -2,14 +2,14 @@ namespace Divert.Windows; public enum DivertEvent { - NetworkPacket = 0, - FlowEstablished = 1, - FlowDeleted = 2, - SocketBind = 3, - SocketConnect = 4, - SocketListen = 5, - SocketAccept = 6, - SocketClose = 7, - ReflectOpen = 8, - ReflectClose = 9, + NetworkPacket = WINDIVERT_EVENT.WINDIVERT_EVENT_NETWORK_PACKET, + FlowEstablished = WINDIVERT_EVENT.WINDIVERT_EVENT_FLOW_ESTABLISHED, + FlowDeleted = WINDIVERT_EVENT.WINDIVERT_EVENT_FLOW_DELETED, + SocketBind = WINDIVERT_EVENT.WINDIVERT_EVENT_SOCKET_BIND, + SocketConnect = WINDIVERT_EVENT.WINDIVERT_EVENT_SOCKET_CONNECT, + SocketListen = WINDIVERT_EVENT.WINDIVERT_EVENT_SOCKET_LISTEN, + SocketAccept = WINDIVERT_EVENT.WINDIVERT_EVENT_SOCKET_ACCEPT, + SocketClose = WINDIVERT_EVENT.WINDIVERT_EVENT_SOCKET_CLOSE, + ReflectOpen = WINDIVERT_EVENT.WINDIVERT_EVENT_REFLECT_OPEN, + ReflectClose = WINDIVERT_EVENT.WINDIVERT_EVENT_REFLECT_CLOSE, } diff --git a/Divert.Windows/DivertFilter.cs b/Divert.Windows/DivertFilter.cs index 57cc1e3..7d3cb17 100644 --- a/Divert.Windows/DivertFilter.cs +++ b/Divert.Windows/DivertFilter.cs @@ -3,9 +3,9 @@ namespace Divert.Windows; -internal record class ReplaceParenthesesOperation(string Expression); +internal record struct ReplaceParenthesesOperation(string Expression); -public class DivertFilter +public partial class DivertFilter { public string Clause { get; } @@ -65,11 +65,15 @@ private static string ReplaceParentheses(string expression) return builder.ToString(); } - private static readonly string orOp = Regex.Escape("||"); + [GeneratedRegex(@"\s(or|\|\|)\s")] + private static partial Regex OrPatternRegex(); - private static bool MatchOrPattern(string s) => Regex.IsMatch(ReplaceParentheses(s), @$"\s(or|{orOp})\s"); + private static bool MatchOrPattern(string s) => OrPatternRegex().IsMatch(ReplaceParentheses(s)); - private static bool MatchAndPattern(string s) => Regex.IsMatch(ReplaceParentheses(s), @"\s(and|&&)\s"); + [GeneratedRegex(@"\s(and|&&)\s")] + private static partial Regex AndPatternRegex(); + + private static bool MatchAndPattern(string s) => AndPatternRegex().IsMatch(ReplaceParentheses(s)); public static DivertFilter operator &(DivertFilter left, DivertFilter right) { @@ -121,22 +125,13 @@ public Field(string field) public override bool Equals(object? obj) { - if (obj is null) - { - return false; - } - else if (obj is Field filter) - { - return field == filter.field; - } - else if (obj is string s) + return obj switch { - return field == s; - } - else - { - return false; - } + null => false, + Field filter => field == filter.field, + string s => field == s, + _ => false, + }; } public override int GetHashCode() => base.GetHashCode(); @@ -175,37 +170,37 @@ public static implicit operator DivertFilter(Field field) return new DivertFilter(clause); } - public static DivertFilter operator ==(Field left, object right) + public static DivertFilter operator ==(Field left, string right) { string clause = $"{left} = {right}"; return new DivertFilter(clause); } - public static DivertFilter operator !=(Field left, object right) + public static DivertFilter operator !=(Field left, string right) { string clause = $"{left} != {right}"; return new DivertFilter(clause); } - public static DivertFilter operator <(Field left, object right) + public static DivertFilter operator <(Field left, string right) { string clause = $"{left} < {right}"; return new DivertFilter(clause); } - public static DivertFilter operator >(Field left, object right) + public static DivertFilter operator >(Field left, string right) { string clause = $"{left} > {right}"; return new DivertFilter(clause); } - public static DivertFilter operator <=(Field left, object right) + public static DivertFilter operator <=(Field left, string right) { string clause = $"{left} <= {right}"; return new DivertFilter(clause); } - public static DivertFilter operator >=(Field left, object right) + public static DivertFilter operator >=(Field left, string right) { string clause = $"{left} >= {right}"; return new DivertFilter(clause); @@ -266,6 +261,7 @@ public static implicit operator DivertFilter(Field field) public static Field ICMP { get; } = "icmp"; + // spell-checker:ignore icmpv6 public static Field ICMPv6 { get; } = "icmpv6"; public static Field TCP { get; } = "tcp"; diff --git a/Divert.Windows/DivertFlags.cs b/Divert.Windows/DivertFlags.cs index 7f3dfd7..314f2fe 100644 --- a/Divert.Windows/DivertFlags.cs +++ b/Divert.Windows/DivertFlags.cs @@ -11,5 +11,5 @@ public enum DivertFlags SendOnly = 0x0008, WriteOnly = SendOnly, NoInstall = 0x0010, - Fragments = 0x0020 + Fragments = 0x0020, } diff --git a/Divert.Windows/DivertHandle.cs b/Divert.Windows/DivertHandle.cs new file mode 100644 index 0000000..bd2c568 --- /dev/null +++ b/Divert.Windows/DivertHandle.cs @@ -0,0 +1,16 @@ +using Microsoft.Win32.SafeHandles; + +namespace Divert.Windows; + +internal sealed class DivertHandle : SafeHandleZeroOrMinusOneIsInvalid +{ + internal IntPtr Handle => handle; + + public DivertHandle(IntPtr handle) + : base(ownsHandle: false) + { + this.handle = handle; + } + + protected override bool ReleaseHandle() => true; +} diff --git a/Divert.Windows/DivertHelper.cs b/Divert.Windows/DivertHelper.cs index c6d1b22..f2f8f44 100644 --- a/Divert.Windows/DivertHelper.cs +++ b/Divert.Windows/DivertHelper.cs @@ -3,19 +3,16 @@ namespace Divert.Windows; -unsafe public static class DivertHelper +public static unsafe class DivertHelper { public static void CalculateChecksums(Span packet, DivertHelperFlags flags) { fixed (byte* pPacket = packet) { - bool success = NativeMethods.WinDivertHelperCalcChecksums( - pPacket, (uint)packet.Length, - null, - (ulong)flags); + bool success = NativeMethods.WinDivertHelperCalcChecksums(pPacket, (uint)packet.Length, null, (ulong)flags); if (!success) { - throw new Win32Exception(Marshal.GetLastWin32Error()); + throw new Win32Exception(Marshal.GetLastPInvokeError()); } } } diff --git a/Divert.Windows/DivertHelperFlags.cs b/Divert.Windows/DivertHelperFlags.cs index 5e547d5..9a3f03d 100644 --- a/Divert.Windows/DivertHelperFlags.cs +++ b/Divert.Windows/DivertHelperFlags.cs @@ -8,5 +8,5 @@ public enum DivertHelperFlags NoICMPChecksum = 2, NoICMPv6Checksum = 4, NoTCPChecksum = 8, - NoUDPChecksum = 16 + NoUDPChecksum = 16, } diff --git a/Divert.Windows/DivertLayer.cs b/Divert.Windows/DivertLayer.cs index 435cd14..4cc1c4b 100644 --- a/Divert.Windows/DivertLayer.cs +++ b/Divert.Windows/DivertLayer.cs @@ -2,9 +2,9 @@ namespace Divert.Windows; public enum DivertLayer { - Network = 0, - Forward = 1, - Flow = 2, - Socket = 3, - Reflect = 4 + Network = WINDIVERT_LAYER.WINDIVERT_LAYER_NETWORK, + Forward = WINDIVERT_LAYER.WINDIVERT_LAYER_NETWORK_FORWARD, + Flow = WINDIVERT_LAYER.WINDIVERT_LAYER_FLOW, + Socket = WINDIVERT_LAYER.WINDIVERT_LAYER_SOCKET, + Reflect = WINDIVERT_LAYER.WINDIVERT_LAYER_REFLECT, } diff --git a/Divert.Windows/DivertReceiveResult.cs b/Divert.Windows/DivertReceiveResult.cs new file mode 100644 index 0000000..16dec94 --- /dev/null +++ b/Divert.Windows/DivertReceiveResult.cs @@ -0,0 +1,7 @@ +namespace Divert.Windows; + +public readonly struct DivertReceiveResult(int length, Memory addresses) +{ + public int Length => length; + public Memory Addresses => addresses; +} diff --git a/Divert.Windows/DivertService.cs b/Divert.Windows/DivertService.cs index 946bced..f907ac5 100644 --- a/Divert.Windows/DivertService.cs +++ b/Divert.Windows/DivertService.cs @@ -1,16 +1,19 @@ using System.ComponentModel; using System.Runtime.InteropServices; -using System.Runtime.Versioning; -using Windows.Win32; +using System.Threading.Channels; using Windows.Win32.Foundation; -[assembly: SupportedOSPlatform("windows6.0.6000")] - namespace Divert.Windows; -unsafe sealed public class DivertService : IDisposable +/// +/// Main entry point WinDivert APIs. +/// +public sealed unsafe class DivertService : IDisposable { - internal HANDLE Handle { get; } + private readonly DivertHandle handle; + + private readonly ThreadPoolBoundHandle threadPoolBoundHandle; + private readonly Channel vtsPool = Channel.CreateUnbounded(); /// /// Opens a WinDivert handle for the given filter. @@ -23,7 +26,8 @@ public DivertService( DivertFilter filter, DivertLayer layer = DivertLayer.Network, short priority = 0, - DivertFlags flags = DivertFlags.None) + DivertFlags flags = DivertFlags.None + ) { ArgumentNullException.ThrowIfNull(filter); @@ -36,7 +40,8 @@ public DivertService( null, 0, &errorStr, - &errorPos); + &errorPos + ); if (!success) { string? errorString = Marshal.PtrToStringAnsi(errorStr); @@ -44,24 +49,52 @@ public DivertService( } var handle = NativeMethods.WinDivertOpen(s.Ptr, (WINDIVERT_LAYER)layer, priority, (ulong)flags); - if (handle.Value.ToInt64() == -1) + if (new HANDLE(handle) == HANDLE.INVALID_HANDLE_VALUE) { - throw new Win32Exception(Marshal.GetLastWin32Error()); + throw new Win32Exception(Marshal.GetLastPInvokeError()); } - Handle = handle; + this.handle = new DivertHandle(handle); + threadPoolBoundHandle = ThreadPoolBoundHandle.BindHandle(this.handle); } private bool disposed = false; + private bool closed = false; + + private void Close(bool throwOnError) + { + if (!closed) + { + bool success = NativeMethods.WinDivertClose(handle.Handle); + if (!success && throwOnError) + { + throw new Win32Exception(Marshal.GetLastPInvokeError()); + } + closed = true; + } + } + + /// + /// Closes the WinDivert handle. + /// + public void Close() => Close(throwOnError: true); + /// + /// Releases all resources used by the . + /// public void Dispose() { if (!disposed) { - bool success = NativeMethods.WinDivertClose(Handle); - if (!success) + if (vtsPool.Writer.TryComplete()) { - throw new Win32Exception(Marshal.GetLastWin32Error()); + while (vtsPool.Reader.TryRead(out var vts)) + { + vts.Dispose(); + } } + threadPoolBoundHandle.Dispose(); + handle.Dispose(); + Close(throwOnError: false); disposed = true; } GC.SuppressFinalize(this); @@ -72,197 +105,81 @@ public void Dispose() Dispose(); } - internal void ThrowIfDisposed() + private void ThrowIfClosedOrDisposed() { - if (disposed) + if (closed) { - throw new ObjectDisposedException(nameof(DivertService)); + throw new InvalidOperationException(nameof(closed)); } + ObjectDisposedException.ThrowIf(disposed, this); } + /// + /// Shuts down the WinDivert handle. + /// + /// Specifies how to shut down the handle. public void Shutdown(DivertShutdown how) { - ThrowIfDisposed(); + ThrowIfClosedOrDisposed(); - bool success = NativeMethods.WinDivertShutdown(Handle, (WINDIVERT_SHUTDOWN)how); + bool success = NativeMethods.WinDivertShutdown(handle.Handle, (WINDIVERT_SHUTDOWN)how); if (!success) { - throw new Win32Exception(Marshal.GetLastWin32Error()); - } - } - - public (int packetLength, DivertAddress address) Receive(Span buffer) - { - ThrowIfDisposed(); - - uint packetLength; - WINDIVERT_ADDRESS address; - fixed (byte* pBuffer = buffer) - { - bool success = NativeMethods.WinDivertRecv( - Handle, - pBuffer, checked((uint)buffer.Length), - &packetLength, - &address); - if (!success) - { - throw new Win32Exception(Marshal.GetLastWin32Error()); - } - } - return ((int)packetLength, new DivertAddress(address)); - } - - public (int packetLength, int addressLength) ReceiveEx( - Span buffer, - Span addresses, - CancellationToken cancellationToken = default) - { - ThrowIfDisposed(); - cancellationToken.ThrowIfCancellationRequested(); - - using var eventHandle = new ManualResetEvent(initialState: false); - using var _ = cancellationToken.Register(() => eventHandle.Set()); - var overlapped = new NativeOverlapped - { - EventHandle = eventHandle.SafeWaitHandle.DangerousGetHandle() - }; - uint addrLen = (uint)(sizeof(DivertAddress) * addresses.Length); - fixed (byte* pBuffer = buffer) - fixed (DivertAddress* pAddr = addresses) - { - bool success = NativeMethods.WinDivertRecvEx( - Handle, - pBuffer, checked((uint)buffer.Length), - null, - 0, - (WINDIVERT_ADDRESS*)pAddr, - &addrLen, - &overlapped); - if (!success) - { - int error = Marshal.GetLastWin32Error(); - if (error == 997) // ERROR_IO_PENDING - { - eventHandle.WaitOne(); - if (cancellationToken.IsCancellationRequested) - { - success = PInvoke.CancelIoEx(Handle, &overlapped); - if (!success) - { - error = Marshal.GetLastWin32Error(); - if (error != 1168) // ERROR_NOT_FOUND - { - throw new Win32Exception(error); - } - } - } - } - else - { - throw new Win32Exception(error); - } - } - uint packetsLength; - success = PInvoke.GetOverlappedResult(Handle, &overlapped, &packetsLength, true); - if (!success) - { - int error = Marshal.GetLastWin32Error(); - if (error == 995) // ERROR_OPERATION_ABORTED - { - throw new OperationCanceledException(); - } - else - { - throw new Win32Exception(error); - } - } - int addressLength = (int)addrLen / sizeof(DivertAddress); - return ((int)packetsLength, addressLength); + throw new Win32Exception(Marshal.GetLastPInvokeError()); } } - public void Send(ReadOnlySpan buffer, DivertAddress address) + /// + /// Receives one or more packets from the network stack. + /// + /// Buffer to receive the packet data. + /// Buffer to receive the packet addresses. + /// Token to observe cancellation requests. + /// A ValueTask representing the asynchronous receive operation. + public ValueTask ReceiveAsync( + Memory buffer, + Memory addresses, + CancellationToken cancellationToken = default + ) { - ThrowIfDisposed(); + ThrowIfClosedOrDisposed(); - var divertAddress = address.Struct; - fixed (byte* pBuffer = buffer) + if (!vtsPool.Reader.TryRead(out var vts)) { - bool success = NativeMethods.WinDivertSend( - Handle, - pBuffer, checked((uint)buffer.Length), - null, - &divertAddress); - if (!success) - { - throw new Win32Exception(Marshal.GetLastWin32Error()); - } + vts = new DivertValueTaskSource( + vtsPool, + handle.Handle, + threadPoolBoundHandle, + runContinuationsAsynchronously: true + ); } + return vts.ReceiveAsync(buffer, addresses, cancellationToken); } - public void SendEx( - ReadOnlySpan buffer, - Span addresses, - CancellationToken cancellationToken = default) + /// + /// Injects one or more packets into the network stack. + /// + /// Buffer containing the packet data. + /// Addresses of the packets to be injected. + /// Token to observe cancellation requests. + /// A ValueTask representing the asynchronous send operation. + public ValueTask SendAsync( + ReadOnlyMemory buffer, + ReadOnlyMemory addresses, + CancellationToken cancellationToken = default + ) { - ThrowIfDisposed(); - cancellationToken.ThrowIfCancellationRequested(); + ThrowIfClosedOrDisposed(); - using var eventHandle = new ManualResetEvent(initialState: false); - using var _ = cancellationToken.Register(() => eventHandle.Set()); - var overlapped = new NativeOverlapped - { - EventHandle = eventHandle.SafeWaitHandle.DangerousGetHandle() - }; - fixed (byte* pBuffer = buffer) - fixed (DivertAddress* pAddr = addresses) + if (!vtsPool.Reader.TryRead(out var vts)) { - bool success = NativeMethods.WinDivertSendEx( - Handle, - pBuffer, checked((uint)buffer.Length), - null, - 0, - (WINDIVERT_ADDRESS*)pAddr, - (uint)(sizeof(WINDIVERT_ADDRESS) * addresses.Length), - &overlapped); - if (!success) - { - int error = Marshal.GetLastWin32Error(); - if (error == 997) // ERROR_IO_PENDING - { - eventHandle.WaitOne(); - if (cancellationToken.IsCancellationRequested) - { - success = PInvoke.CancelIoEx(Handle, &overlapped); - if (!success) - { - error = Marshal.GetLastWin32Error(); - if (error != 1168) // ERROR_NOT_FOUND - { - throw new Win32Exception(error); - } - } - } - } - else - { - throw new Win32Exception(error); - } - } - uint bytesSent; - success = PInvoke.GetOverlappedResult(Handle, &overlapped, &bytesSent, true); - if (!success) - { - int error = Marshal.GetLastWin32Error(); - if (error == 995) // ERROR_OPERATION_ABORTED - { - throw new OperationCanceledException(); - } - else - { - throw new Win32Exception(error); - } - } + vts = new DivertValueTaskSource( + vtsPool, + handle.Handle, + threadPoolBoundHandle, + runContinuationsAsynchronously: true + ); } + return vts.SendAsync(buffer, addresses, cancellationToken); } } diff --git a/Divert.Windows/DivertShutdown.cs b/Divert.Windows/DivertShutdown.cs index 2173006..067ff9d 100644 --- a/Divert.Windows/DivertShutdown.cs +++ b/Divert.Windows/DivertShutdown.cs @@ -1,8 +1,22 @@ namespace Divert.Windows; +/// +/// Specifies how to shut down a WinDivert handle. +/// public enum DivertShutdown { - Receive = 0x1, - Send = 0x2, - Both = 0x3 + /// + /// Stops new packets from being queued for receiving. + /// + Receive = WINDIVERT_SHUTDOWN.WINDIVERT_SHUTDOWN_RECV, + + /// + /// Stops new packets from being injected. + /// + Send = WINDIVERT_SHUTDOWN.WINDIVERT_SHUTDOWN_SEND, + + /// + /// Stops both receiving and sending of packets. + /// + Both = WINDIVERT_SHUTDOWN.WINDIVERT_SHUTDOWN_BOTH, } diff --git a/Divert.Windows/DivertValueTaskSource.cs b/Divert.Windows/DivertValueTaskSource.cs new file mode 100644 index 0000000..a2ce416 --- /dev/null +++ b/Divert.Windows/DivertValueTaskSource.cs @@ -0,0 +1,287 @@ +using System.Buffers; +using System.ComponentModel; +using System.Runtime.InteropServices; +using System.Threading.Channels; +using System.Threading.Tasks.Sources; +using Windows.Win32; +using Windows.Win32.Foundation; + +namespace Divert.Windows; + +internal sealed unsafe class DivertValueTaskSource + : IValueTaskSource, + IValueTaskSource, + IDisposable +{ + private static class Status + { + public const uint Idle = 0; + public const uint Canceled = 1; + public const uint Pending = 2; + public const uint Disposed = 3; + } + + private static readonly IOCompletionCallback ioCompletionCallback = OnIOCompleted; + + private ManualResetValueTaskSourceCore source; + private readonly Channel pool; + private readonly IntPtr divertHandle; + private readonly ThreadPoolBoundHandle threadPoolBoundHandle; + private readonly PreAllocatedOverlapped preAllocatedOverlapped; + private readonly Memory addressesLengthBuffer; + + private struct PendingOperation( + DivertValueTaskSource vts, + Memory packetBuffer, + Memory addresses, + CancellationTokenRegistration cancellationTokenRegistration + ) : IDisposable + { + public NativeOverlapped* NativeOverlapped { get; private set; } = + vts.threadPoolBoundHandle.AllocateNativeOverlapped(vts.preAllocatedOverlapped); + + public readonly Memory Addresses => addresses; + + public MemoryHandle PacketBufferHandle { get; } = packetBuffer.Pin(); + public MemoryHandle AddressesHandle { get; } = addresses.Pin(); + public MemoryHandle AddressesLengthBufferHandle { get; } = vts.addressesLengthBuffer.Pin(); + + public readonly CancellationToken CancellationToken => cancellationTokenRegistration.Token; + + public void Dispose() + { + cancellationTokenRegistration.Dispose(); + + PacketBufferHandle.Dispose(); + AddressesHandle.Dispose(); + AddressesLengthBufferHandle.Dispose(); + + if (NativeOverlapped is not null) + { + vts.threadPoolBoundHandle.FreeNativeOverlapped(NativeOverlapped); + NativeOverlapped = null; + } + } + } + + private PendingOperation pendingOperation; + private uint status; + + public DivertValueTaskSource( + Channel pool, + IntPtr divertHandle, + ThreadPoolBoundHandle threadPoolBoundHandle, + bool runContinuationsAsynchronously + ) + { + source.RunContinuationsAsynchronously = runContinuationsAsynchronously; + this.pool = pool; + this.divertHandle = divertHandle; + this.threadPoolBoundHandle = threadPoolBoundHandle; + preAllocatedOverlapped = new PreAllocatedOverlapped(ioCompletionCallback, this, null); + addressesLengthBuffer = GC.AllocateArray(1, pinned: true); + } + + private void ExecuteOrRequestCancel() + { + uint originalStatus = Interlocked.CompareExchange(ref status, Status.Canceled, Status.Idle); + if (originalStatus is Status.Pending) + { + _ = PInvoke.CancelIoEx(new(divertHandle), pendingOperation.NativeOverlapped); + } + } + + private void CancelIfRequested() + { + uint originalStatus = Interlocked.CompareExchange(ref status, Status.Pending, Status.Idle); + if (originalStatus is Status.Canceled) + { + _ = PInvoke.CancelIoEx(new(divertHandle), pendingOperation.NativeOverlapped); + Interlocked.CompareExchange(ref status, Status.Idle, originalStatus); + } + } + + private PendingOperation PrepareOperation( + Memory packetBuffer, + Memory addresses, + CancellationToken cancellationToken + ) + { + var cancellationTokenRegistration = cancellationToken.CanBeCanceled + ? cancellationToken.UnsafeRegister( + static state => ((DivertValueTaskSource)state!).ExecuteOrRequestCancel(), + this + ) + : default; + pendingOperation = new PendingOperation(this, packetBuffer, addresses, cancellationTokenRegistration); + return pendingOperation; + } + + public void Dispose() + { + if (Interlocked.Exchange(ref status, Status.Disposed) is Status.Disposed) + { + return; + } + + pendingOperation.Dispose(); + preAllocatedOverlapped.Dispose(); + GC.SuppressFinalize(this); + } + + ~DivertValueTaskSource() + { + Dispose(); + } + + internal short Version => source.Version; + + private void Complete(uint errorCode, uint numBytes) + { + var addresses = pendingOperation.Addresses; + int addressesLength = (int)addressesLengthBuffer.Span[0] / sizeof(DivertAddress); + var token = pendingOperation.CancellationToken; + pendingOperation.Dispose(); + + if (errorCode == (uint)WIN32_ERROR.ERROR_NO_DATA) + { + errorCode = 0; + } + + if (errorCode == (uint)WIN32_ERROR.ERROR_SUCCESS) + { + source.SetResult(new DivertReceiveResult((int)numBytes, addresses[..addressesLength])); + } + else if (errorCode == (uint)WIN32_ERROR.ERROR_OPERATION_ABORTED) + { + var exception = token.IsCancellationRequested + ? new OperationCanceledException() + : new OperationCanceledException(token); + source.SetException(exception); + } + else + { + var exception = new Win32Exception((int)errorCode); + source.SetException(exception); + } + } + + private static void OnIOCompleted(uint errorCode, uint numBytes, NativeOverlapped* pOVERLAP) + { + var vts = (DivertValueTaskSource)ThreadPoolBoundHandle.GetNativeOverlappedState(pOVERLAP)!; + vts.Complete(errorCode, numBytes); + } + + public ValueTaskSourceStatus GetStatus(short token) => source.GetStatus(token); + + public void OnCompleted( + Action continuation, + object? state, + short token, + ValueTaskSourceOnCompletedFlags flags + ) => source.OnCompleted(continuation, state, token, flags); + + public DivertReceiveResult GetResult(short token) + { + try + { + return source.GetResult(token); + } + finally + { + Interlocked.CompareExchange(ref status, Status.Idle, Status.Pending); + source.Reset(); + if (!pool.Writer.TryWrite(this)) + { + Dispose(); + } + } + } + + void IValueTaskSource.GetResult(short token) => GetResult(token); + + public ValueTask ReceiveAsync( + Memory buffer, + Memory addresses, + CancellationToken cancellationToken + ) + { + if (cancellationToken.IsCancellationRequested) + { + return ValueTask.FromCanceled(cancellationToken); + } + + addressesLengthBuffer.Span[0] = (uint)(addresses.Length * sizeof(DivertAddress)); + var pendingOperation = PrepareOperation(buffer, addresses, cancellationToken); + bool success = NativeMethods.WinDivertRecvEx( + divertHandle, + pendingOperation.PacketBufferHandle.Pointer, + (uint)buffer.Length, + null, + 0, + (WINDIVERT_ADDRESS*)pendingOperation.AddressesHandle.Pointer, + (uint*)pendingOperation.AddressesLengthBufferHandle.Pointer, + pendingOperation.NativeOverlapped + ); + if (!success) + { + int error = Marshal.GetLastPInvokeError(); + if (error == (int)WIN32_ERROR.ERROR_IO_PENDING) + { + CancelIfRequested(); + } + else + { + pendingOperation.Dispose(); + Interlocked.CompareExchange(ref status, Status.Idle, Status.Canceled); + return ValueTask.FromException(new Win32Exception(error)); + } + } + + return new ValueTask(this, source.Version); + } + + public ValueTask SendAsync( + ReadOnlyMemory buffer, + ReadOnlyMemory addresses, + CancellationToken cancellationToken + ) + { + if (cancellationToken.IsCancellationRequested) + { + return ValueTask.FromCanceled(cancellationToken); + } + + var pendingOperation = PrepareOperation( + MemoryMarshal.AsMemory(buffer), + MemoryMarshal.AsMemory(addresses), + cancellationToken + ); + bool success = NativeMethods.WinDivertSendEx( + divertHandle, + pendingOperation.PacketBufferHandle.Pointer, + (uint)buffer.Length, + null, + 0, + (WINDIVERT_ADDRESS*)pendingOperation.AddressesHandle.Pointer, + (uint)(addresses.Length * sizeof(DivertAddress)), + pendingOperation.NativeOverlapped + ); + if (!success) + { + int error = Marshal.GetLastPInvokeError(); + if (error == (int)WIN32_ERROR.ERROR_IO_PENDING) + { + CancelIfRequested(); + } + else + { + pendingOperation.Dispose(); + Interlocked.CompareExchange(ref status, Status.Idle, Status.Canceled); + return ValueTask.FromException(new Win32Exception(error)); + } + } + + return new ValueTask(this, source.Version); + } +} diff --git a/Divert.Windows/NativeMethods.cs b/Divert.Windows/NativeMethods.cs index 1cdd2a5..d817327 100644 --- a/Divert.Windows/NativeMethods.cs +++ b/Divert.Windows/NativeMethods.cs @@ -1,91 +1,93 @@ using System.Runtime.InteropServices; -using Windows.Win32.Foundation; namespace Divert.Windows; -unsafe internal static class NativeMethods +internal static unsafe partial class NativeMethods { private const string dllName = "WinDivert.dll"; - [DllImport(dllName, SetLastError = true)] - public static extern HANDLE WinDivertOpen( - IntPtr filter, - WINDIVERT_LAYER layer, - short priority, - ulong flags); + [LibraryImport(dllName, SetLastError = true)] + public static partial IntPtr WinDivertOpen(IntPtr filter, WINDIVERT_LAYER layer, short priority, ulong flags); - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertRecv( - HANDLE handle, + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertRecv( + IntPtr handle, void* pPacket, uint packetLen, uint* pRecvLen, - WINDIVERT_ADDRESS* pAddr); + WINDIVERT_ADDRESS* pAddr + ); - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertRecvEx( - HANDLE handle, + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertRecvEx( + IntPtr handle, void* pPacket, uint packetLen, uint* pRecvLen, ulong flags, WINDIVERT_ADDRESS* pAddr, uint* pAddrLen, - NativeOverlapped* lpOverlapped); + NativeOverlapped* lpOverlapped + ); - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertSend( - HANDLE handle, + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertSend( + IntPtr handle, void* pPacket, uint packetLen, uint* pSendLen, - WINDIVERT_ADDRESS* pAddr); + WINDIVERT_ADDRESS* pAddr + ); - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertSendEx( - HANDLE handle, + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertSendEx( + IntPtr handle, void* pPacket, uint packetLen, uint* pSendLen, ulong flags, WINDIVERT_ADDRESS* pAddr, uint addrLen, - NativeOverlapped* lpOverlapped); + NativeOverlapped* lpOverlapped + ); - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertShutdown( - HANDLE handle, - WINDIVERT_SHUTDOWN how); + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertShutdown(IntPtr handle, WINDIVERT_SHUTDOWN how); - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertClose( - HANDLE handle); + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertClose(IntPtr handle); - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertSetParam( - HANDLE handle, - WINDIVERT_PARAM param, - ulong value); + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertSetParam(IntPtr handle, WINDIVERT_PARAM param, ulong value); - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertGetParam( - HANDLE handle, - WINDIVERT_PARAM param, - ulong* pValue); + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertGetParam(IntPtr handle, WINDIVERT_PARAM param, ulong* pValue); - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertHelperCalcChecksums( + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertHelperCalcChecksums( void* pPacket, uint packetLen, WINDIVERT_ADDRESS* pAddr, - ulong flags); + ulong flags + ); - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertHelperCompileFilter( + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertHelperCompileFilter( IntPtr filter, WINDIVERT_LAYER layer, byte* @object, uint objLen, IntPtr* errorStr, - uint* errorPos); + uint* errorPos + ); } diff --git a/Divert.Windows/NativeMethods.txt b/Divert.Windows/NativeMethods.txt index fbcf31c..3737ecf 100644 --- a/Divert.Windows/NativeMethods.txt +++ b/Divert.Windows/NativeMethods.txt @@ -1,2 +1,4 @@ +INVALID_HANDLE_VALUE +WIN32_ERROR GetOverlappedResult CancelIoEx diff --git a/Divert.Windows/NativeTypes.cs b/Divert.Windows/NativeTypes.cs index 30ae8e8..5af8043 100644 --- a/Divert.Windows/NativeTypes.cs +++ b/Divert.Windows/NativeTypes.cs @@ -8,7 +8,7 @@ internal enum WINDIVERT_LAYER WINDIVERT_LAYER_NETWORK_FORWARD = 1, WINDIVERT_LAYER_FLOW = 2, WINDIVERT_LAYER_SOCKET = 3, - WINDIVERT_LAYER_REFLECT = 4 + WINDIVERT_LAYER_REFLECT = 4, } [StructLayout(LayoutKind.Explicit, Size = 8)] @@ -22,7 +22,7 @@ internal struct WINDIVERT_DATA_NETWORK } [StructLayout(LayoutKind.Explicit, Size = 64)] -unsafe internal struct WINDIVERT_DATA_FLOW +internal unsafe struct WINDIVERT_DATA_FLOW { [FieldOffset(0)] public ulong EndpointId; @@ -50,7 +50,7 @@ unsafe internal struct WINDIVERT_DATA_FLOW } [StructLayout(LayoutKind.Explicit, Size = 64)] -unsafe internal struct WINDIVERT_DATA_SOCKET +internal unsafe struct WINDIVERT_DATA_SOCKET { [FieldOffset(0)] public ulong EndpointId; @@ -110,7 +110,7 @@ internal enum WINDIVERT_ADDRESS_BITS : byte } [StructLayout(LayoutKind.Explicit, Size = 80)] -unsafe internal struct WINDIVERT_ADDRESS +internal unsafe struct WINDIVERT_ADDRESS { [FieldOffset(0)] public long Timestamp; @@ -166,12 +166,12 @@ internal enum WINDIVERT_PARAM WINDIVERT_PARAM_QUEUE_TIME = 1, WINDIVERT_PARAM_QUEUE_SIZE = 2, WINDIVERT_PARAM_VERSION_MAJOR = 3, - WINDIVERT_PARAM_VERSION_MINOR = 4 + WINDIVERT_PARAM_VERSION_MINOR = 4, } internal enum WINDIVERT_SHUTDOWN { WINDIVERT_SHUTDOWN_RECV = 0x1, WINDIVERT_SHUTDOWN_SEND = 0x2, - WINDIVERT_SHUTDOWN_BOTH = 0x3 + WINDIVERT_SHUTDOWN_BOTH = 0x3, } diff --git a/Examples/Ping/Ping.csproj b/Examples/Ping/Ping.csproj index ae0f9c1..848f1cd 100644 --- a/Examples/Ping/Ping.csproj +++ b/Examples/Ping/Ping.csproj @@ -1,18 +1,14 @@ - Exe - net6.0 - enable - enable + $(DefaultTargetFramework) + true + win-x64 + $(ProjectRoot).local/Ping + CA2007 - - - - - - \ No newline at end of file + diff --git a/Examples/Ping/Program.cs b/Examples/Ping/Program.cs index 918c2c7..bb3e8f4 100644 --- a/Examples/Ping/Program.cs +++ b/Examples/Ping/Program.cs @@ -1,6 +1,7 @@ using System.Diagnostics; using System.Net; using System.Net.NetworkInformation; +using System.Net.Sockets; using System.Runtime.Versioning; using System.Security.Principal; using Divert.Windows; @@ -16,23 +17,30 @@ FileName = Environment.ProcessPath, Verb = "runas", Arguments = args.Length > 0 ? args[0] : string.Empty, - UseShellExecute = true + UseShellExecute = true, }; Process.Start(startInfo); Environment.Exit(0); } -string remoteIP = args.Length > 0 ? args[0] : "8.8.8.8"; +string? remoteIP = args.Length > 0 ? args[0] : null; +if (remoteIP is null) +{ + var gateway = NetworkInterface + .GetAllNetworkInterfaces() + .Where(nic => + nic.OperationalStatus == OperationalStatus.Up && nic.NetworkInterfaceType != NetworkInterfaceType.Loopback + ) + .SelectMany(nic => nic.GetIPProperties()?.GatewayAddresses?.ToArray() ?? []) + .Select(g => g.Address) + .FirstOrDefault(address => address.AddressFamily == AddressFamily.InterNetwork); + remoteIP = gateway?.ToString(); +} +remoteIP ??= "1.1.1.1"; var outboundFilter = - DivertFilter.Outbound - & !DivertFilter.Loopback - & DivertFilter.RemoteAddress == remoteIP - & DivertFilter.ICMP; -var inboundFilter = - DivertFilter.Inbound - & DivertFilter.RemoteAddress == remoteIP - & DivertFilter.ICMP; + DivertFilter.Outbound & !DivertFilter.Loopback & DivertFilter.RemoteAddress == remoteIP & DivertFilter.ICMP; +var inboundFilter = DivertFilter.Inbound & DivertFilter.RemoteAddress == remoteIP & DivertFilter.ICMP; Console.WriteLine($"{nameof(outboundFilter)}: {outboundFilter}"); Console.WriteLine($"{nameof(inboundFilter)}: {inboundFilter}"); @@ -52,18 +60,20 @@ { Console.WriteLine( $"Reply from {reply.Address} (Ping): " - + $"bytes={reply.Buffer?.Length} " - + $"time={reply.RoundtripTime}ms " - + $"TTL={reply.Options?.Ttl}"); + + $"bytes={reply.Buffer?.Length} " + + $"time={reply.RoundtripTime}ms " + + $"TTL={reply.Options?.Ttl}" + ); } else { Console.WriteLine(reply.Status); + break; } } }); -var outbound = Task.Run(() => +var outbound = Task.Run(async () => { using var cts = new CancellationTokenSource(); cts.CancelAfter(3500); @@ -73,11 +83,11 @@ { try { - var (packetLength, _) = outDivert.ReceiveEx(buffer, addresses, cts.Token); - var packet = buffer.AsSpan(0, packetLength); - var remoteAddress = new IPAddress(packet[16..20]); + var result = await outDivert.ReceiveAsync(buffer, addresses, cts.Token).ConfigureAwait(false); + var packet = buffer.AsMemory(0, result.Length); + var remoteAddress = new IPAddress(packet[16..20].Span); Console.WriteLine($"Pinging {remoteAddress} with {packet.Length - 28} bytes of data (Divert):"); - outDivert.Send(packet, addresses[0]); + await outDivert.SendAsync(packet, addresses, cts.Token); } catch (OperationCanceledException) { @@ -92,25 +102,27 @@ } }); -var inbound = Task.Run(() => +var inbound = Task.Run(async () => { using var cts = new CancellationTokenSource(); cts.CancelAfter(2500); var buffer = new byte[ushort.MaxValue]; + var addresses = new DivertAddress[1]; while (true) { try { - var (packetLength, address) = inDivert.Receive(buffer); - var packet = buffer.AsSpan(0, packetLength); - var remoteAddress = new IPAddress(packet[12..16]); - long timestamp = BitConverter.ToInt64(packet.Slice(28, sizeof(long))); + var receiveResult = await inDivert.ReceiveAsync(buffer, addresses, cts.Token).ConfigureAwait(false); + var packet = buffer.AsMemory(0, receiveResult.Length); + var remoteAddress = new IPAddress(packet[12..16].Span); + long timestamp = BitConverter.ToInt64(packet.Slice(28, sizeof(long)).Span); Console.WriteLine( - $"Reply from {remoteAddress} (Divert): " + $"Reply from {remoteAddress} (Divert): " + $"bytes={packet.Length - 28} " + $"time={DateTimeOffset.Now.ToUnixTimeMilliseconds() - timestamp}ms " - + $"TTL={packet[8]}"); - inDivert.SendEx(packet, new[] { address }, cts.Token); + + $"TTL={packet.Span[8]}" + ); + await inDivert.SendAsync(packet, addresses, cts.Token); } catch (OperationCanceledException) { @@ -125,4 +137,6 @@ } }); -await Task.WhenAll(ping, outbound, inbound); +await Task.WhenAll(ping, outbound, inbound).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); +Console.WriteLine("Done."); +Console.ReadLine(); diff --git a/License b/LICENSE similarity index 100% rename from License rename to LICENSE diff --git a/Workspace.proj b/Workspace.proj new file mode 100644 index 0000000..204d901 --- /dev/null +++ b/Workspace.proj @@ -0,0 +1,6 @@ + + + + + + diff --git a/package.json b/package.json new file mode 100644 index 0000000..3be0d0a --- /dev/null +++ b/package.json @@ -0,0 +1,19 @@ +{ + "private": true, + "type": "module", + "scripts": { + "build": "pnpm run cake Build", + "cake": "dotnet run --project Automation --target", + "format": "pnpm run cake Format", + "lint": "pnpm run cake Lint", + "restore": "pnpm run cake Restore" + }, + "dependencies": { + "@prettier/plugin-xml": "^3.4.2", + "cspell": "^9.2.1", + "prettier": "^3.6.2", + "prettier-plugin-ini": "^1.3.0", + "prettier-plugin-packagejson": "^2.5.19", + "prettier-plugin-sh": "^0.18.0" + } +} From 45b4a639a1101024dcb174f4a1d44f9d37cce5d0 Mon Sep 17 00:00:00 2001 From: gdlol Date: Sun, 2 Nov 2025 03:39:46 +0000 Subject: [PATCH 02/20] test setup --- .config/cspell/cspell.json | 10 + .config/dotnet/Packages.props | 3 + .config/dotnet/Project.props | 1 + .config/dotnet/tools.json | 13 +- .devcontainer/Dockerfile | 3 + .devcontainer/compose.yaml | 3 + .devcontainer/devcontainer.json | 3 +- Automation/Automation.csproj | 2 + Automation/Context.cs | 4 + Automation/Restore.cs | 2 + Automation/Tasks/CIFSMount.cs | 75 +++++ Automation/Test.cs | 101 ++++++ Divert.Windows.TestRunner/AssemblyInfo.cs | 4 + .../CoverageSettings.xml | 11 + .../Divert.Windows.TestRunner.csproj | 24 ++ Divert.Windows.TestRunner/Program.cs | 29 ++ Divert.Windows.TestRunner/app.manifest | 10 + Divert.Windows.Tests/AssemblyInfo.cs | 4 + .../Divert.Windows.Tests.csproj | 24 ++ Divert.Windows.Tests/FilterTests.cs | 13 + .../DivertReceiveValueTaskSource.cs | 121 ++++++++ .../DivertSendValueTaskSource.cs | 114 +++++++ .../AsyncOperation/DivertValueTaskSource.cs | 135 ++++++++ .../AsyncOperation/PendingOperation.cs | 25 ++ Divert.Windows/Divert.Windows.csproj | 5 +- Divert.Windows/DivertHandle.cs | 6 +- Divert.Windows/DivertService.cs | 93 +++--- Divert.Windows/DivertValueTaskSource.cs | 287 ------------------ Divert.Windows/SafeHandleExtensions.cs | 24 ++ Examples/Ping/Ping.csproj | 4 +- Examples/Ping/Program.cs | 18 +- package.json | 4 +- 32 files changed, 818 insertions(+), 357 deletions(-) create mode 100644 Automation/Tasks/CIFSMount.cs create mode 100644 Automation/Test.cs create mode 100644 Divert.Windows.TestRunner/AssemblyInfo.cs create mode 100644 Divert.Windows.TestRunner/CoverageSettings.xml create mode 100644 Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj create mode 100644 Divert.Windows.TestRunner/Program.cs create mode 100644 Divert.Windows.TestRunner/app.manifest create mode 100644 Divert.Windows.Tests/AssemblyInfo.cs create mode 100644 Divert.Windows.Tests/Divert.Windows.Tests.csproj create mode 100644 Divert.Windows.Tests/FilterTests.cs create mode 100644 Divert.Windows/AsyncOperation/DivertReceiveValueTaskSource.cs create mode 100644 Divert.Windows/AsyncOperation/DivertSendValueTaskSource.cs create mode 100644 Divert.Windows/AsyncOperation/DivertValueTaskSource.cs create mode 100644 Divert.Windows/AsyncOperation/PendingOperation.cs delete mode 100644 Divert.Windows/DivertValueTaskSource.cs create mode 100644 Divert.Windows/SafeHandleExtensions.cs diff --git a/.config/cspell/cspell.json b/.config/cspell/cspell.json index 55bcb88..49ffc5c 100644 --- a/.config/cspell/cspell.json +++ b/.config/cspell/cspell.json @@ -5,14 +5,24 @@ "gitignoreRoot": ".", "ignorePaths": ["LICENSE"], "words": [ + "CIFS", "csharpierignore", "csharpierrc", "devcontainer", "devcontainers", + "getgid", + "getuid", "globalconfig", + "globaltool", + "libc", "msbuild", + "MSTEST", "packagejson", + "reportgenerator", "runas", + "runsettings", + "statvfs", + "Syscall", "WINDIVERT" ] } diff --git a/.config/dotnet/Packages.props b/.config/dotnet/Packages.props index 9e360e8..8a46bc8 100644 --- a/.config/dotnet/Packages.props +++ b/.config/dotnet/Packages.props @@ -4,7 +4,10 @@ + + + diff --git a/.config/dotnet/Project.props b/.config/dotnet/Project.props index f9260ca..10c3360 100644 --- a/.config/dotnet/Project.props +++ b/.config/dotnet/Project.props @@ -2,6 +2,7 @@ enable enable + true diff --git a/.config/dotnet/tools.json b/.config/dotnet/tools.json index 7f63919..20f019c 100644 --- a/.config/dotnet/tools.json +++ b/.config/dotnet/tools.json @@ -1,5 +1,16 @@ { "version": 1, "isRoot": true, - "tools": { "csharpier": { "version": "1.1.2", "commands": ["csharpier"], "rollForward": false } } + "tools": { + "csharpier": { + "version": "1.1.2", + "commands": ["csharpier"], + "rollForward": false + }, + "dotnet-reportgenerator-globaltool": { + "version": "5.4.18", + "commands": ["reportgenerator"], + "rollForward": false + } + } } diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index cc342aa..4c27707 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,3 +1,6 @@ FROM mcr.microsoft.com/devcontainers/javascript-node:22 +RUN apt-get update && apt-get install --yes \ + cifs-utils + RUN npm install --global pnpm@latest-10 diff --git a/.devcontainer/compose.yaml b/.devcontainer/compose.yaml index 8cabace..13d9f6b 100644 --- a/.devcontainer/compose.yaml +++ b/.devcontainer/compose.yaml @@ -2,10 +2,13 @@ services: devcontainer: env_file: - .env + - path: ../.local/.env + required: false build: context: . dockerfile: Dockerfile init: true + privileged: true volumes: - WORKSPACES:${WORKSPACES} - ..:${WORKSPACES}/divert-windows diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index bb26252..17cb4db 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -25,7 +25,8 @@ "ignore": "ignore", "attributes": "properties", "rc": "properties", - "*.globalconfig": "ini" + "*.globalconfig": "ini", + "app.manifest": "xml" }, "editor.formatOnSave": true, "editor.defaultFormatter": "esbenp.prettier-vscode", diff --git a/Automation/Automation.csproj b/Automation/Automation.csproj index e9c5ec0..0771360 100644 --- a/Automation/Automation.csproj +++ b/Automation/Automation.csproj @@ -2,8 +2,10 @@ Exe $(DefaultTargetFramework) + CA2007 + diff --git a/Automation/Context.cs b/Automation/Context.cs index a2ac672..1c1e0f1 100644 --- a/Automation/Context.cs +++ b/Automation/Context.cs @@ -11,4 +11,8 @@ public class Context(ICakeContext context) : FrostingContext(context) public static string ProjectRoot => new FileInfo(GetFilePath()).Directory!.Parent!.FullName; public static string Workspaces => new DirectoryInfo(ProjectRoot).Parent!.FullName; + + public static string LocalDirectory => Path.Combine(ProjectRoot, ".local"); + + public static string LocalWindowsDirectory => Path.Combine(LocalDirectory, "windows"); } diff --git a/Automation/Restore.cs b/Automation/Restore.cs index 50011cb..58c1764 100644 --- a/Automation/Restore.cs +++ b/Automation/Restore.cs @@ -1,9 +1,11 @@ +using Automation.Tasks; using Cake.Common.Tools.Command; using Cake.Common.Tools.DotNet; using Cake.Frosting; namespace Automation; +[IsDependentOn(typeof(CIFSMount))] public class Restore : FrostingTask { public override void Run(Context context) diff --git a/Automation/Tasks/CIFSMount.cs b/Automation/Tasks/CIFSMount.cs new file mode 100644 index 0000000..6e0c2d1 --- /dev/null +++ b/Automation/Tasks/CIFSMount.cs @@ -0,0 +1,75 @@ +using Cake.Common.Tools.Command; +using Cake.Core; +using Cake.Core.Diagnostics; +using Cake.Core.IO; +using Cake.Frosting; +using Mono.Unix.Native; + +namespace Automation.Tasks; + +internal sealed record class CIFSMountConfig(string RemoteHost, string Share, string UserName, string Password); + +// Optionally mounts Bin/ directory to a CIFS share, workaround for slow .exe startup time in WSL2 directories. +public class CIFSMount : FrostingTask +{ + public const string CIFS_REMOTE_HOST = nameof(CIFS_REMOTE_HOST); + public const string CIFS_SHARE = nameof(CIFS_SHARE); + public const string CIFS_USERNAME = nameof(CIFS_USERNAME); + public const string CIFS_PASSWORD = nameof(CIFS_PASSWORD); + + private static CIFSMountConfig? LoadConfig() + { + string? remoteHost = Environment.GetEnvironmentVariable(CIFS_REMOTE_HOST); + string? share = Environment.GetEnvironmentVariable(CIFS_SHARE); + string? userName = Environment.GetEnvironmentVariable(CIFS_USERNAME); + string? password = Environment.GetEnvironmentVariable(CIFS_PASSWORD); + if (share is null || userName is null || password is null) + { + return null; + } + return new CIFSMountConfig( + RemoteHost: remoteHost ?? "host.docker.internal", + Share: share, + UserName: userName, + Password: password + ); + } + + public override void Run(Context context) + { + var config = LoadConfig(); + if (config is null) + { + context.Log.Information("CIFS share not configured, skipping."); + return; + } + + string mountPath = Context.LocalWindowsDirectory; + try + { + if (DriveInfo.GetDrives().Any(drive => drive.DriveType is DriveType.Network && drive.Name == mountPath)) + { + context.Log.Information("CIFS share already mounted at {0}, skipping.", mountPath); + // Already mounted. + return; + } + + Directory.CreateDirectory(mountPath); + uint uid = Syscall.getuid(); + uint gid = Syscall.getgid(); + context.Command( + ["sudo"], + ProcessArgumentBuilder + .FromStrings(["mount", "--types", "cifs", $"//{config.RemoteHost}/{config.Share}", mountPath]) + .AppendSwitchSecret( + "--options", + $"username={config.UserName},password={config.Password},uid={uid},gid={gid}" + ) + ); + } + catch (Exception ex) + { + context.Log.Warning("Failed to mount CIFS share: {0}", ex.Message); + } + } +} diff --git a/Automation/Test.cs b/Automation/Test.cs new file mode 100644 index 0000000..8518155 --- /dev/null +++ b/Automation/Test.cs @@ -0,0 +1,101 @@ +using System.Threading.Channels; +using Cake.Common.Tools.DotNet; +using Cake.Common.Tools.DotNet.Tool; +using Cake.Core.Diagnostics; +using Cake.Core.IO; +using Cake.Frosting; +using Path = System.IO.Path; + +namespace Automation; + +internal static class Test +{ + public static string TestResultsDirectory => Path.Combine(Context.LocalWindowsDirectory, "TestResults"); + + public static string CoverletOutput => Path.Combine(TestResultsDirectory, "coverage.cobertura.xml"); + + public static string TestReportsDirectory => Path.Combine(Context.LocalDirectory, "TestReports"); + + public static void GenerateReport(Context context) + { + string sourcePath = Path.Combine(Context.ProjectRoot, "Divert.Windows"); + // spell-checker: ignore sourcedirs targetdir reporttypes + context.DotNetTool( + "reportgenerator", + new DotNetToolSettings + { + ArgumentCustomization = _ => + ProcessArgumentBuilder.FromStrings( + [ + "reportgenerator", + $"-reports:{CoverletOutput}", + $"-sourcedirs:{sourcePath}", + $"-targetdir:{TestReportsDirectory}", + "-reporttypes:Html;MarkdownSummary", + ] + ), + } + ); + } +} + +public class TestReport : FrostingTask +{ + public override void Run(Context context) + { + Test.GenerateReport(context); + } +} + +public class TestReportWatch : AsyncFrostingTask +{ + public override async Task RunAsync(Context context) + { + var changed = Channel.CreateBounded(10); + using var watcher = new FileSystemWatcher + { + Path = Context.LocalWindowsDirectory, + Filter = "coverage.cobertura.xml", + EnableRaisingEvents = true, + IncludeSubdirectories = false, + NotifyFilter = NotifyFilters.LastWrite, + }; + watcher.Changed += (_, e) => changed.Writer.TryWrite(0); + + var poll = Task.Run(async () => + { + while (changed.Writer.TryWrite(0)) + { + await Task.Delay(TimeSpan.FromMilliseconds(1000)); + } + }); + + var exit = Task.Run(() => + { + Console.ReadLine(); + context.Log.Information("Stopping coverage report watcher..."); + changed.Writer.TryComplete(); + }); + + DateTimeOffset? lastChanged = null; + await foreach (var _ in changed.Reader.ReadAllAsync()) + { + try + { + var offset = new DateTimeOffset(File.GetLastWriteTime(Test.CoverletOutput)); + if (lastChanged != offset) + { + lastChanged = offset; + context.Log.Information("Detected change to coverage file, regenerating report..."); + Test.GenerateReport(context); + } + } + catch (Exception e) + { + context.Log.Error("Error regenerating report: {0}", e); + } + } + await Task.WhenAll(poll, exit); + context.Log.Information("Coverage report watcher stopped."); + } +} diff --git a/Divert.Windows.TestRunner/AssemblyInfo.cs b/Divert.Windows.TestRunner/AssemblyInfo.cs new file mode 100644 index 0000000..a298fba --- /dev/null +++ b/Divert.Windows.TestRunner/AssemblyInfo.cs @@ -0,0 +1,4 @@ +using System.Runtime.Versioning; + +[assembly: DoNotParallelize] +[assembly: SupportedOSPlatform("windows6.0.6000")] diff --git a/Divert.Windows.TestRunner/CoverageSettings.xml b/Divert.Windows.TestRunner/CoverageSettings.xml new file mode 100644 index 0000000..48a287b --- /dev/null +++ b/Divert.Windows.TestRunner/CoverageSettings.xml @@ -0,0 +1,11 @@ + + + + + + .*\.g\.cs$ + .*\.generated\.cs$ + + + + diff --git a/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj b/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj new file mode 100644 index 0000000..9e9c4dc --- /dev/null +++ b/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj @@ -0,0 +1,24 @@ + + + Exe + $(DefaultTargetFramework) + app.manifest + win-x64 + true + $(ProjectRoot).local/windows/$(MSBuildThisFileName) + CA2007 + + + + + + + + + + + diff --git a/Divert.Windows.TestRunner/Program.cs b/Divert.Windows.TestRunner/Program.cs new file mode 100644 index 0000000..7e1647f --- /dev/null +++ b/Divert.Windows.TestRunner/Program.cs @@ -0,0 +1,29 @@ +using System.Diagnostics; + +string appPath = Path.GetDirectoryName(Environment.ProcessPath)!; +string testResultsDirectory = Path.Combine(appPath, "../TestResults"); +string coverageSettingsPath = Path.Combine(appPath, "CoverageSettings.xml"); +string coverageOutputPath = Path.Combine(testResultsDirectory, "coverage.cobertura.xml"); + +Directory.SetCurrentDirectory(appPath); + +using var process = Process.Start( + new ProcessStartInfo + { + FileName = Path.Combine(appPath, "Divert.Windows.Tests.exe"), + ArgumentList = + { + "--results-directory", + testResultsDirectory, + "--coverage", + "--coverage-settings", + coverageSettingsPath, + "--coverage-output-format", + "cobertura", + "--coverage-output", + coverageOutputPath, + }, + } +)!; +await process.WaitForExitAsync(); +return process.ExitCode; diff --git a/Divert.Windows.TestRunner/app.manifest b/Divert.Windows.TestRunner/app.manifest new file mode 100644 index 0000000..76ff66a --- /dev/null +++ b/Divert.Windows.TestRunner/app.manifest @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/Divert.Windows.Tests/AssemblyInfo.cs b/Divert.Windows.Tests/AssemblyInfo.cs new file mode 100644 index 0000000..a298fba --- /dev/null +++ b/Divert.Windows.Tests/AssemblyInfo.cs @@ -0,0 +1,4 @@ +using System.Runtime.Versioning; + +[assembly: DoNotParallelize] +[assembly: SupportedOSPlatform("windows6.0.6000")] diff --git a/Divert.Windows.Tests/Divert.Windows.Tests.csproj b/Divert.Windows.Tests/Divert.Windows.Tests.csproj new file mode 100644 index 0000000..f2f94bb --- /dev/null +++ b/Divert.Windows.Tests/Divert.Windows.Tests.csproj @@ -0,0 +1,24 @@ + + + Exe + true + $(DefaultTargetFramework) + false + true + CA2007 + + win-x64 + true + + + + + all + + + + + + + + diff --git a/Divert.Windows.Tests/FilterTests.cs b/Divert.Windows.Tests/FilterTests.cs new file mode 100644 index 0000000..92b3642 --- /dev/null +++ b/Divert.Windows.Tests/FilterTests.cs @@ -0,0 +1,13 @@ +namespace Divert.Windows.Tests; + +[TestClass] +public class FilterTests +{ + [TestMethod] + public void TestFilterToString() + { + var filter = (DivertFilter.TCP | DivertFilter.UDP) & DivertFilter.Outbound; + var result = filter.ToString(); + Assert.AreEqual("(tcp or udp) and outbound", result); + } +} diff --git a/Divert.Windows/AsyncOperation/DivertReceiveValueTaskSource.cs b/Divert.Windows/AsyncOperation/DivertReceiveValueTaskSource.cs new file mode 100644 index 0000000..4457d43 --- /dev/null +++ b/Divert.Windows/AsyncOperation/DivertReceiveValueTaskSource.cs @@ -0,0 +1,121 @@ +using System.ComponentModel; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Threading.Channels; +using System.Threading.Tasks.Sources; +using Windows.Win32.Foundation; + +namespace Divert.Windows.AsyncOperation; + +internal sealed unsafe class DivertReceiveValueTaskSource( + Channel pool, + DivertHandle divertHandle, + ThreadPoolBoundHandle threadPoolBoundHandle, + bool runContinuationsAsynchronously +) : DivertValueTaskSource(divertHandle, threadPoolBoundHandle), IValueTaskSource +{ + private ManualResetValueTaskSourceCore source = new() + { + RunContinuationsAsynchronously = runContinuationsAsynchronously, + }; + + private readonly Memory addressesLengthBuffer = GC.AllocateArray(1, pinned: true); + + protected override void Complete(uint errorCode, uint numBytes, in PendingOperation pendingOperation) + { + var addresses = pendingOperation.Addresses; + int addressesLength = (int)addressesLengthBuffer.Span[0] / sizeof(DivertAddress); + var token = pendingOperation.CancellationToken; + + if (errorCode is (uint)WIN32_ERROR.ERROR_NO_DATA) + { + errorCode = 0; + } + + if (errorCode is (uint)WIN32_ERROR.ERROR_SUCCESS) + { + source.SetResult(new DivertReceiveResult((int)numBytes, addresses[..addressesLength])); + } + else if (errorCode is (uint)WIN32_ERROR.ERROR_OPERATION_ABORTED) + { + var exception = token.IsCancellationRequested + ? new OperationCanceledException() + : new OperationCanceledException(token); + source.SetException(exception); + } + else + { + var exception = new Win32Exception((int)errorCode); + source.SetException(exception); + } + } + + public DivertReceiveResult GetResult(short token) + { + try + { + return source.GetResult(token); + } + finally + { + source.Reset(); + if (!pool.Writer.TryWrite(this)) + { + Dispose(); + } + } + } + + public ValueTaskSourceStatus GetStatus(short token) => source.GetStatus(token); + + public void OnCompleted( + Action continuation, + object? state, + short token, + ValueTaskSourceOnCompletedFlags flags + ) => source.OnCompleted(continuation, state, token, flags); + + public ValueTask ReceiveAsync( + Memory buffer, + Memory addresses, + CancellationToken cancellationToken + ) + { + using var _ = DivertHandle.GetReference(out var handle); + addressesLengthBuffer.Span[0] = (uint)(addresses.Length * sizeof(DivertAddress)); + var pendingOperation = PrepareOperation(buffer, addresses, cancellationToken); + try + { + bool success = NativeMethods.WinDivertRecvEx( + handle, + pendingOperation.PacketBufferHandle.Pointer, + (uint)buffer.Length, + null, + 0, + (WINDIVERT_ADDRESS*)pendingOperation.AddressesHandle.Pointer, + (uint*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(addressesLengthBuffer.Span)), + pendingOperation.NativeOverlapped + ); + if (!success) + { + int error = Marshal.GetLastPInvokeError(); + if (error is (int)WIN32_ERROR.ERROR_IO_PENDING) + { + CancelIfRequested(); + } + else + { + Dispose(); + return ValueTask.FromException(new Win32Exception(error)); + } + } + + return new ValueTask(this, source.Version); + } + catch + { + Dispose(); + throw; + } + } +} diff --git a/Divert.Windows/AsyncOperation/DivertSendValueTaskSource.cs b/Divert.Windows/AsyncOperation/DivertSendValueTaskSource.cs new file mode 100644 index 0000000..8364f70 --- /dev/null +++ b/Divert.Windows/AsyncOperation/DivertSendValueTaskSource.cs @@ -0,0 +1,114 @@ +using System.ComponentModel; +using System.Runtime.InteropServices; +using System.Threading.Channels; +using System.Threading.Tasks.Sources; +using Windows.Win32.Foundation; + +namespace Divert.Windows.AsyncOperation; + +internal sealed unsafe class DivertSendValueTaskSource( + Channel pool, + DivertHandle divertHandle, + ThreadPoolBoundHandle threadPoolBoundHandle, + bool runContinuationsAsynchronously +) : DivertValueTaskSource(divertHandle, threadPoolBoundHandle), IValueTaskSource +{ + private ManualResetValueTaskSourceCore source = new() + { + RunContinuationsAsynchronously = runContinuationsAsynchronously, + }; + + protected override void Complete(uint errorCode, uint numBytes, in PendingOperation pendingOperation) + { + var token = pendingOperation.CancellationToken; + + if (errorCode is (uint)WIN32_ERROR.ERROR_SUCCESS) + { + source.SetResult(0); + } + else if (errorCode is (uint)WIN32_ERROR.ERROR_OPERATION_ABORTED) + { + var exception = token.IsCancellationRequested + ? new OperationCanceledException() + : new OperationCanceledException(token); + source.SetException(exception); + } + else + { + var exception = new Win32Exception((int)errorCode); + source.SetException(exception); + } + } + + public void GetResult(short token) + { + try + { + source.GetResult(token); + } + finally + { + source.Reset(); + if (!pool.Writer.TryWrite(this)) + { + Dispose(); + } + } + } + + public ValueTaskSourceStatus GetStatus(short token) => source.GetStatus(token); + + public void OnCompleted( + Action continuation, + object? state, + short token, + ValueTaskSourceOnCompletedFlags flags + ) => source.OnCompleted(continuation, state, token, flags); + + public ValueTask SendAsync( + ReadOnlyMemory buffer, + ReadOnlyMemory addresses, + CancellationToken cancellationToken + ) + { + using var _ = DivertHandle.GetReference(out var handle); + var pendingOperation = PrepareOperation( + MemoryMarshal.AsMemory(buffer), + MemoryMarshal.AsMemory(addresses), + cancellationToken + ); + try + { + bool success = NativeMethods.WinDivertSendEx( + handle, + pendingOperation.PacketBufferHandle.Pointer, + (uint)buffer.Length, + null, + 0, + (WINDIVERT_ADDRESS*)pendingOperation.AddressesHandle.Pointer, + (uint)(addresses.Length * sizeof(DivertAddress)), + pendingOperation.NativeOverlapped + ); + if (!success) + { + int error = Marshal.GetLastPInvokeError(); + if (error == (int)WIN32_ERROR.ERROR_IO_PENDING) + { + CancelIfRequested(); + } + else + { + Dispose(); + return ValueTask.FromException(new Win32Exception(error)); + } + } + + return new ValueTask(this, source.Version); + } + catch + { + Dispose(); + throw; + } + } +} diff --git a/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs b/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs new file mode 100644 index 0000000..7a78d8c --- /dev/null +++ b/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs @@ -0,0 +1,135 @@ +using Windows.Win32; + +namespace Divert.Windows.AsyncOperation; + +internal abstract unsafe class DivertValueTaskSource : IDisposable +{ + private static class Status + { + public const uint Idle = 0; + public const uint Canceled = 1; + public const uint Pending = 2; + public const uint Disposed = 3; + } + + private static readonly IOCompletionCallback ioCompletionCallback = OnIOCompleted; + + private readonly DivertHandle divertHandle; + private readonly ThreadPoolBoundHandle threadPoolBoundHandle; + private readonly PreAllocatedOverlapped preAllocatedOverlapped; + + private IntPtr nativeOverlapped; + private PendingOperation pendingOperation; + private uint status; + + protected DivertHandle DivertHandle => divertHandle; + + public DivertValueTaskSource(DivertHandle divertHandle, ThreadPoolBoundHandle threadPoolBoundHandle) + { + this.divertHandle = divertHandle; + this.threadPoolBoundHandle = threadPoolBoundHandle; + preAllocatedOverlapped = new PreAllocatedOverlapped(ioCompletionCallback, this, null); + } + + // From cancellation registration. + private void ExecuteOrRequestCancel() + { + uint originalStatus = Interlocked.CompareExchange(ref status, Status.Canceled, Status.Idle); + if (originalStatus is Status.Pending) + { + using (divertHandle.GetReference(out var handle)) + { + _ = PInvoke.CancelIoEx(new(handle), pendingOperation.NativeOverlapped); + } + } + } + + // After ERROR_IO_PENDING. + protected void CancelIfRequested() + { + uint originalStatus = Interlocked.CompareExchange(ref status, Status.Pending, Status.Idle); + if (originalStatus is Status.Canceled) + { + using (divertHandle.GetReference(out var handle)) + { + _ = PInvoke.CancelIoEx(new(handle), pendingOperation.NativeOverlapped); + } + } + } + + // From completion callback. + private void ResetIfPendingOrCanceled() + { + uint originalStatus = Interlocked.CompareExchange(ref status, Status.Idle, Status.Pending); + if (originalStatus is Status.Canceled) + { + Interlocked.CompareExchange(ref status, Status.Idle, originalStatus); + } + } + + protected PendingOperation PrepareOperation( + Memory packetBuffer, + Memory addresses, + CancellationToken cancellationToken + ) + { + var cancellationTokenRegistration = cancellationToken.CanBeCanceled + ? cancellationToken.UnsafeRegister( + static state => ((DivertValueTaskSource)state!).ExecuteOrRequestCancel(), + this + ) + : default; + var nativeOverlapped = threadPoolBoundHandle.AllocateNativeOverlapped(preAllocatedOverlapped); + this.nativeOverlapped = new(nativeOverlapped); + pendingOperation = new PendingOperation( + nativeOverlapped, + packetBuffer, + addresses, + cancellationTokenRegistration + ); + return pendingOperation; + } + + private void DisposePendingOperation() + { + var overlapped = Interlocked.Exchange(ref nativeOverlapped, default); + if (overlapped != default) + { + pendingOperation.Dispose(); + threadPoolBoundHandle.FreeNativeOverlapped((NativeOverlapped*)overlapped); + } + } + + public void Dispose() + { + if (Interlocked.Exchange(ref status, Status.Disposed) is Status.Disposed) + { + return; + } + + DisposePendingOperation(); + preAllocatedOverlapped.Dispose(); + GC.SuppressFinalize(this); + } + + ~DivertValueTaskSource() + { + Dispose(); + } + + protected abstract void Complete(uint errorCode, uint numBytes, in PendingOperation pendingOperation); + + private static void OnIOCompleted(uint errorCode, uint numBytes, NativeOverlapped* pOVERLAP) + { + var vts = (DivertValueTaskSource)ThreadPoolBoundHandle.GetNativeOverlappedState(pOVERLAP)!; + try + { + vts.Complete(errorCode, numBytes, vts.pendingOperation); + vts.ResetIfPendingOrCanceled(); + } + finally + { + vts.DisposePendingOperation(); + } + } +} diff --git a/Divert.Windows/AsyncOperation/PendingOperation.cs b/Divert.Windows/AsyncOperation/PendingOperation.cs new file mode 100644 index 0000000..0de4c5a --- /dev/null +++ b/Divert.Windows/AsyncOperation/PendingOperation.cs @@ -0,0 +1,25 @@ +using System.Buffers; + +namespace Divert.Windows.AsyncOperation; + +internal readonly unsafe struct PendingOperation( + NativeOverlapped* nativeOverlapped, + Memory packetBuffer, + Memory addresses, + CancellationTokenRegistration cancellationTokenRegistration +) : IDisposable +{ + public NativeOverlapped* NativeOverlapped { get; } = nativeOverlapped; + public MemoryHandle PacketBufferHandle { get; } = packetBuffer.Pin(); + public MemoryHandle AddressesHandle { get; } = addresses.Pin(); + + public readonly Memory Addresses => addresses; + public readonly CancellationToken CancellationToken => cancellationTokenRegistration.Token; + + public void Dispose() + { + cancellationTokenRegistration.Dispose(); + PacketBufferHandle.Dispose(); + AddressesHandle.Dispose(); + } +} diff --git a/Divert.Windows/Divert.Windows.csproj b/Divert.Windows/Divert.Windows.csproj index ee50c1c..d475485 100644 --- a/Divert.Windows/Divert.Windows.csproj +++ b/Divert.Windows/Divert.Windows.csproj @@ -2,7 +2,10 @@ $(DefaultTargetFramework) x64 - true + + + + $(MSBuildProjectDirectory)=$(MSBuildProjectName) diff --git a/Divert.Windows/DivertHandle.cs b/Divert.Windows/DivertHandle.cs index bd2c568..3238d3a 100644 --- a/Divert.Windows/DivertHandle.cs +++ b/Divert.Windows/DivertHandle.cs @@ -4,13 +4,11 @@ namespace Divert.Windows; internal sealed class DivertHandle : SafeHandleZeroOrMinusOneIsInvalid { - internal IntPtr Handle => handle; - public DivertHandle(IntPtr handle) - : base(ownsHandle: false) + : base(ownsHandle: true) { this.handle = handle; } - protected override bool ReleaseHandle() => true; + protected override bool ReleaseHandle() => NativeMethods.WinDivertClose(handle); } diff --git a/Divert.Windows/DivertService.cs b/Divert.Windows/DivertService.cs index f907ac5..4fb61f5 100644 --- a/Divert.Windows/DivertService.cs +++ b/Divert.Windows/DivertService.cs @@ -1,6 +1,7 @@ using System.ComponentModel; using System.Runtime.InteropServices; using System.Threading.Channels; +using Divert.Windows.AsyncOperation; using Windows.Win32.Foundation; namespace Divert.Windows; @@ -10,10 +11,11 @@ namespace Divert.Windows; /// public sealed unsafe class DivertService : IDisposable { - private readonly DivertHandle handle; + private readonly DivertHandle divertHandle; private readonly ThreadPoolBoundHandle threadPoolBoundHandle; - private readonly Channel vtsPool = Channel.CreateUnbounded(); + private readonly Channel receiveVtsPool; + private readonly Channel sendVtsPool; /// /// Opens a WinDivert handle for the given filter. @@ -53,50 +55,39 @@ public DivertService( { throw new Win32Exception(Marshal.GetLastPInvokeError()); } - this.handle = new DivertHandle(handle); - threadPoolBoundHandle = ThreadPoolBoundHandle.BindHandle(this.handle); + divertHandle = new DivertHandle(handle); + threadPoolBoundHandle = ThreadPoolBoundHandle.BindHandle(divertHandle); + receiveVtsPool = Channel.CreateUnbounded(); + sendVtsPool = Channel.CreateUnbounded(); } private bool disposed = false; - private bool closed = false; - private void Close(bool throwOnError) + private static void DisposeVtsPool(Channel vtsPool) + where T : IDisposable { - if (!closed) + vtsPool.Writer.TryComplete(); + while (vtsPool.Reader.TryRead(out var vts)) { - bool success = NativeMethods.WinDivertClose(handle.Handle); - if (!success && throwOnError) - { - throw new Win32Exception(Marshal.GetLastPInvokeError()); - } - closed = true; + vts.Dispose(); } } - /// - /// Closes the WinDivert handle. - /// - public void Close() => Close(throwOnError: true); - /// /// Releases all resources used by the . /// public void Dispose() { - if (!disposed) + if (disposed) { - if (vtsPool.Writer.TryComplete()) - { - while (vtsPool.Reader.TryRead(out var vts)) - { - vts.Dispose(); - } - } - threadPoolBoundHandle.Dispose(); - handle.Dispose(); - Close(throwOnError: false); - disposed = true; + return; } + disposed = true; + + DisposeVtsPool(receiveVtsPool); + DisposeVtsPool(sendVtsPool); + threadPoolBoundHandle.Dispose(); + divertHandle.Dispose(); GC.SuppressFinalize(this); } @@ -105,24 +96,16 @@ public void Dispose() Dispose(); } - private void ThrowIfClosedOrDisposed() - { - if (closed) - { - throw new InvalidOperationException(nameof(closed)); - } - ObjectDisposedException.ThrowIf(disposed, this); - } - /// /// Shuts down the WinDivert handle. /// /// Specifies how to shut down the handle. public void Shutdown(DivertShutdown how) { - ThrowIfClosedOrDisposed(); + ObjectDisposedException.ThrowIf(disposed, this); - bool success = NativeMethods.WinDivertShutdown(handle.Handle, (WINDIVERT_SHUTDOWN)how); + using var _ = divertHandle.GetReference(out var handle); + bool success = NativeMethods.WinDivertShutdown(handle, (WINDIVERT_SHUTDOWN)how); if (!success) { throw new Win32Exception(Marshal.GetLastPInvokeError()); @@ -142,13 +125,17 @@ public ValueTask ReceiveAsync( CancellationToken cancellationToken = default ) { - ThrowIfClosedOrDisposed(); + ObjectDisposedException.ThrowIf(disposed, this); + if (cancellationToken.IsCancellationRequested) + { + return ValueTask.FromCanceled(cancellationToken); + } - if (!vtsPool.Reader.TryRead(out var vts)) + if (!receiveVtsPool.Reader.TryRead(out var vts)) { - vts = new DivertValueTaskSource( - vtsPool, - handle.Handle, + vts = new DivertReceiveValueTaskSource( + receiveVtsPool, + divertHandle, threadPoolBoundHandle, runContinuationsAsynchronously: true ); @@ -169,13 +156,17 @@ public ValueTask SendAsync( CancellationToken cancellationToken = default ) { - ThrowIfClosedOrDisposed(); + ObjectDisposedException.ThrowIf(disposed, this); + if (cancellationToken.IsCancellationRequested) + { + return ValueTask.FromCanceled(cancellationToken); + } - if (!vtsPool.Reader.TryRead(out var vts)) + if (!sendVtsPool.Reader.TryRead(out var vts)) { - vts = new DivertValueTaskSource( - vtsPool, - handle.Handle, + vts = new DivertSendValueTaskSource( + sendVtsPool, + divertHandle, threadPoolBoundHandle, runContinuationsAsynchronously: true ); diff --git a/Divert.Windows/DivertValueTaskSource.cs b/Divert.Windows/DivertValueTaskSource.cs deleted file mode 100644 index a2ce416..0000000 --- a/Divert.Windows/DivertValueTaskSource.cs +++ /dev/null @@ -1,287 +0,0 @@ -using System.Buffers; -using System.ComponentModel; -using System.Runtime.InteropServices; -using System.Threading.Channels; -using System.Threading.Tasks.Sources; -using Windows.Win32; -using Windows.Win32.Foundation; - -namespace Divert.Windows; - -internal sealed unsafe class DivertValueTaskSource - : IValueTaskSource, - IValueTaskSource, - IDisposable -{ - private static class Status - { - public const uint Idle = 0; - public const uint Canceled = 1; - public const uint Pending = 2; - public const uint Disposed = 3; - } - - private static readonly IOCompletionCallback ioCompletionCallback = OnIOCompleted; - - private ManualResetValueTaskSourceCore source; - private readonly Channel pool; - private readonly IntPtr divertHandle; - private readonly ThreadPoolBoundHandle threadPoolBoundHandle; - private readonly PreAllocatedOverlapped preAllocatedOverlapped; - private readonly Memory addressesLengthBuffer; - - private struct PendingOperation( - DivertValueTaskSource vts, - Memory packetBuffer, - Memory addresses, - CancellationTokenRegistration cancellationTokenRegistration - ) : IDisposable - { - public NativeOverlapped* NativeOverlapped { get; private set; } = - vts.threadPoolBoundHandle.AllocateNativeOverlapped(vts.preAllocatedOverlapped); - - public readonly Memory Addresses => addresses; - - public MemoryHandle PacketBufferHandle { get; } = packetBuffer.Pin(); - public MemoryHandle AddressesHandle { get; } = addresses.Pin(); - public MemoryHandle AddressesLengthBufferHandle { get; } = vts.addressesLengthBuffer.Pin(); - - public readonly CancellationToken CancellationToken => cancellationTokenRegistration.Token; - - public void Dispose() - { - cancellationTokenRegistration.Dispose(); - - PacketBufferHandle.Dispose(); - AddressesHandle.Dispose(); - AddressesLengthBufferHandle.Dispose(); - - if (NativeOverlapped is not null) - { - vts.threadPoolBoundHandle.FreeNativeOverlapped(NativeOverlapped); - NativeOverlapped = null; - } - } - } - - private PendingOperation pendingOperation; - private uint status; - - public DivertValueTaskSource( - Channel pool, - IntPtr divertHandle, - ThreadPoolBoundHandle threadPoolBoundHandle, - bool runContinuationsAsynchronously - ) - { - source.RunContinuationsAsynchronously = runContinuationsAsynchronously; - this.pool = pool; - this.divertHandle = divertHandle; - this.threadPoolBoundHandle = threadPoolBoundHandle; - preAllocatedOverlapped = new PreAllocatedOverlapped(ioCompletionCallback, this, null); - addressesLengthBuffer = GC.AllocateArray(1, pinned: true); - } - - private void ExecuteOrRequestCancel() - { - uint originalStatus = Interlocked.CompareExchange(ref status, Status.Canceled, Status.Idle); - if (originalStatus is Status.Pending) - { - _ = PInvoke.CancelIoEx(new(divertHandle), pendingOperation.NativeOverlapped); - } - } - - private void CancelIfRequested() - { - uint originalStatus = Interlocked.CompareExchange(ref status, Status.Pending, Status.Idle); - if (originalStatus is Status.Canceled) - { - _ = PInvoke.CancelIoEx(new(divertHandle), pendingOperation.NativeOverlapped); - Interlocked.CompareExchange(ref status, Status.Idle, originalStatus); - } - } - - private PendingOperation PrepareOperation( - Memory packetBuffer, - Memory addresses, - CancellationToken cancellationToken - ) - { - var cancellationTokenRegistration = cancellationToken.CanBeCanceled - ? cancellationToken.UnsafeRegister( - static state => ((DivertValueTaskSource)state!).ExecuteOrRequestCancel(), - this - ) - : default; - pendingOperation = new PendingOperation(this, packetBuffer, addresses, cancellationTokenRegistration); - return pendingOperation; - } - - public void Dispose() - { - if (Interlocked.Exchange(ref status, Status.Disposed) is Status.Disposed) - { - return; - } - - pendingOperation.Dispose(); - preAllocatedOverlapped.Dispose(); - GC.SuppressFinalize(this); - } - - ~DivertValueTaskSource() - { - Dispose(); - } - - internal short Version => source.Version; - - private void Complete(uint errorCode, uint numBytes) - { - var addresses = pendingOperation.Addresses; - int addressesLength = (int)addressesLengthBuffer.Span[0] / sizeof(DivertAddress); - var token = pendingOperation.CancellationToken; - pendingOperation.Dispose(); - - if (errorCode == (uint)WIN32_ERROR.ERROR_NO_DATA) - { - errorCode = 0; - } - - if (errorCode == (uint)WIN32_ERROR.ERROR_SUCCESS) - { - source.SetResult(new DivertReceiveResult((int)numBytes, addresses[..addressesLength])); - } - else if (errorCode == (uint)WIN32_ERROR.ERROR_OPERATION_ABORTED) - { - var exception = token.IsCancellationRequested - ? new OperationCanceledException() - : new OperationCanceledException(token); - source.SetException(exception); - } - else - { - var exception = new Win32Exception((int)errorCode); - source.SetException(exception); - } - } - - private static void OnIOCompleted(uint errorCode, uint numBytes, NativeOverlapped* pOVERLAP) - { - var vts = (DivertValueTaskSource)ThreadPoolBoundHandle.GetNativeOverlappedState(pOVERLAP)!; - vts.Complete(errorCode, numBytes); - } - - public ValueTaskSourceStatus GetStatus(short token) => source.GetStatus(token); - - public void OnCompleted( - Action continuation, - object? state, - short token, - ValueTaskSourceOnCompletedFlags flags - ) => source.OnCompleted(continuation, state, token, flags); - - public DivertReceiveResult GetResult(short token) - { - try - { - return source.GetResult(token); - } - finally - { - Interlocked.CompareExchange(ref status, Status.Idle, Status.Pending); - source.Reset(); - if (!pool.Writer.TryWrite(this)) - { - Dispose(); - } - } - } - - void IValueTaskSource.GetResult(short token) => GetResult(token); - - public ValueTask ReceiveAsync( - Memory buffer, - Memory addresses, - CancellationToken cancellationToken - ) - { - if (cancellationToken.IsCancellationRequested) - { - return ValueTask.FromCanceled(cancellationToken); - } - - addressesLengthBuffer.Span[0] = (uint)(addresses.Length * sizeof(DivertAddress)); - var pendingOperation = PrepareOperation(buffer, addresses, cancellationToken); - bool success = NativeMethods.WinDivertRecvEx( - divertHandle, - pendingOperation.PacketBufferHandle.Pointer, - (uint)buffer.Length, - null, - 0, - (WINDIVERT_ADDRESS*)pendingOperation.AddressesHandle.Pointer, - (uint*)pendingOperation.AddressesLengthBufferHandle.Pointer, - pendingOperation.NativeOverlapped - ); - if (!success) - { - int error = Marshal.GetLastPInvokeError(); - if (error == (int)WIN32_ERROR.ERROR_IO_PENDING) - { - CancelIfRequested(); - } - else - { - pendingOperation.Dispose(); - Interlocked.CompareExchange(ref status, Status.Idle, Status.Canceled); - return ValueTask.FromException(new Win32Exception(error)); - } - } - - return new ValueTask(this, source.Version); - } - - public ValueTask SendAsync( - ReadOnlyMemory buffer, - ReadOnlyMemory addresses, - CancellationToken cancellationToken - ) - { - if (cancellationToken.IsCancellationRequested) - { - return ValueTask.FromCanceled(cancellationToken); - } - - var pendingOperation = PrepareOperation( - MemoryMarshal.AsMemory(buffer), - MemoryMarshal.AsMemory(addresses), - cancellationToken - ); - bool success = NativeMethods.WinDivertSendEx( - divertHandle, - pendingOperation.PacketBufferHandle.Pointer, - (uint)buffer.Length, - null, - 0, - (WINDIVERT_ADDRESS*)pendingOperation.AddressesHandle.Pointer, - (uint)(addresses.Length * sizeof(DivertAddress)), - pendingOperation.NativeOverlapped - ); - if (!success) - { - int error = Marshal.GetLastPInvokeError(); - if (error == (int)WIN32_ERROR.ERROR_IO_PENDING) - { - CancelIfRequested(); - } - else - { - pendingOperation.Dispose(); - Interlocked.CompareExchange(ref status, Status.Idle, Status.Canceled); - return ValueTask.FromException(new Win32Exception(error)); - } - } - - return new ValueTask(this, source.Version); - } -} diff --git a/Divert.Windows/SafeHandleExtensions.cs b/Divert.Windows/SafeHandleExtensions.cs new file mode 100644 index 0000000..c5e69bb --- /dev/null +++ b/Divert.Windows/SafeHandleExtensions.cs @@ -0,0 +1,24 @@ +using System.Runtime.InteropServices; + +namespace Divert.Windows; + +internal readonly struct SafeHandleReference(T safeHandle) : IDisposable + where T : SafeHandle +{ + public void Dispose() + { + safeHandle.DangerousRelease(); + } +} + +internal static class SafeHandleExtensions +{ + public static SafeHandleReference GetReference(this T safeHandle, out IntPtr handle) + where T : SafeHandle + { + bool success = false; + safeHandle.DangerousAddRef(ref success); + handle = safeHandle.DangerousGetHandle(); + return new SafeHandleReference(safeHandle); + } +} diff --git a/Examples/Ping/Ping.csproj b/Examples/Ping/Ping.csproj index 848f1cd..47c9acc 100644 --- a/Examples/Ping/Ping.csproj +++ b/Examples/Ping/Ping.csproj @@ -2,9 +2,9 @@ Exe $(DefaultTargetFramework) - true win-x64 - $(ProjectRoot).local/Ping + true + $(ProjectRoot).local/windows/$(MSBuildThisFileName) CA2007 diff --git a/Examples/Ping/Program.cs b/Examples/Ping/Program.cs index bb3e8f4..71e97ef 100644 --- a/Examples/Ping/Program.cs +++ b/Examples/Ping/Program.cs @@ -26,15 +26,15 @@ string? remoteIP = args.Length > 0 ? args[0] : null; if (remoteIP is null) { - var gateway = NetworkInterface - .GetAllNetworkInterfaces() - .Where(nic => - nic.OperationalStatus == OperationalStatus.Up && nic.NetworkInterfaceType != NetworkInterfaceType.Loopback - ) - .SelectMany(nic => nic.GetIPProperties()?.GatewayAddresses?.ToArray() ?? []) - .Select(g => g.Address) - .FirstOrDefault(address => address.AddressFamily == AddressFamily.InterNetwork); - remoteIP = gateway?.ToString(); + var gateWays = + from nic in NetworkInterface.GetAllNetworkInterfaces() + where + nic is { OperationalStatus: OperationalStatus.Up, NetworkInterfaceType: not NetworkInterfaceType.Loopback } + from gateway in nic.GetIPProperties().GatewayAddresses + let address = gateway.Address + where address is { AddressFamily: AddressFamily.InterNetwork } + select address; + remoteIP = gateWays.FirstOrDefault()?.ToString(); } remoteIP ??= "1.1.1.1"; diff --git a/package.json b/package.json index 3be0d0a..a9bd675 100644 --- a/package.json +++ b/package.json @@ -6,7 +6,9 @@ "cake": "dotnet run --project Automation --target", "format": "pnpm run cake Format", "lint": "pnpm run cake Lint", - "restore": "pnpm run cake Restore" + "restore": "pnpm run cake Restore", + "test-report": "pnpm run cake TestReport", + "test-report-watch": "pnpm run cake TestReportWatch" }, "dependencies": { "@prettier/plugin-xml": "^3.4.2", From 7ced53f6fbe520dabe4e07e1ae62181db7206c74 Mon Sep 17 00:00:00 2001 From: gdlol Date: Sun, 2 Nov 2025 08:49:09 +0000 Subject: [PATCH 03/20] test runner watch --- Automation/Build.cs | 2 + Automation/Tasks/CIFSMount.cs | 7 +- Automation/Test.cs | 99 +++++++------ Divert.Windows.TestRunner/AssemblyInfo.cs | 4 - .../Divert.Windows.TestRunner.csproj | 4 +- Divert.Windows.TestRunner/Program.cs | 140 +++++++++++++++--- .../Divert.Windows.Tests.csproj | 6 +- Examples/Ping/Ping.csproj | 2 +- package.json | 4 +- 9 files changed, 188 insertions(+), 80 deletions(-) delete mode 100644 Divert.Windows.TestRunner/AssemblyInfo.cs diff --git a/Automation/Build.cs b/Automation/Build.cs index 3736eb1..8b54cb9 100644 --- a/Automation/Build.cs +++ b/Automation/Build.cs @@ -1,3 +1,4 @@ +using Cake.Common.IO; using Cake.Common.Tools.DotNet; using Cake.Common.Tools.DotNet.MSBuild; using Cake.Frosting; @@ -8,6 +9,7 @@ public class Build : FrostingTask { public override void Run(Context context) { + context.CleanDirectory(Context.LocalWindowsDirectory); context.DotNetBuild( Context.ProjectRoot, new() { MSBuildSettings = new() { TreatAllWarningsAs = MSBuildTreatAllWarningsAs.Error } } diff --git a/Automation/Tasks/CIFSMount.cs b/Automation/Tasks/CIFSMount.cs index 6e0c2d1..a4f59b9 100644 --- a/Automation/Tasks/CIFSMount.cs +++ b/Automation/Tasks/CIFSMount.cs @@ -1,3 +1,4 @@ +using Cake.Common.Diagnostics; using Cake.Common.Tools.Command; using Cake.Core; using Cake.Core.Diagnostics; @@ -40,7 +41,7 @@ public override void Run(Context context) var config = LoadConfig(); if (config is null) { - context.Log.Information("CIFS share not configured, skipping."); + context.Information("CIFS share not configured, skipping."); return; } @@ -49,7 +50,7 @@ public override void Run(Context context) { if (DriveInfo.GetDrives().Any(drive => drive.DriveType is DriveType.Network && drive.Name == mountPath)) { - context.Log.Information("CIFS share already mounted at {0}, skipping.", mountPath); + context.Information("CIFS share already mounted at {0}, skipping.", mountPath); // Already mounted. return; } @@ -69,7 +70,7 @@ public override void Run(Context context) } catch (Exception ex) { - context.Log.Warning("Failed to mount CIFS share: {0}", ex.Message); + context.Warning("Failed to mount CIFS share: {0}", ex.Message); } } } diff --git a/Automation/Test.cs b/Automation/Test.cs index 8518155..b297759 100644 --- a/Automation/Test.cs +++ b/Automation/Test.cs @@ -1,4 +1,5 @@ -using System.Threading.Channels; +using System.Net.Sockets; +using Cake.Common.Diagnostics; using Cake.Common.Tools.DotNet; using Cake.Common.Tools.DotNet.Tool; using Cake.Core.Diagnostics; @@ -8,14 +9,23 @@ namespace Automation; -internal static class Test +public class Test : AsyncFrostingTask { + public const string TEST_HOST = nameof(TEST_HOST); + public const string TestProjectName = "Divert.Windows.Tests"; + public const string TestRunnerProjectName = "Divert.Windows.TestRunner"; + + public static string GetTestHost() => Environment.GetEnvironmentVariable(TEST_HOST) ?? "host.docker.internal"; + public static string TestResultsDirectory => Path.Combine(Context.LocalWindowsDirectory, "TestResults"); public static string CoverletOutput => Path.Combine(TestResultsDirectory, "coverage.cobertura.xml"); public static string TestReportsDirectory => Path.Combine(Context.LocalDirectory, "TestReports"); + public static string TestRunnerLockFilePath => + Path.Combine(Context.LocalWindowsDirectory, $"{TestRunnerProjectName}/TestRunner.lock"); + public static void GenerateReport(Context context) { string sourcePath = Path.Combine(Context.ProjectRoot, "Divert.Windows"); @@ -37,65 +47,58 @@ public static void GenerateReport(Context context) } ); } -} -public class TestReport : FrostingTask -{ - public override void Run(Context context) - { - Test.GenerateReport(context); - } -} - -public class TestReportWatch : AsyncFrostingTask -{ public override async Task RunAsync(Context context) { - var changed = Channel.CreateBounded(10); - using var watcher = new FileSystemWatcher + context.DotNetBuild(Path.Combine(Context.ProjectRoot, TestProjectName)); + if (!Directory.Exists(Path.Combine(Context.LocalWindowsDirectory, TestRunnerProjectName))) { - Path = Context.LocalWindowsDirectory, - Filter = "coverage.cobertura.xml", - EnableRaisingEvents = true, - IncludeSubdirectories = false, - NotifyFilter = NotifyFilters.LastWrite, - }; - watcher.Changed += (_, e) => changed.Writer.TryWrite(0); + context.DotNetBuild(Path.Combine(Context.ProjectRoot, TestRunnerProjectName)); + } - var poll = Task.Run(async () => + context.Information($"Waiting for test runner lock file {TestRunnerLockFilePath}..."); + int port = 0; + while (port is 0) { - while (changed.Writer.TryWrite(0)) + try + { + string text = await File.ReadAllTextAsync(TestRunnerLockFilePath); + port = int.Parse(text); + break; + } + catch { await Task.Delay(TimeSpan.FromMilliseconds(1000)); } - }); - - var exit = Task.Run(() => - { - Console.ReadLine(); - context.Log.Information("Stopping coverage report watcher..."); - changed.Writer.TryComplete(); - }); + } - DateTimeOffset? lastChanged = null; - await foreach (var _ in changed.Reader.ReadAllAsync()) + using var client = new TcpClient(); + await client.ConnectAsync(GetTestHost(), port); + using var stream = client.GetStream(); + using var reader = new StreamReader(stream); + string lastLine = string.Empty; + while (true) { - try + string? line = await reader.ReadLineAsync(); + if (line is null) { - var offset = new DateTimeOffset(File.GetLastWriteTime(Test.CoverletOutput)); - if (lastChanged != offset) - { - lastChanged = offset; - context.Log.Information("Detected change to coverage file, regenerating report..."); - Test.GenerateReport(context); - } - } - catch (Exception e) - { - context.Log.Error("Error regenerating report: {0}", e); + break; } + lastLine = line; + Console.WriteLine(line); + } + if (!int.TryParse(lastLine, out int exitCode) || exitCode != 0) + { + throw new Exception("Tests failed"); } - await Task.WhenAll(poll, exit); - context.Log.Information("Coverage report watcher stopped."); + GenerateReport(context); + } +} + +public class TestReport : FrostingTask +{ + public override void Run(Context context) + { + Test.GenerateReport(context); } } diff --git a/Divert.Windows.TestRunner/AssemblyInfo.cs b/Divert.Windows.TestRunner/AssemblyInfo.cs deleted file mode 100644 index a298fba..0000000 --- a/Divert.Windows.TestRunner/AssemblyInfo.cs +++ /dev/null @@ -1,4 +0,0 @@ -using System.Runtime.Versioning; - -[assembly: DoNotParallelize] -[assembly: SupportedOSPlatform("windows6.0.6000")] diff --git a/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj b/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj index 9e9c4dc..efea64a 100644 --- a/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj +++ b/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj @@ -10,7 +10,9 @@ - + + none + diff --git a/Divert.Windows.TestRunner/Program.cs b/Divert.Windows.TestRunner/Program.cs index 7e1647f..0ad9c9e 100644 --- a/Divert.Windows.TestRunner/Program.cs +++ b/Divert.Windows.TestRunner/Program.cs @@ -1,29 +1,133 @@ using System.Diagnostics; +using System.Net; +using System.Net.Sockets; +using System.Reflection; +using System.Text; +using System.Threading.Channels; string appPath = Path.GetDirectoryName(Environment.ProcessPath)!; string testResultsDirectory = Path.Combine(appPath, "../TestResults"); string coverageSettingsPath = Path.Combine(appPath, "CoverageSettings.xml"); string coverageOutputPath = Path.Combine(testResultsDirectory, "coverage.cobertura.xml"); +string testAppPath = Path.Combine(appPath, "../Divert.Windows.Tests"); -Directory.SetCurrentDirectory(appPath); +Directory.SetCurrentDirectory(testAppPath); -using var process = Process.Start( - new ProcessStartInfo +Process LaunchTestProcess(bool redirect) => + Process.Start( + new ProcessStartInfo() + { + FileName = Path.Combine(testAppPath, "Divert.Windows.Tests.exe"), + ArgumentList = + { + "--results-directory", + testResultsDirectory, + "--coverage", + "--coverage-settings", + coverageSettingsPath, + "--coverage-output-format", + "cobertura", + "--coverage-output", + coverageOutputPath, + }, + RedirectStandardOutput = redirect, + RedirectStandardError = redirect, + Environment = + { + // MSTest should respect DOTNET_SYSTEM_CONSOLE_ALLOW_ANSI_COLOR_REDIRECTION instead. + ["GITHUB_ACTIONS"] = "true", // Trick MSTest to output ANSI colors. + }, + } + )!; + +if (args is not ["watch", ..]) +{ + using var process = LaunchTestProcess(redirect: false); + await process.WaitForExitAsync(); + return process.ExitCode; +} + +Console.WriteLine("Starting Test Runner in watch mode..."); +using var mutex = new Mutex(true, Assembly.GetExecutingAssembly().FullName, out bool createdNew); +if (!createdNew) +{ + Console.WriteLine("Another instance is already running. Exiting..."); + return 1; +} + +using var cts = new CancellationTokenSource(); +Console.CancelKeyPress += (s, e) => +{ + e.Cancel = true; + cts.Cancel(); +}; +var token = cts.Token; + +using var listener = new TcpListener(IPAddress.Any, 0); +listener.Start(); +int port = ((IPEndPoint)listener.LocalEndpoint).Port; +Console.WriteLine($"Listening on port {port}..."); +using var lockFile = new FileStream( + Path.Combine(appPath, "TestRunner.lock"), + FileMode.Create, + FileAccess.ReadWrite, + FileShare.Read, + bufferSize: 0, + FileOptions.Asynchronous | FileOptions.DeleteOnClose +); +await lockFile.WriteAsync(Encoding.UTF8.GetBytes(port.ToString() + '\n'), token); +await lockFile.FlushAsync(token); + +try +{ + while (!token.IsCancellationRequested) { - FileName = Path.Combine(appPath, "Divert.Windows.Tests.exe"), - ArgumentList = + Console.WriteLine("Waiting for client connection..."); + using var client = await listener.AcceptSocketAsync(token); + Console.WriteLine("Received client connection."); + using var stream = new NetworkStream(client, ownsSocket: false); + using var process = LaunchTestProcess(redirect: true); + using var _ = token.Register(() => process.Kill(entireProcessTree: true)); + + var lines = Channel.CreateBounded(1024); + var stdOutReader = new StreamReader(process.StandardOutput.BaseStream); + var stdErrReader = new StreamReader(process.StandardError.BaseStream); + + async Task ForwardLines(StreamReader reader) { - "--results-directory", - testResultsDirectory, - "--coverage", - "--coverage-settings", - coverageSettingsPath, - "--coverage-output-format", - "cobertura", - "--coverage-output", - coverageOutputPath, - }, + string? line = null; + while (true) + { + line = await reader.ReadLineAsync(token); + if (line is null) + { + break; + } + await lines.Writer.WriteAsync(line, token); + } + } + var readStdOutTask = ForwardLines(stdOutReader); + var readStdErrTask = ForwardLines(stdErrReader); + + using var writer = new StreamWriter(stream) { AutoFlush = true }; + var writeTask = Task.Run( + async () => + { + await foreach (var line in lines.Reader.ReadAllAsync(token)) + { + Console.WriteLine(line); + await writer.WriteLineAsync(line.AsMemory(), token); + } + }, + token + ); + + await process.WaitForExitAsync(token); + lines.Writer.Complete(); + await Task.WhenAll(readStdOutTask, readStdErrTask, writeTask); + await writer.WriteLineAsync(process.ExitCode.ToString().AsMemory(), token); } -)!; -await process.WaitForExitAsync(); -return process.ExitCode; +} +catch (OperationCanceledException) when (token.IsCancellationRequested) { } + +return 0; diff --git a/Divert.Windows.Tests/Divert.Windows.Tests.csproj b/Divert.Windows.Tests/Divert.Windows.Tests.csproj index f2f94bb..209d8eb 100644 --- a/Divert.Windows.Tests/Divert.Windows.Tests.csproj +++ b/Divert.Windows.Tests/Divert.Windows.Tests.csproj @@ -3,12 +3,12 @@ Exe true $(DefaultTargetFramework) + win-x64 + true + $(ProjectRoot).local/windows/$(MSBuildThisFileName) false true CA2007 - - win-x64 - true diff --git a/Examples/Ping/Ping.csproj b/Examples/Ping/Ping.csproj index 47c9acc..f9fd15d 100644 --- a/Examples/Ping/Ping.csproj +++ b/Examples/Ping/Ping.csproj @@ -4,7 +4,7 @@ $(DefaultTargetFramework) win-x64 true - $(ProjectRoot).local/windows/$(MSBuildThisFileName) + $(ProjectRoot).local/windows/Examples/$(MSBuildThisFileName) CA2007 diff --git a/package.json b/package.json index a9bd675..72ff098 100644 --- a/package.json +++ b/package.json @@ -7,8 +7,8 @@ "format": "pnpm run cake Format", "lint": "pnpm run cake Lint", "restore": "pnpm run cake Restore", - "test-report": "pnpm run cake TestReport", - "test-report-watch": "pnpm run cake TestReportWatch" + "test": "pnpm run cake Test", + "test-report": "pnpm run cake TestReport" }, "dependencies": { "@prettier/plugin-xml": "^3.4.2", From e0fcd42a09472d29bd1b0666a65ce3a2577c6c72 Mon Sep 17 00:00:00 2001 From: gdlol Date: Sun, 2 Nov 2025 17:24:06 +0000 Subject: [PATCH 04/20] tests --- .../Divert.Windows.TestRunner.csproj | 5 -- Divert.Windows.TestRunner/Program.cs | 7 +- .../Divert.Windows.Tests.csproj | 8 ++ Divert.Windows.Tests/DivertServiceTests.cs | 79 ++++++++++++++++++ Divert.Windows.Tests/FilterTests.cs | 81 ++++++++++++++++++- Divert.Windows/AssemblyInfo.cs | 2 + Divert.Windows/DivertFilter.cs | 19 +++-- Divert.Windows/DivertHelper.cs | 2 +- 8 files changed, 187 insertions(+), 16 deletions(-) create mode 100644 Divert.Windows.Tests/DivertServiceTests.cs diff --git a/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj b/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj index efea64a..59acd39 100644 --- a/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj +++ b/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj @@ -16,11 +16,6 @@ - diff --git a/Divert.Windows.TestRunner/Program.cs b/Divert.Windows.TestRunner/Program.cs index 0ad9c9e..69320d6 100644 --- a/Divert.Windows.TestRunner/Program.cs +++ b/Divert.Windows.TestRunner/Program.cs @@ -124,8 +124,11 @@ async Task ForwardLines(StreamReader reader) await process.WaitForExitAsync(token); lines.Writer.Complete(); - await Task.WhenAll(readStdOutTask, readStdErrTask, writeTask); - await writer.WriteLineAsync(process.ExitCode.ToString().AsMemory(), token); + await Task.WhenAll(readStdOutTask, readStdErrTask, writeTask) + .ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + await writer + .WriteLineAsync(process.ExitCode.ToString().AsMemory(), token) + .ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); } } catch (OperationCanceledException) when (token.IsCancellationRequested) { } diff --git a/Divert.Windows.Tests/Divert.Windows.Tests.csproj b/Divert.Windows.Tests/Divert.Windows.Tests.csproj index 209d8eb..b45d517 100644 --- a/Divert.Windows.Tests/Divert.Windows.Tests.csproj +++ b/Divert.Windows.Tests/Divert.Windows.Tests.csproj @@ -21,4 +21,12 @@ + + + + diff --git a/Divert.Windows.Tests/DivertServiceTests.cs b/Divert.Windows.Tests/DivertServiceTests.cs new file mode 100644 index 0000000..c1a09bb --- /dev/null +++ b/Divert.Windows.Tests/DivertServiceTests.cs @@ -0,0 +1,79 @@ +using System.Net; +using System.Net.Sockets; + +namespace Divert.Windows.Tests; + +[TestClass] +public sealed class DivertServiceTests : IDisposable +{ + private readonly CancellationTokenSource cts; + private readonly CancellationToken token; + + public DivertServiceTests() + { + cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + token = cts.Token; + } + + public void Dispose() + { + cts.Dispose(); + } + + private static UdpClient CreateUdpListener(out int port) + { + var client = new UdpClient(new IPEndPoint(IPAddress.Loopback, 0)); + var localEndPoint = (IPEndPoint)client.Client.LocalEndPoint!; + port = localEndPoint.Port; + return client; + } + + [TestMethod] + public async Task ModifyPayload() + { + using var listener = CreateUdpListener(out int port); + var receive = listener.ReceiveAsync(token).AsTask(); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port.ToString()); + using var service = new DivertService(filter); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + // send 3 bytes payload + using var client = new UdpClient(); + client.Connect(IPAddress.Loopback, port); + await client.SendAsync(new byte[] { 1, 2, 3 }, token); + + var divertResult = await divertReceive; + Assert.AreEqual(20 + 8 + 3, divertResult.Length); + var packet = packetBuffer.AsMemory(0, divertResult.Length); + CollectionAssert.AreEqual(new byte[] { 1, 2, 3 }, packet.ToArray()[28..]); + Assert.IsFalse(receive.IsCompleted); + + // Re-inject + addressBuffer[0] = new DivertAddress { IsImpostor = true, IsOutbound = true }; + await service.SendAsync(packet, addressBuffer, token); + var result = await receive; + Assert.HasCount(3, result.Buffer); + CollectionAssert.AreEqual(new byte[] { 1, 2, 3 }, result.Buffer); + + // Re-inject with modified data + receive = listener.ReceiveAsync(token).AsTask(); + new byte[] { 4, 5, 6 }.CopyTo(packet.Span[28..]); + DivertHelper.CalculateChecksums(packet.Span); + Assert.IsFalse(receive.IsCompleted); + await service.SendAsync(packet, addressBuffer, token); + result = await receive; + Assert.HasCount(3, result.Buffer); + CollectionAssert.AreEqual(new byte[] { 4, 5, 6 }, result.Buffer); + } +} diff --git a/Divert.Windows.Tests/FilterTests.cs b/Divert.Windows.Tests/FilterTests.cs index 92b3642..7005f67 100644 --- a/Divert.Windows.Tests/FilterTests.cs +++ b/Divert.Windows.Tests/FilterTests.cs @@ -3,11 +3,86 @@ [TestClass] public class FilterTests { + public static IEnumerable FilterCases() + { + return + [ + [(DivertFilter.TCP | DivertFilter.UDP), "tcp or udp"], + [(DivertFilter.TCP & DivertFilter.Loopback), "tcp and loopback"], + [(DivertFilter.TCP | DivertFilter.UDP) & DivertFilter.Outbound, "(tcp or udp) and outbound"], + [DivertFilter.Outbound & (DivertFilter.TCP | DivertFilter.UDP), "outbound and (tcp or udp)"], + [ + (DivertFilter.TCP & (DivertFilter.RemotePort == "80" | DivertFilter.RemotePort == "443")) + | DivertFilter.Outbound, + "(tcp and (remotePort = 80 or remotePort = 443)) or outbound", + ], + [ + DivertFilter.Outbound + & (DivertFilter.TCP | DivertFilter.UDP) + & (DivertFilter.RemotePort == "80" | DivertFilter.RemotePort == "443"), + "outbound and (tcp or udp) and (remotePort = 80 or remotePort = 443)", + ], + [ + !DivertFilter.Loopback | (DivertFilter.TCP & DivertFilter.RemotePort == "443"), + "not loopback or (tcp and remotePort = 443)", + ], + [ + DivertFilter.Loopback | (!DivertFilter.TCP & DivertFilter.RemotePort != "443"), + "loopback or (not tcp and remotePort != 443)", + ], + [ + DivertFilter.Inbound & (DivertFilter.LocalPort > "30000" | DivertFilter.LocalPort < "1024"), + "inbound and (localPort > 30000 or localPort < 1024)", + ], + [ + DivertFilter.RemotePort >= "5000" | DivertFilter.RemotePort <= "6000", + "remotePort >= 5000 or remotePort <= 6000", + ], + [ + DivertFilter.Packet(10) == "0x1" + | DivertFilter.Packet16(20) == "0x2" + | DivertFilter.Packet32(30) == "0x3", + "packet[10] = 0x1 or packet16[20] = 0x2 or packet32[30] = 0x3", + ], + ]; + } + [TestMethod] - public void TestFilterToString() + [DynamicData(nameof(FilterCases))] + public void FilterToString(DivertFilter filter, string expected) { - var filter = (DivertFilter.TCP | DivertFilter.UDP) & DivertFilter.Outbound; var result = filter.ToString(); - Assert.AreEqual("(tcp or udp) and outbound", result); + Assert.AreEqual(expected, result); + } + + [TestMethod] + public void InvalidFilter() + { + // Unmatched parentheses + Assert.Throws(() => DivertFilter.Outbound & "(tcp or udp))"); + Assert.Throws(() => DivertFilter.Outbound & "((tcp or udp)"); + } + + [TestMethod] + public void Equals() + { + var filter = new DivertFilter("(tcp or udp) and outbound"); + Assert.IsTrue(filter.Equals(filter)); + Assert.IsTrue(filter.Equals("(tcp or udp) and outbound")); + Assert.IsTrue(filter.Equals((DivertFilter.TCP | DivertFilter.UDP) & DivertFilter.Outbound)); + Assert.IsFalse(filter.Equals(DivertFilter.TCP)); + Assert.IsFalse(filter.Equals(null)); + Assert.IsFalse(new DivertFilter("5").Equals(5)); + Assert.AreEqual(filter.GetHashCode(), filter.ToString().GetHashCode()); + + var field = new DivertFilter.Field("tcp"); + Assert.IsTrue(field.Equals(field)); + Assert.IsTrue(field.Equals("tcp")); + Assert.IsTrue(field.Equals(DivertFilter.TCP)); + Assert.IsFalse(field.Equals("udp")); + Assert.IsFalse(field.Equals(DivertFilter.UDP)); + Assert.IsFalse(field.Equals(null)); + Assert.IsFalse(new DivertFilter.Field("5").Equals(5)); + Assert.AreEqual(field.GetHashCode(), field.ToString().GetHashCode()); } } diff --git a/Divert.Windows/AssemblyInfo.cs b/Divert.Windows/AssemblyInfo.cs index 654c4a8..a39cdcf 100644 --- a/Divert.Windows/AssemblyInfo.cs +++ b/Divert.Windows/AssemblyInfo.cs @@ -1,3 +1,5 @@ +using System.Runtime.CompilerServices; using System.Runtime.Versioning; [assembly: SupportedOSPlatform("windows6.0.6000")] +[assembly: InternalsVisibleTo("Divert.Windows.Tests")] diff --git a/Divert.Windows/DivertFilter.cs b/Divert.Windows/DivertFilter.cs index 7d3cb17..ceebee9 100644 --- a/Divert.Windows/DivertFilter.cs +++ b/Divert.Windows/DivertFilter.cs @@ -112,6 +112,19 @@ public override string ToString() return Clause; } + public override bool Equals(object? obj) + { + return obj switch + { + null => false, + DivertFilter filter => Clause == filter.Clause, + string s => Clause == s, + _ => false, + }; + } + + public override int GetHashCode() => Clause.GetHashCode(); + public class Field { private readonly string field; @@ -134,7 +147,7 @@ public override bool Equals(object? obj) }; } - public override int GetHashCode() => base.GetHashCode(); + public override int GetHashCode() => field.GetHashCode(); public override string ToString() => field; @@ -162,10 +175,6 @@ public static implicit operator DivertFilter(Field field) public static DivertFilter operator !(Field value) { - if (MatchAndPattern(value.field) || MatchOrPattern(value.field)) - { - value = new Field($"({value.field})"); - } string clause = $"not {value}"; return new DivertFilter(clause); } diff --git a/Divert.Windows/DivertHelper.cs b/Divert.Windows/DivertHelper.cs index f2f8f44..26d1d6c 100644 --- a/Divert.Windows/DivertHelper.cs +++ b/Divert.Windows/DivertHelper.cs @@ -5,7 +5,7 @@ namespace Divert.Windows; public static unsafe class DivertHelper { - public static void CalculateChecksums(Span packet, DivertHelperFlags flags) + public static void CalculateChecksums(Span packet, DivertHelperFlags flags = DivertHelperFlags.None) { fixed (byte* pPacket = packet) { From 1e332f7ac6d35960c5f63164359e34bd4023e7bf Mon Sep 17 00:00:00 2001 From: gdlol Date: Wed, 5 Nov 2025 12:22:41 +0000 Subject: [PATCH 05/20] Refactoring --- Divert.Windows.Tests/DivertServiceTests.cs | 33 +++- .../AsyncOperation/CancellationHandle.cs | 60 ++++++ .../AsyncOperation/DivertReceiveExecutor.cs | 37 ++++ .../DivertReceiveValueTaskSource.cs | 121 ------------ .../AsyncOperation/DivertSendExecutor.cs | 36 ++++ .../DivertSendValueTaskSource.cs | 114 ----------- .../AsyncOperation/DivertValueTaskSource.cs | 184 +++++++++--------- .../AsyncOperation/IOCompletionOperation.cs | 86 ++++++++ .../AsyncOperation/PendingOperation.cs | 20 +- Divert.Windows/DivertFilter.cs | 12 +- Divert.Windows/DivertReceiveResult.cs | 13 +- Divert.Windows/DivertService.cs | 80 +++----- Divert.Windows/SafeHandleExtensions.cs | 2 +- 13 files changed, 406 insertions(+), 392 deletions(-) create mode 100644 Divert.Windows/AsyncOperation/CancellationHandle.cs create mode 100644 Divert.Windows/AsyncOperation/DivertReceiveExecutor.cs delete mode 100644 Divert.Windows/AsyncOperation/DivertReceiveValueTaskSource.cs create mode 100644 Divert.Windows/AsyncOperation/DivertSendExecutor.cs delete mode 100644 Divert.Windows/AsyncOperation/DivertSendValueTaskSource.cs create mode 100644 Divert.Windows/AsyncOperation/IOCompletionOperation.cs diff --git a/Divert.Windows.Tests/DivertServiceTests.cs b/Divert.Windows.Tests/DivertServiceTests.cs index c1a09bb..439fc0f 100644 --- a/Divert.Windows.Tests/DivertServiceTests.cs +++ b/Divert.Windows.Tests/DivertServiceTests.cs @@ -39,7 +39,7 @@ public async Task ModifyPayload() & DivertFilter.Loopback & DivertFilter.Ip & !DivertFilter.Impostor - & (DivertFilter.RemotePort == port.ToString()); + & (DivertFilter.RemotePort == port); using var service = new DivertService(filter); var packetBuffer = new byte[ushort.MaxValue + 40]; @@ -54,9 +54,10 @@ public async Task ModifyPayload() await client.SendAsync(new byte[] { 1, 2, 3 }, token); var divertResult = await divertReceive; - Assert.AreEqual(20 + 8 + 3, divertResult.Length); - var packet = packetBuffer.AsMemory(0, divertResult.Length); + Assert.AreEqual(20 + 8 + 3, divertResult.DataLength); + var packet = packetBuffer.AsMemory(0, divertResult.DataLength); CollectionAssert.AreEqual(new byte[] { 1, 2, 3 }, packet.ToArray()[28..]); + Assert.AreEqual(1, divertResult.AddressLength); Assert.IsFalse(receive.IsCompleted); // Re-inject @@ -76,4 +77,30 @@ public async Task ModifyPayload() Assert.HasCount(3, result.Buffer); CollectionAssert.AreEqual(new byte[] { 4, 5, 6 }, result.Buffer); } + + [TestMethod] + public async Task Close() + { + using var listener = CreateUdpListener(out int port); + var receive = listener.ReceiveAsync(token).AsTask(); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port); + using var service = new DivertService(filter); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + service.Dispose(); + var exception = await Assert.ThrowsAsync(async () => await divertReceive); + Assert.IsFalse(token.IsCancellationRequested); + Assert.AreNotEqual(token, exception.CancellationToken); + Assert.AreEqual(CancellationToken.None, exception.CancellationToken); + } } diff --git a/Divert.Windows/AsyncOperation/CancellationHandle.cs b/Divert.Windows/AsyncOperation/CancellationHandle.cs new file mode 100644 index 0000000..a25786e --- /dev/null +++ b/Divert.Windows/AsyncOperation/CancellationHandle.cs @@ -0,0 +1,60 @@ +using System.Runtime.InteropServices; +using Windows.Win32; + +namespace Divert.Windows.AsyncOperation; + +internal sealed unsafe class CancellationHandle(SafeHandle safeHandle) : IDisposable +{ + private static class Status + { + public const int Idle = 0; + public const int Canceled = 1; + public const int Pending = 2; + public const int Disposed = 3; + } + + private int status; + + public SafeHandle SafeHandle => safeHandle; + + // From cancellation registration. + public void RequestOrInvokeCancel(NativeOverlapped* nativeOverlapped) + { + int originalStatus = Interlocked.CompareExchange(ref status, Status.Canceled, Status.Idle); + if (originalStatus is Status.Pending) + { + using (safeHandle.Reference(out var handle)) + { + _ = PInvoke.CancelIoEx(new(handle), nativeOverlapped); + } + } + } + + // After ERROR_IO_PENDING. + public void CancelWhenRequested(NativeOverlapped* nativeOverlapped) + { + int originalStatus = Interlocked.CompareExchange(ref status, Status.Pending, Status.Idle); + if (originalStatus is Status.Canceled) + { + using (safeHandle.Reference(out var handle)) + { + _ = PInvoke.CancelIoEx(new(handle), nativeOverlapped); + } + } + } + + // From completion callback. + public void Reset() + { + int originalStatus = Interlocked.CompareExchange(ref status, Status.Idle, Status.Pending); + if (originalStatus is Status.Canceled) + { + Interlocked.CompareExchange(ref status, Status.Idle, originalStatus); + } + } + + public void Dispose() + { + Interlocked.Exchange(ref status, Status.Disposed); + } +} diff --git a/Divert.Windows/AsyncOperation/DivertReceiveExecutor.cs b/Divert.Windows/AsyncOperation/DivertReceiveExecutor.cs new file mode 100644 index 0000000..c706f01 --- /dev/null +++ b/Divert.Windows/AsyncOperation/DivertReceiveExecutor.cs @@ -0,0 +1,37 @@ +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Divert.Windows.AsyncOperation; + +internal sealed class DivertReceiveExecutor : IDivertValueTaskExecutor +{ + private readonly Memory addressesLengthBuffer = GC.AllocateArray(1, pinned: true); + + public unsafe bool Execute(SafeHandle safeHandle, ref readonly PendingOperation pendingOperation) + { + using var _ = safeHandle.Reference(out var handle); + addressesLengthBuffer.Span[0] = (uint)(pendingOperation.Addresses.Length * sizeof(DivertAddress)); + return NativeMethods.WinDivertRecvEx( + handle, + pendingOperation.PacketBufferHandle.Pointer, + (uint)pendingOperation.PacketBuffer.Length, + null, + 0, + (WINDIVERT_ADDRESS*)pendingOperation.AddressesHandle.Pointer, + (uint*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(addressesLengthBuffer.Span)), + pendingOperation.NativeOverlapped + ); + } + + public async ValueTask ReceiveAsync( + DivertValueTaskSource source, + Memory buffer, + Memory addresses, + CancellationToken cancellationToken + ) + { + int dataLength = await source.ExecuteAsync(this, buffer, addresses, cancellationToken).ConfigureAwait(false); + int addressesLength = (int)addressesLengthBuffer.Span[0] / Marshal.SizeOf(); + return new DivertReceiveResult(dataLength, addressesLength); + } +} diff --git a/Divert.Windows/AsyncOperation/DivertReceiveValueTaskSource.cs b/Divert.Windows/AsyncOperation/DivertReceiveValueTaskSource.cs deleted file mode 100644 index 4457d43..0000000 --- a/Divert.Windows/AsyncOperation/DivertReceiveValueTaskSource.cs +++ /dev/null @@ -1,121 +0,0 @@ -using System.ComponentModel; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; -using System.Threading.Channels; -using System.Threading.Tasks.Sources; -using Windows.Win32.Foundation; - -namespace Divert.Windows.AsyncOperation; - -internal sealed unsafe class DivertReceiveValueTaskSource( - Channel pool, - DivertHandle divertHandle, - ThreadPoolBoundHandle threadPoolBoundHandle, - bool runContinuationsAsynchronously -) : DivertValueTaskSource(divertHandle, threadPoolBoundHandle), IValueTaskSource -{ - private ManualResetValueTaskSourceCore source = new() - { - RunContinuationsAsynchronously = runContinuationsAsynchronously, - }; - - private readonly Memory addressesLengthBuffer = GC.AllocateArray(1, pinned: true); - - protected override void Complete(uint errorCode, uint numBytes, in PendingOperation pendingOperation) - { - var addresses = pendingOperation.Addresses; - int addressesLength = (int)addressesLengthBuffer.Span[0] / sizeof(DivertAddress); - var token = pendingOperation.CancellationToken; - - if (errorCode is (uint)WIN32_ERROR.ERROR_NO_DATA) - { - errorCode = 0; - } - - if (errorCode is (uint)WIN32_ERROR.ERROR_SUCCESS) - { - source.SetResult(new DivertReceiveResult((int)numBytes, addresses[..addressesLength])); - } - else if (errorCode is (uint)WIN32_ERROR.ERROR_OPERATION_ABORTED) - { - var exception = token.IsCancellationRequested - ? new OperationCanceledException() - : new OperationCanceledException(token); - source.SetException(exception); - } - else - { - var exception = new Win32Exception((int)errorCode); - source.SetException(exception); - } - } - - public DivertReceiveResult GetResult(short token) - { - try - { - return source.GetResult(token); - } - finally - { - source.Reset(); - if (!pool.Writer.TryWrite(this)) - { - Dispose(); - } - } - } - - public ValueTaskSourceStatus GetStatus(short token) => source.GetStatus(token); - - public void OnCompleted( - Action continuation, - object? state, - short token, - ValueTaskSourceOnCompletedFlags flags - ) => source.OnCompleted(continuation, state, token, flags); - - public ValueTask ReceiveAsync( - Memory buffer, - Memory addresses, - CancellationToken cancellationToken - ) - { - using var _ = DivertHandle.GetReference(out var handle); - addressesLengthBuffer.Span[0] = (uint)(addresses.Length * sizeof(DivertAddress)); - var pendingOperation = PrepareOperation(buffer, addresses, cancellationToken); - try - { - bool success = NativeMethods.WinDivertRecvEx( - handle, - pendingOperation.PacketBufferHandle.Pointer, - (uint)buffer.Length, - null, - 0, - (WINDIVERT_ADDRESS*)pendingOperation.AddressesHandle.Pointer, - (uint*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(addressesLengthBuffer.Span)), - pendingOperation.NativeOverlapped - ); - if (!success) - { - int error = Marshal.GetLastPInvokeError(); - if (error is (int)WIN32_ERROR.ERROR_IO_PENDING) - { - CancelIfRequested(); - } - else - { - Dispose(); - return ValueTask.FromException(new Win32Exception(error)); - } - } - - return new ValueTask(this, source.Version); - } - catch - { - Dispose(); - throw; - } - } -} diff --git a/Divert.Windows/AsyncOperation/DivertSendExecutor.cs b/Divert.Windows/AsyncOperation/DivertSendExecutor.cs new file mode 100644 index 0000000..e116dfc --- /dev/null +++ b/Divert.Windows/AsyncOperation/DivertSendExecutor.cs @@ -0,0 +1,36 @@ +using System.Runtime.InteropServices; + +namespace Divert.Windows.AsyncOperation; + +internal sealed unsafe class DivertSendExecutor : IDivertValueTaskExecutor +{ + public bool Execute(SafeHandle safeHandle, ref readonly PendingOperation pendingOperation) + { + using var _ = safeHandle.Reference(out var handle); + return NativeMethods.WinDivertSendEx( + handle, + pendingOperation.PacketBufferHandle.Pointer, + (uint)pendingOperation.PacketBuffer.Length, + null, + 0, + (WINDIVERT_ADDRESS*)pendingOperation.AddressesHandle.Pointer, + (uint)(pendingOperation.Addresses.Length * sizeof(DivertAddress)), + pendingOperation.NativeOverlapped + ); + } + + public ValueTask SendAsync( + DivertValueTaskSource source, + ReadOnlyMemory buffer, + ReadOnlyMemory addresses, + CancellationToken cancellationToken + ) + { + return source.ExecuteAsync( + this, + MemoryMarshal.AsMemory(buffer), + MemoryMarshal.AsMemory(addresses), + cancellationToken + ); + } +} diff --git a/Divert.Windows/AsyncOperation/DivertSendValueTaskSource.cs b/Divert.Windows/AsyncOperation/DivertSendValueTaskSource.cs deleted file mode 100644 index 8364f70..0000000 --- a/Divert.Windows/AsyncOperation/DivertSendValueTaskSource.cs +++ /dev/null @@ -1,114 +0,0 @@ -using System.ComponentModel; -using System.Runtime.InteropServices; -using System.Threading.Channels; -using System.Threading.Tasks.Sources; -using Windows.Win32.Foundation; - -namespace Divert.Windows.AsyncOperation; - -internal sealed unsafe class DivertSendValueTaskSource( - Channel pool, - DivertHandle divertHandle, - ThreadPoolBoundHandle threadPoolBoundHandle, - bool runContinuationsAsynchronously -) : DivertValueTaskSource(divertHandle, threadPoolBoundHandle), IValueTaskSource -{ - private ManualResetValueTaskSourceCore source = new() - { - RunContinuationsAsynchronously = runContinuationsAsynchronously, - }; - - protected override void Complete(uint errorCode, uint numBytes, in PendingOperation pendingOperation) - { - var token = pendingOperation.CancellationToken; - - if (errorCode is (uint)WIN32_ERROR.ERROR_SUCCESS) - { - source.SetResult(0); - } - else if (errorCode is (uint)WIN32_ERROR.ERROR_OPERATION_ABORTED) - { - var exception = token.IsCancellationRequested - ? new OperationCanceledException() - : new OperationCanceledException(token); - source.SetException(exception); - } - else - { - var exception = new Win32Exception((int)errorCode); - source.SetException(exception); - } - } - - public void GetResult(short token) - { - try - { - source.GetResult(token); - } - finally - { - source.Reset(); - if (!pool.Writer.TryWrite(this)) - { - Dispose(); - } - } - } - - public ValueTaskSourceStatus GetStatus(short token) => source.GetStatus(token); - - public void OnCompleted( - Action continuation, - object? state, - short token, - ValueTaskSourceOnCompletedFlags flags - ) => source.OnCompleted(continuation, state, token, flags); - - public ValueTask SendAsync( - ReadOnlyMemory buffer, - ReadOnlyMemory addresses, - CancellationToken cancellationToken - ) - { - using var _ = DivertHandle.GetReference(out var handle); - var pendingOperation = PrepareOperation( - MemoryMarshal.AsMemory(buffer), - MemoryMarshal.AsMemory(addresses), - cancellationToken - ); - try - { - bool success = NativeMethods.WinDivertSendEx( - handle, - pendingOperation.PacketBufferHandle.Pointer, - (uint)buffer.Length, - null, - 0, - (WINDIVERT_ADDRESS*)pendingOperation.AddressesHandle.Pointer, - (uint)(addresses.Length * sizeof(DivertAddress)), - pendingOperation.NativeOverlapped - ); - if (!success) - { - int error = Marshal.GetLastPInvokeError(); - if (error == (int)WIN32_ERROR.ERROR_IO_PENDING) - { - CancelIfRequested(); - } - else - { - Dispose(); - return ValueTask.FromException(new Win32Exception(error)); - } - } - - return new ValueTask(this, source.Version); - } - catch - { - Dispose(); - throw; - } - } -} diff --git a/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs b/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs index 7a78d8c..155084c 100644 --- a/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs +++ b/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs @@ -1,135 +1,145 @@ -using Windows.Win32; +using System.ComponentModel; +using System.Runtime.InteropServices; +using System.Threading.Channels; +using System.Threading.Tasks.Sources; +using Windows.Win32.Foundation; namespace Divert.Windows.AsyncOperation; -internal abstract unsafe class DivertValueTaskSource : IDisposable +internal interface IDivertValueTaskExecutor { - private static class Status - { - public const uint Idle = 0; - public const uint Canceled = 1; - public const uint Pending = 2; - public const uint Disposed = 3; - } - - private static readonly IOCompletionCallback ioCompletionCallback = OnIOCompleted; + bool Execute(SafeHandle safeHandle, ref readonly PendingOperation pendingOperation); +} - private readonly DivertHandle divertHandle; - private readonly ThreadPoolBoundHandle threadPoolBoundHandle; - private readonly PreAllocatedOverlapped preAllocatedOverlapped; +internal unsafe class DivertValueTaskSource : IDisposable, IValueTaskSource, IOCompletionHandler +{ + private readonly Channel pool; + private readonly IOCompletionOperation ioCompletionOperation; - private IntPtr nativeOverlapped; + private ManualResetValueTaskSourceCore source; private PendingOperation pendingOperation; - private uint status; - - protected DivertHandle DivertHandle => divertHandle; - public DivertValueTaskSource(DivertHandle divertHandle, ThreadPoolBoundHandle threadPoolBoundHandle) + public DivertValueTaskSource( + Channel pool, + DivertHandle divertHandle, + ThreadPoolBoundHandle threadPoolBoundHandle, + bool runContinuationsAsynchronously + ) { - this.divertHandle = divertHandle; - this.threadPoolBoundHandle = threadPoolBoundHandle; - preAllocatedOverlapped = new PreAllocatedOverlapped(ioCompletionCallback, this, null); + this.pool = pool; + ioCompletionOperation = new IOCompletionOperation( + divertHandle, + threadPoolBoundHandle, + this + ); + source = new ManualResetValueTaskSourceCore + { + RunContinuationsAsynchronously = runContinuationsAsynchronously, + }; } - // From cancellation registration. - private void ExecuteOrRequestCancel() + private SafeHandle SafeHandle => ioCompletionOperation.SafeHandle; + + public void Dispose() { - uint originalStatus = Interlocked.CompareExchange(ref status, Status.Canceled, Status.Idle); - if (originalStatus is Status.Pending) - { - using (divertHandle.GetReference(out var handle)) - { - _ = PInvoke.CancelIoEx(new(handle), pendingOperation.NativeOverlapped); - } - } + pendingOperation.Dispose(); + ioCompletionOperation.Dispose(); } - // After ERROR_IO_PENDING. - protected void CancelIfRequested() + public int GetResult(short token) { - uint originalStatus = Interlocked.CompareExchange(ref status, Status.Pending, Status.Idle); - if (originalStatus is Status.Canceled) + try { - using (divertHandle.GetReference(out var handle)) + return source.GetResult(token); + } + finally + { + source.Reset(); + if (!pool.Writer.TryWrite(this)) { - _ = PInvoke.CancelIoEx(new(handle), pendingOperation.NativeOverlapped); + Dispose(); } } } - // From completion callback. - private void ResetIfPendingOrCanceled() - { - uint originalStatus = Interlocked.CompareExchange(ref status, Status.Idle, Status.Pending); - if (originalStatus is Status.Canceled) - { - Interlocked.CompareExchange(ref status, Status.Idle, originalStatus); - } - } + public ValueTaskSourceStatus GetStatus(short token) => source.GetStatus(token); + + public void OnCompleted( + Action continuation, + object? state, + short token, + ValueTaskSourceOnCompletedFlags flags + ) => source.OnCompleted(continuation, state, token, flags); - protected PendingOperation PrepareOperation( + private ref readonly PendingOperation PrepareOperation( Memory packetBuffer, Memory addresses, CancellationToken cancellationToken ) { - var cancellationTokenRegistration = cancellationToken.CanBeCanceled - ? cancellationToken.UnsafeRegister( - static state => ((DivertValueTaskSource)state!).ExecuteOrRequestCancel(), - this - ) - : default; - var nativeOverlapped = threadPoolBoundHandle.AllocateNativeOverlapped(preAllocatedOverlapped); - this.nativeOverlapped = new(nativeOverlapped); + var cancellationTokenRegistration = ioCompletionOperation.Prepare(cancellationToken, out var nativeOverlapped); pendingOperation = new PendingOperation( nativeOverlapped, + cancellationTokenRegistration, packetBuffer, - addresses, - cancellationTokenRegistration + addresses ); - return pendingOperation; + return ref pendingOperation; } - private void DisposePendingOperation() + public void OnCompleted(uint errorCode, uint numBytes) { - var overlapped = Interlocked.Exchange(ref nativeOverlapped, default); - if (overlapped != default) + using var _ = pendingOperation; + if (errorCode is (uint)WIN32_ERROR.ERROR_SUCCESS) { - pendingOperation.Dispose(); - threadPoolBoundHandle.FreeNativeOverlapped((NativeOverlapped*)overlapped); + source.SetResult((int)numBytes); } - } - - public void Dispose() - { - if (Interlocked.Exchange(ref status, Status.Disposed) is Status.Disposed) + else if (errorCode is (uint)WIN32_ERROR.ERROR_OPERATION_ABORTED) { - return; + var token = pendingOperation.CancellationToken; + var exception = token.IsCancellationRequested + ? new OperationCanceledException(token) + : new OperationCanceledException(); + source.SetException(exception); + } + else + { + var exception = new Win32Exception((int)errorCode); + source.SetException(exception); } - - DisposePendingOperation(); - preAllocatedOverlapped.Dispose(); - GC.SuppressFinalize(this); - } - - ~DivertValueTaskSource() - { - Dispose(); } - protected abstract void Complete(uint errorCode, uint numBytes, in PendingOperation pendingOperation); - - private static void OnIOCompleted(uint errorCode, uint numBytes, NativeOverlapped* pOVERLAP) + public ValueTask ExecuteAsync( + TExecutor executor, + Memory buffer, + Memory addresses, + CancellationToken cancellationToken + ) + where TExecutor : IDivertValueTaskExecutor { - var vts = (DivertValueTaskSource)ThreadPoolBoundHandle.GetNativeOverlappedState(pOVERLAP)!; + ref readonly var pendingOperation = ref PrepareOperation(buffer, addresses, cancellationToken); try { - vts.Complete(errorCode, numBytes, vts.pendingOperation); - vts.ResetIfPendingOrCanceled(); + bool success = executor.Execute(SafeHandle, in pendingOperation); + if (!success) + { + int error = Marshal.GetLastPInvokeError(); + if (error is (int)WIN32_ERROR.ERROR_IO_PENDING) + { + ioCompletionOperation.CancelWhenRequested(); + } + else + { + Dispose(); + return ValueTask.FromException(new Win32Exception(error)); + } + } + return new ValueTask(this, source.Version); } - finally + catch { - vts.DisposePendingOperation(); + Dispose(); + throw; } } } diff --git a/Divert.Windows/AsyncOperation/IOCompletionOperation.cs b/Divert.Windows/AsyncOperation/IOCompletionOperation.cs new file mode 100644 index 0000000..e4e7653 --- /dev/null +++ b/Divert.Windows/AsyncOperation/IOCompletionOperation.cs @@ -0,0 +1,86 @@ +using System.Diagnostics; +using System.Runtime.InteropServices; + +namespace Divert.Windows.AsyncOperation; + +internal interface IOCompletionHandler +{ + void OnCompleted(uint errorCode, uint numBytes); +} + +internal sealed unsafe class IOCompletionOperation : IDisposable, IOCompletionHandler + where THandler : IOCompletionHandler +{ + private static void OnIOCompleted(uint errorCode, uint numBytes, NativeOverlapped* pOVERLAP) + { + var operation = (IOCompletionOperation)ThreadPoolBoundHandle.GetNativeOverlappedState(pOVERLAP)!; + operation.OnCompleted(errorCode, numBytes); + } + + private static readonly IOCompletionCallback ioCompletionCallback = OnIOCompleted; + + private readonly ThreadPoolBoundHandle threadPoolBoundHandle; + private readonly CancellationHandle cancellationHandle; + private readonly PreAllocatedOverlapped preAllocatedOverlapped; + private readonly THandler handler; + + private NativeOverlapped* nativeOverlapped; + + public IOCompletionOperation(SafeHandle safeHandle, ThreadPoolBoundHandle threadPoolBoundHandle, THandler handler) + { + cancellationHandle = new CancellationHandle(safeHandle); + this.threadPoolBoundHandle = threadPoolBoundHandle; + this.handler = handler; + preAllocatedOverlapped = new PreAllocatedOverlapped(ioCompletionCallback, this, null); + } + + public SafeHandle SafeHandle => cancellationHandle.SafeHandle; + + public void Dispose() + { + if (nativeOverlapped is not null) + { + threadPoolBoundHandle.FreeNativeOverlapped(nativeOverlapped); + nativeOverlapped = null; + } + preAllocatedOverlapped.Dispose(); + cancellationHandle.Dispose(); + } + + public CancellationTokenRegistration Prepare(CancellationToken cancellationToken, out NativeOverlapped* overlapped) + { + Debug.Assert(nativeOverlapped is null); + nativeOverlapped = threadPoolBoundHandle.AllocateNativeOverlapped(preAllocatedOverlapped); + var registration = cancellationToken.CanBeCanceled + ? cancellationToken.UnsafeRegister( + static state => + { + var operation = (IOCompletionOperation)state!; + operation.cancellationHandle.RequestOrInvokeCancel(operation.nativeOverlapped); + }, + this + ) + : default; + overlapped = nativeOverlapped; + return registration; + } + + public void CancelWhenRequested() + { + cancellationHandle.CancelWhenRequested(nativeOverlapped); + } + + public void OnCompleted(uint errorCode, uint numBytes) + { + try + { + handler.OnCompleted(errorCode, numBytes); + } + finally + { + cancellationHandle.Reset(); + threadPoolBoundHandle.FreeNativeOverlapped(nativeOverlapped); + nativeOverlapped = null; + } + } +} diff --git a/Divert.Windows/AsyncOperation/PendingOperation.cs b/Divert.Windows/AsyncOperation/PendingOperation.cs index 0de4c5a..282a21d 100644 --- a/Divert.Windows/AsyncOperation/PendingOperation.cs +++ b/Divert.Windows/AsyncOperation/PendingOperation.cs @@ -2,24 +2,28 @@ namespace Divert.Windows.AsyncOperation; -internal readonly unsafe struct PendingOperation( +internal unsafe struct PendingOperation( NativeOverlapped* nativeOverlapped, + CancellationTokenRegistration cancellationTokenRegistration, Memory packetBuffer, - Memory addresses, - CancellationTokenRegistration cancellationTokenRegistration + Memory addresses ) : IDisposable { - public NativeOverlapped* NativeOverlapped { get; } = nativeOverlapped; - public MemoryHandle PacketBufferHandle { get; } = packetBuffer.Pin(); - public MemoryHandle AddressesHandle { get; } = addresses.Pin(); + private MemoryHandle packetBufferHandle = packetBuffer.Pin(); + private MemoryHandle addressesHandle = addresses.Pin(); + public readonly NativeOverlapped* NativeOverlapped => nativeOverlapped; + public readonly MemoryHandle PacketBufferHandle => packetBufferHandle; + public readonly MemoryHandle AddressesHandle => addressesHandle; + + public readonly Memory PacketBuffer => packetBuffer; public readonly Memory Addresses => addresses; public readonly CancellationToken CancellationToken => cancellationTokenRegistration.Token; public void Dispose() { cancellationTokenRegistration.Dispose(); - PacketBufferHandle.Dispose(); - AddressesHandle.Dispose(); + packetBufferHandle.Dispose(); + addressesHandle.Dispose(); } } diff --git a/Divert.Windows/DivertFilter.cs b/Divert.Windows/DivertFilter.cs index ceebee9..d47ff80 100644 --- a/Divert.Windows/DivertFilter.cs +++ b/Divert.Windows/DivertFilter.cs @@ -179,37 +179,37 @@ public static implicit operator DivertFilter(Field field) return new DivertFilter(clause); } - public static DivertFilter operator ==(Field left, string right) + public static DivertFilter operator ==(Field left, object right) { string clause = $"{left} = {right}"; return new DivertFilter(clause); } - public static DivertFilter operator !=(Field left, string right) + public static DivertFilter operator !=(Field left, object right) { string clause = $"{left} != {right}"; return new DivertFilter(clause); } - public static DivertFilter operator <(Field left, string right) + public static DivertFilter operator <(Field left, object right) { string clause = $"{left} < {right}"; return new DivertFilter(clause); } - public static DivertFilter operator >(Field left, string right) + public static DivertFilter operator >(Field left, object right) { string clause = $"{left} > {right}"; return new DivertFilter(clause); } - public static DivertFilter operator <=(Field left, string right) + public static DivertFilter operator <=(Field left, object right) { string clause = $"{left} <= {right}"; return new DivertFilter(clause); } - public static DivertFilter operator >=(Field left, string right) + public static DivertFilter operator >=(Field left, object right) { string clause = $"{left} >= {right}"; return new DivertFilter(clause); diff --git a/Divert.Windows/DivertReceiveResult.cs b/Divert.Windows/DivertReceiveResult.cs index 16dec94..3587ed0 100644 --- a/Divert.Windows/DivertReceiveResult.cs +++ b/Divert.Windows/DivertReceiveResult.cs @@ -1,7 +1,14 @@ namespace Divert.Windows; -public readonly struct DivertReceiveResult(int length, Memory addresses) +public readonly struct DivertReceiveResult(int dataLength, int addressLength) { - public int Length => length; - public Memory Addresses => addresses; + /// + /// Gets the length of the received data. + /// + public int DataLength { get; } = dataLength; + + /// + /// Gets the length of the addresses. + /// + public int AddressLength { get; } = addressLength; } diff --git a/Divert.Windows/DivertService.cs b/Divert.Windows/DivertService.cs index 4fb61f5..cccff2d 100644 --- a/Divert.Windows/DivertService.cs +++ b/Divert.Windows/DivertService.cs @@ -12,10 +12,13 @@ namespace Divert.Windows; public sealed unsafe class DivertService : IDisposable { private readonly DivertHandle divertHandle; + private readonly bool runContinuationsAsynchronously; private readonly ThreadPoolBoundHandle threadPoolBoundHandle; - private readonly Channel receiveVtsPool; - private readonly Channel sendVtsPool; + private readonly Channel sendVtsPool; + private readonly Channel receiveVtsPool; + private readonly DivertReceiveExecutor receiveExecutor; + private readonly DivertSendExecutor sendExecutor; /// /// Opens a WinDivert handle for the given filter. @@ -28,7 +31,8 @@ public DivertService( DivertFilter filter, DivertLayer layer = DivertLayer.Network, short priority = 0, - DivertFlags flags = DivertFlags.None + DivertFlags flags = DivertFlags.None, + bool runContinuationsAsynchronously = true ) { ArgumentNullException.ThrowIfNull(filter); @@ -56,18 +60,20 @@ public DivertService( throw new Win32Exception(Marshal.GetLastPInvokeError()); } divertHandle = new DivertHandle(handle); + this.runContinuationsAsynchronously = runContinuationsAsynchronously; threadPoolBoundHandle = ThreadPoolBoundHandle.BindHandle(divertHandle); - receiveVtsPool = Channel.CreateUnbounded(); - sendVtsPool = Channel.CreateUnbounded(); + sendVtsPool = Channel.CreateUnbounded(); + receiveVtsPool = Channel.CreateUnbounded(); + receiveExecutor = new DivertReceiveExecutor(); + sendExecutor = new DivertSendExecutor(); } private bool disposed = false; - private static void DisposeVtsPool(Channel vtsPool) - where T : IDisposable + private static void DisposeVtsPool(Channel pool) { - vtsPool.Writer.TryComplete(); - while (vtsPool.Reader.TryRead(out var vts)) + pool.Writer.TryComplete(); + while (pool.Reader.TryRead(out var vts)) { vts.Dispose(); } @@ -84,32 +90,24 @@ public void Dispose() } disposed = true; - DisposeVtsPool(receiveVtsPool); DisposeVtsPool(sendVtsPool); + DisposeVtsPool(receiveVtsPool); threadPoolBoundHandle.Dispose(); divertHandle.Dispose(); - GC.SuppressFinalize(this); } - ~DivertService() - { - Dispose(); - } - - /// - /// Shuts down the WinDivert handle. - /// - /// Specifies how to shut down the handle. - public void Shutdown(DivertShutdown how) + private DivertValueTaskSource GetVts(Channel vtsPool) { - ObjectDisposedException.ThrowIf(disposed, this); - - using var _ = divertHandle.GetReference(out var handle); - bool success = NativeMethods.WinDivertShutdown(handle, (WINDIVERT_SHUTDOWN)how); - if (!success) + if (!vtsPool.Reader.TryRead(out var vts)) { - throw new Win32Exception(Marshal.GetLastPInvokeError()); + vts = new DivertValueTaskSource( + sendVtsPool, + divertHandle, + threadPoolBoundHandle, + runContinuationsAsynchronously + ); } + return vts; } /// @@ -131,16 +129,8 @@ public ValueTask ReceiveAsync( return ValueTask.FromCanceled(cancellationToken); } - if (!receiveVtsPool.Reader.TryRead(out var vts)) - { - vts = new DivertReceiveValueTaskSource( - receiveVtsPool, - divertHandle, - threadPoolBoundHandle, - runContinuationsAsynchronously: true - ); - } - return vts.ReceiveAsync(buffer, addresses, cancellationToken); + var vts = GetVts(receiveVtsPool); + return receiveExecutor.ReceiveAsync(vts, buffer, addresses, cancellationToken); } /// @@ -150,7 +140,7 @@ public ValueTask ReceiveAsync( /// Addresses of the packets to be injected. /// Token to observe cancellation requests. /// A ValueTask representing the asynchronous send operation. - public ValueTask SendAsync( + public ValueTask SendAsync( ReadOnlyMemory buffer, ReadOnlyMemory addresses, CancellationToken cancellationToken = default @@ -159,18 +149,10 @@ public ValueTask SendAsync( ObjectDisposedException.ThrowIf(disposed, this); if (cancellationToken.IsCancellationRequested) { - return ValueTask.FromCanceled(cancellationToken); + return ValueTask.FromCanceled(cancellationToken); } - if (!sendVtsPool.Reader.TryRead(out var vts)) - { - vts = new DivertSendValueTaskSource( - sendVtsPool, - divertHandle, - threadPoolBoundHandle, - runContinuationsAsynchronously: true - ); - } - return vts.SendAsync(buffer, addresses, cancellationToken); + var vts = GetVts(sendVtsPool); + return sendExecutor.SendAsync(vts, buffer, addresses, cancellationToken); } } diff --git a/Divert.Windows/SafeHandleExtensions.cs b/Divert.Windows/SafeHandleExtensions.cs index c5e69bb..d99aecd 100644 --- a/Divert.Windows/SafeHandleExtensions.cs +++ b/Divert.Windows/SafeHandleExtensions.cs @@ -13,7 +13,7 @@ public void Dispose() internal static class SafeHandleExtensions { - public static SafeHandleReference GetReference(this T safeHandle, out IntPtr handle) + public static SafeHandleReference Reference(this T safeHandle, out IntPtr handle) where T : SafeHandle { bool success = false; From 25e10d06e59bb04a6d344d57968f6c353820a669 Mon Sep 17 00:00:00 2001 From: gdlol Date: Wed, 5 Nov 2025 18:02:34 +0000 Subject: [PATCH 06/20] Test cancellation --- Automation/Test.cs | 14 +++- .../Divert.Windows.Tests.csproj | 6 +- Divert.Windows.Tests/DivertServiceTests.cs | 76 +++++++++++++++++++ Divert.Windows.Tests/ExecutorDelayPipe.cs | 24 ++++++ .../AsyncOperation/DivertValueTaskSource.cs | 6 +- .../IDivertValueTaskExecutor.cs | 26 +++++++ Divert.Windows/Divert.Windows.csproj | 4 + Divert.Windows/NativeMethods.txt | 1 - 8 files changed, 145 insertions(+), 12 deletions(-) create mode 100644 Divert.Windows.Tests/ExecutorDelayPipe.cs create mode 100644 Divert.Windows/AsyncOperation/IDivertValueTaskExecutor.cs diff --git a/Automation/Test.cs b/Automation/Test.cs index b297759..eece2d9 100644 --- a/Automation/Test.cs +++ b/Automation/Test.cs @@ -50,7 +50,19 @@ public static void GenerateReport(Context context) public override async Task RunAsync(Context context) { - context.DotNetBuild(Path.Combine(Context.ProjectRoot, TestProjectName)); + context.DotNetBuild( + Path.Combine(Context.ProjectRoot, TestProjectName), + new() + { + MSBuildSettings = new() + { + Properties = + { + ["DivertWindowsTests"] = ["true"], // Enable DivertValueTaskExecutorDelay + }, + }, + } + ); if (!Directory.Exists(Path.Combine(Context.LocalWindowsDirectory, TestRunnerProjectName))) { context.DotNetBuild(Path.Combine(Context.ProjectRoot, TestRunnerProjectName)); diff --git a/Divert.Windows.Tests/Divert.Windows.Tests.csproj b/Divert.Windows.Tests/Divert.Windows.Tests.csproj index b45d517..d73b683 100644 --- a/Divert.Windows.Tests/Divert.Windows.Tests.csproj +++ b/Divert.Windows.Tests/Divert.Windows.Tests.csproj @@ -23,10 +23,6 @@ - + diff --git a/Divert.Windows.Tests/DivertServiceTests.cs b/Divert.Windows.Tests/DivertServiceTests.cs index 439fc0f..fa275de 100644 --- a/Divert.Windows.Tests/DivertServiceTests.cs +++ b/Divert.Windows.Tests/DivertServiceTests.cs @@ -103,4 +103,80 @@ public async Task Close() Assert.AreNotEqual(token, exception.CancellationToken); Assert.AreEqual(CancellationToken.None, exception.CancellationToken); } + + [TestMethod] + public async Task CancelReceive() + { + using var listener = CreateUdpListener(out int port); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port); + using var service = new DivertService(filter); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + + // Cancel before receive + { + using var cts = new CancellationTokenSource(); + cts.Cancel(); + var token = cts.Token; + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, token).AsTask(); + Assert.IsTrue(divertReceive.IsCanceled); + } + + // Cancel on receive + using (var pipe = new ExecutorDelayPipe()) + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(this.token); + var token = cts.Token; + + var divertReceive = Task.Run(async () => await service.ReceiveAsync(packetBuffer, addressBuffer, token)); + Assert.IsFalse(divertReceive.IsCompleted); + + await pipe.Stream.WaitForConnectionAsync(token); + cts.Cancel(); + Assert.IsFalse(divertReceive.IsCompleted); + pipe.Stream.WriteByte(0); + await pipe.Stream.FlushAsync(token); + pipe.Stream.Disconnect(); + + var exception = await Assert.ThrowsAsync(async () => await divertReceive); + Assert.IsTrue(token.IsCancellationRequested); + Assert.AreEqual(token, exception.CancellationToken); + } + + // Cancel pending receive + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(this.token); + var token = cts.Token; + + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + cts.Cancel(); + var exception = await Assert.ThrowsAsync(async () => await divertReceive); + Assert.IsTrue(token.IsCancellationRequested); + Assert.AreEqual(token, exception.CancellationToken); + } + } + + [TestMethod] + public void CancelSend() + { + using var service = new DivertService(DivertFilter.False); + + var packetBuffer = new byte[100]; + var addressBuffer = new DivertAddress[1]; + + using var cts = CancellationTokenSource.CreateLinkedTokenSource(this.token); + var token = cts.Token; + cts.Cancel(); + var divertSend = service.SendAsync(packetBuffer, addressBuffer, token).AsTask(); + Assert.IsTrue(divertSend.IsCanceled); + } } diff --git a/Divert.Windows.Tests/ExecutorDelayPipe.cs b/Divert.Windows.Tests/ExecutorDelayPipe.cs new file mode 100644 index 0000000..5f50037 --- /dev/null +++ b/Divert.Windows.Tests/ExecutorDelayPipe.cs @@ -0,0 +1,24 @@ +using System.IO.Pipes; +using Divert.Windows.AsyncOperation; + +namespace Divert.Windows.Tests; + +internal sealed class ExecutorDelayPipe : IDisposable +{ + private readonly string name; + + public NamedPipeServerStream Stream { get; } + + public ExecutorDelayPipe() + { + name = Guid.NewGuid().ToString("N"); + Environment.SetEnvironmentVariable(nameof(DivertValueTaskExecutorDelay), name); + Stream = new NamedPipeServerStream(name, PipeDirection.InOut); + } + + public void Dispose() + { + Environment.SetEnvironmentVariable(nameof(DivertValueTaskExecutorDelay), null); + Stream.Dispose(); + } +} diff --git a/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs b/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs index 155084c..199d0b1 100644 --- a/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs +++ b/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs @@ -6,11 +6,6 @@ namespace Divert.Windows.AsyncOperation; -internal interface IDivertValueTaskExecutor -{ - bool Execute(SafeHandle safeHandle, ref readonly PendingOperation pendingOperation); -} - internal unsafe class DivertValueTaskSource : IDisposable, IValueTaskSource, IOCompletionHandler { private readonly Channel pool; @@ -120,6 +115,7 @@ CancellationToken cancellationToken ref readonly var pendingOperation = ref PrepareOperation(buffer, addresses, cancellationToken); try { + executor.DelayExecutionInTests(); bool success = executor.Execute(SafeHandle, in pendingOperation); if (!success) { diff --git a/Divert.Windows/AsyncOperation/IDivertValueTaskExecutor.cs b/Divert.Windows/AsyncOperation/IDivertValueTaskExecutor.cs new file mode 100644 index 0000000..177769d --- /dev/null +++ b/Divert.Windows/AsyncOperation/IDivertValueTaskExecutor.cs @@ -0,0 +1,26 @@ +using System.Diagnostics; +using System.IO.Pipes; +using System.Runtime.InteropServices; + +namespace Divert.Windows.AsyncOperation; + +internal interface IDivertValueTaskExecutor +{ + bool Execute(SafeHandle safeHandle, ref readonly PendingOperation pendingOperation); +} + +internal static class DivertValueTaskExecutorDelay +{ + [Conditional("DIVERT_WINDOWS_TESTS")] + public static void DelayExecutionInTests(this IDivertValueTaskExecutor executor) + { + if (Environment.GetEnvironmentVariable(nameof(DivertValueTaskExecutorDelay)) is not string name) + { + return; + } + + using var stream = new NamedPipeClientStream(".", name, PipeDirection.InOut); + stream.Connect(); // Notify delay + stream.ReadByte(); // continue + } +} diff --git a/Divert.Windows/Divert.Windows.csproj b/Divert.Windows/Divert.Windows.csproj index d475485..e2586ce 100644 --- a/Divert.Windows/Divert.Windows.csproj +++ b/Divert.Windows/Divert.Windows.csproj @@ -4,6 +4,10 @@ x64 + + $(DefineConstants);DIVERT_WINDOWS_TESTS + + $(MSBuildProjectDirectory)=$(MSBuildProjectName) diff --git a/Divert.Windows/NativeMethods.txt b/Divert.Windows/NativeMethods.txt index 3737ecf..2a1e234 100644 --- a/Divert.Windows/NativeMethods.txt +++ b/Divert.Windows/NativeMethods.txt @@ -1,4 +1,3 @@ INVALID_HANDLE_VALUE WIN32_ERROR -GetOverlappedResult CancelIoEx From 2b22704315cff4355fa21ed5215ac7d2d9167e5e Mon Sep 17 00:00:00 2001 From: gdlol Date: Thu, 6 Nov 2025 14:15:22 +0000 Subject: [PATCH 07/20] Refactoring --- .../Divert.Windows.TestRunner.csproj | 6 ------ .../AsyncOperation/DivertValueTaskSource.cs | 11 +++-------- .../AsyncOperation/IOCompletionOperation.cs | 11 ++++++----- Divert.Windows/AsyncOperation/PendingOperation.cs | 7 +++---- Examples/Ping/Program.cs | 4 ++-- 5 files changed, 14 insertions(+), 25 deletions(-) diff --git a/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj b/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj index 59acd39..70d7fbe 100644 --- a/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj +++ b/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj @@ -9,12 +9,6 @@ CA2007 - - - none - - - diff --git a/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs b/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs index 199d0b1..1a6abef 100644 --- a/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs +++ b/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs @@ -72,13 +72,8 @@ private ref readonly PendingOperation PrepareOperation( CancellationToken cancellationToken ) { - var cancellationTokenRegistration = ioCompletionOperation.Prepare(cancellationToken, out var nativeOverlapped); - pendingOperation = new PendingOperation( - nativeOverlapped, - cancellationTokenRegistration, - packetBuffer, - addresses - ); + var nativeOverlapped = ioCompletionOperation.Prepare(cancellationToken); + pendingOperation = new PendingOperation(nativeOverlapped, packetBuffer, addresses, cancellationToken); return ref pendingOperation; } @@ -122,7 +117,7 @@ CancellationToken cancellationToken int error = Marshal.GetLastPInvokeError(); if (error is (int)WIN32_ERROR.ERROR_IO_PENDING) { - ioCompletionOperation.CancelWhenRequested(); + ioCompletionOperation.CancelWhenRequested(pendingOperation.NativeOverlapped); } else { diff --git a/Divert.Windows/AsyncOperation/IOCompletionOperation.cs b/Divert.Windows/AsyncOperation/IOCompletionOperation.cs index e4e7653..733ed55 100644 --- a/Divert.Windows/AsyncOperation/IOCompletionOperation.cs +++ b/Divert.Windows/AsyncOperation/IOCompletionOperation.cs @@ -24,6 +24,7 @@ private static void OnIOCompleted(uint errorCode, uint numBytes, NativeOverlappe private readonly PreAllocatedOverlapped preAllocatedOverlapped; private readonly THandler handler; + private CancellationTokenRegistration cancellationRegistration; private NativeOverlapped* nativeOverlapped; public IOCompletionOperation(SafeHandle safeHandle, ThreadPoolBoundHandle threadPoolBoundHandle, THandler handler) @@ -47,11 +48,11 @@ public void Dispose() cancellationHandle.Dispose(); } - public CancellationTokenRegistration Prepare(CancellationToken cancellationToken, out NativeOverlapped* overlapped) + public NativeOverlapped* Prepare(CancellationToken cancellationToken) { Debug.Assert(nativeOverlapped is null); nativeOverlapped = threadPoolBoundHandle.AllocateNativeOverlapped(preAllocatedOverlapped); - var registration = cancellationToken.CanBeCanceled + cancellationRegistration = cancellationToken.CanBeCanceled ? cancellationToken.UnsafeRegister( static state => { @@ -61,17 +62,17 @@ public CancellationTokenRegistration Prepare(CancellationToken cancellationToken this ) : default; - overlapped = nativeOverlapped; - return registration; + return nativeOverlapped; } - public void CancelWhenRequested() + public void CancelWhenRequested(NativeOverlapped* nativeOverlapped) { cancellationHandle.CancelWhenRequested(nativeOverlapped); } public void OnCompleted(uint errorCode, uint numBytes) { + cancellationRegistration.Dispose(); try { handler.OnCompleted(errorCode, numBytes); diff --git a/Divert.Windows/AsyncOperation/PendingOperation.cs b/Divert.Windows/AsyncOperation/PendingOperation.cs index 282a21d..b6ca76d 100644 --- a/Divert.Windows/AsyncOperation/PendingOperation.cs +++ b/Divert.Windows/AsyncOperation/PendingOperation.cs @@ -4,9 +4,9 @@ namespace Divert.Windows.AsyncOperation; internal unsafe struct PendingOperation( NativeOverlapped* nativeOverlapped, - CancellationTokenRegistration cancellationTokenRegistration, Memory packetBuffer, - Memory addresses + Memory addresses, + CancellationToken cancellationToken ) : IDisposable { private MemoryHandle packetBufferHandle = packetBuffer.Pin(); @@ -18,11 +18,10 @@ Memory addresses public readonly Memory PacketBuffer => packetBuffer; public readonly Memory Addresses => addresses; - public readonly CancellationToken CancellationToken => cancellationTokenRegistration.Token; + public readonly CancellationToken CancellationToken => cancellationToken; public void Dispose() { - cancellationTokenRegistration.Dispose(); packetBufferHandle.Dispose(); addressesHandle.Dispose(); } diff --git a/Examples/Ping/Program.cs b/Examples/Ping/Program.cs index 71e97ef..30d2a88 100644 --- a/Examples/Ping/Program.cs +++ b/Examples/Ping/Program.cs @@ -84,7 +84,7 @@ from gateway in nic.GetIPProperties().GatewayAddresses try { var result = await outDivert.ReceiveAsync(buffer, addresses, cts.Token).ConfigureAwait(false); - var packet = buffer.AsMemory(0, result.Length); + var packet = buffer.AsMemory(0, result.DataLength); var remoteAddress = new IPAddress(packet[16..20].Span); Console.WriteLine($"Pinging {remoteAddress} with {packet.Length - 28} bytes of data (Divert):"); await outDivert.SendAsync(packet, addresses, cts.Token); @@ -113,7 +113,7 @@ from gateway in nic.GetIPProperties().GatewayAddresses try { var receiveResult = await inDivert.ReceiveAsync(buffer, addresses, cts.Token).ConfigureAwait(false); - var packet = buffer.AsMemory(0, receiveResult.Length); + var packet = buffer.AsMemory(0, receiveResult.DataLength); var remoteAddress = new IPAddress(packet[12..16].Span); long timestamp = BitConverter.ToInt64(packet.Slice(28, sizeof(long)).Span); Console.WriteLine( From d790128c38fadbd661a7354c3464ecfd2cb21267 Mon Sep 17 00:00:00 2001 From: gdlol Date: Thu, 6 Nov 2025 15:11:08 +0000 Subject: [PATCH 08/20] tests --- .config/cspell/cspell.json | 1 + .config/dotnet/Project.props | 1 + .../Divert.Windows.Tests.csproj | 6 ++ Divert.Windows.Tests/DivertServiceTests.cs | 85 +++++++++++++++++++ Divert.Windows.Tests/ExecutorDelayPipe.cs | 7 +- Divert.Windows.Tests/NativeMethods.txt | 3 + Divert.Windows/AssemblyInfo.cs | 2 - .../AsyncOperation/CancellationHandle.cs | 4 +- .../AsyncOperation/DivertReceiveExecutor.cs | 2 +- .../AsyncOperation/DivertSendExecutor.cs | 2 +- .../AsyncOperation/DivertValueTaskSource.cs | 4 +- .../IDivertValueTaskExecutor.cs | 6 +- Divert.Windows/CString.cs | 6 -- Divert.Windows/DivertHandle.cs | 4 +- Divert.Windows/DivertService.cs | 14 +++ Divert.Windows/SafeHandleExtensions.cs | 12 ++- 16 files changed, 135 insertions(+), 24 deletions(-) create mode 100644 Divert.Windows.Tests/NativeMethods.txt diff --git a/.config/cspell/cspell.json b/.config/cspell/cspell.json index 49ffc5c..9ed6f75 100644 --- a/.config/cspell/cspell.json +++ b/.config/cspell/cspell.json @@ -10,6 +10,7 @@ "csharpierrc", "devcontainer", "devcontainers", + "Finalizers", "getgid", "getuid", "globalconfig", diff --git a/.config/dotnet/Project.props b/.config/dotnet/Project.props index 10c3360..a384f84 100644 --- a/.config/dotnet/Project.props +++ b/.config/dotnet/Project.props @@ -3,6 +3,7 @@ enable enable true + preview diff --git a/Divert.Windows.Tests/Divert.Windows.Tests.csproj b/Divert.Windows.Tests/Divert.Windows.Tests.csproj index d73b683..5d93a34 100644 --- a/Divert.Windows.Tests/Divert.Windows.Tests.csproj +++ b/Divert.Windows.Tests/Divert.Windows.Tests.csproj @@ -25,4 +25,10 @@ + + + + all + + diff --git a/Divert.Windows.Tests/DivertServiceTests.cs b/Divert.Windows.Tests/DivertServiceTests.cs index fa275de..3b9083b 100644 --- a/Divert.Windows.Tests/DivertServiceTests.cs +++ b/Divert.Windows.Tests/DivertServiceTests.cs @@ -1,5 +1,9 @@ +using System.ComponentModel; using System.Net; using System.Net.Sockets; +using Windows.Win32; +using Windows.Win32.Foundation; +using Windows.Win32.Storage.FileSystem; namespace Divert.Windows.Tests; @@ -179,4 +183,85 @@ public void CancelSend() var divertSend = service.SendAsync(packetBuffer, addressBuffer, token).AsTask(); Assert.IsTrue(divertSend.IsCanceled); } + + [TestMethod] + public async Task InsufficientBuffer() + { + using var listener = CreateUdpListener(out int port); + var receive = listener.ReceiveAsync(token).AsTask(); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port); + using var service = new DivertService(filter); + + var packetBuffer = new byte[10]; // insufficient buffer + var addressBuffer = new DivertAddress[1]; + + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + // send 3 bytes payload + using var client = new UdpClient(); + client.Connect(IPAddress.Loopback, port); + await client.SendAsync(new byte[] { 1, 2, 3 }, token); + + var exception = await Assert.ThrowsAsync(async () => await divertReceive); + Assert.AreEqual((int)WIN32_ERROR.ERROR_INSUFFICIENT_BUFFER, exception.NativeErrorCode); + } + + [TestMethod] + public async Task DisposeOnReceive() + { + using var listener = CreateUdpListener(out int port); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port); + using var service = new DivertService(filter); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + + using var pipe = new ExecutorDelayPipe(); + var divertReceive = Task.Run(async () => await service.ReceiveAsync(packetBuffer, addressBuffer, token)); + Assert.IsFalse(divertReceive.IsCompleted); + + await pipe.Stream.WaitForConnectionAsync(token); + Assert.IsFalse(divertReceive.IsCompleted); + service.Dispose(); + pipe.Stream.WriteByte(0); + await pipe.Stream.FlushAsync(token); + pipe.Stream.Disconnect(); + + await Assert.ThrowsAsync(async () => await divertReceive); + } + + [TestMethod] + public async Task InvalidHandle() + { + var handle = PInvoke.CreateFile( + Path.GetTempFileName(), + (uint)GENERIC_ACCESS_RIGHTS.GENERIC_READ, + FILE_SHARE_MODE.FILE_SHARE_NONE, + null, + FILE_CREATION_DISPOSITION.OPEN_EXISTING, + FILE_FLAGS_AND_ATTRIBUTES.FILE_ATTRIBUTE_NORMAL | FILE_FLAGS_AND_ATTRIBUTES.FILE_FLAG_OVERLAPPED, + null + ); + using var service = new DivertService(handle); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + var exception = await Assert.ThrowsAsync(async () => + await service.ReceiveAsync(packetBuffer, addressBuffer, token) + ); + Assert.AreEqual((int)WIN32_ERROR.ERROR_INVALID_PARAMETER, exception.NativeErrorCode); + } } diff --git a/Divert.Windows.Tests/ExecutorDelayPipe.cs b/Divert.Windows.Tests/ExecutorDelayPipe.cs index 5f50037..5b4dbe0 100644 --- a/Divert.Windows.Tests/ExecutorDelayPipe.cs +++ b/Divert.Windows.Tests/ExecutorDelayPipe.cs @@ -1,10 +1,11 @@ using System.IO.Pipes; -using Divert.Windows.AsyncOperation; namespace Divert.Windows.Tests; internal sealed class ExecutorDelayPipe : IDisposable { + private const string DIVERT_WINDOWS_TESTS = nameof(DIVERT_WINDOWS_TESTS); + private readonly string name; public NamedPipeServerStream Stream { get; } @@ -12,13 +13,13 @@ internal sealed class ExecutorDelayPipe : IDisposable public ExecutorDelayPipe() { name = Guid.NewGuid().ToString("N"); - Environment.SetEnvironmentVariable(nameof(DivertValueTaskExecutorDelay), name); + Environment.SetEnvironmentVariable(DIVERT_WINDOWS_TESTS, name); Stream = new NamedPipeServerStream(name, PipeDirection.InOut); } public void Dispose() { - Environment.SetEnvironmentVariable(nameof(DivertValueTaskExecutorDelay), null); + Environment.SetEnvironmentVariable(DIVERT_WINDOWS_TESTS, null); Stream.Dispose(); } } diff --git a/Divert.Windows.Tests/NativeMethods.txt b/Divert.Windows.Tests/NativeMethods.txt new file mode 100644 index 0000000..7e147bb --- /dev/null +++ b/Divert.Windows.Tests/NativeMethods.txt @@ -0,0 +1,3 @@ +WIN32_ERROR +CreateFile +GENERIC_ACCESS_RIGHTS diff --git a/Divert.Windows/AssemblyInfo.cs b/Divert.Windows/AssemblyInfo.cs index a39cdcf..654c4a8 100644 --- a/Divert.Windows/AssemblyInfo.cs +++ b/Divert.Windows/AssemblyInfo.cs @@ -1,5 +1,3 @@ -using System.Runtime.CompilerServices; using System.Runtime.Versioning; [assembly: SupportedOSPlatform("windows6.0.6000")] -[assembly: InternalsVisibleTo("Divert.Windows.Tests")] diff --git a/Divert.Windows/AsyncOperation/CancellationHandle.cs b/Divert.Windows/AsyncOperation/CancellationHandle.cs index a25786e..dfd8d45 100644 --- a/Divert.Windows/AsyncOperation/CancellationHandle.cs +++ b/Divert.Windows/AsyncOperation/CancellationHandle.cs @@ -23,7 +23,7 @@ public void RequestOrInvokeCancel(NativeOverlapped* nativeOverlapped) int originalStatus = Interlocked.CompareExchange(ref status, Status.Canceled, Status.Idle); if (originalStatus is Status.Pending) { - using (safeHandle.Reference(out var handle)) + using (safeHandle.DangerousGetHandle(out var handle)) { _ = PInvoke.CancelIoEx(new(handle), nativeOverlapped); } @@ -36,7 +36,7 @@ public void CancelWhenRequested(NativeOverlapped* nativeOverlapped) int originalStatus = Interlocked.CompareExchange(ref status, Status.Pending, Status.Idle); if (originalStatus is Status.Canceled) { - using (safeHandle.Reference(out var handle)) + using (safeHandle.DangerousGetHandle(out var handle)) { _ = PInvoke.CancelIoEx(new(handle), nativeOverlapped); } diff --git a/Divert.Windows/AsyncOperation/DivertReceiveExecutor.cs b/Divert.Windows/AsyncOperation/DivertReceiveExecutor.cs index c706f01..38db653 100644 --- a/Divert.Windows/AsyncOperation/DivertReceiveExecutor.cs +++ b/Divert.Windows/AsyncOperation/DivertReceiveExecutor.cs @@ -9,7 +9,7 @@ internal sealed class DivertReceiveExecutor : IDivertValueTaskExecutor public unsafe bool Execute(SafeHandle safeHandle, ref readonly PendingOperation pendingOperation) { - using var _ = safeHandle.Reference(out var handle); + using var _ = safeHandle.DangerousGetHandle(out var handle); addressesLengthBuffer.Span[0] = (uint)(pendingOperation.Addresses.Length * sizeof(DivertAddress)); return NativeMethods.WinDivertRecvEx( handle, diff --git a/Divert.Windows/AsyncOperation/DivertSendExecutor.cs b/Divert.Windows/AsyncOperation/DivertSendExecutor.cs index e116dfc..a010e5a 100644 --- a/Divert.Windows/AsyncOperation/DivertSendExecutor.cs +++ b/Divert.Windows/AsyncOperation/DivertSendExecutor.cs @@ -6,7 +6,7 @@ internal sealed unsafe class DivertSendExecutor : IDivertValueTaskExecutor { public bool Execute(SafeHandle safeHandle, ref readonly PendingOperation pendingOperation) { - using var _ = safeHandle.Reference(out var handle); + using var _ = safeHandle.DangerousGetHandle(out var handle); return NativeMethods.WinDivertSendEx( handle, pendingOperation.PacketBufferHandle.Pointer, diff --git a/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs b/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs index 1a6abef..7c77360 100644 --- a/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs +++ b/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs @@ -6,7 +6,7 @@ namespace Divert.Windows.AsyncOperation; -internal unsafe class DivertValueTaskSource : IDisposable, IValueTaskSource, IOCompletionHandler +internal sealed unsafe class DivertValueTaskSource : IDisposable, IValueTaskSource, IOCompletionHandler { private readonly Channel pool; private readonly IOCompletionOperation ioCompletionOperation; @@ -107,9 +107,9 @@ CancellationToken cancellationToken ) where TExecutor : IDivertValueTaskExecutor { - ref readonly var pendingOperation = ref PrepareOperation(buffer, addresses, cancellationToken); try { + ref readonly var pendingOperation = ref PrepareOperation(buffer, addresses, cancellationToken); executor.DelayExecutionInTests(); bool success = executor.Execute(SafeHandle, in pendingOperation); if (!success) diff --git a/Divert.Windows/AsyncOperation/IDivertValueTaskExecutor.cs b/Divert.Windows/AsyncOperation/IDivertValueTaskExecutor.cs index 177769d..b8d5394 100644 --- a/Divert.Windows/AsyncOperation/IDivertValueTaskExecutor.cs +++ b/Divert.Windows/AsyncOperation/IDivertValueTaskExecutor.cs @@ -11,10 +11,12 @@ internal interface IDivertValueTaskExecutor internal static class DivertValueTaskExecutorDelay { - [Conditional("DIVERT_WINDOWS_TESTS")] + private const string DIVERT_WINDOWS_TESTS = "DIVERT_WINDOWS_TESTS"; + + [Conditional(DIVERT_WINDOWS_TESTS)] public static void DelayExecutionInTests(this IDivertValueTaskExecutor executor) { - if (Environment.GetEnvironmentVariable(nameof(DivertValueTaskExecutorDelay)) is not string name) + if (Environment.GetEnvironmentVariable(DIVERT_WINDOWS_TESTS) is not string name) { return; } diff --git a/Divert.Windows/CString.cs b/Divert.Windows/CString.cs index 7dcdcca..328cd64 100644 --- a/Divert.Windows/CString.cs +++ b/Divert.Windows/CString.cs @@ -15,11 +15,5 @@ public void Dispose() Marshal.FreeHGlobal(Ptr); disposed = true; } - GC.SuppressFinalize(this); - } - - ~CString() - { - Dispose(); } } diff --git a/Divert.Windows/DivertHandle.cs b/Divert.Windows/DivertHandle.cs index 3238d3a..0a2952a 100644 --- a/Divert.Windows/DivertHandle.cs +++ b/Divert.Windows/DivertHandle.cs @@ -4,8 +4,8 @@ namespace Divert.Windows; internal sealed class DivertHandle : SafeHandleZeroOrMinusOneIsInvalid { - public DivertHandle(IntPtr handle) - : base(ownsHandle: true) + public DivertHandle(IntPtr handle, bool ownsHandle = true) + : base(ownsHandle) { this.handle = handle; } diff --git a/Divert.Windows/DivertService.cs b/Divert.Windows/DivertService.cs index cccff2d..42ca7c5 100644 --- a/Divert.Windows/DivertService.cs +++ b/Divert.Windows/DivertService.cs @@ -68,6 +68,20 @@ public DivertService( sendExecutor = new DivertSendExecutor(); } + public DivertService(SafeHandle handle, bool runContinuationsAsynchronously = true) + { + ArgumentNullException.ThrowIfNull(handle); + + using var _ = handle.DangerousGetHandle(out var nativeHandle); + divertHandle = new DivertHandle(nativeHandle, ownsHandle: false); + this.runContinuationsAsynchronously = runContinuationsAsynchronously; + threadPoolBoundHandle = ThreadPoolBoundHandle.BindHandle(divertHandle); + sendVtsPool = Channel.CreateUnbounded(); + receiveVtsPool = Channel.CreateUnbounded(); + receiveExecutor = new DivertReceiveExecutor(); + sendExecutor = new DivertSendExecutor(); + } + private bool disposed = false; private static void DisposeVtsPool(Channel pool) diff --git a/Divert.Windows/SafeHandleExtensions.cs b/Divert.Windows/SafeHandleExtensions.cs index d99aecd..0124f5b 100644 --- a/Divert.Windows/SafeHandleExtensions.cs +++ b/Divert.Windows/SafeHandleExtensions.cs @@ -2,18 +2,24 @@ namespace Divert.Windows; -internal readonly struct SafeHandleReference(T safeHandle) : IDisposable +internal ref struct SafeHandleReference(T safeHandle) : IDisposable where T : SafeHandle { + private bool disposed; + public void Dispose() { - safeHandle.DangerousRelease(); + if (!disposed) + { + safeHandle.DangerousRelease(); + disposed = true; + } } } internal static class SafeHandleExtensions { - public static SafeHandleReference Reference(this T safeHandle, out IntPtr handle) + public static SafeHandleReference DangerousGetHandle(this T safeHandle, out IntPtr handle) where T : SafeHandle { bool success = false; From b4f7126847d6c489c0fc23a0aadff7ca8a819816 Mon Sep 17 00:00:00 2001 From: gdlol Date: Sat, 8 Nov 2025 04:35:48 +0000 Subject: [PATCH 09/20] tests --- .config/cspell/cspell.json | 1 + Divert.Windows.TestRunner/Program.cs | 33 ++++- Divert.Windows.Tests/ChecksumTests.cs | 85 ++++++++++++ Divert.Windows.Tests/DivertServiceTests.cs | 110 +++++++++------ Divert.Windows.Tests/DivertTests.cs | 42 ++++++ Divert.Windows.Tests/FlowTests.cs | 69 ++++++++++ Divert.Windows.Tests/HelperTests.cs | 18 +++ Divert.Windows.Tests/ReflectTests.cs | 148 +++++++++++++++++++++ Divert.Windows.Tests/SocketTests.cs | 40 ++++++ Divert.Windows/DivertAddress.cs | 41 +----- Divert.Windows/DivertFilter.cs | 77 +++++++++-- Divert.Windows/DivertHelper.cs | 39 +++++- Divert.Windows/DivertService.cs | 5 +- Divert.Windows/NativeMethods.cs | 9 ++ 14 files changed, 625 insertions(+), 92 deletions(-) create mode 100644 Divert.Windows.Tests/ChecksumTests.cs create mode 100644 Divert.Windows.Tests/DivertTests.cs create mode 100644 Divert.Windows.Tests/FlowTests.cs create mode 100644 Divert.Windows.Tests/HelperTests.cs create mode 100644 Divert.Windows.Tests/ReflectTests.cs create mode 100644 Divert.Windows.Tests/SocketTests.cs diff --git a/.config/cspell/cspell.json b/.config/cspell/cspell.json index 9ed6f75..27b7028 100644 --- a/.config/cspell/cspell.json +++ b/.config/cspell/cspell.json @@ -15,6 +15,7 @@ "getuid", "globalconfig", "globaltool", + "ICMPV6", "libc", "msbuild", "MSTEST", diff --git a/Divert.Windows.TestRunner/Program.cs b/Divert.Windows.TestRunner/Program.cs index 69320d6..1ab0781 100644 --- a/Divert.Windows.TestRunner/Program.cs +++ b/Divert.Windows.TestRunner/Program.cs @@ -85,9 +85,12 @@ Process LaunchTestProcess(bool redirect) => Console.WriteLine("Waiting for client connection..."); using var client = await listener.AcceptSocketAsync(token); Console.WriteLine("Received client connection."); + + using var sessionCts = CancellationTokenSource.CreateLinkedTokenSource(token); + var sessionToken = sessionCts.Token; using var stream = new NetworkStream(client, ownsSocket: false); using var process = LaunchTestProcess(redirect: true); - using var _ = token.Register(() => process.Kill(entireProcessTree: true)); + using var _ = sessionToken.Register(() => process.Kill(entireProcessTree: true)); var lines = Channel.CreateBounded(1024); var stdOutReader = new StreamReader(process.StandardOutput.BaseStream); @@ -98,28 +101,42 @@ async Task ForwardLines(StreamReader reader) string? line = null; while (true) { - line = await reader.ReadLineAsync(token); + line = await reader.ReadLineAsync(sessionToken); if (line is null) { break; } - await lines.Writer.WriteAsync(line, token); + await lines.Writer.WriteAsync(line, sessionToken); } } var readStdOutTask = ForwardLines(stdOutReader); var readStdErrTask = ForwardLines(stdErrReader); using var writer = new StreamWriter(stream) { AutoFlush = true }; + var readTask = Task.Run( + async () => + { + try + { + await client.ReceiveAsync(new byte[ushort.MaxValue], sessionToken); // Monitor disconnect + } + finally + { + sessionCts.Cancel(); + } + }, + sessionToken + ); var writeTask = Task.Run( async () => { - await foreach (var line in lines.Reader.ReadAllAsync(token)) + await foreach (var line in lines.Reader.ReadAllAsync(sessionToken)) { Console.WriteLine(line); - await writer.WriteLineAsync(line.AsMemory(), token); + await writer.WriteLineAsync(line.AsMemory(), sessionToken); } }, - token + sessionToken ); await process.WaitForExitAsync(token); @@ -127,8 +144,10 @@ async Task ForwardLines(StreamReader reader) await Task.WhenAll(readStdOutTask, readStdErrTask, writeTask) .ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); await writer - .WriteLineAsync(process.ExitCode.ToString().AsMemory(), token) + .WriteLineAsync(process.ExitCode.ToString().AsMemory(), sessionToken) .ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + client.Close(); + await readTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); } } catch (OperationCanceledException) when (token.IsCancellationRequested) { } diff --git a/Divert.Windows.Tests/ChecksumTests.cs b/Divert.Windows.Tests/ChecksumTests.cs new file mode 100644 index 0000000..1c39311 --- /dev/null +++ b/Divert.Windows.Tests/ChecksumTests.cs @@ -0,0 +1,85 @@ +using System.Buffers.Binary; +using System.Net; +using System.Net.Sockets; +using System.Runtime.InteropServices; + +namespace Divert.Windows.Tests; + +[TestClass] +public class ChecksumTests : DivertTests +{ + [TestMethod] + public async Task InvalidChecksum() + { + using var listener = CreateUdpListener(out int port); + var receive = listener.ReceiveAsync(Token).AsTask(); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port); + using var service = new DivertService(filter); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + // send 3 bytes payload + using var client = new UdpClient(); + client.Connect(IPAddress.Loopback, port); + await client.SendAsync(new byte[] { 1, 2, 3 }, Token); + + var divertResult = await divertReceive; + var packet = packetBuffer.AsMemory(0, divertResult.DataLength); + + // Calculate and then invalidate checksums + Assert.IsTrue(DivertHelper.CalculateChecksums(packet.Span)); + ushort ipChecksum = BinaryPrimitives.ReadUInt16BigEndian(packet.Span[10..12]); + ushort udpChecksum = BinaryPrimitives.ReadUInt16BigEndian(packet.Span[26..28]); + BinaryPrimitives.WriteUInt16BigEndian(packet.Span[10..12], (ushort)~ipChecksum); // invalidate IP checksum + BinaryPrimitives.WriteUInt16BigEndian(packet.Span[26..28], (ushort)~udpChecksum); // invalidate UDP checksum + + // Recalculate + addressBuffer[0].IsIPChecksumValid = false; + addressBuffer[0].IsTCPChecksumValid = false; + addressBuffer[0].IsUDPChecksumValid = false; + Assert.IsTrue(DivertHelper.CalculateChecksums(packet.Span, ref addressBuffer[0])); + Assert.AreEqual(ipChecksum, BinaryPrimitives.ReadUInt16BigEndian(packet.Span[10..12])); + Assert.AreEqual(udpChecksum, BinaryPrimitives.ReadUInt16BigEndian(packet.Span[26..28])); + Assert.IsTrue(addressBuffer[0].IsIPChecksumValid); + Assert.IsTrue(addressBuffer[0].IsUDPChecksumValid); + Assert.IsFalse(addressBuffer[0].IsTCPChecksumValid); + + // Recalculate only UDP checksum + BinaryPrimitives.WriteUInt16BigEndian(packet.Span[10..12], (ushort)~ipChecksum); // invalidate IP checksum + BinaryPrimitives.WriteUInt16BigEndian(packet.Span[26..28], (ushort)~udpChecksum); // invalidate UDP checksum + addressBuffer[0].IsIPChecksumValid = false; + addressBuffer[0].IsTCPChecksumValid = false; + addressBuffer[0].IsUDPChecksumValid = false; + Assert.IsTrue( + DivertHelper.CalculateChecksums(packet.Span, ref addressBuffer[0], DivertHelperFlags.NoIPChecksum) + ); + Assert.AreNotEqual(ipChecksum, BinaryPrimitives.ReadUInt16BigEndian(packet.Span[10..12])); + Assert.AreEqual(udpChecksum, BinaryPrimitives.ReadUInt16BigEndian(packet.Span[26..28])); + Assert.IsFalse(addressBuffer[0].IsIPChecksumValid); + Assert.IsTrue(addressBuffer[0].IsUDPChecksumValid); + Assert.IsFalse(addressBuffer[0].IsTCPChecksumValid); + + // Recalculate only IP checksum + BinaryPrimitives.WriteUInt16BigEndian(packet.Span[10..12], (ushort)~ipChecksum); // invalidate IP checksum + BinaryPrimitives.WriteUInt16BigEndian(packet.Span[26..28], (ushort)~udpChecksum); // invalidate UDP checksum + Assert.IsTrue(DivertHelper.CalculateChecksums(packet.Span, DivertHelperFlags.NoUDPChecksum)); + Assert.AreEqual(ipChecksum, BinaryPrimitives.ReadUInt16BigEndian(packet.Span[10..12])); + Assert.AreNotEqual(udpChecksum, BinaryPrimitives.ReadUInt16BigEndian(packet.Span[26..28])); + + // Invalid packet + Assert.IsFalse(DivertHelper.CalculateChecksums(default)); + Assert.AreEqual(0, Marshal.GetLastPInvokeError()); + Assert.IsFalse(DivertHelper.CalculateChecksums(default, ref addressBuffer[0])); + Assert.AreEqual(0, Marshal.GetLastPInvokeError()); + } +} diff --git a/Divert.Windows.Tests/DivertServiceTests.cs b/Divert.Windows.Tests/DivertServiceTests.cs index 3b9083b..c995c6a 100644 --- a/Divert.Windows.Tests/DivertServiceTests.cs +++ b/Divert.Windows.Tests/DivertServiceTests.cs @@ -1,4 +1,5 @@ using System.ComponentModel; +using System.Diagnostics; using System.Net; using System.Net.Sockets; using Windows.Win32; @@ -8,35 +9,44 @@ namespace Divert.Windows.Tests; [TestClass] -public sealed class DivertServiceTests : IDisposable +public sealed class DivertServiceTests : DivertTests { - private readonly CancellationTokenSource cts; - private readonly CancellationToken token; - - public DivertServiceTests() - { - cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); - token = cts.Token; - } - - public void Dispose() + private static IEnumerable InvalidFilterCases() { - cts.Dispose(); + return + [ + [ + DivertLayer.Network, + DivertFilter.Protocol == "invalid", + "Filter expression contains a bad token (11): ...invalid", + ], + [ + DivertLayer.Network, + DivertFilter.Ip & DivertFilter.Layer == DivertLayer.Network & DivertFilter.Loopback, + "Filter expression contains a bad token for layer (7): ...layer = NETWORK and loopback", + ], + [ + DivertLayer.Network, + DivertFilter.Ip & DivertFilter.Event == DivertEvent.SocketBind & DivertFilter.Loopback, + "Filter expression parse error (15): ...BIND and loopback", + ], + ]; } - private static UdpClient CreateUdpListener(out int port) + [TestMethod] + [DynamicData(nameof(InvalidFilterCases))] + public void InvalidFilter(DivertLayer layer, DivertFilter filter, string message) { - var client = new UdpClient(new IPEndPoint(IPAddress.Loopback, 0)); - var localEndPoint = (IPEndPoint)client.Client.LocalEndPoint!; - port = localEndPoint.Port; - return client; + var exception = Assert.Throws(() => new DivertService(filter, layer: layer)); + Assert.AreEqual("filter", exception.ParamName); + Assert.StartsWith(message, exception.Message); } [TestMethod] public async Task ModifyPayload() { using var listener = CreateUdpListener(out int port); - var receive = listener.ReceiveAsync(token).AsTask(); + var receive = listener.ReceiveAsync(Token).AsTask(); var filter = DivertFilter.UDP @@ -49,34 +59,56 @@ public async Task ModifyPayload() var packetBuffer = new byte[ushort.MaxValue + 40]; var addressBuffer = new DivertAddress[1]; - var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, token).AsTask(); + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, Token).AsTask(); Assert.IsFalse(divertReceive.IsCompleted); // send 3 bytes payload using var client = new UdpClient(); client.Connect(IPAddress.Loopback, port); - await client.SendAsync(new byte[] { 1, 2, 3 }, token); + + long begin = Stopwatch.GetTimestamp(); + await client.SendAsync(new byte[] { 1, 2, 3 }, Token); var divertResult = await divertReceive; + long end = Stopwatch.GetTimestamp(); + Assert.AreEqual(20 + 8 + 3, divertResult.DataLength); var packet = packetBuffer.AsMemory(0, divertResult.DataLength); CollectionAssert.AreEqual(new byte[] { 1, 2, 3 }, packet.ToArray()[28..]); Assert.AreEqual(1, divertResult.AddressLength); Assert.IsFalse(receive.IsCompleted); + Assert.AreEqual(DivertLayer.Network, addressBuffer[0].Layer); + Assert.IsGreaterThanOrEqualTo(begin, addressBuffer[0].Timestamp); + Assert.IsLessThanOrEqualTo(end, addressBuffer[0].Timestamp); + Assert.AreEqual(DivertEvent.NetworkPacket, addressBuffer[0].Event); + Assert.IsFalse(addressBuffer[0].IsSniffed); + Assert.IsTrue(addressBuffer[0].IsOutbound); + Assert.IsTrue(addressBuffer[0].IsLoopback); + Assert.IsFalse(addressBuffer[0].IsImpostor); + Assert.IsFalse(addressBuffer[0].IsIPv6); + var networkData = addressBuffer[0].GetNetworkData(); + Assert.AreEqual(GetLoopbackInterfaceIndex(), (int)networkData.InterfaceIndex); + Assert.AreEqual(0, (int)networkData.SubInterfaceIndex); + Assert.Throws(() => addressBuffer[0].GetFlowData()); + Assert.Throws(() => addressBuffer[0].GetSocketData()); + Assert.Throws(() => addressBuffer[0].GetReflectData()); // Re-inject - addressBuffer[0] = new DivertAddress { IsImpostor = true, IsOutbound = true }; - await service.SendAsync(packet, addressBuffer, token); + addressBuffer[0] = new DivertAddress(1, 1) { IsImpostor = true, IsOutbound = true }; + await service.SendAsync(packet, addressBuffer, Token); var result = await receive; Assert.HasCount(3, result.Buffer); CollectionAssert.AreEqual(new byte[] { 1, 2, 3 }, result.Buffer); // Re-inject with modified data - receive = listener.ReceiveAsync(token).AsTask(); + addressBuffer[0].Reset(); + addressBuffer[0].IsImpostor = true; + addressBuffer[0].IsOutbound = true; + receive = listener.ReceiveAsync(Token).AsTask(); new byte[] { 4, 5, 6 }.CopyTo(packet.Span[28..]); - DivertHelper.CalculateChecksums(packet.Span); + Assert.IsTrue(DivertHelper.CalculateChecksums(packet.Span)); Assert.IsFalse(receive.IsCompleted); - await service.SendAsync(packet, addressBuffer, token); + await service.SendAsync(packet, addressBuffer, CancellationToken.None); result = await receive; Assert.HasCount(3, result.Buffer); CollectionAssert.AreEqual(new byte[] { 4, 5, 6 }, result.Buffer); @@ -86,7 +118,7 @@ public async Task ModifyPayload() public async Task Close() { using var listener = CreateUdpListener(out int port); - var receive = listener.ReceiveAsync(token).AsTask(); + var receive = listener.ReceiveAsync(Token).AsTask(); var filter = DivertFilter.UDP @@ -98,13 +130,13 @@ public async Task Close() var packetBuffer = new byte[ushort.MaxValue + 40]; var addressBuffer = new DivertAddress[1]; - var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, token).AsTask(); + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, Token).AsTask(); Assert.IsFalse(divertReceive.IsCompleted); service.Dispose(); var exception = await Assert.ThrowsAsync(async () => await divertReceive); - Assert.IsFalse(token.IsCancellationRequested); - Assert.AreNotEqual(token, exception.CancellationToken); + Assert.IsFalse(Token.IsCancellationRequested); + Assert.AreNotEqual(Token, exception.CancellationToken); Assert.AreEqual(CancellationToken.None, exception.CancellationToken); } @@ -136,7 +168,7 @@ public async Task CancelReceive() // Cancel on receive using (var pipe = new ExecutorDelayPipe()) { - using var cts = CancellationTokenSource.CreateLinkedTokenSource(this.token); + using var cts = CancellationTokenSource.CreateLinkedTokenSource(Token); var token = cts.Token; var divertReceive = Task.Run(async () => await service.ReceiveAsync(packetBuffer, addressBuffer, token)); @@ -156,7 +188,7 @@ public async Task CancelReceive() // Cancel pending receive { - using var cts = CancellationTokenSource.CreateLinkedTokenSource(this.token); + using var cts = CancellationTokenSource.CreateLinkedTokenSource(Token); var token = cts.Token; var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, token).AsTask(); @@ -177,7 +209,7 @@ public void CancelSend() var packetBuffer = new byte[100]; var addressBuffer = new DivertAddress[1]; - using var cts = CancellationTokenSource.CreateLinkedTokenSource(this.token); + using var cts = CancellationTokenSource.CreateLinkedTokenSource(Token); var token = cts.Token; cts.Cancel(); var divertSend = service.SendAsync(packetBuffer, addressBuffer, token).AsTask(); @@ -188,7 +220,7 @@ public void CancelSend() public async Task InsufficientBuffer() { using var listener = CreateUdpListener(out int port); - var receive = listener.ReceiveAsync(token).AsTask(); + var receive = listener.ReceiveAsync(Token).AsTask(); var filter = DivertFilter.UDP @@ -201,13 +233,13 @@ public async Task InsufficientBuffer() var packetBuffer = new byte[10]; // insufficient buffer var addressBuffer = new DivertAddress[1]; - var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, token).AsTask(); + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, Token).AsTask(); Assert.IsFalse(divertReceive.IsCompleted); // send 3 bytes payload using var client = new UdpClient(); client.Connect(IPAddress.Loopback, port); - await client.SendAsync(new byte[] { 1, 2, 3 }, token); + await client.SendAsync(new byte[] { 1, 2, 3 }, Token); var exception = await Assert.ThrowsAsync(async () => await divertReceive); Assert.AreEqual((int)WIN32_ERROR.ERROR_INSUFFICIENT_BUFFER, exception.NativeErrorCode); @@ -230,14 +262,14 @@ public async Task DisposeOnReceive() var addressBuffer = new DivertAddress[1]; using var pipe = new ExecutorDelayPipe(); - var divertReceive = Task.Run(async () => await service.ReceiveAsync(packetBuffer, addressBuffer, token)); + var divertReceive = Task.Run(async () => await service.ReceiveAsync(packetBuffer, addressBuffer, Token)); Assert.IsFalse(divertReceive.IsCompleted); - await pipe.Stream.WaitForConnectionAsync(token); + await pipe.Stream.WaitForConnectionAsync(Token); Assert.IsFalse(divertReceive.IsCompleted); service.Dispose(); pipe.Stream.WriteByte(0); - await pipe.Stream.FlushAsync(token); + await pipe.Stream.FlushAsync(Token); pipe.Stream.Disconnect(); await Assert.ThrowsAsync(async () => await divertReceive); @@ -260,7 +292,7 @@ public async Task InvalidHandle() var packetBuffer = new byte[ushort.MaxValue + 40]; var addressBuffer = new DivertAddress[1]; var exception = await Assert.ThrowsAsync(async () => - await service.ReceiveAsync(packetBuffer, addressBuffer, token) + await service.ReceiveAsync(packetBuffer, addressBuffer, Token) ); Assert.AreEqual((int)WIN32_ERROR.ERROR_INVALID_PARAMETER, exception.NativeErrorCode); } diff --git a/Divert.Windows.Tests/DivertTests.cs b/Divert.Windows.Tests/DivertTests.cs new file mode 100644 index 0000000..29bfeb4 --- /dev/null +++ b/Divert.Windows.Tests/DivertTests.cs @@ -0,0 +1,42 @@ +using System.Net; +using System.Net.NetworkInformation; +using System.Net.Sockets; + +namespace Divert.Windows.Tests; + +public abstract class DivertTests : IDisposable +{ + private readonly CancellationTokenSource cts; + private readonly CancellationToken token; + + protected CancellationToken Token => token; + + public DivertTests() + { + cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + token = cts.Token; + } + + public void Dispose() + { + cts.Dispose(); + } + + public static UdpClient CreateUdpListener(out int port) + { + var client = new UdpClient(new IPEndPoint(IPAddress.Loopback, 0)); + var localEndPoint = (IPEndPoint)client.Client.LocalEndPoint!; + port = localEndPoint.Port; + return client; + } + + public static int GetLoopbackInterfaceIndex() + { + return NetworkInterface + .GetAllNetworkInterfaces() + .Single(i => i.NetworkInterfaceType == NetworkInterfaceType.Loopback) + .GetIPProperties() + .GetIPv4Properties() + .Index; + } +} diff --git a/Divert.Windows.Tests/FlowTests.cs b/Divert.Windows.Tests/FlowTests.cs new file mode 100644 index 0000000..799240b --- /dev/null +++ b/Divert.Windows.Tests/FlowTests.cs @@ -0,0 +1,69 @@ +using System.Net; +using System.Net.Sockets; + +namespace Divert.Windows.Tests; + +[TestClass] +public class FlowTests : DivertTests +{ + [TestMethod] + public async Task FlowData() + { + using var listener = new TcpListener(IPAddress.Loopback, 0); + listener.Start(); + int port = ((IPEndPoint)listener.LocalEndpoint).Port; + + var filter = DivertFilter.TCP & DivertFilter.Loopback & (DivertFilter.RemotePort == port); + using var service = new DivertService( + filter, + DivertLayer.Flow, + flags: DivertFlags.Sniff | DivertFlags.ReceiveOnly + ); + + var addressBuffer = new DivertAddress[1]; + var divertReceive = service.ReceiveAsync(default, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + using var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + var accept = listener.AcceptSocketAsync(Token); + Assert.IsFalse(accept.IsCompleted); + await client.ConnectAsync(IPAddress.Loopback, port, Token); + int clientPort = ((IPEndPoint)client.LocalEndPoint!).Port; + using var serverSocket = await accept; + await client.SendAsync(new byte[] { 1, 2, 3 }, SocketFlags.None); + + var divertResult = await divertReceive; + var address = addressBuffer[0]; + Assert.AreEqual(0, divertResult.DataLength); + Assert.AreEqual(DivertLayer.Flow, address.Layer); + Assert.AreEqual(DivertEvent.FlowEstablished, address.Event); + Assert.Throws(() => address.GetNetworkData()); + var flowData = address.GetFlowData(); + Assert.AreEqual(Environment.ProcessId, (int)flowData.ProcessId); + Assert.AreEqual(IPAddress.Loopback, flowData.LocalAddress); + Assert.AreEqual(IPAddress.Loopback, flowData.RemoteAddress); + Assert.AreEqual((ushort)port, flowData.RemotePort); + Assert.AreEqual((ushort)clientPort, flowData.LocalPort); + Assert.AreEqual((byte)ProtocolType.Tcp, flowData.Protocol); + + addressBuffer[0].Reset(); + divertReceive = service.ReceiveAsync(default, addressBuffer, Token).AsTask(); + + client.Shutdown(SocketShutdown.Both); + serverSocket.Shutdown(SocketShutdown.Both); + serverSocket.Close(); + client.Close(); + + divertResult = await divertReceive; + address = addressBuffer[0]; + Assert.AreEqual(DivertLayer.Flow, address.Layer); + Assert.AreEqual(DivertEvent.FlowDeleted, address.Event); + flowData = address.GetFlowData(); + Assert.AreEqual(Environment.ProcessId, (int)flowData.ProcessId); + Assert.AreEqual(IPAddress.Loopback, flowData.LocalAddress); + Assert.AreEqual(IPAddress.Loopback, flowData.RemoteAddress); + Assert.AreEqual((ushort)port, flowData.RemotePort); + Assert.AreEqual((ushort)clientPort, flowData.LocalPort); + Assert.AreEqual((byte)ProtocolType.Tcp, flowData.Protocol); + } +} diff --git a/Divert.Windows.Tests/HelperTests.cs b/Divert.Windows.Tests/HelperTests.cs new file mode 100644 index 0000000..28da73f --- /dev/null +++ b/Divert.Windows.Tests/HelperTests.cs @@ -0,0 +1,18 @@ +using System.ComponentModel; +using Windows.Win32.Foundation; + +namespace Divert.Windows.Tests; + +[TestClass] +public class HelperTests : DivertTests +{ + [TestMethod] + public void InvalidFilter() + { + var invalidFilter = Array.Empty(); + var exception = Assert.Throws(() => + DivertHelper.FormatFilter(invalidFilter, DivertLayer.Network) + ); + Assert.AreEqual((int)WIN32_ERROR.ERROR_INVALID_PARAMETER, exception.NativeErrorCode); + } +} diff --git a/Divert.Windows.Tests/ReflectTests.cs b/Divert.Windows.Tests/ReflectTests.cs new file mode 100644 index 0000000..1984be5 --- /dev/null +++ b/Divert.Windows.Tests/ReflectTests.cs @@ -0,0 +1,148 @@ +using System.ComponentModel; +using System.Diagnostics; +using System.Net.Sockets; +using Windows.Win32.Foundation; + +namespace Divert.Windows.Tests; + +[TestClass] +public class ReflectTests : DivertTests +{ + [TestMethod] + public async Task ReflectData() + { + using var reflectService = new DivertService( + true, + DivertLayer.Reflect, + flags: DivertFlags.ReceiveOnly | DivertFlags.Sniff | DivertFlags.NoInstall + ); + + var dataBuffer = new byte[ushort.MaxValue]; + var addressBuffer = new DivertAddress[1]; + var divertReceive = reflectService.ReceiveAsync(dataBuffer, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + long begin = Stopwatch.GetTimestamp(); + using var service = new DivertService(false, priority: 5, flags: DivertFlags.WriteOnly); + long end = Stopwatch.GetTimestamp(); + + var divertResult = await divertReceive; + var address = addressBuffer[0]; + Assert.AreEqual(DivertLayer.Reflect, address.Layer); + Assert.AreEqual(DivertEvent.ReflectOpen, address.Event); + var reflectData = address.GetReflectData(); + Assert.IsTrue(reflectData.Timestamp >= begin && reflectData.Timestamp <= end); + Assert.AreEqual(Environment.ProcessId, (int)reflectData.ProcessId); + Assert.AreEqual(DivertLayer.Network, reflectData.Layer); + Assert.AreEqual(DivertFlags.WriteOnly, reflectData.Flags); + Assert.AreEqual(5, reflectData.Priority); + var packet = dataBuffer.AsSpan(0, divertResult.DataLength); + string filterString = DivertHelper.FormatFilter(packet, reflectData.Layer); + Assert.AreEqual("false", filterString); + } + + public static IEnumerable FormatFilterCases() + { + return + [ + [ + DivertLayer.Socket, + DivertFilter.Protocol == ProtocolType.Tcp + | DivertFilter.Protocol == ProtocolType.Udp + | DivertFilter.Protocol == ProtocolType.Icmp + | DivertFilter.Protocol == ProtocolType.IcmpV6, + $"protocol = {(int)ProtocolType.Tcp} or " + + $"protocol = {(int)ProtocolType.Udp} or " + + $"protocol = {(int)ProtocolType.Icmp} or " + + $"protocol = {(int)ProtocolType.IcmpV6}", + ], + [DivertLayer.Socket, DivertFilter.Protocol == ProtocolType.IP, $"protocol = {(int)ProtocolType.IP}"], + [DivertLayer.Socket, DivertFilter.Ip == true, "ip"], + [DivertLayer.Socket, DivertFilter.Ip == false, "not ip"], + [DivertLayer.Socket, DivertFilter.Ip != true, "not ip"], + [DivertLayer.Socket, DivertFilter.Ip != false, "ip"], + [DivertLayer.Socket, DivertFilter.Protocol == ProtocolType.Unknown, "false"], + [DivertLayer.Socket, DivertFilter.Protocol != ProtocolType.Unknown, "true"], + [DivertLayer.Socket, DivertFilter.Protocol == Enum.GetValues().Max() + 1, "false"], + [ + DivertLayer.Socket, + ( + DivertFilter.Event == DivertEvent.SocketBind + | DivertFilter.Event == DivertEvent.SocketConnect + | DivertFilter.Event == DivertEvent.SocketListen + | DivertFilter.Event == DivertEvent.SocketAccept + | DivertFilter.Event == DivertEvent.SocketClose + ), + "event = BIND or event = CONNECT or event = LISTEN or event = ACCEPT or event = CLOSE", + ], + [DivertLayer.Network, DivertFilter.Event == DivertEvent.NetworkPacket, "event = PACKET"], + [DivertLayer.Forward, DivertFilter.Event != DivertEvent.NetworkPacket, "event != PACKET"], + [ + DivertLayer.Flow, + (DivertFilter.Event == DivertEvent.FlowEstablished | DivertFilter.Event == DivertEvent.FlowDeleted), + "event = ESTABLISHED or event = DELETED", + ], + ]; + } + + [TestMethod] + [DynamicData(nameof(FormatFilterCases))] + public async Task FormatFilter(DivertLayer layer, DivertFilter filter, string expected) + { + using var reflectService = new DivertService( + DivertFilter.Event == DivertEvent.ReflectOpen + & DivertFilter.ProcessId == Environment.ProcessId + & ( + DivertFilter.Layer == DivertLayer.Network + | DivertFilter.Layer == DivertLayer.Socket + | DivertFilter.Layer == DivertLayer.Forward + | DivertFilter.Layer == DivertLayer.Flow + | DivertFilter.Layer != DivertLayer.Reflect + ), + DivertLayer.Reflect, + priority: 2, + DivertFlags.ReceiveOnly | DivertFlags.Sniff | DivertFlags.NoInstall + ); + + var dataBuffer = new byte[ushort.MaxValue]; + var addressBuffer = new DivertAddress[1]; + var divertReceive = reflectService.ReceiveAsync(dataBuffer, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + using var service = new DivertService(filter, layer, priority: 1, DivertFlags.Sniff | DivertFlags.ReceiveOnly); + + var divertResult = await divertReceive; + var address = addressBuffer[0]; + Assert.AreEqual(DivertLayer.Reflect, address.Layer); + Assert.AreEqual(DivertEvent.ReflectOpen, address.Event); + var reflectData = address.GetReflectData(); + Assert.AreEqual(Environment.ProcessId, (int)reflectData.ProcessId); + Assert.AreEqual(layer, reflectData.Layer); + Assert.AreEqual(DivertFlags.Sniff | DivertFlags.ReceiveOnly, reflectData.Flags); + Assert.AreEqual(1, reflectData.Priority); + var packet = dataBuffer.AsSpan(0, divertResult.DataLength); + string filterString = DivertHelper.FormatFilter(packet, reflectData.Layer); + Assert.AreEqual(expected, filterString); + } + + [TestMethod] + public void InvalidFilter() + { + var exception = Assert.Throws(() => + { + using var service = new DivertService( + DivertFilter.Event == Enum.GetValues().Max() + 1, + DivertLayer.Reflect + ); + }); + Assert.AreEqual((int)WIN32_ERROR.ERROR_INVALID_PARAMETER, exception.NativeErrorCode); + exception = Assert.Throws(() => + { + using var service = new DivertService( + DivertFilter.Layer == Enum.GetValues().Max() + 1, + DivertLayer.Reflect + ); + }); + Assert.AreEqual((int)WIN32_ERROR.ERROR_INVALID_PARAMETER, exception.NativeErrorCode); + } +} diff --git a/Divert.Windows.Tests/SocketTests.cs b/Divert.Windows.Tests/SocketTests.cs new file mode 100644 index 0000000..9dece24 --- /dev/null +++ b/Divert.Windows.Tests/SocketTests.cs @@ -0,0 +1,40 @@ +using System.Net; +using System.Net.Sockets; + +namespace Divert.Windows.Tests; + +[TestClass] +public class SocketTests : DivertTests +{ + [TestMethod] + public async Task SocketData() + { + var filter = DivertFilter.TCP & DivertFilter.Loopback & (DivertFilter.ProcessId == Environment.ProcessId); + using var service = new DivertService( + filter, + DivertLayer.Socket, + flags: DivertFlags.Sniff | DivertFlags.ReceiveOnly + ); + + var addressBuffer = new DivertAddress[1]; + var divertReceive = service.ReceiveAsync(default, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + socket.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + int port = ((IPEndPoint)socket.LocalEndPoint!).Port; + + var divertResult = await divertReceive; + var address = addressBuffer[0]; + Assert.AreEqual(0, divertResult.DataLength); + Assert.AreEqual(DivertLayer.Socket, address.Layer); + Assert.AreEqual(DivertEvent.SocketBind, address.Event); + var socketData = address.GetSocketData(); + Assert.AreEqual(Environment.ProcessId, (int)socketData.ProcessId); + Assert.AreEqual(IPAddress.Loopback, socketData.LocalAddress); + Assert.AreEqual(IPAddress.Any, socketData.RemoteAddress); + Assert.AreEqual((ushort)port, socketData.LocalPort); + Assert.AreEqual(0, socketData.RemotePort); + Assert.AreEqual((byte)ProtocolType.Tcp, socketData.Protocol); + } +} diff --git a/Divert.Windows/DivertAddress.cs b/Divert.Windows/DivertAddress.cs index 334d9a0..f152b35 100644 --- a/Divert.Windows/DivertAddress.cs +++ b/Divert.Windows/DivertAddress.cs @@ -48,11 +48,6 @@ public struct ReflectData private WINDIVERT_ADDRESS address; - internal DivertAddress(WINDIVERT_ADDRESS address) - { - this.address = address; - } - public DivertAddress(int interfaceIndex, int subInterfaceIndex) { address = new WINDIVERT_ADDRESS @@ -73,23 +68,11 @@ public void Reset() } } - public long Timestamp - { - readonly get { return address.Timestamp; } - set { address.Timestamp = value; } - } + public readonly long Timestamp => address.Timestamp; - public DivertLayer Layer - { - readonly get { return (DivertLayer)address.Layer; } - set { address.Layer = (byte)value; } - } + public readonly DivertLayer Layer => (DivertLayer)address.Layer; - public DivertEvent Event - { - readonly get { return (DivertEvent)address.Event; } - set { address.Layer = (byte)value; } - } + public readonly DivertEvent Event => (DivertEvent)address.Event; private readonly bool GetBit(WINDIVERT_ADDRESS_BITS bit) { @@ -108,11 +91,7 @@ private void SetBit(WINDIVERT_ADDRESS_BITS bit, bool value) } } - public bool IsSniffed - { - readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.Sniffed); } - set { SetBit(WINDIVERT_ADDRESS_BITS.Sniffed, value); } - } + public readonly bool IsSniffed => GetBit(WINDIVERT_ADDRESS_BITS.Sniffed); public bool IsOutbound { @@ -120,11 +99,7 @@ public bool IsOutbound set { SetBit(WINDIVERT_ADDRESS_BITS.Outbound, value); } } - public bool IsLoopback - { - readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.Loopback); } - set { SetBit(WINDIVERT_ADDRESS_BITS.Loopback, value); } - } + public readonly bool IsLoopback => GetBit(WINDIVERT_ADDRESS_BITS.Loopback); public bool IsImpostor { @@ -132,11 +107,7 @@ public bool IsImpostor set { SetBit(WINDIVERT_ADDRESS_BITS.Impostor, value); } } - public bool IsIPv6 - { - readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.IPv6); } - set { SetBit(WINDIVERT_ADDRESS_BITS.IPv6, value); } - } + public readonly bool IsIPv6 => GetBit(WINDIVERT_ADDRESS_BITS.IPv6); public bool IsIPChecksumValid { diff --git a/Divert.Windows/DivertFilter.cs b/Divert.Windows/DivertFilter.cs index d47ff80..6968ca9 100644 --- a/Divert.Windows/DivertFilter.cs +++ b/Divert.Windows/DivertFilter.cs @@ -1,3 +1,4 @@ +using System.Net.Sockets; using System.Text; using System.Text.RegularExpressions; @@ -16,12 +17,8 @@ public DivertFilter(string clause) Clause = clause; } - public static implicit operator DivertFilter(string clause) - { - return new DivertFilter(clause); - } - - private static string ReplaceParentheses(string expression) + // e.g. "(a and (b or c)) or d" -> "() or d" + private static string CollapseParentheses(string expression) { var builder = new StringBuilder(); int index = 0; @@ -68,12 +65,12 @@ private static string ReplaceParentheses(string expression) [GeneratedRegex(@"\s(or|\|\|)\s")] private static partial Regex OrPatternRegex(); - private static bool MatchOrPattern(string s) => OrPatternRegex().IsMatch(ReplaceParentheses(s)); + private static bool MatchOrPattern(string s) => OrPatternRegex().IsMatch(CollapseParentheses(s)); [GeneratedRegex(@"\s(and|&&)\s")] private static partial Regex AndPatternRegex(); - private static bool MatchAndPattern(string s) => AndPatternRegex().IsMatch(ReplaceParentheses(s)); + private static bool MatchAndPattern(string s) => AndPatternRegex().IsMatch(CollapseParentheses(s)); public static DivertFilter operator &(DivertFilter left, DivertFilter right) { @@ -179,6 +176,44 @@ public static implicit operator DivertFilter(Field field) return new DivertFilter(clause); } + private static string Macro(bool value) => value ? "TRUE" : "FALSE"; + + private static string Macro(ProtocolType protocolType) => + protocolType switch + { + ProtocolType.Tcp => "TCP", + ProtocolType.Udp => "UDP", + ProtocolType.Icmp => "ICMP", + ProtocolType.IcmpV6 => "ICMPV6", + _ => ((int)protocolType).ToString(), + }; + + private static string Macro(DivertEvent e) => + e switch + { + DivertEvent.NetworkPacket => "PACKET", + DivertEvent.FlowEstablished => "ESTABLISHED", + DivertEvent.FlowDeleted => "DELETED", + DivertEvent.SocketBind => "BIND", + DivertEvent.SocketConnect => "CONNECT", + DivertEvent.SocketListen => "LISTEN", + DivertEvent.SocketAccept => "ACCEPT", + DivertEvent.ReflectOpen => "OPEN", + DivertEvent.SocketClose or DivertEvent.ReflectClose => "CLOSE", + _ => e.ToString(), + }; + + private static string Macro(DivertLayer layer) => + layer switch + { + DivertLayer.Network => "NETWORK", + DivertLayer.Forward => "NETWORK_FORWARD", + DivertLayer.Flow => "FLOW", + DivertLayer.Socket => "SOCKET", + DivertLayer.Reflect => "REFLECT", + _ => layer.ToString(), + }; + public static DivertFilter operator ==(Field left, object right) { string clause = $"{left} = {right}"; @@ -191,6 +226,22 @@ public static implicit operator DivertFilter(Field field) return new DivertFilter(clause); } + public static DivertFilter operator ==(Field left, bool right) => left == Macro(right); + + public static DivertFilter operator !=(Field left, bool right) => left != Macro(right); + + public static DivertFilter operator ==(Field left, ProtocolType right) => left == Macro(right); + + public static DivertFilter operator !=(Field left, ProtocolType right) => left != Macro(right); + + public static DivertFilter operator ==(Field left, DivertEvent right) => left == Macro(right); + + public static DivertFilter operator !=(Field left, DivertEvent right) => left != Macro(right); + + public static DivertFilter operator ==(Field left, DivertLayer right) => left == Macro(right); + + public static DivertFilter operator !=(Field left, DivertLayer right) => left != Macro(right); + public static DivertFilter operator <(Field left, object right) { string clause = $"{left} < {right}"; @@ -286,4 +337,14 @@ public static implicit operator DivertFilter(Field field) public static Field RemoteAddress { get; } = "remoteAddr"; public static Field RemotePort { get; } = "remotePort"; + + public static implicit operator DivertFilter(string clause) + { + return new DivertFilter(clause); + } + + public static implicit operator DivertFilter(bool value) + { + return value ? True : False; + } } diff --git a/Divert.Windows/DivertHelper.cs b/Divert.Windows/DivertHelper.cs index 26d1d6c..fdfdb54 100644 --- a/Divert.Windows/DivertHelper.cs +++ b/Divert.Windows/DivertHelper.cs @@ -5,15 +5,50 @@ namespace Divert.Windows; public static unsafe class DivertHelper { - public static void CalculateChecksums(Span packet, DivertHelperFlags flags = DivertHelperFlags.None) + public static bool CalculateChecksums(Span packet, DivertHelperFlags flags = DivertHelperFlags.None) { fixed (byte* pPacket = packet) { - bool success = NativeMethods.WinDivertHelperCalcChecksums(pPacket, (uint)packet.Length, null, (ulong)flags); + return NativeMethods.WinDivertHelperCalcChecksums(pPacket, (uint)packet.Length, null, (ulong)flags); + } + } + + public static bool CalculateChecksums( + Span packet, + ref DivertAddress address, + DivertHelperFlags flags = DivertHelperFlags.None + ) + { + fixed (byte* pPacket = packet) + fixed (DivertAddress* pAddress = &address) + { + return NativeMethods.WinDivertHelperCalcChecksums( + pPacket, + (uint)packet.Length, + (WINDIVERT_ADDRESS*)pAddress, + (ulong)flags + ); + } + } + + public static string FormatFilter(Span filter, DivertLayer layer) + { + Memory buffer = GC.AllocateArray(ushort.MaxValue, pinned: true); + using var bufferHandle = buffer.Pin(); + fixed (byte* pFilter = filter) + { + bool success = NativeMethods.WinDivertHelperFormatFilter( + new(pFilter), + (WINDIVERT_LAYER)layer, + (byte*)bufferHandle.Pointer, + (uint)buffer.Length + ); if (!success) { throw new Win32Exception(Marshal.GetLastPInvokeError()); } + + return Marshal.PtrToStringAnsi(new IntPtr(bufferHandle.Pointer))!; } } } diff --git a/Divert.Windows/DivertService.cs b/Divert.Windows/DivertService.cs index 42ca7c5..95fb5c3 100644 --- a/Divert.Windows/DivertService.cs +++ b/Divert.Windows/DivertService.cs @@ -51,7 +51,10 @@ public DivertService( if (!success) { string? errorString = Marshal.PtrToStringAnsi(errorStr); - throw new ArgumentException($"{errorPos}: {errorString}", nameof(filter)); + throw new ArgumentException( + $"{errorString} ({errorPos}): ...{filter.Clause[(int)errorPos..]}", + nameof(filter) + ); } var handle = NativeMethods.WinDivertOpen(s.Ptr, (WINDIVERT_LAYER)layer, priority, (ulong)flags); diff --git a/Divert.Windows/NativeMethods.cs b/Divert.Windows/NativeMethods.cs index d817327..120f303 100644 --- a/Divert.Windows/NativeMethods.cs +++ b/Divert.Windows/NativeMethods.cs @@ -90,4 +90,13 @@ public static partial bool WinDivertHelperCompileFilter( IntPtr* errorStr, uint* errorPos ); + + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertHelperFormatFilter( + IntPtr filter, + WINDIVERT_LAYER layer, + byte* buffer, + uint bufLen + ); } From ed427b7cae7deac28484c6332139ee7ce54fbdd4 Mon Sep 17 00:00:00 2001 From: gdlol Date: Sat, 8 Nov 2025 15:45:13 +0000 Subject: [PATCH 10/20] DivertIOControl and helpers. --- Divert.Windows.Tests/DivertServiceTests.cs | 49 +++++- Divert.Windows.Tests/HelperTests.cs | 126 +++++++++++++- .../AsyncOperation/DivertReceiveExecutor.cs | 8 +- Divert.Windows/CString.cs | 4 +- Divert.Windows/Constants.cs | 16 ++ Divert.Windows/DivertHandle.cs | 2 +- Divert.Windows/DivertHelper.cs | 96 ++++++++++- Divert.Windows/DivertIOControl.cs | 84 +++++++++ Divert.Windows/DivertService.cs | 163 +++++++++++++----- Divert.Windows/NativeMethods.cs | 13 ++ Divert.Windows/NativeMethods.txt | 4 + Divert.Windows/NativeTypes.cs | 13 ++ 12 files changed, 523 insertions(+), 55 deletions(-) create mode 100644 Divert.Windows/Constants.cs create mode 100644 Divert.Windows/DivertIOControl.cs diff --git a/Divert.Windows.Tests/DivertServiceTests.cs b/Divert.Windows.Tests/DivertServiceTests.cs index c995c6a..624ed62 100644 --- a/Divert.Windows.Tests/DivertServiceTests.cs +++ b/Divert.Windows.Tests/DivertServiceTests.cs @@ -11,7 +11,50 @@ namespace Divert.Windows.Tests; [TestClass] public sealed class DivertServiceTests : DivertTests { - private static IEnumerable InvalidFilterCases() + [TestMethod] + public void Parameters() + { + Assert.Throws(() => + new DivertService(false, priority: DivertService.HighestPriority + 1) + ); + Assert.Throws(() => + new DivertService(false, priority: DivertService.LowestPriority - 1) + ); + + using var service = new DivertService(false, priority: DivertService.HighestPriority); + using var _ = new DivertService(false, priority: DivertService.LowestPriority); + Assert.AreEqual(new Version(2, 2), service.Version); + + Assert.AreEqual(DivertService.DefaultQueueLength, service.QueueLength); + service.QueueLength = DivertService.MinQueueLength; + Assert.AreEqual(DivertService.MinQueueLength, service.QueueLength); + service.QueueLength = DivertService.MaxQueueLength; + Assert.AreEqual(DivertService.MaxQueueLength, service.QueueLength); + Assert.Throws(() => service.QueueLength = DivertService.MinQueueLength - 1); + Assert.Throws(() => service.QueueLength = DivertService.MaxQueueLength + 1); + + Assert.AreEqual(DivertService.DefaultQueueTime, service.QueueTime); + service.QueueTime = DivertService.MinQueueTime; + Assert.AreEqual(DivertService.MinQueueTime, service.QueueTime); + service.QueueTime = DivertService.MaxQueueTime; + Assert.AreEqual(DivertService.MaxQueueTime, service.QueueTime); + Assert.Throws(() => + service.QueueTime = DivertService.MinQueueTime - TimeSpan.FromMilliseconds(1) + ); + Assert.Throws(() => + service.QueueTime = DivertService.MaxQueueTime + TimeSpan.FromMilliseconds(1) + ); + + Assert.AreEqual(DivertService.DefaultQueueSize, service.QueueSize); + service.QueueSize = DivertService.MinQueueSize; + Assert.AreEqual(DivertService.MinQueueSize, service.QueueSize); + service.QueueSize = DivertService.MaxQueueSize; + Assert.AreEqual(DivertService.MaxQueueSize, service.QueueSize); + Assert.Throws(() => service.QueueSize = DivertService.MinQueueSize - 1); + Assert.Throws(() => service.QueueSize = DivertService.MaxQueueSize + 1); + } + + public static IEnumerable InvalidFilterCases() { return [ @@ -288,6 +331,7 @@ public async Task InvalidHandle() null ); using var service = new DivertService(handle); + Assert.AreEqual(handle.DangerousGetHandle(), service.SafeHandle.DangerousGetHandle()); var packetBuffer = new byte[ushort.MaxValue + 40]; var addressBuffer = new DivertAddress[1]; @@ -295,5 +339,8 @@ public async Task InvalidHandle() await service.ReceiveAsync(packetBuffer, addressBuffer, Token) ); Assert.AreEqual((int)WIN32_ERROR.ERROR_INVALID_PARAMETER, exception.NativeErrorCode); + + exception = Assert.Throws(() => service.QueueLength = DivertService.DefaultQueueLength); + Assert.AreEqual((int)WIN32_ERROR.ERROR_ACCESS_DENIED, exception.NativeErrorCode); } } diff --git a/Divert.Windows.Tests/HelperTests.cs b/Divert.Windows.Tests/HelperTests.cs index 28da73f..7dd3c18 100644 --- a/Divert.Windows.Tests/HelperTests.cs +++ b/Divert.Windows.Tests/HelperTests.cs @@ -1,4 +1,6 @@ using System.ComponentModel; +using System.Net; +using System.Net.Sockets; using Windows.Win32.Foundation; namespace Divert.Windows.Tests; @@ -7,7 +9,7 @@ namespace Divert.Windows.Tests; public class HelperTests : DivertTests { [TestMethod] - public void InvalidFilter() + public void FormatFilter() { var invalidFilter = Array.Empty(); var exception = Assert.Throws(() => @@ -15,4 +17,126 @@ public void InvalidFilter() ); Assert.AreEqual((int)WIN32_ERROR.ERROR_INVALID_PARAMETER, exception.NativeErrorCode); } + + [TestMethod] + public async Task DecrementTtl() + { + using var listener = CreateUdpListener(out int port); + var receive = listener.ReceiveAsync(Token).AsTask(); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port); + using var service = new DivertService(filter); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + // send 3 bytes payload + using var client = new UdpClient(); + client.Connect(IPAddress.Loopback, port); + + await client.SendAsync(new byte[] { 1, 2, 3 }, Token); + var divertResult = await divertReceive; + var packet = packetBuffer.AsMemory(0, divertResult.DataLength); + + // set ttl to 2 + packet.Span[8] = 2; + Assert.IsTrue(DivertHelper.DecrementTtl(packet.Span)); + Assert.AreEqual(1, packet.Span[8]); + Assert.IsFalse(DivertHelper.DecrementTtl(packet.Span)); + Assert.AreEqual(1, packet.Span[8]); + } + + [TestMethod] + public async Task CompileFilter() + { + using var listener = CreateUdpListener(out int port); + var receive = listener.ReceiveAsync(Token).AsTask(); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port); + var filterBuffer = DivertHelper.CompileFilter(filter, DivertLayer.Network); + using var service = new DivertService(filterBuffer); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + // send 3 bytes payload + using var client = new UdpClient(); + client.Connect(IPAddress.Loopback, port); + + await client.SendAsync(new byte[] { 1, 2, 3 }, Token); + var divertResult = await divertReceive; + + Assert.AreEqual(20 + 8 + 3, divertResult.DataLength); + var packet = packetBuffer.AsMemory(0, divertResult.DataLength); + CollectionAssert.AreEqual(new byte[] { 1, 2, 3 }, packet.ToArray()[28..]); + } + + private static IEnumerable InvalidFilterCases() => DivertServiceTests.InvalidFilterCases(); + + [TestMethod] + [DynamicData(nameof(InvalidFilterCases))] + public void CompileInvalidFilter(DivertLayer layer, DivertFilter filter, string message) + { + var exception = Assert.Throws(() => DivertHelper.CompileFilter(filter, layer)); + Assert.AreEqual("filter", exception.ParamName); + Assert.StartsWith(message, exception.Message); + } + + [TestMethod] + public async Task EvaluateFilter() + { + using var listener = CreateUdpListener(out int port); + var receive = listener.ReceiveAsync(Token).AsTask(); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port); + using var service = new DivertService(filter); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + // send 3 bytes payload + using var client = new UdpClient(); + client.Connect(IPAddress.Loopback, port); + + await client.SendAsync(new byte[] { 1, 2, 3 }, Token); + var divertResult = await divertReceive; + var packet = packetBuffer.AsMemory(0, divertResult.DataLength); + var address = addressBuffer[0]; + + Assert.IsTrue(DivertHelper.EvaluateFilter(filter, packet.Span, address)); + Assert.IsTrue( + DivertHelper.EvaluateFilter(DivertHelper.CompileFilter(filter, DivertLayer.Network), packet.Span, address) + ); + Assert.IsTrue(DivertHelper.EvaluateFilter(DivertFilter.UDP, packet.Span, address)); + Assert.IsFalse(DivertHelper.EvaluateFilter(DivertFilter.TCP, packet.Span, address)); + + var exception = Assert.Throws(() => DivertHelper.EvaluateFilter([], packet.Span, address)); + Assert.AreEqual((int)WIN32_ERROR.ERROR_INVALID_PARAMETER, exception.NativeErrorCode); + exception = Assert.Throws(() => DivertHelper.EvaluateFilter(filter, [], address)); + Assert.AreEqual((int)WIN32_ERROR.ERROR_INVALID_PARAMETER, exception.NativeErrorCode); + } } diff --git a/Divert.Windows/AsyncOperation/DivertReceiveExecutor.cs b/Divert.Windows/AsyncOperation/DivertReceiveExecutor.cs index 38db653..a373376 100644 --- a/Divert.Windows/AsyncOperation/DivertReceiveExecutor.cs +++ b/Divert.Windows/AsyncOperation/DivertReceiveExecutor.cs @@ -5,12 +5,12 @@ namespace Divert.Windows.AsyncOperation; internal sealed class DivertReceiveExecutor : IDivertValueTaskExecutor { - private readonly Memory addressesLengthBuffer = GC.AllocateArray(1, pinned: true); + private readonly uint[] addressesLengthBuffer = GC.AllocateArray(1, pinned: true); public unsafe bool Execute(SafeHandle safeHandle, ref readonly PendingOperation pendingOperation) { using var _ = safeHandle.DangerousGetHandle(out var handle); - addressesLengthBuffer.Span[0] = (uint)(pendingOperation.Addresses.Length * sizeof(DivertAddress)); + addressesLengthBuffer[0] = (uint)(pendingOperation.Addresses.Length * sizeof(DivertAddress)); return NativeMethods.WinDivertRecvEx( handle, pendingOperation.PacketBufferHandle.Pointer, @@ -18,7 +18,7 @@ public unsafe bool Execute(SafeHandle safeHandle, ref readonly PendingOperation null, 0, (WINDIVERT_ADDRESS*)pendingOperation.AddressesHandle.Pointer, - (uint*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(addressesLengthBuffer.Span)), + (uint*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(addressesLengthBuffer.AsSpan())), pendingOperation.NativeOverlapped ); } @@ -31,7 +31,7 @@ CancellationToken cancellationToken ) { int dataLength = await source.ExecuteAsync(this, buffer, addresses, cancellationToken).ConfigureAwait(false); - int addressesLength = (int)addressesLengthBuffer.Span[0] / Marshal.SizeOf(); + int addressesLength = (int)addressesLengthBuffer[0] / Marshal.SizeOf(); return new DivertReceiveResult(dataLength, addressesLength); } } diff --git a/Divert.Windows/CString.cs b/Divert.Windows/CString.cs index 328cd64..de23046 100644 --- a/Divert.Windows/CString.cs +++ b/Divert.Windows/CString.cs @@ -4,7 +4,7 @@ namespace Divert.Windows; internal sealed class CString(string str) : IDisposable { - internal IntPtr Ptr { get; } = Marshal.StringToHGlobalAnsi(str); + internal IntPtr Pointer { get; } = Marshal.StringToHGlobalAnsi(str); private bool disposed; @@ -12,7 +12,7 @@ public void Dispose() { if (!disposed) { - Marshal.FreeHGlobal(Ptr); + Marshal.FreeHGlobal(Pointer); disposed = true; } } diff --git a/Divert.Windows/Constants.cs b/Divert.Windows/Constants.cs new file mode 100644 index 0000000..ef91e99 --- /dev/null +++ b/Divert.Windows/Constants.cs @@ -0,0 +1,16 @@ +namespace Divert.Windows; + +internal static class Constants +{ + public const short WINDIVERT_PRIORITY_HIGHEST = 3000; + public const short WINDIVERT_PRIORITY_LOWEST = -WINDIVERT_PRIORITY_HIGHEST; + public const int WINDIVERT_PARAM_QUEUE_LENGTH_DEFAULT = 4096; + public const int WINDIVERT_PARAM_QUEUE_LENGTH_MIN = 32; + public const int WINDIVERT_PARAM_QUEUE_LENGTH_MAX = 16384; + public const int WINDIVERT_PARAM_QUEUE_TIME_DEFAULT = 2000; + public const int WINDIVERT_PARAM_QUEUE_TIME_MIN = 100; + public const int WINDIVERT_PARAM_QUEUE_TIME_MAX = 16000; + public const int WINDIVERT_PARAM_QUEUE_SIZE_DEFAULT = 4 * 1024 * 1024; + public const int WINDIVERT_PARAM_QUEUE_SIZE_MIN = 64 * 1024; + public const int WINDIVERT_PARAM_QUEUE_SIZE_MAX = 32 * 1024 * 1024; +} diff --git a/Divert.Windows/DivertHandle.cs b/Divert.Windows/DivertHandle.cs index 0a2952a..ee9a4f0 100644 --- a/Divert.Windows/DivertHandle.cs +++ b/Divert.Windows/DivertHandle.cs @@ -2,7 +2,7 @@ namespace Divert.Windows; -internal sealed class DivertHandle : SafeHandleZeroOrMinusOneIsInvalid +public sealed class DivertHandle : SafeHandleZeroOrMinusOneIsInvalid { public DivertHandle(IntPtr handle, bool ownsHandle = true) : base(ownsHandle) diff --git a/Divert.Windows/DivertHelper.cs b/Divert.Windows/DivertHelper.cs index fdfdb54..a06b151 100644 --- a/Divert.Windows/DivertHelper.cs +++ b/Divert.Windows/DivertHelper.cs @@ -1,5 +1,7 @@ using System.ComponentModel; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using Windows.Win32.Foundation; namespace Divert.Windows; @@ -31,16 +33,100 @@ public static bool CalculateChecksums( } } - public static string FormatFilter(Span filter, DivertLayer layer) + public static bool DecrementTtl(Span packet) { - Memory buffer = GC.AllocateArray(ushort.MaxValue, pinned: true); - using var bufferHandle = buffer.Pin(); + fixed (byte* pPacket = packet) + { + return NativeMethods.WinDivertHelperDecrementTTL(pPacket, (uint)packet.Length); + } + } + + public static ReadOnlySpan CompileFilter( + DivertFilter filter, + DivertLayer layer, + int bufferLength = ushort.MaxValue + ) + { + ArgumentNullException.ThrowIfNull(filter); + + using var s = new CString(filter.Clause); + Span buffer = GC.AllocateArray(bufferLength, pinned: true); + var pBuffer = Unsafe.AsPointer(ref MemoryMarshal.GetReference(buffer)); + IntPtr errorStr; + uint errorPos; + bool success = NativeMethods.WinDivertHelperCompileFilter( + s.Pointer, + (WINDIVERT_LAYER)layer, + (byte*)pBuffer, + (uint)buffer.Length, + &errorStr, + &errorPos + ); + if (!success) + { + string? errorString = Marshal.PtrToStringAnsi(errorStr); + throw new ArgumentException( + $"{errorString} ({errorPos}): ...{filter.Clause[(int)errorPos..]}", + nameof(filter) + ); + } + + return buffer; + } + + private static bool EvaluateFilter(IntPtr filter, ReadOnlySpan packet, in DivertAddress address) + { + fixed (byte* pPacket = packet) + fixed (DivertAddress* pAddress = &address) + { + bool success = NativeMethods.WinDivertHelperEvalFilter( + filter, + pPacket, + (uint)packet.Length, + (WINDIVERT_ADDRESS*)pAddress + ); + if (!success) + { + int error = Marshal.GetLastPInvokeError(); + if (error is not (int)WIN32_ERROR.ERROR_SUCCESS) + { + throw new Win32Exception(error); + } + } + return success; + } + } + + public static bool EvaluateFilter(ReadOnlySpan filter, ReadOnlySpan packet, in DivertAddress address) + { + fixed (byte* pFilter = filter) + fixed (byte* pPacket = packet) + fixed (DivertAddress* pAddress = &address) + { + return EvaluateFilter(new IntPtr(pFilter), packet, address); + } + } + + public static bool EvaluateFilter(DivertFilter filter, ReadOnlySpan packet, in DivertAddress address) + { + using var s = new CString(filter.Clause); + fixed (byte* pPacket = packet) + fixed (DivertAddress* pAddress = &address) + { + return EvaluateFilter(s.Pointer, packet, address); + } + } + + public static string FormatFilter(Span filter, DivertLayer layer, int maxLength = ushort.MaxValue) + { + Span buffer = GC.AllocateArray(maxLength, pinned: true); + var pBuffer = Unsafe.AsPointer(ref MemoryMarshal.GetReference(buffer)); fixed (byte* pFilter = filter) { bool success = NativeMethods.WinDivertHelperFormatFilter( new(pFilter), (WINDIVERT_LAYER)layer, - (byte*)bufferHandle.Pointer, + (byte*)pBuffer, (uint)buffer.Length ); if (!success) @@ -48,7 +134,7 @@ public static string FormatFilter(Span filter, DivertLayer layer) throw new Win32Exception(Marshal.GetLastPInvokeError()); } - return Marshal.PtrToStringAnsi(new IntPtr(bufferHandle.Pointer))!; + return Marshal.PtrToStringAnsi(new IntPtr(pBuffer))!; } } } diff --git a/Divert.Windows/DivertIOControl.cs b/Divert.Windows/DivertIOControl.cs new file mode 100644 index 0000000..334f61c --- /dev/null +++ b/Divert.Windows/DivertIOControl.cs @@ -0,0 +1,84 @@ +using System.ComponentModel; +using System.Runtime.InteropServices; +using Windows.Win32; +using Windows.Win32.Foundation; + +namespace Divert.Windows; + +internal static unsafe class DivertIOControl +{ + private static uint CTL_CODE(uint deviceType, uint function, uint method, uint access) + { + return (deviceType << 16) | (access << 14) | (function << 2) | method; + } + + private static readonly uint SetParamControlCode = CTL_CODE( + PInvoke.FILE_DEVICE_NETWORK, + 0x925, + PInvoke.METHOD_IN_DIRECT, + (uint)FileAccess.ReadWrite + ); + + private static readonly uint GetParamControlCode = CTL_CODE( + PInvoke.FILE_DEVICE_NETWORK, + 0x926, + PInvoke.METHOD_OUT_DIRECT, + (uint)FileAccess.Read + ); + + private static void ManualResetCallback(uint errorCode, uint numBytes, NativeOverlapped* pOVERLAP) + { + var manualResetEvent = (ManualResetEventSlim)ThreadPoolBoundHandle.GetNativeOverlappedState(pOVERLAP)!; + manualResetEvent.Set(); + } + + private static readonly IOCompletionCallback manualResetCallback = ManualResetCallback; + + private static ulong DeviceIOControl(ThreadPoolBoundHandle threadPoolBoundHandle, WINDIVERT_IOCTL* ioctl, uint code) + { + ulong value; + using var eventHandle = new ManualResetEventSlim(initialState: false); + var nativeOverlapped = threadPoolBoundHandle.AllocateNativeOverlapped(manualResetCallback, eventHandle, null); + try + { + using var _ = threadPoolBoundHandle.Handle.DangerousGetHandle(out var handle); + bool success = PInvoke.DeviceIoControl( + new HANDLE(handle), + code, + ioctl, + (uint)sizeof(WINDIVERT_IOCTL), + &value, + sizeof(ulong), + null, + nativeOverlapped + ); + if (!success) + { + int error = Marshal.GetLastPInvokeError(); + if (error is not (int)WIN32_ERROR.ERROR_IO_PENDING) + { + throw new Win32Exception(error); + } + eventHandle.Wait(); + } + } + finally + { + threadPoolBoundHandle.FreeNativeOverlapped(nativeOverlapped); + } + + return value; + } + + public static void SetParam(ThreadPoolBoundHandle threadPoolBoundHandle, WINDIVERT_PARAM param, ulong value) + { + var ioctl = new WINDIVERT_IOCTL { SetParam = param, Value = value }; + DeviceIOControl(threadPoolBoundHandle, &ioctl, SetParamControlCode); + } + + public static ulong GetParam(ThreadPoolBoundHandle threadPoolBoundHandle, WINDIVERT_PARAM param) + { + var ioctl = new WINDIVERT_IOCTL { GetParam = param }; + return DeviceIOControl(threadPoolBoundHandle, &ioctl, GetParamControlCode); + } +} diff --git a/Divert.Windows/DivertService.cs b/Divert.Windows/DivertService.cs index 95fb5c3..4a40259 100644 --- a/Divert.Windows/DivertService.cs +++ b/Divert.Windows/DivertService.cs @@ -1,4 +1,5 @@ using System.ComponentModel; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Threading.Channels; using Divert.Windows.AsyncOperation; @@ -7,10 +8,22 @@ namespace Divert.Windows; /// -/// Main entry point WinDivert APIs. +/// Main entry point for WinDivert operations. /// public sealed unsafe class DivertService : IDisposable { + public const int HighestPriority = Constants.WINDIVERT_PRIORITY_HIGHEST; + public const int LowestPriority = Constants.WINDIVERT_PRIORITY_LOWEST; + public const int DefaultQueueLength = Constants.WINDIVERT_PARAM_QUEUE_LENGTH_DEFAULT; + public const int MinQueueLength = Constants.WINDIVERT_PARAM_QUEUE_LENGTH_MIN; + public const int MaxQueueLength = Constants.WINDIVERT_PARAM_QUEUE_LENGTH_MAX; + public static TimeSpan DefaultQueueTime => TimeSpan.FromMilliseconds(Constants.WINDIVERT_PARAM_QUEUE_TIME_DEFAULT); + public static TimeSpan MinQueueTime => TimeSpan.FromMilliseconds(Constants.WINDIVERT_PARAM_QUEUE_TIME_MIN); + public static TimeSpan MaxQueueTime => TimeSpan.FromMilliseconds(Constants.WINDIVERT_PARAM_QUEUE_TIME_MAX); + public const int DefaultQueueSize = Constants.WINDIVERT_PARAM_QUEUE_SIZE_DEFAULT; + public const int MinQueueSize = Constants.WINDIVERT_PARAM_QUEUE_SIZE_MIN; + public const int MaxQueueSize = Constants.WINDIVERT_PARAM_QUEUE_SIZE_MAX; + private readonly DivertHandle divertHandle; private readonly bool runContinuationsAsynchronously; @@ -20,13 +33,19 @@ public sealed unsafe class DivertService : IDisposable private readonly DivertReceiveExecutor receiveExecutor; private readonly DivertSendExecutor sendExecutor; + private DivertService(DivertHandle divertHandle) + { + this.divertHandle = divertHandle; + threadPoolBoundHandle = ThreadPoolBoundHandle.BindHandle(divertHandle); + sendVtsPool = Channel.CreateUnbounded(); + receiveVtsPool = Channel.CreateUnbounded(); + receiveExecutor = new DivertReceiveExecutor(); + sendExecutor = new DivertSendExecutor(); + } + /// - /// Opens a WinDivert handle for the given filter. + /// Initializes a new instance of the class. /// - /// A packet filter string specified in the WinDivert filter language. - /// The layer. - /// The priority of the handle. - /// Additional flags. public DivertService( DivertFilter filter, DivertLayer layer = DivertLayer.Network, @@ -34,55 +53,54 @@ public DivertService( DivertFlags flags = DivertFlags.None, bool runContinuationsAsynchronously = true ) + : this(OpenHandle(DivertHelper.CompileFilter(filter, layer), layer, priority, flags)) { - ArgumentNullException.ThrowIfNull(filter); - - using var s = new CString(filter.Clause); - IntPtr errorStr; - uint errorPos; - bool success = NativeMethods.WinDivertHelperCompileFilter( - s.Ptr, - (WINDIVERT_LAYER)layer, - null, - 0, - &errorStr, - &errorPos - ); - if (!success) + this.runContinuationsAsynchronously = runContinuationsAsynchronously; + } + + private static DivertHandle OpenHandle( + ReadOnlySpan filter, + DivertLayer layer, + short priority, + DivertFlags flags + ) + { + if (priority < LowestPriority || priority > HighestPriority) { - string? errorString = Marshal.PtrToStringAnsi(errorStr); - throw new ArgumentException( - $"{errorString} ({errorPos}): ...{filter.Clause[(int)errorPos..]}", - nameof(filter) - ); + throw new ArgumentOutOfRangeException(nameof(priority)); } - var handle = NativeMethods.WinDivertOpen(s.Ptr, (WINDIVERT_LAYER)layer, priority, (ulong)flags); + var pBuffer = Unsafe.AsPointer(ref MemoryMarshal.GetReference(filter)); + var handle = NativeMethods.WinDivertOpen(new(pBuffer), (WINDIVERT_LAYER)layer, priority, (ulong)flags); if (new HANDLE(handle) == HANDLE.INVALID_HANDLE_VALUE) { throw new Win32Exception(Marshal.GetLastPInvokeError()); } - divertHandle = new DivertHandle(handle); + return new DivertHandle(handle); + } + + /// + /// Initializes a new instance of the class. + /// + public DivertService( + ReadOnlySpan filter, + DivertLayer layer = DivertLayer.Network, + short priority = 0, + DivertFlags flags = DivertFlags.None, + bool runContinuationsAsynchronously = true + ) + : this(OpenHandle(filter, layer, priority, flags)) + { this.runContinuationsAsynchronously = runContinuationsAsynchronously; - threadPoolBoundHandle = ThreadPoolBoundHandle.BindHandle(divertHandle); - sendVtsPool = Channel.CreateUnbounded(); - receiveVtsPool = Channel.CreateUnbounded(); - receiveExecutor = new DivertReceiveExecutor(); - sendExecutor = new DivertSendExecutor(); } + /// + /// Initializes a new instance of the class. + /// public DivertService(SafeHandle handle, bool runContinuationsAsynchronously = true) + : this(new DivertHandle(handle.DangerousGetHandle(), ownsHandle: false)) { - ArgumentNullException.ThrowIfNull(handle); - - using var _ = handle.DangerousGetHandle(out var nativeHandle); - divertHandle = new DivertHandle(nativeHandle, ownsHandle: false); this.runContinuationsAsynchronously = runContinuationsAsynchronously; - threadPoolBoundHandle = ThreadPoolBoundHandle.BindHandle(divertHandle); - sendVtsPool = Channel.CreateUnbounded(); - receiveVtsPool = Channel.CreateUnbounded(); - receiveExecutor = new DivertReceiveExecutor(); - sendExecutor = new DivertSendExecutor(); } private bool disposed = false; @@ -113,6 +131,8 @@ public void Dispose() divertHandle.Dispose(); } + public DivertHandle SafeHandle => divertHandle; + private DivertValueTaskSource GetVts(Channel vtsPool) { if (!vtsPool.Reader.TryRead(out var vts)) @@ -172,4 +192,65 @@ public ValueTask SendAsync( var vts = GetVts(sendVtsPool); return sendExecutor.SendAsync(vts, buffer, addresses, cancellationToken); } + + public Version Version + { + get + { + int major = (int) + DivertIOControl.GetParam(threadPoolBoundHandle, WINDIVERT_PARAM.WINDIVERT_PARAM_VERSION_MAJOR); + int minor = (int) + DivertIOControl.GetParam(threadPoolBoundHandle, WINDIVERT_PARAM.WINDIVERT_PARAM_VERSION_MINOR); + return new Version(major, minor); + } + } + + public int QueueLength + { + get => (int)DivertIOControl.GetParam(threadPoolBoundHandle, WINDIVERT_PARAM.WINDIVERT_PARAM_QUEUE_LENGTH); + set + { + if (value < MinQueueLength || value > MaxQueueLength) + { + throw new ArgumentOutOfRangeException(nameof(value)); + } + + DivertIOControl.SetParam(threadPoolBoundHandle, WINDIVERT_PARAM.WINDIVERT_PARAM_QUEUE_LENGTH, (ulong)value); + } + } + + public TimeSpan QueueTime + { + get => + TimeSpan.FromMilliseconds( + DivertIOControl.GetParam(threadPoolBoundHandle, WINDIVERT_PARAM.WINDIVERT_PARAM_QUEUE_TIME) + ); + set + { + if (value < MinQueueTime || value > MaxQueueTime) + { + throw new ArgumentOutOfRangeException(nameof(value)); + } + + DivertIOControl.SetParam( + threadPoolBoundHandle, + WINDIVERT_PARAM.WINDIVERT_PARAM_QUEUE_TIME, + (ulong)value.TotalMilliseconds + ); + } + } + + public int QueueSize + { + get => (int)DivertIOControl.GetParam(threadPoolBoundHandle, WINDIVERT_PARAM.WINDIVERT_PARAM_QUEUE_SIZE); + set + { + if (value < MinQueueSize || value > MaxQueueSize) + { + throw new ArgumentOutOfRangeException(nameof(value)); + } + + DivertIOControl.SetParam(threadPoolBoundHandle, WINDIVERT_PARAM.WINDIVERT_PARAM_QUEUE_SIZE, (ulong)value); + } + } } diff --git a/Divert.Windows/NativeMethods.cs b/Divert.Windows/NativeMethods.cs index 120f303..563fa65 100644 --- a/Divert.Windows/NativeMethods.cs +++ b/Divert.Windows/NativeMethods.cs @@ -80,6 +80,10 @@ public static partial bool WinDivertHelperCalcChecksums( ulong flags ); + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertHelperDecrementTTL(void* pPacket, uint packetLen); + [LibraryImport(dllName, SetLastError = true)] [return: MarshalAs(UnmanagedType.Bool)] public static partial bool WinDivertHelperCompileFilter( @@ -99,4 +103,13 @@ public static partial bool WinDivertHelperFormatFilter( byte* buffer, uint bufLen ); + + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertHelperEvalFilter( + IntPtr filter, + void* pPacket, + uint packetLen, + WINDIVERT_ADDRESS* pAddr + ); } diff --git a/Divert.Windows/NativeMethods.txt b/Divert.Windows/NativeMethods.txt index 2a1e234..8926c54 100644 --- a/Divert.Windows/NativeMethods.txt +++ b/Divert.Windows/NativeMethods.txt @@ -1,3 +1,7 @@ INVALID_HANDLE_VALUE WIN32_ERROR CancelIoEx +FILE_DEVICE_NETWORK +METHOD_IN_DIRECT +METHOD_OUT_DIRECT +DeviceIoControl diff --git a/Divert.Windows/NativeTypes.cs b/Divert.Windows/NativeTypes.cs index 5af8043..3f57897 100644 --- a/Divert.Windows/NativeTypes.cs +++ b/Divert.Windows/NativeTypes.cs @@ -175,3 +175,16 @@ internal enum WINDIVERT_SHUTDOWN WINDIVERT_SHUTDOWN_SEND = 0x2, WINDIVERT_SHUTDOWN_BOTH = 0x3, } + +[StructLayout(LayoutKind.Explicit, Size = 16)] +internal struct WINDIVERT_IOCTL +{ + [FieldOffset(0)] + public WINDIVERT_PARAM GetParam; + + [FieldOffset(0)] + public ulong Value; + + [FieldOffset(8)] + public WINDIVERT_PARAM SetParam; +} From 9433933de4298d14ec6558fa3b19baf1185254fb Mon Sep 17 00:00:00 2001 From: gdlol Date: Sun, 9 Nov 2025 04:28:11 +0000 Subject: [PATCH 11/20] Documentations and examples. --- Divert.Windows/Constants.cs | 1 + Divert.Windows/Divert.Windows.csproj | 1 + Divert.Windows/DivertAddress.cs | 183 ++++++++++++++++++++++- Divert.Windows/DivertEvent.cs | 42 ++++++ Divert.Windows/DivertFilter.cs | 204 ++++++++++++++++++++++++++ Divert.Windows/DivertFlags.cs | 38 +++++ Divert.Windows/DivertHandle.cs | 18 +++ Divert.Windows/DivertHelper.cs | 60 ++++++++ Divert.Windows/DivertHelperFlags.cs | 26 ++++ Divert.Windows/DivertIOControl.cs | 4 + Divert.Windows/DivertLayer.cs | 22 +++ Divert.Windows/DivertReceiveResult.cs | 5 + Divert.Windows/DivertService.cs | 105 +++++++++++-- Examples/Flow/Flow.csproj | 15 ++ Examples/Flow/Program.cs | 136 +++++++++++++++++ Examples/Flow/app.manifest | 10 ++ Examples/Ping/Ping.csproj | 1 + Examples/Ping/Program.cs | 21 +-- Examples/Ping/app.manifest | 10 ++ Examples/Socket/Program.cs | 131 +++++++++++++++++ Examples/Socket/Socket.csproj | 15 ++ Examples/Socket/app.manifest | 10 ++ 22 files changed, 1024 insertions(+), 34 deletions(-) create mode 100644 Examples/Flow/Flow.csproj create mode 100644 Examples/Flow/Program.cs create mode 100644 Examples/Flow/app.manifest create mode 100644 Examples/Ping/app.manifest create mode 100644 Examples/Socket/Program.cs create mode 100644 Examples/Socket/Socket.csproj create mode 100644 Examples/Socket/app.manifest diff --git a/Divert.Windows/Constants.cs b/Divert.Windows/Constants.cs index ef91e99..d3774c5 100644 --- a/Divert.Windows/Constants.cs +++ b/Divert.Windows/Constants.cs @@ -13,4 +13,5 @@ internal static class Constants public const int WINDIVERT_PARAM_QUEUE_SIZE_DEFAULT = 4 * 1024 * 1024; public const int WINDIVERT_PARAM_QUEUE_SIZE_MIN = 64 * 1024; public const int WINDIVERT_PARAM_QUEUE_SIZE_MAX = 32 * 1024 * 1024; + public const int WINDIVERT_BATCH_MAX = byte.MaxValue; } diff --git a/Divert.Windows/Divert.Windows.csproj b/Divert.Windows/Divert.Windows.csproj index e2586ce..fe9a6e3 100644 --- a/Divert.Windows/Divert.Windows.csproj +++ b/Divert.Windows/Divert.Windows.csproj @@ -2,6 +2,7 @@ $(DefaultTargetFramework) x64 + true diff --git a/Divert.Windows/DivertAddress.cs b/Divert.Windows/DivertAddress.cs index f152b35..0a43dd8 100644 --- a/Divert.Windows/DivertAddress.cs +++ b/Divert.Windows/DivertAddress.cs @@ -4,51 +4,160 @@ namespace Divert.Windows; +/// +/// Represents the address of a captured or injected packet. +/// [StructLayout(LayoutKind.Sequential)] public unsafe struct DivertAddress { + /// + /// Network layer data. + /// public struct NetworkData { + /// + /// The interface index on which the packet was captured or to which it will be injected. + /// public uint InterfaceIndex; + + /// + /// The sub-interface index of the interface. + /// public uint SubInterfaceIndex; } + /// + /// Flow layer data. + /// public struct FlowData { + /// + /// The endpoint ID of the flow. + /// public ulong EndpointId; + + /// + /// The parent endpoint ID of the flow. + /// public ulong ParentEndpointId; + + /// + /// The process ID associated with the flow. + /// public uint ProcessId; + + /// + /// The local IP address of the flow. + /// public IPAddress LocalAddress; + + /// + /// The remote IP address of the flow. + /// public IPAddress RemoteAddress; + + /// + /// The local port of the flow. + /// public ushort LocalPort; + + /// + /// The remote port of the flow. + /// public ushort RemotePort; + + /// + /// The protocol of the flow. + /// public byte Protocol; } + /// + /// Socket layer data. + /// public struct SocketData { + /// + /// The endpoint ID of the socket operation. + /// public ulong EndpointId; + + /// + /// The parent endpoint ID of the socket operation. + /// public ulong ParentEndpointId; + + /// + /// The process ID associated with the socket operation. + /// public uint ProcessId; + + /// + /// The local IP address of the socket operation. + /// public IPAddress LocalAddress; + + /// + /// The remote IP address of the socket operation. + /// public IPAddress RemoteAddress; + + /// + /// The local port of the socket operation. + /// public ushort LocalPort; + + /// + /// The remote port of the socket operation. + /// public ushort RemotePort; + + /// + /// The protocol of the socket operation. + /// public byte Protocol; } + /// + /// Reflect layer data. + /// public struct ReflectData { + /// + /// The timestamp when the handle was opened. + /// public long Timestamp; + + /// + /// The process ID that opened the handle. + /// public uint ProcessId; + + /// + /// The layer parameter used when opening the handle. + /// public DivertLayer Layer; + + /// + /// The flags parameter used when opening the handle. + /// public DivertFlags Flags; + + /// + /// The priority parameter used when opening the handle. + /// public short Priority; } private WINDIVERT_ADDRESS address; - public DivertAddress(int interfaceIndex, int subInterfaceIndex) + /// + /// Initializes a new instance of the struct with the specified interface and + /// sub-interface indices. + /// + /// The interface index. + /// The sub-interface index. + public DivertAddress(int interfaceIndex, int subInterfaceIndex = 0) { address = new WINDIVERT_ADDRESS { @@ -60,6 +169,9 @@ public DivertAddress(int interfaceIndex, int subInterfaceIndex) }; } + /// + /// Clears all fields of the . + /// public void Reset() { fixed (WINDIVERT_ADDRESS* pAddress = &address) @@ -68,10 +180,19 @@ public void Reset() } } + /// + /// Gets the timestamp of the event. + /// public readonly long Timestamp => address.Timestamp; + /// + /// Gets the layer of the event. + /// public readonly DivertLayer Layer => (DivertLayer)address.Layer; + /// + /// Gets the event type. + /// public readonly DivertEvent Event => (DivertEvent)address.Event; private readonly bool GetBit(WINDIVERT_ADDRESS_BITS bit) @@ -91,42 +212,75 @@ private void SetBit(WINDIVERT_ADDRESS_BITS bit, bool value) } } + /// + /// Gets a value indicating whether the packet/event was sniffed (not blocked). + /// public readonly bool IsSniffed => GetBit(WINDIVERT_ADDRESS_BITS.Sniffed); + /// + /// Gets or sets a value indicating whether the packet/event is outbound. + /// public bool IsOutbound { readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.Outbound); } set { SetBit(WINDIVERT_ADDRESS_BITS.Outbound, value); } } + /// + /// Gets a value indicating whether the packet is loopback. + /// public readonly bool IsLoopback => GetBit(WINDIVERT_ADDRESS_BITS.Loopback); + /// + /// Gets or sets a value indicating impostor packets. + /// public bool IsImpostor { readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.Impostor); } set { SetBit(WINDIVERT_ADDRESS_BITS.Impostor, value); } } + /// + /// Gets a value indicating whether the packet/event is IPv6. + /// public readonly bool IsIPv6 => GetBit(WINDIVERT_ADDRESS_BITS.IPv6); + /// + /// Gets or sets a value indicating whether the IP checksum is valid. + /// public bool IsIPChecksumValid { readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.IPChecksum); } set { SetBit(WINDIVERT_ADDRESS_BITS.IPChecksum, value); } } + /// + /// Gets or sets a value indicating whether the TCP checksum is valid. + /// public bool IsTCPChecksumValid { readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.TCPChecksum); } set { SetBit(WINDIVERT_ADDRESS_BITS.TCPChecksum, value); } } + /// + /// Gets or sets a value indicating whether the UDP checksum is valid. + /// public bool IsUDPChecksumValid { readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.UDPChecksum); } set { SetBit(WINDIVERT_ADDRESS_BITS.UDPChecksum, value); } } + /// + /// Gets the network data associated with the event. + /// + /// + /// The associated with the event. + /// + /// + /// The is not or . + /// public NetworkData GetNetworkData() { return Layer switch @@ -153,6 +307,15 @@ private readonly IPAddress GetIPAddress(Span bytes) return address; } + /// + /// Gets the flow data associated with the event. + /// + /// + /// The associated with the event. + /// + /// + /// The is not . + /// public FlowData GetFlowData() { switch (Layer) @@ -175,6 +338,15 @@ public FlowData GetFlowData() } } + /// + /// Gets the socket data associated with the event. + /// + /// + /// The associated with the event. + /// + /// + /// The is not . + /// public SocketData GetSocketData() { switch (Layer) @@ -197,6 +369,15 @@ public SocketData GetSocketData() } } + /// + /// Gets the reflect data associated with the event. + /// + /// + /// The associated with the event. + /// + /// + /// The is not . + /// public ReflectData GetReflectData() { return Layer switch diff --git a/Divert.Windows/DivertEvent.cs b/Divert.Windows/DivertEvent.cs index 3d1b351..eb22284 100644 --- a/Divert.Windows/DivertEvent.cs +++ b/Divert.Windows/DivertEvent.cs @@ -1,15 +1,57 @@ namespace Divert.Windows; +/// +/// Types of Divert events. +/// public enum DivertEvent { + /// + /// A network packet. + /// NetworkPacket = WINDIVERT_EVENT.WINDIVERT_EVENT_NETWORK_PACKET, + + /// + /// A flow has been established. + /// FlowEstablished = WINDIVERT_EVENT.WINDIVERT_EVENT_FLOW_ESTABLISHED, + + /// + /// A flow has been deleted. + /// FlowDeleted = WINDIVERT_EVENT.WINDIVERT_EVENT_FLOW_DELETED, + + /// + /// A bind() socket operation. + /// SocketBind = WINDIVERT_EVENT.WINDIVERT_EVENT_SOCKET_BIND, + + /// + /// A connect() socket operation. + /// SocketConnect = WINDIVERT_EVENT.WINDIVERT_EVENT_SOCKET_CONNECT, + + /// + /// A listen() socket operation. + /// SocketListen = WINDIVERT_EVENT.WINDIVERT_EVENT_SOCKET_LISTEN, + + /// + /// An accept() socket operation. + /// SocketAccept = WINDIVERT_EVENT.WINDIVERT_EVENT_SOCKET_ACCEPT, + + /// + /// A socket is unbound or a connection is closed. + /// SocketClose = WINDIVERT_EVENT.WINDIVERT_EVENT_SOCKET_CLOSE, + + /// + /// A new WinDivert handle was opened. + /// ReflectOpen = WINDIVERT_EVENT.WINDIVERT_EVENT_REFLECT_OPEN, + + /// + /// A WinDivert handle was closed. + /// ReflectClose = WINDIVERT_EVENT.WINDIVERT_EVENT_REFLECT_CLOSE, } diff --git a/Divert.Windows/DivertFilter.cs b/Divert.Windows/DivertFilter.cs index 6968ca9..b6fb48f 100644 --- a/Divert.Windows/DivertFilter.cs +++ b/Divert.Windows/DivertFilter.cs @@ -6,10 +6,20 @@ namespace Divert.Windows; internal record struct ReplaceParenthesesOperation(string Expression); +/// +/// Helper type to build WinDivert filter expressions. +/// public partial class DivertFilter { + /// + /// The filter clause. + /// public string Clause { get; } + /// + /// Creates a new filter with the specified clause. + /// + /// The filter clause. public DivertFilter(string clause) { ArgumentNullException.ThrowIfNull(clause); @@ -72,6 +82,9 @@ private static string CollapseParentheses(string expression) private static bool MatchAndPattern(string s) => AndPatternRegex().IsMatch(CollapseParentheses(s)); + /// + /// Combines two filters with an AND operation. + /// public static DivertFilter operator &(DivertFilter left, DivertFilter right) { string leftClause = left.Clause; @@ -88,6 +101,9 @@ private static string CollapseParentheses(string expression) return new DivertFilter(clause); } + /// + /// Combines two filters with an OR operation. + /// public static DivertFilter operator |(DivertFilter left, DivertFilter right) { string leftClause = left.Clause; @@ -104,11 +120,16 @@ private static string CollapseParentheses(string expression) return new DivertFilter(clause); } + /// + /// Returns the filter clause. + /// + /// The filter clause. public override string ToString() { return Clause; } + /// public override bool Equals(object? obj) { return obj switch @@ -120,12 +141,20 @@ public override bool Equals(object? obj) }; } + /// public override int GetHashCode() => Clause.GetHashCode(); + /// + /// A layer specific property for matching packets/events. + /// public class Field { private readonly string field; + /// + /// Creates a new field with the specified name. + /// + /// The field name. public Field(string field) { ArgumentNullException.ThrowIfNull(field); @@ -133,6 +162,7 @@ public Field(string field) this.field = field; } + /// public override bool Equals(object? obj) { return obj switch @@ -144,32 +174,53 @@ public override bool Equals(object? obj) }; } + /// public override int GetHashCode() => field.GetHashCode(); + /// + /// Returns the field name. + /// + /// The field name. public override string ToString() => field; + /// + /// Implicitly converts a string to a Field. + /// + /// The field name. public static implicit operator Field(string field) { return new Field(field); } + /// + /// Implicitly converts a Field to a DivertFilter. + /// public static implicit operator DivertFilter(Field field) { return new DivertFilter(field.ToString()); } + /// + /// Combines two fields with an AND operation. + /// public static DivertFilter operator &(Field left, Field right) { DivertFilter value = left; return value & right; } + /// + /// Combines two fields with an OR operation. + /// public static DivertFilter operator |(Field left, Field right) { DivertFilter value = left; return value | right; } + /// + /// Negates a field. + /// public static DivertFilter operator !(Field value) { string clause = $"not {value}"; @@ -214,52 +265,94 @@ private static string Macro(DivertLayer layer) => _ => layer.ToString(), }; + /// + /// Equality operator for Field and object. + /// public static DivertFilter operator ==(Field left, object right) { string clause = $"{left} = {right}"; return new DivertFilter(clause); } + /// + /// Inequality operator for Field and object. + /// public static DivertFilter operator !=(Field left, object right) { string clause = $"{left} != {right}"; return new DivertFilter(clause); } + /// + /// Equality operator for Field and bool. + /// public static DivertFilter operator ==(Field left, bool right) => left == Macro(right); + /// + /// Inequality operator for Field and bool. + /// public static DivertFilter operator !=(Field left, bool right) => left != Macro(right); + /// + /// Equality operator for Field and . + /// public static DivertFilter operator ==(Field left, ProtocolType right) => left == Macro(right); + /// + /// Inequality operator for Field and . + /// public static DivertFilter operator !=(Field left, ProtocolType right) => left != Macro(right); + /// + /// Equality operator for Field and . + /// public static DivertFilter operator ==(Field left, DivertEvent right) => left == Macro(right); + /// + /// Inequality operator for Field and . + /// public static DivertFilter operator !=(Field left, DivertEvent right) => left != Macro(right); + /// + /// Equality operator for Field and . + /// public static DivertFilter operator ==(Field left, DivertLayer right) => left == Macro(right); + /// + /// Inequality operator for Field and . + /// public static DivertFilter operator !=(Field left, DivertLayer right) => left != Macro(right); + /// + /// Less than operator for Field and object. + /// public static DivertFilter operator <(Field left, object right) { string clause = $"{left} < {right}"; return new DivertFilter(clause); } + /// + /// Greater than operator for Field and object. + /// public static DivertFilter operator >(Field left, object right) { string clause = $"{left} > {right}"; return new DivertFilter(clause); } + /// + /// Less than or equal operator for Field and object. + /// public static DivertFilter operator <=(Field left, object right) { string clause = $"{left} <= {right}"; return new DivertFilter(clause); } + /// + /// Greater than or equal operator for Field and object. + /// public static DivertFilter operator >=(Field left, object right) { string clause = $"{left} >= {right}"; @@ -267,82 +360,193 @@ private static string Macro(DivertLayer layer) => } } + /// + /// The true expression. + /// public static DivertFilter True { get; } = "true"; + /// + /// The false expression. + /// public static DivertFilter False { get; } = "false"; + /// + /// The value zero. + /// public static Field Zero { get; } = "zero"; + /// + /// The packet/event timestamp. + /// public static Field Timestamp { get; } = "timestamp"; + /// + /// The event type. + /// public static Field Event { get; } = "event"; + /// + /// Is outbound? + /// public static Field Outbound { get; } = "outbound"; + /// + /// Is inbound? + /// public static Field Inbound { get; } = "inbound"; + /// + /// The interface index. + /// public static Field InterfaceIndex { get; } = "ifIdx"; + /// + /// The sub-interface index. + /// public static Field SubInterfaceIndex { get; } = "subIfIdx"; + /// + /// Is loopback packet? + /// public static Field Loopback { get; } = "loopback"; + /// + /// Is impostor packet? + /// public static Field Impostor { get; } = "impostor"; + /// + /// Is IPv4 fragment? + /// public static Field Fragment { get; } = "fragment"; + /// + /// The endpoint ID. + /// public static Field EndpointId { get; } = "endpointId"; + /// + /// The parent endpoint ID. + /// public static Field ParentEndpointId { get; } = "parentEndpointId"; + /// + /// The process ID. + /// public static Field ProcessId { get; } = "processId"; + /// + /// 8-bit random number. + /// public static Field Random8 { get; } = "random8"; + /// + /// 16-bit random number. + /// public static Field Random16 { get; } = "random16"; + /// + /// 32-bit random number. + /// public static Field Random32 { get; } = "random32"; + /// + /// The layer of the WinDivert handle. + /// public static Field Layer { get; } = "layer"; + /// + /// The priority of the WinDivert handle. + /// public static Field Priority { get; } = "priority"; + /// + /// The i-th byte of the packet. + /// public static Field Packet(int i) => $"packet[{i}]"; + /// + /// The i-th 16-bit word of the packet. + /// public static Field Packet16(int i) => $"packet16[{i}]"; + /// + /// The i-th 32-bit word of the packet. + /// public static Field Packet32(int i) => $"packet32[{i}]"; + /// + /// The packet length. + /// public static Field Length { get; } = "length"; + /// + /// Is IPv4? + /// public static Field Ip { get; } = "ip"; + /// + /// Is IPv6? + /// public static Field Ipv6 { get; } = "ipv6"; + /// + /// Is ICMP? + /// public static Field ICMP { get; } = "icmp"; + /// + /// Is ICMPv6? + /// // spell-checker:ignore icmpv6 public static Field ICMPv6 { get; } = "icmpv6"; + /// + /// Is TCP? + /// public static Field TCP { get; } = "tcp"; + /// + /// Is UDP? + /// public static Field UDP { get; } = "udp"; + /// + /// The protocol. + /// public static Field Protocol { get; } = "protocol"; + /// + /// The local address. + /// public static Field LocalAddress { get; } = "localAddr"; + /// + /// The local port. + /// public static Field LocalPort { get; } = "localPort"; + /// + /// The remote address. + /// public static Field RemoteAddress { get; } = "remoteAddr"; + /// + /// The remote port. + /// public static Field RemotePort { get; } = "remotePort"; + /// + /// Implicitly converts a string to a . + /// public static implicit operator DivertFilter(string clause) { return new DivertFilter(clause); } + /// + /// Implicitly converts a bool to a . + /// public static implicit operator DivertFilter(bool value) { return value ? True : False; diff --git a/Divert.Windows/DivertFlags.cs b/Divert.Windows/DivertFlags.cs index 314f2fe..5cf99c8 100644 --- a/Divert.Windows/DivertFlags.cs +++ b/Divert.Windows/DivertFlags.cs @@ -1,15 +1,53 @@ namespace Divert.Windows; +/// +/// Flags for configuring the behavior of the WinDivert handle. +/// [Flags] public enum DivertFlags { + /// + /// Default mode: packets are dropped and captured. + /// None = 0, + + /// + /// Packet sniffing mode: packets are captured but not dropped. + /// Sniff = 0x0001, + + /// + /// Packets are silently dropped. + /// Drop = 0x0002, + + /// + /// Receive-only mode: disables sending packets. + /// ReceiveOnly = 0x0004, + + /// + /// Same as . + /// ReadOnly = ReceiveOnly, + + /// + /// Send-only mode: disables receiving packets. + /// SendOnly = 0x0008, + + /// + /// Same as . + /// WriteOnly = SendOnly, + + /// + /// Prevents automatic installation of the WinDivert driver. + /// NoInstall = 0x0010, + + /// + /// Captures fragmented packets. + /// Fragments = 0x0020, } diff --git a/Divert.Windows/DivertHandle.cs b/Divert.Windows/DivertHandle.cs index ee9a4f0..2575fb3 100644 --- a/Divert.Windows/DivertHandle.cs +++ b/Divert.Windows/DivertHandle.cs @@ -2,13 +2,31 @@ namespace Divert.Windows; +/// +/// Safe handle for a WinDivert handle. +/// public sealed class DivertHandle : SafeHandleZeroOrMinusOneIsInvalid { + /// + /// Creates a new instance of the class. + /// + /// + /// The WinDivert handle. + /// + /// + /// Whether the handle should be released when the SafeHandle is disposed. + /// public DivertHandle(IntPtr handle, bool ownsHandle = true) : base(ownsHandle) { this.handle = handle; } + /// + /// Releases the WinDivert handle. + /// + /// + /// true if the handle was released successfully; otherwise, false. + /// protected override bool ReleaseHandle() => NativeMethods.WinDivertClose(handle); } diff --git a/Divert.Windows/DivertHelper.cs b/Divert.Windows/DivertHelper.cs index a06b151..60ae7dc 100644 --- a/Divert.Windows/DivertHelper.cs +++ b/Divert.Windows/DivertHelper.cs @@ -5,8 +5,19 @@ namespace Divert.Windows; +/// +/// WinDivert helper methods. +/// public static unsafe class DivertHelper { + /// + /// Calculates the checksums for the specified packet. + /// + /// + /// The packet data. + /// + /// The flags to disable individual checksum calculations. + /// true if the checksums were calculated successfully; otherwise, false. public static bool CalculateChecksums(Span packet, DivertHelperFlags flags = DivertHelperFlags.None) { fixed (byte* pPacket = packet) @@ -15,6 +26,17 @@ public static bool CalculateChecksums(Span packet, DivertHelperFlags flags } } + /// + /// Calculates the checksums for the specified packet. + /// + /// + /// The packet data. + /// + /// + /// A reference to a structure where the corresponding ChecksumValid fields will be set. + /// + /// The flags to disable individual checksum calculations. + /// true if the checksums were calculated successfully; otherwise, false. public static bool CalculateChecksums( Span packet, ref DivertAddress address, @@ -33,6 +55,11 @@ public static bool CalculateChecksums( } } + /// + /// Decrements the TTL field in the IP header of the specified packet. + /// + /// The packet data. + /// true if the result is non-zero; otherwise, false. public static bool DecrementTtl(Span packet) { fixed (byte* pPacket = packet) @@ -41,6 +68,18 @@ public static bool DecrementTtl(Span packet) } } + /// + /// Compiles a WinDivert filter string into a compact object representation. + /// + /// + /// The filter to compile. + /// + /// The layer. + /// The length of the buffer. + /// The compiled filter. + /// + /// The filter is invalid. + /// public static ReadOnlySpan CompileFilter( DivertFilter filter, DivertLayer layer, @@ -97,6 +136,13 @@ private static bool EvaluateFilter(IntPtr filter, ReadOnlySpan packet, in } } + /// + /// Evaluates a compiled WinDivert filter against the specified packet and address. + /// + /// The compiled filter. + /// The packet data. + /// The address information. + /// true if the packet matches the filter; otherwise, false. public static bool EvaluateFilter(ReadOnlySpan filter, ReadOnlySpan packet, in DivertAddress address) { fixed (byte* pFilter = filter) @@ -107,6 +153,13 @@ public static bool EvaluateFilter(ReadOnlySpan filter, ReadOnlySpan } } + /// + /// Evaluates a WinDivert filter string against the specified packet and address. + /// + /// The filter to evaluate. + /// The packet data. + /// The address information. + /// true if the packet matches the filter; otherwise, false. public static bool EvaluateFilter(DivertFilter filter, ReadOnlySpan packet, in DivertAddress address) { using var s = new CString(filter.Clause); @@ -117,6 +170,13 @@ public static bool EvaluateFilter(DivertFilter filter, ReadOnlySpan packet } } + /// + /// Formats a compiled WinDivert filter into a human-readable string. + /// + /// The compiled filter. + /// The layer. + /// The maximum length of the formatted string. + /// The formatted filter string. public static string FormatFilter(Span filter, DivertLayer layer, int maxLength = ushort.MaxValue) { Span buffer = GC.AllocateArray(maxLength, pinned: true); diff --git a/Divert.Windows/DivertHelperFlags.cs b/Divert.Windows/DivertHelperFlags.cs index 9a3f03d..78ad53c 100644 --- a/Divert.Windows/DivertHelperFlags.cs +++ b/Divert.Windows/DivertHelperFlags.cs @@ -1,12 +1,38 @@ namespace Divert.Windows; +/// +/// Flags to disable individual checksum calculations. +/// [Flags] public enum DivertHelperFlags { + /// + /// No flags specified. + /// None = 0, + + /// + /// Disables IP checksum calculation. + /// NoIPChecksum = 1, + + /// + /// Disables ICMP checksum calculation. + /// NoICMPChecksum = 2, + + /// + /// Disables ICMPv6 checksum calculation. + /// NoICMPv6Checksum = 4, + + /// + /// Disables TCP checksum calculation. + /// NoTCPChecksum = 8, + + /// + /// Disables UDP checksum calculation. + /// NoUDPChecksum = 16, } diff --git a/Divert.Windows/DivertIOControl.cs b/Divert.Windows/DivertIOControl.cs index 334f61c..ba4615c 100644 --- a/Divert.Windows/DivertIOControl.cs +++ b/Divert.Windows/DivertIOControl.cs @@ -5,6 +5,10 @@ namespace Divert.Windows; +/// +/// Replaces and as they +/// are not applicable to thread pool bound handles. +/// internal static unsafe class DivertIOControl { private static uint CTL_CODE(uint deviceType, uint function, uint method, uint access) diff --git a/Divert.Windows/DivertLayer.cs b/Divert.Windows/DivertLayer.cs index 4cc1c4b..4abed19 100644 --- a/Divert.Windows/DivertLayer.cs +++ b/Divert.Windows/DivertLayer.cs @@ -1,10 +1,32 @@ namespace Divert.Windows; +/// +/// Specifies the WinDivert layer. +/// public enum DivertLayer { + /// + /// Network packets to/from the local machine. + /// Network = WINDIVERT_LAYER.WINDIVERT_LAYER_NETWORK, + + /// + /// Network packets being forwarded through the local machine. + /// Forward = WINDIVERT_LAYER.WINDIVERT_LAYER_NETWORK_FORWARD, + + /// + /// Network flow events. + /// Flow = WINDIVERT_LAYER.WINDIVERT_LAYER_FLOW, + + /// + /// Socket events. + /// Socket = WINDIVERT_LAYER.WINDIVERT_LAYER_SOCKET, + + /// + /// WinDivert handle events. + /// Reflect = WINDIVERT_LAYER.WINDIVERT_LAYER_REFLECT, } diff --git a/Divert.Windows/DivertReceiveResult.cs b/Divert.Windows/DivertReceiveResult.cs index 3587ed0..627ba4a 100644 --- a/Divert.Windows/DivertReceiveResult.cs +++ b/Divert.Windows/DivertReceiveResult.cs @@ -1,5 +1,10 @@ namespace Divert.Windows; +/// +/// Represents the result of a Divert receive operation. +/// +/// The length of the received data. +/// The length of the addresses. public readonly struct DivertReceiveResult(int dataLength, int addressLength) { /// diff --git a/Divert.Windows/DivertService.cs b/Divert.Windows/DivertService.cs index 4a40259..e0d48c1 100644 --- a/Divert.Windows/DivertService.cs +++ b/Divert.Windows/DivertService.cs @@ -12,18 +12,66 @@ namespace Divert.Windows; /// public sealed unsafe class DivertService : IDisposable { + /// + /// The highest priority for a WinDivert handle. + /// public const int HighestPriority = Constants.WINDIVERT_PRIORITY_HIGHEST; + + /// + /// The lowest priority for a WinDivert handle. + /// public const int LowestPriority = Constants.WINDIVERT_PRIORITY_LOWEST; + + /// + /// The default packet queue length for receive operations. + /// public const int DefaultQueueLength = Constants.WINDIVERT_PARAM_QUEUE_LENGTH_DEFAULT; + + /// + /// The minimum packet queue length for receive operations. + /// public const int MinQueueLength = Constants.WINDIVERT_PARAM_QUEUE_LENGTH_MIN; + + /// + /// The maximum packet queue length for receive operations. + /// public const int MaxQueueLength = Constants.WINDIVERT_PARAM_QUEUE_LENGTH_MAX; + + /// + /// The default packet queue time. + /// public static TimeSpan DefaultQueueTime => TimeSpan.FromMilliseconds(Constants.WINDIVERT_PARAM_QUEUE_TIME_DEFAULT); + + /// + /// The minimum packet queue time. + /// public static TimeSpan MinQueueTime => TimeSpan.FromMilliseconds(Constants.WINDIVERT_PARAM_QUEUE_TIME_MIN); + + /// + /// The maximum packet queue time. + /// public static TimeSpan MaxQueueTime => TimeSpan.FromMilliseconds(Constants.WINDIVERT_PARAM_QUEUE_TIME_MAX); + + /// + /// The default max number of bytes in the packet queue for receive operations. + /// public const int DefaultQueueSize = Constants.WINDIVERT_PARAM_QUEUE_SIZE_DEFAULT; + + /// + /// The minimum max number of bytes in the packet queue for receive operations. + /// public const int MinQueueSize = Constants.WINDIVERT_PARAM_QUEUE_SIZE_MIN; + + /// + /// The maximum max number of bytes in the packet queue for receive operations. + /// public const int MaxQueueSize = Constants.WINDIVERT_PARAM_QUEUE_SIZE_MAX; + /// + /// The maximum number of packets in a single send or receive operation. + /// + public const int MaxBatchSize = Constants.WINDIVERT_BATCH_MAX; + private readonly DivertHandle divertHandle; private readonly bool runContinuationsAsynchronously; @@ -43,21 +91,6 @@ private DivertService(DivertHandle divertHandle) sendExecutor = new DivertSendExecutor(); } - /// - /// Initializes a new instance of the class. - /// - public DivertService( - DivertFilter filter, - DivertLayer layer = DivertLayer.Network, - short priority = 0, - DivertFlags flags = DivertFlags.None, - bool runContinuationsAsynchronously = true - ) - : this(OpenHandle(DivertHelper.CompileFilter(filter, layer), layer, priority, flags)) - { - this.runContinuationsAsynchronously = runContinuationsAsynchronously; - } - private static DivertHandle OpenHandle( ReadOnlySpan filter, DivertLayer layer, @@ -82,6 +115,31 @@ DivertFlags flags /// /// Initializes a new instance of the class. /// + /// The packet filter. + /// The layer. + /// The priority of the handle. + /// The handle flags. + /// Whether to force continuations to run asynchronously. + public DivertService( + DivertFilter filter, + DivertLayer layer = DivertLayer.Network, + short priority = 0, + DivertFlags flags = DivertFlags.None, + bool runContinuationsAsynchronously = true + ) + : this(OpenHandle(DivertHelper.CompileFilter(filter, layer), layer, priority, flags)) + { + this.runContinuationsAsynchronously = runContinuationsAsynchronously; + } + + /// + /// Initializes a new instance of the class. + /// + /// The packet filter. + /// The layer. + /// The priority of the handle. + /// The handle flags. + /// Whether to force continuations to run asynchronously. public DivertService( ReadOnlySpan filter, DivertLayer layer = DivertLayer.Network, @@ -97,6 +155,8 @@ public DivertService( /// /// Initializes a new instance of the class. /// + /// An existing WinDivert handle. + /// Whether to force continuations to run asynchronously. public DivertService(SafeHandle handle, bool runContinuationsAsynchronously = true) : this(new DivertHandle(handle.DangerousGetHandle(), ownsHandle: false)) { @@ -131,6 +191,9 @@ public void Dispose() divertHandle.Dispose(); } + /// + /// Gets the underlying WinDivert handle. + /// public DivertHandle SafeHandle => divertHandle; private DivertValueTaskSource GetVts(Channel vtsPool) @@ -193,6 +256,9 @@ public ValueTask SendAsync( return sendExecutor.SendAsync(vts, buffer, addresses, cancellationToken); } + /// + /// Gets the version of the WinDivert driver. + /// public Version Version { get @@ -205,6 +271,9 @@ public Version Version } } + /// + /// Gets or sets the length of the receive queue. + /// public int QueueLength { get => (int)DivertIOControl.GetParam(threadPoolBoundHandle, WINDIVERT_PARAM.WINDIVERT_PARAM_QUEUE_LENGTH); @@ -219,6 +288,9 @@ public int QueueLength } } + /// + /// Gets or sets the maximum packet queue time. + /// public TimeSpan QueueTime { get => @@ -240,6 +312,9 @@ public TimeSpan QueueTime } } + /// + /// Gets or sets the maximum number of bytes in the receive queue. + /// public int QueueSize { get => (int)DivertIOControl.GetParam(threadPoolBoundHandle, WINDIVERT_PARAM.WINDIVERT_PARAM_QUEUE_SIZE); diff --git a/Examples/Flow/Flow.csproj b/Examples/Flow/Flow.csproj new file mode 100644 index 0000000..8927562 --- /dev/null +++ b/Examples/Flow/Flow.csproj @@ -0,0 +1,15 @@ + + + Exe + $(DefaultTargetFramework) + app.manifest + win-x64 + true + $(ProjectRoot).local/windows/Examples/$(MSBuildThisFileName) + CA2007 + + + + + + diff --git a/Examples/Flow/Program.cs b/Examples/Flow/Program.cs new file mode 100644 index 0000000..af780d7 --- /dev/null +++ b/Examples/Flow/Program.cs @@ -0,0 +1,136 @@ +using System.Net; +using System.Net.Sockets; +using System.Runtime.Versioning; +using Divert.Windows; + +[assembly: SupportedOSPlatform("windows6.0.6000")] + +// Logs flow events on a loopback TCP listener. + +using var service = new DivertService( + DivertFilter.ProcessId == Environment.ProcessId & DivertFilter.TCP, + DivertLayer.Flow, + flags: DivertFlags.Sniff | DivertFlags.ReceiveOnly +) +{ + QueueTime = DivertService.MaxQueueTime, +}; + +static void WriteLine(string prefix, string message) +{ + Console.WriteLine($"{prefix}: {message}"); +} + +using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); +listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); +listener.Listen(); +int port = ((IPEndPoint)listener.LocalEndPoint!).Port; +WriteLine(nameof(TcpListener), $"Listening on port {port}..."); +using var cts = new CancellationTokenSource(); +var token = cts.Token; + +var sniff = Task.Run(async () => +{ + try + { + var addresses = new DivertAddress[1]; + while (true) + { + await service.ReceiveAsync(Memory.Empty, addresses, token); + var @event = addresses[0].Event; + var socketData = addresses[0].GetFlowData(); + if (socketData.LocalPort == port) + { + WriteLine( + nameof(DivertService), + $"{@event} from {nameof(TcpListener)}, " + + $"local port = {socketData.LocalPort}, remote port = {socketData.RemotePort}." + ); + } + else + { + WriteLine( + nameof(DivertService), + $"{@event} from {nameof(TcpClient)}, " + + $"local port = {socketData.LocalPort}, remote port = {socketData.RemotePort}." + ); + } + } + } + catch (OperationCanceledException) when (token.IsCancellationRequested) { } + catch (Exception e) + { + WriteLine("Error", e.ToString()); + } +}); + +var listen = Task.Run(async () => +{ + try + { + while (true) + { + var client = await listener.AcceptAsync(token); + int clientPort = ((IPEndPoint)client.RemoteEndPoint!).Port; + WriteLine(nameof(TcpListener), $"Accepted connection from port {clientPort}."); + _ = Task.Run(async () => + { + using var _ = client; + var buffer = new byte[1024]; + + while (true) + { + int length = await client.ReceiveAsync(buffer, token); + if (length is 0) + { + WriteLine(nameof(TcpListener), "Closing connection..."); + client.Shutdown(SocketShutdown.Send); + break; + } + } + }); + } + } + catch (OperationCanceledException) when (token.IsCancellationRequested) { } + catch (Exception e) + { + WriteLine("Error", e.ToString()); + } +}); + +WriteLine("Info", "Preparing TCP client..."); +try +{ + for (int i = 0; i < 3; i++) + { + await Task.Delay(TimeSpan.FromSeconds(3), token); + Console.WriteLine(); + using var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + WriteLine(nameof(TcpClient), $"Connecting to port {port}..."); + await client.ConnectAsync(IPAddress.Loopback, port, token); + WriteLine(nameof(TcpClient), $"Connected to port {port}."); + + WriteLine(nameof(TcpClient), "Closing connection..."); + client.Shutdown(SocketShutdown.Send); + var buffer = new byte[1024]; + while (true) + { + int length = await client.ReceiveAsync(buffer, token); + if (length is 0) + { + break; + } + } + } +} +catch (Exception e) +{ + WriteLine("Error", e.ToString()); +} + +await Task.Delay(TimeSpan.FromSeconds(3), token); +cts.Cancel(); +await Task.WhenAll(sniff, listen).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); +Console.WriteLine(); +WriteLine("Info", "Done."); +Console.ReadLine(); diff --git a/Examples/Flow/app.manifest b/Examples/Flow/app.manifest new file mode 100644 index 0000000..76ff66a --- /dev/null +++ b/Examples/Flow/app.manifest @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/Examples/Ping/Ping.csproj b/Examples/Ping/Ping.csproj index f9fd15d..ee85910 100644 --- a/Examples/Ping/Ping.csproj +++ b/Examples/Ping/Ping.csproj @@ -2,6 +2,7 @@ Exe $(DefaultTargetFramework) + app.manifest win-x64 true $(ProjectRoot).local/windows/Examples/$(MSBuildThisFileName) diff --git a/Examples/Ping/Program.cs b/Examples/Ping/Program.cs index 30d2a88..e0c471e 100644 --- a/Examples/Ping/Program.cs +++ b/Examples/Ping/Program.cs @@ -1,27 +1,12 @@ -using System.Diagnostics; -using System.Net; +using System.Net; using System.Net.NetworkInformation; using System.Net.Sockets; using System.Runtime.Versioning; -using System.Security.Principal; using Divert.Windows; [assembly: SupportedOSPlatform("windows6.0.6000")] -var identity = WindowsIdentity.GetCurrent(); -var principal = new WindowsPrincipal(identity); -if (!principal.IsInRole(WindowsBuiltInRole.Administrator)) -{ - var startInfo = new ProcessStartInfo - { - FileName = Environment.ProcessPath, - Verb = "runas", - Arguments = args.Length > 0 ? args[0] : string.Empty, - UseShellExecute = true, - }; - Process.Start(startInfo); - Environment.Exit(0); -} +// Divert and re-inject ICMP packets to/from default gateway (or loopback if none). string? remoteIP = args.Length > 0 ? args[0] : null; if (remoteIP is null) @@ -36,7 +21,7 @@ from gateway in nic.GetIPProperties().GatewayAddresses select address; remoteIP = gateWays.FirstOrDefault()?.ToString(); } -remoteIP ??= "1.1.1.1"; +remoteIP ??= IPAddress.Loopback.ToString(); var outboundFilter = DivertFilter.Outbound & !DivertFilter.Loopback & DivertFilter.RemoteAddress == remoteIP & DivertFilter.ICMP; diff --git a/Examples/Ping/app.manifest b/Examples/Ping/app.manifest new file mode 100644 index 0000000..76ff66a --- /dev/null +++ b/Examples/Ping/app.manifest @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/Examples/Socket/Program.cs b/Examples/Socket/Program.cs new file mode 100644 index 0000000..7cafa92 --- /dev/null +++ b/Examples/Socket/Program.cs @@ -0,0 +1,131 @@ +using System.Net; +using System.Net.Sockets; +using System.Runtime.Versioning; +using Divert.Windows; + +[assembly: SupportedOSPlatform("windows6.0.6000")] + +// Logs socket events on a loopback TCP listener. + +using var service = new DivertService( + DivertFilter.ProcessId == Environment.ProcessId & DivertFilter.TCP, + DivertLayer.Socket, + flags: DivertFlags.Sniff | DivertFlags.ReceiveOnly +) +{ + QueueTime = DivertService.MaxQueueTime, +}; + +static void WriteLine(string prefix, string message) +{ + Console.WriteLine($"{prefix}: {message}"); +} + +using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); +listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); +listener.Listen(); +int port = ((IPEndPoint)listener.LocalEndPoint!).Port; +WriteLine(nameof(TcpListener), $"Listening on port {port}..."); +using var cts = new CancellationTokenSource(); +var token = cts.Token; + +var sniff = Task.Run(async () => +{ + try + { + var addresses = new DivertAddress[1]; + while (true) + { + await service.ReceiveAsync(Memory.Empty, addresses, token); + var @event = addresses[0].Event; + var socketData = addresses[0].GetSocketData(); + if (socketData.LocalPort == port) + { + WriteLine( + nameof(DivertService), + $"{@event} from {nameof(TcpListener)}, " + + $"local port = {socketData.LocalPort}, remote port = {socketData.RemotePort}." + ); + } + else + { + WriteLine( + nameof(DivertService), + $"{@event} from {nameof(TcpClient)}, " + + $"local port = {socketData.LocalPort}, remote port = {socketData.RemotePort}." + ); + } + } + } + catch (OperationCanceledException) when (token.IsCancellationRequested) { } + catch (Exception e) + { + WriteLine("Error", e.ToString()); + } +}); + +var listen = Task.Run(async () => +{ + try + { + while (true) + { + var client = await listener.AcceptAsync(token); + int clientPort = ((IPEndPoint)client.RemoteEndPoint!).Port; + WriteLine(nameof(TcpListener), $"Accepted connection from port {clientPort}."); + _ = Task.Run(async () => + { + using var _ = client; + var buffer = new byte[1024]; + + while (true) + { + int length = await client.ReceiveAsync(buffer, token); + if (length is 0) + { + WriteLine(nameof(TcpListener), "Closing connection..."); + client.Shutdown(SocketShutdown.Send); + break; + } + } + }); + } + } + catch (OperationCanceledException) when (token.IsCancellationRequested) { } + catch (Exception e) + { + WriteLine("Error", e.ToString()); + } +}); + +try +{ + WriteLine("Info", "Preparing TCP client..."); + await Task.Delay(TimeSpan.FromSeconds(3), token); + using var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + WriteLine(nameof(TcpClient), $"Connecting to port {port}..."); + await client.ConnectAsync(IPAddress.Loopback, port, token); + WriteLine(nameof(TcpClient), $"Connected to port {port}."); + + WriteLine(nameof(TcpClient), "Closing connection..."); + client.Shutdown(SocketShutdown.Send); + var buffer = new byte[1024]; + while (true) + { + int length = await client.ReceiveAsync(buffer, token); + if (length is 0) + { + break; + } + } +} +catch (Exception e) +{ + WriteLine("Error", e.ToString()); +} + +await Task.Delay(TimeSpan.FromSeconds(3), token); +cts.Cancel(); +await Task.WhenAll(sniff, listen).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); +WriteLine("Info", "Done."); +Console.ReadLine(); diff --git a/Examples/Socket/Socket.csproj b/Examples/Socket/Socket.csproj new file mode 100644 index 0000000..8927562 --- /dev/null +++ b/Examples/Socket/Socket.csproj @@ -0,0 +1,15 @@ + + + Exe + $(DefaultTargetFramework) + app.manifest + win-x64 + true + $(ProjectRoot).local/windows/Examples/$(MSBuildThisFileName) + CA2007 + + + + + + diff --git a/Examples/Socket/app.manifest b/Examples/Socket/app.manifest new file mode 100644 index 0000000..76ff66a --- /dev/null +++ b/Examples/Socket/app.manifest @@ -0,0 +1,10 @@ + + + + + + + + + + From 4c2369fb3064db6736d929e2061e072ccfdd6382 Mon Sep 17 00:00:00 2001 From: gdlol Date: Sun, 9 Nov 2025 14:11:36 +0000 Subject: [PATCH 12/20] Examples. --- Divert.Windows.TestRunner/Program.cs | 3 + Divert.Windows.Tests/ChecksumTests.cs | 4 +- Divert.Windows.Tests/DivertTests.cs | 1 + Divert.Windows/DivertReceiveResult.cs | 11 ++ Examples/Greeting/Greeting.csproj | 15 +++ Examples/Greeting/Program.cs | 25 +++++ Examples/Greeting/app.manifest | 10 ++ Examples/Http/Http.csproj | 15 +++ Examples/Http/Program.cs | 138 ++++++++++++++++++++++++++ Examples/Http/app.manifest | 10 ++ Examples/Ping/Program.cs | 8 +- ReadMe.md | 24 ++++- 12 files changed, 257 insertions(+), 7 deletions(-) create mode 100644 Examples/Greeting/Greeting.csproj create mode 100644 Examples/Greeting/Program.cs create mode 100644 Examples/Greeting/app.manifest create mode 100644 Examples/Http/Http.csproj create mode 100644 Examples/Http/Program.cs create mode 100644 Examples/Http/app.manifest diff --git a/Divert.Windows.TestRunner/Program.cs b/Divert.Windows.TestRunner/Program.cs index 1ab0781..43b0de9 100644 --- a/Divert.Windows.TestRunner/Program.cs +++ b/Divert.Windows.TestRunner/Program.cs @@ -5,6 +5,9 @@ using System.Text; using System.Threading.Channels; +// Runs ../Divert.Windows.Tests with coverage collection enabled. +// In "watch" mode, run tests when receiving requests from Devcontainer. + string appPath = Path.GetDirectoryName(Environment.ProcessPath)!; string testResultsDirectory = Path.Combine(appPath, "../TestResults"); string coverageSettingsPath = Path.Combine(appPath, "CoverageSettings.xml"); diff --git a/Divert.Windows.Tests/ChecksumTests.cs b/Divert.Windows.Tests/ChecksumTests.cs index 1c39311..7c6cfae 100644 --- a/Divert.Windows.Tests/ChecksumTests.cs +++ b/Divert.Windows.Tests/ChecksumTests.cs @@ -33,8 +33,8 @@ public async Task InvalidChecksum() client.Connect(IPAddress.Loopback, port); await client.SendAsync(new byte[] { 1, 2, 3 }, Token); - var divertResult = await divertReceive; - var packet = packetBuffer.AsMemory(0, divertResult.DataLength); + (int length, _) = await divertReceive; + var packet = packetBuffer.AsMemory(0, length); // Calculate and then invalidate checksums Assert.IsTrue(DivertHelper.CalculateChecksums(packet.Span)); diff --git a/Divert.Windows.Tests/DivertTests.cs b/Divert.Windows.Tests/DivertTests.cs index 29bfeb4..7d47daf 100644 --- a/Divert.Windows.Tests/DivertTests.cs +++ b/Divert.Windows.Tests/DivertTests.cs @@ -20,6 +20,7 @@ public DivertTests() public void Dispose() { cts.Dispose(); + GC.SuppressFinalize(this); } public static UdpClient CreateUdpListener(out int port) diff --git a/Divert.Windows/DivertReceiveResult.cs b/Divert.Windows/DivertReceiveResult.cs index 627ba4a..8c7ee8e 100644 --- a/Divert.Windows/DivertReceiveResult.cs +++ b/Divert.Windows/DivertReceiveResult.cs @@ -16,4 +16,15 @@ public readonly struct DivertReceiveResult(int dataLength, int addressLength) /// Gets the length of the addresses. /// public int AddressLength { get; } = addressLength; + + /// + /// Deconstructs the result into its components. + /// + /// The length of the received data. + /// The length of the addresses. + public void Deconstruct(out int dataLength, out int addressLength) + { + dataLength = DataLength; + addressLength = AddressLength; + } } diff --git a/Examples/Greeting/Greeting.csproj b/Examples/Greeting/Greeting.csproj new file mode 100644 index 0000000..8927562 --- /dev/null +++ b/Examples/Greeting/Greeting.csproj @@ -0,0 +1,15 @@ + + + Exe + $(DefaultTargetFramework) + app.manifest + win-x64 + true + $(ProjectRoot).local/windows/Examples/$(MSBuildThisFileName) + CA2007 + + + + + + diff --git a/Examples/Greeting/Program.cs b/Examples/Greeting/Program.cs new file mode 100644 index 0000000..341ee83 --- /dev/null +++ b/Examples/Greeting/Program.cs @@ -0,0 +1,25 @@ +using System.Net; +using System.Net.Sockets; +using System.Runtime.Versioning; +using System.Text; +using Divert.Windows; + +// Captures a UDP packet and print to console. + +[assembly: SupportedOSPlatform("windows6.0.6000")] + +using var client = new UdpClient(new IPEndPoint(IPAddress.Loopback, 0)); +var localEndPoint = (IPEndPoint)client.Client.LocalEndPoint!; +int port = localEndPoint.Port; +Console.WriteLine($"Created UDP client on port {port}."); + +using var service = new DivertService(DivertFilter.UDP & DivertFilter.LocalPort == port); +var buffer = new byte[1024]; +var receive = service.ReceiveAsync(buffer, new DivertAddress[1]).AsTask(); + +Console.WriteLine("Sending packet to self..."); +await client.SendAsync("Hello"u8.ToArray(), localEndPoint); + +(int packetLength, _) = await receive; +string message = Encoding.UTF8.GetString(buffer.AsSpan(28, packetLength)); +Console.WriteLine($"{message} from WinDivert!"); diff --git a/Examples/Greeting/app.manifest b/Examples/Greeting/app.manifest new file mode 100644 index 0000000..76ff66a --- /dev/null +++ b/Examples/Greeting/app.manifest @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/Examples/Http/Http.csproj b/Examples/Http/Http.csproj new file mode 100644 index 0000000..8927562 --- /dev/null +++ b/Examples/Http/Http.csproj @@ -0,0 +1,15 @@ + + + Exe + $(DefaultTargetFramework) + app.manifest + win-x64 + true + $(ProjectRoot).local/windows/Examples/$(MSBuildThisFileName) + CA2007 + + + + + + diff --git a/Examples/Http/Program.cs b/Examples/Http/Program.cs new file mode 100644 index 0000000..9796597 --- /dev/null +++ b/Examples/Http/Program.cs @@ -0,0 +1,138 @@ +using System.Buffers.Binary; +using System.Collections.Concurrent; +using System.Net; +using System.Runtime.Versioning; +using Divert.Windows; + +// Directs all HTTP traffic to a localhost:8080. + +[assembly: SupportedOSPlatform("windows6.0.6000")] + +if (!(args is [string arg, ..] && ushort.TryParse(arg, out ushort listenPort))) +{ + listenPort = 8080; +} + +Console.WriteLine($"Redirecting all HTTP traffic to http://localhost:{listenPort}..."); +using var listener = new HttpListener() { Prefixes = { $"http://*:{listenPort}/" } }; +listener.Start(); +var listen = Task.Run(async () => +{ + var buffer = "Hello from local HTTP server!"u8.ToArray(); + while (true) + { + var context = await listener.GetContextAsync(); + _ = Task.Run(async () => + { + var remoteEndPoint = context.Request.RemoteEndPoint; + Console.WriteLine( + $"Received request from {remoteEndPoint.Address}:{remoteEndPoint.Port} for {context.Request.Url}" + ); + var response = context.Response; + response.ContentLength64 = buffer.Length; + response.ContentType = "text/plain"; + await response.OutputStream.WriteAsync(buffer); + response.OutputStream.Close(); + response.Close(); + }); + } +}); + +using var outService = new DivertService(DivertFilter.Outbound & DivertFilter.RemotePort == 80); +using var inService = new DivertService( + (DivertFilter.RemoteAddress == IPAddress.Loopback | DivertFilter.RemoteAddress == IPAddress.IPv6Loopback) + & DivertFilter.LocalPort == listenPort +); +var portMapping = new ConcurrentDictionary(); + +var redirect = Task.Run(async () => +{ + var buffer = new byte[ushort.MaxValue + 40]; + var addresses = new DivertAddress[1]; + while (true) + { + var result = await outService.ReceiveAsync(buffer, addresses); + ushort sourcePort; + IPAddress originalSource; + IPAddress originalDestination; + if (addresses[0].IsIPv6) + { + sourcePort = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(40)); + originalSource = new IPAddress(buffer.AsSpan(8, 16)); + originalDestination = new IPAddress(buffer.AsSpan(24, 16)); + IPAddress.IPv6Loopback.GetAddressBytes().CopyTo(buffer, 8); + IPAddress.IPv6Loopback.GetAddressBytes().CopyTo(buffer, 24); + BinaryPrimitives.WriteUInt16BigEndian(buffer.AsSpan(42), listenPort); + } + else + { + int ihl = (buffer[0] & 0x0F) * 4; + sourcePort = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(ihl)); + originalSource = new IPAddress(buffer.AsSpan(12, 4)); + originalDestination = new IPAddress(buffer.AsSpan(16, 4)); + IPAddress.Loopback.GetAddressBytes().CopyTo(buffer, 12); + IPAddress.Loopback.GetAddressBytes().CopyTo(buffer, 16); + BinaryPrimitives.WriteUInt16BigEndian(buffer.AsSpan(ihl + 2), listenPort); + } + Console.WriteLine( + $"Redirecting outbound HTTP packet from {originalSource}:{sourcePort} to {originalDestination}:80..." + ); + portMapping[sourcePort] = (originalSource, originalDestination); + var packet = buffer.AsMemory(0, result.DataLength); + DivertHelper.CalculateChecksums(packet.Span, ref addresses[0]); + await outService.SendAsync(packet, addresses); + } +}); + +var inject = Task.Run(async () => +{ + var buffer = new byte[ushort.MaxValue + 40]; + var addresses = new DivertAddress[1]; + while (true) + { + (int packetLength, _) = await inService.ReceiveAsync(buffer, addresses); + ushort targetPort; + IPAddress? originalSource = null; + IPAddress? originalDestination = null; + if (addresses[0].IsIPv6) + { + targetPort = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(42)); + if (portMapping.TryGetValue(targetPort, out var value)) + { + (originalSource, originalDestination) = value; + originalDestination.GetAddressBytes().CopyTo(buffer, 8); + originalSource.GetAddressBytes().CopyTo(buffer, 24); + BinaryPrimitives.WriteUInt16BigEndian(buffer.AsSpan(40), 80); + } + } + else + { + int ihl = (buffer[0] & 0x0F) * 4; + targetPort = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(ihl + 2)); + if (portMapping.TryGetValue(targetPort, out var value)) + { + (originalSource, originalDestination) = value; + originalDestination.GetAddressBytes().CopyTo(buffer, 12); + originalSource.GetAddressBytes().CopyTo(buffer, 16); + BinaryPrimitives.WriteUInt16BigEndian(buffer.AsSpan(ihl), 80); + } + } + if (originalSource is not null) + { + Console.WriteLine( + $"Injecting inbound HTTP packet from {originalDestination}:{80} to {originalSource}:{targetPort}..." + ); + } + var packet = buffer.AsMemory(0, packetLength); + DivertHelper.CalculateChecksums(packet.Span, ref addresses[0]); + await inService.SendAsync(packet, addresses); + } +}); + +using var client = new HttpClient(); +string response = await client.GetStringAsync("http://example.com/"); +Console.WriteLine($"Response from example.com (injected): {response}"); + +Console.WriteLine("Ready."); + +await await Task.WhenAny(listen, redirect, inject); diff --git a/Examples/Http/app.manifest b/Examples/Http/app.manifest new file mode 100644 index 0000000..76ff66a --- /dev/null +++ b/Examples/Http/app.manifest @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/Examples/Ping/Program.cs b/Examples/Ping/Program.cs index e0c471e..2d2798b 100644 --- a/Examples/Ping/Program.cs +++ b/Examples/Ping/Program.cs @@ -68,8 +68,8 @@ from gateway in nic.GetIPProperties().GatewayAddresses { try { - var result = await outDivert.ReceiveAsync(buffer, addresses, cts.Token).ConfigureAwait(false); - var packet = buffer.AsMemory(0, result.DataLength); + (int packetLength, _) = await outDivert.ReceiveAsync(buffer, addresses, cts.Token).ConfigureAwait(false); + var packet = buffer.AsMemory(0, packetLength); var remoteAddress = new IPAddress(packet[16..20].Span); Console.WriteLine($"Pinging {remoteAddress} with {packet.Length - 28} bytes of data (Divert):"); await outDivert.SendAsync(packet, addresses, cts.Token); @@ -97,8 +97,8 @@ from gateway in nic.GetIPProperties().GatewayAddresses { try { - var receiveResult = await inDivert.ReceiveAsync(buffer, addresses, cts.Token).ConfigureAwait(false); - var packet = buffer.AsMemory(0, receiveResult.DataLength); + (int packetLength, _) = await inDivert.ReceiveAsync(buffer, addresses, cts.Token).ConfigureAwait(false); + var packet = buffer.AsMemory(0, packetLength); var remoteAddress = new IPAddress(packet[12..16].Span); long timestamp = BitConverter.ToInt64(packet.Slice(28, sizeof(long)).Span); Console.WriteLine( diff --git a/ReadMe.md b/ReadMe.md index 6129734..1765ae2 100644 --- a/ReadMe.md +++ b/ReadMe.md @@ -1,5 +1,27 @@ # Divert.Windows -WinDivert .NET APIs. +High quality .NET APIs for WinDivert. See https://reqrypt.org/windivert.html. + +# Example + +```csharp +using var client = new UdpClient(new IPEndPoint(IPAddress.Loopback, 0)); +var localEndPoint = (IPEndPoint)client.Client.LocalEndPoint!; +int port = localEndPoint.Port; +Console.WriteLine($"Created UDP client on port {port}."); + +using var service = new DivertService(DivertFilter.UDP & DivertFilter.LocalPort == port); +var buffer = new byte[1024]; +var receive = service.ReceiveAsync(buffer, new DivertAddress[1]).AsTask(); + +Console.WriteLine("Sending packet to self..."); +await client.SendAsync("Hello"u8.ToArray(), localEndPoint); + +(int packetLength, _) = await receive; +string message = Encoding.UTF8.GetString(buffer.AsSpan(28, packetLength)); +Console.WriteLine($"{message} from WinDivert!"); +``` + +See also [Examples/](Examples/) From ee41e72f7f8ce52452e78426d4e07157d87aec9f Mon Sep 17 00:00:00 2001 From: gdlol Date: Sun, 9 Nov 2025 14:36:59 +0000 Subject: [PATCH 13/20] Update license --- LICENSE | 186 ++++++++++++++++++++++++++++++++++++++++++++++++------ ReadMe.md | 4 ++ 2 files changed, 169 insertions(+), 21 deletions(-) diff --git a/LICENSE b/LICENSE index 646bbf8..0a04128 100644 --- a/LICENSE +++ b/LICENSE @@ -1,21 +1,165 @@ -MIT License - -Copyright (c) 2022 gdlol - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. + GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. + + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. + + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. diff --git a/ReadMe.md b/ReadMe.md index 1765ae2..d3714d7 100644 --- a/ReadMe.md +++ b/ReadMe.md @@ -25,3 +25,7 @@ Console.WriteLine($"{message} from WinDivert!"); ``` See also [Examples/](Examples/) + +# License + +[LGPL-3.0](LICENSE) From 94c88552e50b5cb814efe915a913993ecbb145eb Mon Sep 17 00:00:00 2001 From: gdlol Date: Sun, 9 Nov 2025 15:00:44 +0000 Subject: [PATCH 14/20] Pack --- .config/dotnet/Packages.props | 2 ++ Automation/Automation.csproj | 1 + Automation/Context.cs | 2 ++ Automation/Pack.cs | 38 ++++++++++++++++++++++++++++ Divert.Windows/Divert.Windows.csproj | 10 ++++++++ package.json | 1 + 6 files changed, 54 insertions(+) create mode 100644 Automation/Pack.cs diff --git a/.config/dotnet/Packages.props b/.config/dotnet/Packages.props index 8a46bc8..038ebbd 100644 --- a/.config/dotnet/Packages.props +++ b/.config/dotnet/Packages.props @@ -5,7 +5,9 @@ + + diff --git a/Automation/Automation.csproj b/Automation/Automation.csproj index 0771360..aa79bbb 100644 --- a/Automation/Automation.csproj +++ b/Automation/Automation.csproj @@ -6,6 +6,7 @@ + diff --git a/Automation/Context.cs b/Automation/Context.cs index 1c1e0f1..713ad16 100644 --- a/Automation/Context.cs +++ b/Automation/Context.cs @@ -15,4 +15,6 @@ public class Context(ICakeContext context) : FrostingContext(context) public static string LocalDirectory => Path.Combine(ProjectRoot, ".local"); public static string LocalWindowsDirectory => Path.Combine(LocalDirectory, "windows"); + + public static string PackagesDirectory => Path.Combine(LocalDirectory, "packages"); } diff --git a/Automation/Pack.cs b/Automation/Pack.cs new file mode 100644 index 0000000..15bcf74 --- /dev/null +++ b/Automation/Pack.cs @@ -0,0 +1,38 @@ +using Cake.Common.IO; +using Cake.Common.Tools.DotNet; +using Cake.Frosting; +using Git = LibGit2Sharp; + +namespace Automation; + +public class Pack : FrostingTask +{ + public override void Run(Context context) + { + context.CleanDirectory(Context.PackagesDirectory); + + using var repository = new Git.Repository(Context.ProjectRoot); + string authors = repository.Config.Get("user.name").Value; + + context.DotNetPack( + Path.Combine(Context.ProjectRoot, "Divert.Windows"), + new() + { + MSBuildSettings = new() + { + Properties = + { + ["ReadMePath"] = [Path.Combine(Context.ProjectRoot, "ReadMe.md")], + ["PackageOutputPath"] = [Context.PackagesDirectory], + ["Authors"] = [authors], + ["PackageDescription"] = ["High quality .NET APIs for WinDivert."], + ["PackageLicenseExpression"] = ["LGPL-3.0-only"], + ["PackageRequireLicenseAcceptance"] = ["true"], + ["PackageTags"] = ["WinDivert divert networking packet capture"], + ["PackageReadmeFile"] = ["ReadMe.md"], + }, + }, + } + ); + } +} diff --git a/Divert.Windows/Divert.Windows.csproj b/Divert.Windows/Divert.Windows.csproj index fe9a6e3..f0edfcd 100644 --- a/Divert.Windows/Divert.Windows.csproj +++ b/Divert.Windows/Divert.Windows.csproj @@ -19,4 +19,14 @@ + + + + all + + + + + + diff --git a/package.json b/package.json index 72ff098..ac8998f 100644 --- a/package.json +++ b/package.json @@ -6,6 +6,7 @@ "cake": "dotnet run --project Automation --target", "format": "pnpm run cake Format", "lint": "pnpm run cake Lint", + "pack": "pnpm run cake Pack", "restore": "pnpm run cake Restore", "test": "pnpm run cake Test", "test-report": "pnpm run cake TestReport" From a1a2f969ebc3fdd98ee833a5bce90786438a39fc Mon Sep 17 00:00:00 2001 From: gdlol Date: Sun, 9 Nov 2025 15:39:45 +0000 Subject: [PATCH 15/20] CI config --- .config/git/ignore | 1 + .devcontainer/devcontainer.json | 3 +- .github/workflows/main.yml | 87 ++++++++++++++++++++++++++++ Automation/Build.cs | 12 +++- Automation/GitPush.cs | 36 ++++++++++++ Automation/Publish.cs | 27 +++++++++ Divert.Windows.TestRunner/Program.cs | 9 +-- Divert.Windows/Divert.Windows.csproj | 6 +- package.json | 1 + 9 files changed, 171 insertions(+), 11 deletions(-) create mode 100644 .github/workflows/main.yml create mode 100644 Automation/GitPush.cs create mode 100644 Automation/Publish.cs diff --git a/.config/git/ignore b/.config/git/ignore index d1c0ab2..b2bcadf 100644 --- a/.config/git/ignore +++ b/.config/git/ignore @@ -6,3 +6,4 @@ _* !**/.config/** node_modules/ +!.github/ diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 17cb4db..0667149 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -18,7 +18,8 @@ "ms-azuretools.vscode-docker", "streetsidesoftware.code-spell-checker", "ms-dotnettools.csharp", - "csharpier.csharpier-vscode" + "csharpier.csharpier-vscode", + "github.vscode-github-actions" ], "settings": { "files.associations": { diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..34d44cf --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,87 @@ +name: "Main" + +on: + pull_request: + push: + branches: + - main + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout (GitHub) + uses: actions/checkout@v4 + + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ github.token }} + + - name: Pre-build dev container image + uses: devcontainers/ci@v0.3 + with: + imageName: ghcr.io/${{ github.repository }}/devcontainer + cacheFrom: ghcr.io/${{ github.repository }}/devcontainer + push: always + runCmd: pnpm restore + + - name: Lint + uses: devcontainers/ci@v0.3 + with: + cacheFrom: ghcr.io/${{ github.repository }}/devcontainer + push: never + runCmd: pnpm lint + + - name: Build + uses: devcontainers/ci@v0.3 + env: + DivertWindowsTests: "true" + with: + cacheFrom: ghcr.io/${{ github.repository }}/devcontainer + push: never + runCmd: pnpm build + + - name: Upload Build Artifacts + uses: actions/upload-artifact@v4 + with: + name: build-artifacts + path: | + .local/windows/Divert.Windows.TestRunner + .local/windows/Divert.Windows.Tests + + test: + needs: build + runs-on: windows-latest + steps: + - name: Download Build Artifacts + uses: actions/download-artifact@v4 + with: + name: build-artifacts + + - name: Run Tests + run: .local/windows/Divert.Windows.TestRunner/Divert.Windows.TestRunner.exe + + - name: Upload Test Results + uses: actions/upload-artifact@v4 + with: + name: test-results + path: .local/windows/TestResults + + code-coverage: + needs: test + runs-on: ubuntu-latest + steps: + - name: Download Test Results + uses: actions/download-artifact@v4 + with: + name: test-results + + - name: Codecov + uses: codecov/codecov-action@v5 + with: + files: .local/windows/TestResults/coverage.cobertura.xml + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/Automation/Build.cs b/Automation/Build.cs index 8b54cb9..05e38cc 100644 --- a/Automation/Build.cs +++ b/Automation/Build.cs @@ -12,7 +12,17 @@ public override void Run(Context context) context.CleanDirectory(Context.LocalWindowsDirectory); context.DotNetBuild( Context.ProjectRoot, - new() { MSBuildSettings = new() { TreatAllWarningsAs = MSBuildTreatAllWarningsAs.Error } } + new() + { + MSBuildSettings = new() + { + TreatAllWarningsAs = MSBuildTreatAllWarningsAs.Error, + Properties = + { + ["DivertWindowsTests"] = ["true"], // Enable DivertValueTaskExecutorDelay + }, + }, + } ); } } diff --git a/Automation/GitPush.cs b/Automation/GitPush.cs new file mode 100644 index 0000000..a5c07fe --- /dev/null +++ b/Automation/GitPush.cs @@ -0,0 +1,36 @@ +using Cake.Frosting; +using LibGit2Sharp; +using Git = LibGit2Sharp; + +namespace Automation; + +public class GitPush : FrostingTask +{ + private const string GIT_TOKEN = nameof(GIT_TOKEN); + + public override void Run(Context context) + { + string token = + Environment.GetEnvironmentVariable(GIT_TOKEN) + ?? throw new InvalidOperationException($"Environment variable {GIT_TOKEN} is not set."); + using var repo = new Git.Repository(Context.ProjectRoot); + string currentBranch = repo.Head.FriendlyName; + + var remote = repo.Network.Remotes["origin"] ?? throw new InvalidOperationException(); + var pushRefSpec = $"+refs/heads/{currentBranch}:refs/heads/{currentBranch}"; + PushStatusError? error = null; + var options = new Git.PushOptions + { + CredentialsProvider = (_, _, _) => + new Git.UsernamePasswordCredentials { Username = "git", Password = token }, + OnPushStatusError = (pushStatusErrors) => error = pushStatusErrors, + }; + repo.Network.Push(remote, pushRefSpec, options); + if (error is not null) + { + throw new InvalidOperationException( + $"Error pushing to remote. Reference: {error.Reference}, Message: {error.Message}" + ); + } + } +} diff --git a/Automation/Publish.cs b/Automation/Publish.cs new file mode 100644 index 0000000..9046b3a --- /dev/null +++ b/Automation/Publish.cs @@ -0,0 +1,27 @@ +using Cake.Common.Tools.DotNet; +using Cake.Common.Tools.DotNet.NuGet.Push; +using Cake.Frosting; + +namespace Automation; + +[IsDependentOn(typeof(Pack))] +public class Publish : FrostingTask +{ + public override void Run(Context context) + { + string package = Directory.GetFiles(Context.PackagesDirectory).Single(); + string apiKey = + Environment.GetEnvironmentVariable("NUGET_API_KEY") + ?? throw new InvalidOperationException("NUGET_API_KEY is not set."); + string source = Environment.GetEnvironmentVariable("NUGET_SOURCE") ?? "https://api.nuget.org/v3/index.json"; + context.DotNetNuGetPush( + package, + new DotNetNuGetPushSettings + { + Source = source, + ApiKey = apiKey, + SkipDuplicate = true, + } + ); + } +} diff --git a/Divert.Windows.TestRunner/Program.cs b/Divert.Windows.TestRunner/Program.cs index 43b0de9..72cbf1f 100644 --- a/Divert.Windows.TestRunner/Program.cs +++ b/Divert.Windows.TestRunner/Program.cs @@ -38,7 +38,7 @@ Process LaunchTestProcess(bool redirect) => Environment = { // MSTest should respect DOTNET_SYSTEM_CONSOLE_ALLOW_ANSI_COLOR_REDIRECTION instead. - ["GITHUB_ACTIONS"] = "true", // Trick MSTest to output ANSI colors. + ["GITHUB_ACTIONS"] = redirect.ToString().ToLowerInvariant(), // Trick MSTest to output ANSI colors. }, } )!; @@ -96,9 +96,6 @@ Process LaunchTestProcess(bool redirect) => using var _ = sessionToken.Register(() => process.Kill(entireProcessTree: true)); var lines = Channel.CreateBounded(1024); - var stdOutReader = new StreamReader(process.StandardOutput.BaseStream); - var stdErrReader = new StreamReader(process.StandardError.BaseStream); - async Task ForwardLines(StreamReader reader) { string? line = null; @@ -112,8 +109,8 @@ async Task ForwardLines(StreamReader reader) await lines.Writer.WriteAsync(line, sessionToken); } } - var readStdOutTask = ForwardLines(stdOutReader); - var readStdErrTask = ForwardLines(stdErrReader); + var readStdOutTask = ForwardLines(process.StandardOutput); + var readStdErrTask = ForwardLines(process.StandardError); using var writer = new StreamWriter(stream) { AutoFlush = true }; var readTask = Task.Run( diff --git a/Divert.Windows/Divert.Windows.csproj b/Divert.Windows/Divert.Windows.csproj index f0edfcd..4f68965 100644 --- a/Divert.Windows/Divert.Windows.csproj +++ b/Divert.Windows/Divert.Windows.csproj @@ -5,7 +5,7 @@ true - + $(DefineConstants);DIVERT_WINDOWS_TESTS @@ -21,12 +21,12 @@ - + all - + diff --git a/package.json b/package.json index ac8998f..14ab355 100644 --- a/package.json +++ b/package.json @@ -5,6 +5,7 @@ "build": "pnpm run cake Build", "cake": "dotnet run --project Automation --target", "format": "pnpm run cake Format", + "git-push": "pnpm run cake GitPush", "lint": "pnpm run cake Lint", "pack": "pnpm run cake Pack", "restore": "pnpm run cake Restore", From 3e2efc48f2ddbf0af85939b634c537d458bf88df Mon Sep 17 00:00:00 2001 From: gdlol Date: Sun, 9 Nov 2025 16:39:44 +0000 Subject: [PATCH 16/20] Fix image name. --- .github/workflows/main.yml | 13 +++++++++---- Automation/GitPush.cs | 8 +++----- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 34d44cf..a1fc67c 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -10,7 +10,12 @@ on: jobs: build: runs-on: ubuntu-latest + env: + GITHUB_REPOSITORY: ${{ github.repository }} steps: + - name: Set lowercase repository name + run: echo "GITHUB_REPOSITORY=${GITHUB_REPOSITORY@L}" >> $GITHUB_ENV + - name: Checkout (GitHub) uses: actions/checkout@v4 @@ -24,15 +29,15 @@ jobs: - name: Pre-build dev container image uses: devcontainers/ci@v0.3 with: - imageName: ghcr.io/${{ github.repository }}/devcontainer - cacheFrom: ghcr.io/${{ github.repository }}/devcontainer + imageName: ghcr.io/${{ env.GITHUB_REPOSITORY }}/devcontainer + cacheFrom: ghcr.io/${{ env.GITHUB_REPOSITORY }}/devcontainer push: always runCmd: pnpm restore - name: Lint uses: devcontainers/ci@v0.3 with: - cacheFrom: ghcr.io/${{ github.repository }}/devcontainer + cacheFrom: ghcr.io/${{ env.GITHUB_REPOSITORY }}/devcontainer push: never runCmd: pnpm lint @@ -41,7 +46,7 @@ jobs: env: DivertWindowsTests: "true" with: - cacheFrom: ghcr.io/${{ github.repository }}/devcontainer + cacheFrom: ghcr.io/${{ env.GITHUB_REPOSITORY }}/devcontainer push: never runCmd: pnpm build diff --git a/Automation/GitPush.cs b/Automation/GitPush.cs index a5c07fe..18248f4 100644 --- a/Automation/GitPush.cs +++ b/Automation/GitPush.cs @@ -1,6 +1,5 @@ using Cake.Frosting; using LibGit2Sharp; -using Git = LibGit2Sharp; namespace Automation; @@ -13,16 +12,15 @@ public override void Run(Context context) string token = Environment.GetEnvironmentVariable(GIT_TOKEN) ?? throw new InvalidOperationException($"Environment variable {GIT_TOKEN} is not set."); - using var repo = new Git.Repository(Context.ProjectRoot); + using var repo = new Repository(Context.ProjectRoot); string currentBranch = repo.Head.FriendlyName; var remote = repo.Network.Remotes["origin"] ?? throw new InvalidOperationException(); var pushRefSpec = $"+refs/heads/{currentBranch}:refs/heads/{currentBranch}"; PushStatusError? error = null; - var options = new Git.PushOptions + var options = new PushOptions { - CredentialsProvider = (_, _, _) => - new Git.UsernamePasswordCredentials { Username = "git", Password = token }, + CredentialsProvider = (_, _, _) => new UsernamePasswordCredentials { Username = "git", Password = token }, OnPushStatusError = (pushStatusErrors) => error = pushStatusErrors, }; repo.Network.Push(remote, pushRefSpec, options); From c754cb0b1d0fd55fd10fac0a7b6416fbbc7a1110 Mon Sep 17 00:00:00 2001 From: gdlol Date: Sun, 9 Nov 2025 16:50:08 +0000 Subject: [PATCH 17/20] update CI config --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index a1fc67c..f88defb 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -68,7 +68,7 @@ jobs: name: build-artifacts - name: Run Tests - run: .local/windows/Divert.Windows.TestRunner/Divert.Windows.TestRunner.exe + run: ./.local/windows/Divert.Windows.TestRunner/Divert.Windows.TestRunner.exe - name: Upload Test Results uses: actions/upload-artifact@v4 From 6de85a7f4a2bfcff2486b096febde2754feb1e76 Mon Sep 17 00:00:00 2001 From: gdlol Date: Sun, 9 Nov 2025 16:57:52 +0000 Subject: [PATCH 18/20] update CI config --- .github/workflows/main.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index f88defb..8aa1178 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -43,8 +43,6 @@ jobs: - name: Build uses: devcontainers/ci@v0.3 - env: - DivertWindowsTests: "true" with: cacheFrom: ghcr.io/${{ env.GITHUB_REPOSITORY }}/devcontainer push: never @@ -66,9 +64,10 @@ jobs: uses: actions/download-artifact@v4 with: name: build-artifacts + path: .local/windows - name: Run Tests - run: ./.local/windows/Divert.Windows.TestRunner/Divert.Windows.TestRunner.exe + run: .local/windows/Divert.Windows.TestRunner/Divert.Windows.TestRunner.exe - name: Upload Test Results uses: actions/upload-artifact@v4 @@ -84,6 +83,7 @@ jobs: uses: actions/download-artifact@v4 with: name: test-results + path: .local/windows - name: Codecov uses: codecov/codecov-action@v5 From 33a0637e56b92cca67a8729363bd97f74bf5288a Mon Sep 17 00:00:00 2001 From: gdlol Date: Sun, 9 Nov 2025 17:03:53 +0000 Subject: [PATCH 19/20] update CI config --- .github/workflows/main.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8aa1178..65a281b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -16,7 +16,7 @@ jobs: - name: Set lowercase repository name run: echo "GITHUB_REPOSITORY=${GITHUB_REPOSITORY@L}" >> $GITHUB_ENV - - name: Checkout (GitHub) + - name: Checkout uses: actions/checkout@v4 - name: Login to GitHub Container Registry @@ -79,6 +79,9 @@ jobs: needs: test runs-on: ubuntu-latest steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Download Test Results uses: actions/download-artifact@v4 with: @@ -88,5 +91,6 @@ jobs: - name: Codecov uses: codecov/codecov-action@v5 with: + fail_ci_if_error: true files: .local/windows/TestResults/coverage.cobertura.xml token: ${{ secrets.CODECOV_TOKEN }} From fe1aed37a44e5d7af2d7287c8900861ba32f7ce8 Mon Sep 17 00:00:00 2001 From: gdlol Date: Sun, 9 Nov 2025 17:12:30 +0000 Subject: [PATCH 20/20] update CI config --- .github/workflows/main.yml | 9 ++++----- .github/workflows/publish.yml | 29 +++++++++++++++++++++++++++++ package.json | 1 + 3 files changed, 34 insertions(+), 5 deletions(-) create mode 100644 .github/workflows/publish.yml diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 65a281b..343d2b8 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -72,8 +72,8 @@ jobs: - name: Upload Test Results uses: actions/upload-artifact@v4 with: - name: test-results - path: .local/windows/TestResults + name: test-coverage + path: .local/windows/TestResults/coverage.cobertura.xml code-coverage: needs: test @@ -85,12 +85,11 @@ jobs: - name: Download Test Results uses: actions/download-artifact@v4 with: - name: test-results - path: .local/windows + name: test-coverage - name: Codecov uses: codecov/codecov-action@v5 with: fail_ci_if_error: true - files: .local/windows/TestResults/coverage.cobertura.xml + files: coverage.cobertura.xml token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..2068cbd --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,29 @@ +name: "Publish" + +on: + push: + tags: + - "[0-9]+.[0-9]+.[0-9]+" + workflow_dispatch: + +jobs: + publish: + runs-on: ubuntu-latest + env: + GITHUB_REPOSITORY: ${{ github.repository }} + steps: + - name: Set lowercase repository name + run: echo "GITHUB_REPOSITORY=${GITHUB_REPOSITORY@L}" >> $GITHUB_ENV + + - name: Checkout + uses: actions/checkout@v4 + + - name: Publish + uses: devcontainers/ci@v0.3 + env: + NUGET_API_KEY: ${{ secrets.NUGET_API_KEY }} + with: + imageName: ghcr.io/${{ env.GITHUB_REPOSITORY }}/devcontainer + cacheFrom: ghcr.io/${{ env.GITHUB_REPOSITORY }}/devcontainer + push: never + runCmd: pnpm run publish diff --git a/package.json b/package.json index 14ab355..dba536a 100644 --- a/package.json +++ b/package.json @@ -8,6 +8,7 @@ "git-push": "pnpm run cake GitPush", "lint": "pnpm run cake Lint", "pack": "pnpm run cake Pack", + "publish": "pnpm run cake Publish", "restore": "pnpm run cake Restore", "test": "pnpm run cake Test", "test-report": "pnpm run cake TestReport"