diff --git a/.config/csharpier/.csharpierignore b/.config/csharpier/.csharpierignore new file mode 100644 index 0000000..51045e2 --- /dev/null +++ b/.config/csharpier/.csharpierignore @@ -0,0 +1,2 @@ +**/* +!**/*.cs diff --git a/.config/csharpier/.csharpierrc.json b/.config/csharpier/.csharpierrc.json new file mode 100644 index 0000000..963354f --- /dev/null +++ b/.config/csharpier/.csharpierrc.json @@ -0,0 +1,3 @@ +{ + "printWidth": 120 +} diff --git a/.config/cspell/cspell.json b/.config/cspell/cspell.json new file mode 100644 index 0000000..27b7028 --- /dev/null +++ b/.config/cspell/cspell.json @@ -0,0 +1,30 @@ +{ + "version": "0.2", + "enableGlobDot": true, + "useGitignore": true, + "gitignoreRoot": ".", + "ignorePaths": ["LICENSE"], + "words": [ + "CIFS", + "csharpierignore", + "csharpierrc", + "devcontainer", + "devcontainers", + "Finalizers", + "getgid", + "getuid", + "globalconfig", + "globaltool", + "ICMPV6", + "libc", + "msbuild", + "MSTEST", + "packagejson", + "reportgenerator", + "runas", + "runsettings", + "statvfs", + "Syscall", + "WINDIVERT" + ] +} diff --git a/.config/dotnet/.globalconfig b/.config/dotnet/.globalconfig new file mode 100644 index 0000000..9cd5137 --- /dev/null +++ b/.config/dotnet/.globalconfig @@ -0,0 +1,3 @@ +dotnet_diagnostic.IDE0005.severity=warning +dotnet_diagnostic.CA1852.severity=warning +dotnet_diagnostic.CA2007.severity=warning diff --git a/.config/dotnet/Format.targets b/.config/dotnet/Format.targets new file mode 100644 index 0000000..ddcc8d5 --- /dev/null +++ b/.config/dotnet/Format.targets @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/.config/dotnet/Packages.props b/.config/dotnet/Packages.props new file mode 100644 index 0000000..038ebbd --- /dev/null +++ b/.config/dotnet/Packages.props @@ -0,0 +1,15 @@ + + + true + + + + + + + + + + + + diff --git a/.config/dotnet/Project.props b/.config/dotnet/Project.props new file mode 100644 index 0000000..a384f84 --- /dev/null +++ b/.config/dotnet/Project.props @@ -0,0 +1,22 @@ + + + enable + enable + true + preview + + + + + + + + + + true + true + embedded + true + $(MSBuildProjectDirectory)=/_/$(MSBuildProjectName) + + diff --git a/.config/dotnet/tools.json b/.config/dotnet/tools.json new file mode 100644 index 0000000..20f019c --- /dev/null +++ b/.config/dotnet/tools.json @@ -0,0 +1,16 @@ +{ + "version": 1, + "isRoot": true, + "tools": { + "csharpier": { + "version": "1.1.2", + "commands": ["csharpier"], + "rollForward": false + }, + "dotnet-reportgenerator-globaltool": { + "version": "5.4.18", + "commands": ["reportgenerator"], + "rollForward": false + } + } +} diff --git a/.gitattributes b/.config/git/attributes similarity index 100% rename from .gitattributes rename to .config/git/attributes diff --git a/.config/git/ignore b/.config/git/ignore new file mode 100644 index 0000000..b2bcadf --- /dev/null +++ b/.config/git/ignore @@ -0,0 +1,9 @@ +.* +_* +!.devcontainer/ +!**/.devcontainer/** +!.config/ +!**/.config/** + +node_modules/ +!.github/ diff --git a/.config/pnpm/rc b/.config/pnpm/rc new file mode 100644 index 0000000..863fba9 --- /dev/null +++ b/.config/pnpm/rc @@ -0,0 +1,3 @@ +lockfile=false +resolution-mode=time-based +store-dir=/home/dev/.local/share/pnpm/store diff --git a/.config/prettier/.prettierrc.json b/.config/prettier/.prettierrc.json new file mode 100644 index 0000000..b8ae6eb --- /dev/null +++ b/.config/prettier/.prettierrc.json @@ -0,0 +1,9 @@ +{ + "printWidth": 120, + "plugins": ["prettier-plugin-packagejson", "prettier-plugin-sh", "@prettier/plugin-xml", "prettier-plugin-ini"], + "xmlWhitespaceSensitivity": "ignore", + "overrides": [ + { "files": "app.manifest", "options": { "parser": "xml" } }, + { "files": "*.globalconfig", "options": { "parser": "ini" } } + ] +} diff --git a/.config/workspaces/Directory.Build.props b/.config/workspaces/Directory.Build.props new file mode 100644 index 0000000..b7d1bef --- /dev/null +++ b/.config/workspaces/Directory.Build.props @@ -0,0 +1,9 @@ + + + $(MSBuildThisFileDirectory)artifacts + + + + + + diff --git a/.config/workspaces/pnpm-workspace.yaml b/.config/workspaces/pnpm-workspace.yaml new file mode 100644 index 0000000..e69de29 diff --git a/.devcontainer/.env b/.devcontainer/.env new file mode 100644 index 0000000..9ae87c3 --- /dev/null +++ b/.devcontainer/.env @@ -0,0 +1,8 @@ +WORKSPACES=/workspaces +XDG_CONFIG_HOME=/home/dev/.config +XDG_CACHE_HOME=/home/dev/.cache +XDG_DATA_HOME=/home/dev/.local/share +XDG_STATE_HOME=/home/dev/.local/state +XDG_DATA_DIRS=/usr/local/share:/usr/share +XDG_CONFIG_DIRS=/etc/xdg +NUGET_PACKAGES=/home/dev/.local/share/NuGet/global-packages diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 0000000..4c27707 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,6 @@ +FROM mcr.microsoft.com/devcontainers/javascript-node:22 + +RUN apt-get update && apt-get install --yes \ + cifs-utils + +RUN npm install --global pnpm@latest-10 diff --git a/.devcontainer/compose.yaml b/.devcontainer/compose.yaml new file mode 100644 index 0000000..13d9f6b --- /dev/null +++ b/.devcontainer/compose.yaml @@ -0,0 +1,24 @@ +services: + devcontainer: + env_file: + - .env + - path: ../.local/.env + required: false + build: + context: . + dockerfile: Dockerfile + init: true + privileged: true + volumes: + - WORKSPACES:${WORKSPACES} + - ..:${WORKSPACES}/divert-windows + - XDG_CONFIG_HOME:${XDG_CONFIG_HOME} + - XDG_CACHE_HOME:${XDG_CACHE_HOME} + - XDG_DATA_HOME:${XDG_DATA_HOME} + - XDG_STATE_HOME:${XDG_STATE_HOME} +volumes: + WORKSPACES: + XDG_CONFIG_HOME: + XDG_CACHE_HOME: + XDG_DATA_HOME: + XDG_STATE_HOME: diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..0667149 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,41 @@ +// spell-checker:ignore esbenp dotnettools +{ + "name": "divert-windows", + "dockerComposeFile": "compose.yaml", + "service": "devcontainer", + "remoteUser": "dev", + "overrideCommand": true, + "workspaceFolder": "/workspaces/divert-windows", + "features": { + "ghcr.io/devcontainer-config/features/user-init:2": {}, + "ghcr.io/devcontainer-config/features/dot-config:3": {}, + "ghcr.io/devcontainers/features/dotnet:2": { "version": "8.0" } + }, + "customizations": { + "vscode": { + "extensions": [ + "esbenp.prettier-vscode", + "ms-azuretools.vscode-docker", + "streetsidesoftware.code-spell-checker", + "ms-dotnettools.csharp", + "csharpier.csharpier-vscode", + "github.vscode-github-actions" + ], + "settings": { + "files.associations": { + "ignore": "ignore", + "attributes": "properties", + "rc": "properties", + "*.globalconfig": "ini", + "app.manifest": "xml" + }, + "editor.formatOnSave": true, + "editor.defaultFormatter": "esbenp.prettier-vscode", + "cSpell.autoFormatConfigFile": true, + "cSpell.checkOnlyEnabledFileTypes": false, + "[csharp]": { "editor.defaultFormatter": "csharpier.csharpier-vscode" } + } + } + }, + "onCreateCommand": "pnpm install && pnpm restore || true" +} diff --git a/.devcontainer/dot-config.json b/.devcontainer/dot-config.json new file mode 100644 index 0000000..cda4619 --- /dev/null +++ b/.devcontainer/dot-config.json @@ -0,0 +1,15 @@ +{ + "git": { "attributes": "/home/dev/.config/git/attributes", "ignore": "/home/dev/.config/git/ignore" }, + "pnpm": { "rc": "/home/dev/.config/pnpm/rc" }, + "prettier": { ".prettierrc.json": ".prettierrc.json" }, + "cspell": { "cspell.json": "cspell.json" }, + "workspaces": { + "../git/attributes": ".gitattributes", + "../git/ignore": ".gitignore", + "../../package.json": "package.json", + "pnpm-workspace.yaml": "pnpm-workspace.yaml", + "Directory.Build.props": "Directory.Build.props" + }, + "csharpier": { ".csharpierrc.json": ".csharpierrc.json", ".csharpierignore": ".csharpierignore" }, + "dotnet": { "tools.json": ".config/dotnet-tools.json" } +} diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..343d2b8 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,95 @@ +name: "Main" + +on: + pull_request: + push: + branches: + - main + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + env: + GITHUB_REPOSITORY: ${{ github.repository }} + steps: + - name: Set lowercase repository name + run: echo "GITHUB_REPOSITORY=${GITHUB_REPOSITORY@L}" >> $GITHUB_ENV + + - name: Checkout + uses: actions/checkout@v4 + + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ github.token }} + + - name: Pre-build dev container image + uses: devcontainers/ci@v0.3 + with: + imageName: ghcr.io/${{ env.GITHUB_REPOSITORY }}/devcontainer + cacheFrom: ghcr.io/${{ env.GITHUB_REPOSITORY }}/devcontainer + push: always + runCmd: pnpm restore + + - name: Lint + uses: devcontainers/ci@v0.3 + with: + cacheFrom: ghcr.io/${{ env.GITHUB_REPOSITORY }}/devcontainer + push: never + runCmd: pnpm lint + + - name: Build + uses: devcontainers/ci@v0.3 + with: + cacheFrom: ghcr.io/${{ env.GITHUB_REPOSITORY }}/devcontainer + push: never + runCmd: pnpm build + + - name: Upload Build Artifacts + uses: actions/upload-artifact@v4 + with: + name: build-artifacts + path: | + .local/windows/Divert.Windows.TestRunner + .local/windows/Divert.Windows.Tests + + test: + needs: build + runs-on: windows-latest + steps: + - name: Download Build Artifacts + uses: actions/download-artifact@v4 + with: + name: build-artifacts + path: .local/windows + + - name: Run Tests + run: .local/windows/Divert.Windows.TestRunner/Divert.Windows.TestRunner.exe + + - name: Upload Test Results + uses: actions/upload-artifact@v4 + with: + name: test-coverage + path: .local/windows/TestResults/coverage.cobertura.xml + + code-coverage: + needs: test + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Download Test Results + uses: actions/download-artifact@v4 + with: + name: test-coverage + + - name: Codecov + uses: codecov/codecov-action@v5 + with: + fail_ci_if_error: true + files: coverage.cobertura.xml + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..2068cbd --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,29 @@ +name: "Publish" + +on: + push: + tags: + - "[0-9]+.[0-9]+.[0-9]+" + workflow_dispatch: + +jobs: + publish: + runs-on: ubuntu-latest + env: + GITHUB_REPOSITORY: ${{ github.repository }} + steps: + - name: Set lowercase repository name + run: echo "GITHUB_REPOSITORY=${GITHUB_REPOSITORY@L}" >> $GITHUB_ENV + + - name: Checkout + uses: actions/checkout@v4 + + - name: Publish + uses: devcontainers/ci@v0.3 + env: + NUGET_API_KEY: ${{ secrets.NUGET_API_KEY }} + with: + imageName: ghcr.io/${{ env.GITHUB_REPOSITORY }}/devcontainer + cacheFrom: ghcr.io/${{ env.GITHUB_REPOSITORY }}/devcontainer + push: never + runCmd: pnpm run publish diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 2246bd7..0000000 --- a/.gitignore +++ /dev/null @@ -1,5 +0,0 @@ -.*/ -obj/ -bin/ -/Build/ -/Publish/ diff --git a/Automation/Automation.csproj b/Automation/Automation.csproj new file mode 100644 index 0000000..aa79bbb --- /dev/null +++ b/Automation/Automation.csproj @@ -0,0 +1,12 @@ + + + Exe + $(DefaultTargetFramework) + CA2007 + + + + + + + diff --git a/Automation/Build.cs b/Automation/Build.cs new file mode 100644 index 0000000..05e38cc --- /dev/null +++ b/Automation/Build.cs @@ -0,0 +1,28 @@ +using Cake.Common.IO; +using Cake.Common.Tools.DotNet; +using Cake.Common.Tools.DotNet.MSBuild; +using Cake.Frosting; + +namespace Automation; + +public class Build : FrostingTask +{ + public override void Run(Context context) + { + context.CleanDirectory(Context.LocalWindowsDirectory); + context.DotNetBuild( + Context.ProjectRoot, + new() + { + MSBuildSettings = new() + { + TreatAllWarningsAs = MSBuildTreatAllWarningsAs.Error, + Properties = + { + ["DivertWindowsTests"] = ["true"], // Enable DivertValueTaskExecutorDelay + }, + }, + } + ); + } +} diff --git a/Automation/Context.cs b/Automation/Context.cs new file mode 100644 index 0000000..713ad16 --- /dev/null +++ b/Automation/Context.cs @@ -0,0 +1,20 @@ +using System.Runtime.CompilerServices; +using Cake.Core; +using Cake.Frosting; + +namespace Automation; + +public class Context(ICakeContext context) : FrostingContext(context) +{ + static string GetFilePath([CallerFilePath] string? path = null) => path!; + + public static string ProjectRoot => new FileInfo(GetFilePath()).Directory!.Parent!.FullName; + + public static string Workspaces => new DirectoryInfo(ProjectRoot).Parent!.FullName; + + public static string LocalDirectory => Path.Combine(ProjectRoot, ".local"); + + public static string LocalWindowsDirectory => Path.Combine(LocalDirectory, "windows"); + + public static string PackagesDirectory => Path.Combine(LocalDirectory, "packages"); +} diff --git a/Automation/Format.cs b/Automation/Format.cs new file mode 100644 index 0000000..96bee98 --- /dev/null +++ b/Automation/Format.cs @@ -0,0 +1,9 @@ +using Automation.Tasks; +using Cake.Frosting; + +namespace Automation; + +[IsDependentOn(typeof(PrettierFormat))] +[IsDependentOn(typeof(DotNetFormat))] +[IsDependentOn(typeof(CSharpierFormat))] +public class Format : FrostingTask; diff --git a/Automation/GitPush.cs b/Automation/GitPush.cs new file mode 100644 index 0000000..18248f4 --- /dev/null +++ b/Automation/GitPush.cs @@ -0,0 +1,34 @@ +using Cake.Frosting; +using LibGit2Sharp; + +namespace Automation; + +public class GitPush : FrostingTask +{ + private const string GIT_TOKEN = nameof(GIT_TOKEN); + + public override void Run(Context context) + { + string token = + Environment.GetEnvironmentVariable(GIT_TOKEN) + ?? throw new InvalidOperationException($"Environment variable {GIT_TOKEN} is not set."); + using var repo = new Repository(Context.ProjectRoot); + string currentBranch = repo.Head.FriendlyName; + + var remote = repo.Network.Remotes["origin"] ?? throw new InvalidOperationException(); + var pushRefSpec = $"+refs/heads/{currentBranch}:refs/heads/{currentBranch}"; + PushStatusError? error = null; + var options = new PushOptions + { + CredentialsProvider = (_, _, _) => new UsernamePasswordCredentials { Username = "git", Password = token }, + OnPushStatusError = (pushStatusErrors) => error = pushStatusErrors, + }; + repo.Network.Push(remote, pushRefSpec, options); + if (error is not null) + { + throw new InvalidOperationException( + $"Error pushing to remote. Reference: {error.Reference}, Message: {error.Message}" + ); + } + } +} diff --git a/Automation/Lint.cs b/Automation/Lint.cs new file mode 100644 index 0000000..5ebf948 --- /dev/null +++ b/Automation/Lint.cs @@ -0,0 +1,10 @@ +using Automation.Tasks; +using Cake.Frosting; + +namespace Automation; + +[IsDependentOn(typeof(PrettierCheck))] +[IsDependentOn(typeof(DotNetFormatCheck))] +[IsDependentOn(typeof(CSharpierCheck))] +[IsDependentOn(typeof(CSpell))] +public class Lint : FrostingTask; diff --git a/Automation/Pack.cs b/Automation/Pack.cs new file mode 100644 index 0000000..15bcf74 --- /dev/null +++ b/Automation/Pack.cs @@ -0,0 +1,38 @@ +using Cake.Common.IO; +using Cake.Common.Tools.DotNet; +using Cake.Frosting; +using Git = LibGit2Sharp; + +namespace Automation; + +public class Pack : FrostingTask +{ + public override void Run(Context context) + { + context.CleanDirectory(Context.PackagesDirectory); + + using var repository = new Git.Repository(Context.ProjectRoot); + string authors = repository.Config.Get("user.name").Value; + + context.DotNetPack( + Path.Combine(Context.ProjectRoot, "Divert.Windows"), + new() + { + MSBuildSettings = new() + { + Properties = + { + ["ReadMePath"] = [Path.Combine(Context.ProjectRoot, "ReadMe.md")], + ["PackageOutputPath"] = [Context.PackagesDirectory], + ["Authors"] = [authors], + ["PackageDescription"] = ["High quality .NET APIs for WinDivert."], + ["PackageLicenseExpression"] = ["LGPL-3.0-only"], + ["PackageRequireLicenseAcceptance"] = ["true"], + ["PackageTags"] = ["WinDivert divert networking packet capture"], + ["PackageReadmeFile"] = ["ReadMe.md"], + }, + }, + } + ); + } +} diff --git a/Automation/Program.cs b/Automation/Program.cs new file mode 100644 index 0000000..f2fde84 --- /dev/null +++ b/Automation/Program.cs @@ -0,0 +1,5 @@ +using Automation; +using Cake.Frosting; + +Directory.SetCurrentDirectory(Context.Workspaces); +return new CakeHost().UseContext().Run(args.Concat(["--verbosity", "diagnostic"])); diff --git a/Automation/Publish.cs b/Automation/Publish.cs new file mode 100644 index 0000000..9046b3a --- /dev/null +++ b/Automation/Publish.cs @@ -0,0 +1,27 @@ +using Cake.Common.Tools.DotNet; +using Cake.Common.Tools.DotNet.NuGet.Push; +using Cake.Frosting; + +namespace Automation; + +[IsDependentOn(typeof(Pack))] +public class Publish : FrostingTask +{ + public override void Run(Context context) + { + string package = Directory.GetFiles(Context.PackagesDirectory).Single(); + string apiKey = + Environment.GetEnvironmentVariable("NUGET_API_KEY") + ?? throw new InvalidOperationException("NUGET_API_KEY is not set."); + string source = Environment.GetEnvironmentVariable("NUGET_SOURCE") ?? "https://api.nuget.org/v3/index.json"; + context.DotNetNuGetPush( + package, + new DotNetNuGetPushSettings + { + Source = source, + ApiKey = apiKey, + SkipDuplicate = true, + } + ); + } +} diff --git a/Automation/Publish/Program.cs b/Automation/Publish/Program.cs deleted file mode 100644 index 35d589f..0000000 --- a/Automation/Publish/Program.cs +++ /dev/null @@ -1,78 +0,0 @@ -using System.ComponentModel; -using System.Runtime.CompilerServices; -using Microsoft.DotNet.Cli.Utils; - -static string GetFilePath([CallerFilePath] string? path = null) -{ - if (path is null) - { - throw new InvalidOperationException(nameof(path)); - } - return path; -} - -static void Run(string commandName, params string[] args) -{ - var command = Command.Create(commandName, args); - Console.WriteLine($"{commandName} {command.CommandArgs}"); - var result = command.Execute(); - if (result.ExitCode != 0) - { - throw new Win32Exception(result.ExitCode); - } -} - -static string GetOutput(string commandName, params string[] args) -{ - var command = Command.Create(commandName, args).CaptureStdOut(); - var result = command.Execute(); - if (result.ExitCode != 0) - { - throw new Win32Exception(result.ExitCode); - } - return result.StdOut; -} - -string version = args.Length > 0 ? args[0] : "1.0.0"; - -string filePath = GetFilePath(); -string projectPath = new FileInfo(filePath).Directory?.Parent?.Parent?.FullName!; -Console.WriteLine($"Project path: {projectPath}"); - -string publishPath = Path.Combine(projectPath, "Publish"); -if (Directory.Exists(publishPath)) -{ - Directory.Delete(publishPath, recursive: true); -} - -// Get metadata from Git. -string userName = GetOutput("git", "config", "user.name").Trim(); -var remotes = GetOutput("git", "remote").Trim(); -string? repositoryUrl = remotes.Split('\n').FirstOrDefault() switch -{ - null or "" => null, - string remote => GetOutput("git", "remote", "get-url", remote.Trim()).Trim() -}; -Console.WriteLine($"{nameof(userName)}: {userName}"); -Console.WriteLine($"{nameof(repositoryUrl)}: {repositoryUrl}"); -string projectName = "Divert.Windows"; -string description = "WinDivert .NET APIs."; - -var arguments = new List -{ - "pack", - Path.Combine(projectPath, projectName, $"{projectName}.csproj"), - "--configuration", "Release", - "--output", publishPath, - $"-property:PackageVersion={version}", - $"-property:Authors={userName}", - $"-property:PackageDescription={description}", - "-property:PackageLicenseExpression=MIT", - "-property:PackageRequireLicenseAcceptance=true", - "-property:PackageTags=WinDivert", -}; -if (repositoryUrl is not null) -{ - arguments.Add($"-property:RepositoryUrl={repositoryUrl}"); -} -Run("dotnet", arguments.ToArray()); diff --git a/Automation/Publish/Publish.csproj b/Automation/Publish/Publish.csproj deleted file mode 100644 index 7dbbd88..0000000 --- a/Automation/Publish/Publish.csproj +++ /dev/null @@ -1,14 +0,0 @@ - - - - Exe - net6.0 - enable - enable - - - - - - - \ No newline at end of file diff --git a/Automation/Restore.cs b/Automation/Restore.cs new file mode 100644 index 0000000..58c1764 --- /dev/null +++ b/Automation/Restore.cs @@ -0,0 +1,16 @@ +using Automation.Tasks; +using Cake.Common.Tools.Command; +using Cake.Common.Tools.DotNet; +using Cake.Frosting; + +namespace Automation; + +[IsDependentOn(typeof(CIFSMount))] +public class Restore : FrostingTask +{ + public override void Run(Context context) + { + context.Command(["dotnet"], "tool restore"); + context.DotNetRestore(Context.ProjectRoot); + } +} diff --git a/Automation/Restore/Program.cs b/Automation/Restore/Program.cs deleted file mode 100644 index 7a665d5..0000000 --- a/Automation/Restore/Program.cs +++ /dev/null @@ -1,33 +0,0 @@ -using System.ComponentModel; -using System.Runtime.CompilerServices; -using Microsoft.DotNet.Cli.Utils; - -static string GetFilePath([CallerFilePath] string? path = null) -{ - if (path is null) - { - throw new InvalidOperationException(nameof(path)); - } - return path; -} - -static void Run(string commandName, params string[] args) -{ - var command = Command.Create(commandName, args); - Console.WriteLine($"{commandName} {command.CommandArgs}"); - var result = command.Execute(); - if (result.ExitCode != 0) - { - throw new Win32Exception(result.ExitCode); - } -} - -string filePath = GetFilePath(); -string workspacePath = new FileInfo(filePath).Directory?.Parent?.Parent?.FullName!; -Console.WriteLine($"Workspace path: {workspacePath}"); - -Run("dotnet", "nuget", "locals", "http-cache", "--clear"); -Parallel.ForEach(Directory.EnumerateFiles(workspacePath, "*.csproj", SearchOption.AllDirectories), csProjectPath => -{ - Run("dotnet", "restore", csProjectPath); -}); diff --git a/Automation/Restore/Restore.csproj b/Automation/Restore/Restore.csproj deleted file mode 100644 index 7dbbd88..0000000 --- a/Automation/Restore/Restore.csproj +++ /dev/null @@ -1,14 +0,0 @@ - - - - Exe - net6.0 - enable - enable - - - - - - - \ No newline at end of file diff --git a/Automation/Tasks/CIFSMount.cs b/Automation/Tasks/CIFSMount.cs new file mode 100644 index 0000000..a4f59b9 --- /dev/null +++ b/Automation/Tasks/CIFSMount.cs @@ -0,0 +1,76 @@ +using Cake.Common.Diagnostics; +using Cake.Common.Tools.Command; +using Cake.Core; +using Cake.Core.Diagnostics; +using Cake.Core.IO; +using Cake.Frosting; +using Mono.Unix.Native; + +namespace Automation.Tasks; + +internal sealed record class CIFSMountConfig(string RemoteHost, string Share, string UserName, string Password); + +// Optionally mounts Bin/ directory to a CIFS share, workaround for slow .exe startup time in WSL2 directories. +public class CIFSMount : FrostingTask +{ + public const string CIFS_REMOTE_HOST = nameof(CIFS_REMOTE_HOST); + public const string CIFS_SHARE = nameof(CIFS_SHARE); + public const string CIFS_USERNAME = nameof(CIFS_USERNAME); + public const string CIFS_PASSWORD = nameof(CIFS_PASSWORD); + + private static CIFSMountConfig? LoadConfig() + { + string? remoteHost = Environment.GetEnvironmentVariable(CIFS_REMOTE_HOST); + string? share = Environment.GetEnvironmentVariable(CIFS_SHARE); + string? userName = Environment.GetEnvironmentVariable(CIFS_USERNAME); + string? password = Environment.GetEnvironmentVariable(CIFS_PASSWORD); + if (share is null || userName is null || password is null) + { + return null; + } + return new CIFSMountConfig( + RemoteHost: remoteHost ?? "host.docker.internal", + Share: share, + UserName: userName, + Password: password + ); + } + + public override void Run(Context context) + { + var config = LoadConfig(); + if (config is null) + { + context.Information("CIFS share not configured, skipping."); + return; + } + + string mountPath = Context.LocalWindowsDirectory; + try + { + if (DriveInfo.GetDrives().Any(drive => drive.DriveType is DriveType.Network && drive.Name == mountPath)) + { + context.Information("CIFS share already mounted at {0}, skipping.", mountPath); + // Already mounted. + return; + } + + Directory.CreateDirectory(mountPath); + uint uid = Syscall.getuid(); + uint gid = Syscall.getgid(); + context.Command( + ["sudo"], + ProcessArgumentBuilder + .FromStrings(["mount", "--types", "cifs", $"//{config.RemoteHost}/{config.Share}", mountPath]) + .AppendSwitchSecret( + "--options", + $"username={config.UserName},password={config.Password},uid={uid},gid={gid}" + ) + ); + } + catch (Exception ex) + { + context.Warning("Failed to mount CIFS share: {0}", ex.Message); + } + } +} diff --git a/Automation/Tasks/CSharpier.cs b/Automation/Tasks/CSharpier.cs new file mode 100644 index 0000000..b099697 --- /dev/null +++ b/Automation/Tasks/CSharpier.cs @@ -0,0 +1,20 @@ +using Cake.Common.Tools.DotNet; +using Cake.Frosting; + +namespace Automation.Tasks; + +public class CSharpierCheck : FrostingTask +{ + public override void Run(Context context) + { + context.DotNetTool(Context.ProjectRoot, "csharpier", $"check {Context.ProjectRoot}"); + } +} + +public class CSharpierFormat : FrostingTask +{ + public override void Run(Context context) + { + context.DotNetTool(Context.ProjectRoot, "csharpier", $"format {Context.ProjectRoot}"); + } +} diff --git a/Automation/Tasks/CSpell.cs b/Automation/Tasks/CSpell.cs new file mode 100644 index 0000000..75a23d4 --- /dev/null +++ b/Automation/Tasks/CSpell.cs @@ -0,0 +1,12 @@ +using Cake.Common.Tools.Command; +using Cake.Frosting; + +namespace Automation.Tasks; + +public class CSpell : FrostingTask +{ + public override void Run(Context context) + { + context.Command(["pnpm"], $"cspell {Context.ProjectRoot}"); + } +} diff --git a/Automation/Tasks/DotNetFormat.cs b/Automation/Tasks/DotNetFormat.cs new file mode 100644 index 0000000..f698863 --- /dev/null +++ b/Automation/Tasks/DotNetFormat.cs @@ -0,0 +1,26 @@ +using Cake.Common.Tools.DotNet; +using Cake.Frosting; + +namespace Automation.Tasks; + +public class DotNetFormatCheck : FrostingTask +{ + public override void Run(Context context) + { + context.DotNetMSBuild( + Context.ProjectRoot, + new() { Targets = { "GetTargetPath" }, Properties = { ["DotNetFormatCheck"] = ["true"] } } + ); + } +} + +public class DotNetFormat : FrostingTask +{ + public override void Run(Context context) + { + context.DotNetMSBuild( + Context.ProjectRoot, + new() { Targets = { "GetTargetPath" }, Properties = { ["DotNetFormat"] = ["true"] } } + ); + } +} diff --git a/Automation/Tasks/Prettier.cs b/Automation/Tasks/Prettier.cs new file mode 100644 index 0000000..3b7fe1f --- /dev/null +++ b/Automation/Tasks/Prettier.cs @@ -0,0 +1,20 @@ +using Cake.Common.Tools.Command; +using Cake.Frosting; + +namespace Automation.Tasks; + +public class PrettierCheck : FrostingTask +{ + public override void Run(Context context) + { + context.Command(["pnpm"], $"prettier --check {Context.ProjectRoot}"); + } +} + +public class PrettierFormat : FrostingTask +{ + public override void Run(Context context) + { + context.Command(["pnpm"], $"prettier --write {Context.ProjectRoot}"); + } +} diff --git a/Automation/Test.cs b/Automation/Test.cs new file mode 100644 index 0000000..eece2d9 --- /dev/null +++ b/Automation/Test.cs @@ -0,0 +1,116 @@ +using System.Net.Sockets; +using Cake.Common.Diagnostics; +using Cake.Common.Tools.DotNet; +using Cake.Common.Tools.DotNet.Tool; +using Cake.Core.Diagnostics; +using Cake.Core.IO; +using Cake.Frosting; +using Path = System.IO.Path; + +namespace Automation; + +public class Test : AsyncFrostingTask +{ + public const string TEST_HOST = nameof(TEST_HOST); + public const string TestProjectName = "Divert.Windows.Tests"; + public const string TestRunnerProjectName = "Divert.Windows.TestRunner"; + + public static string GetTestHost() => Environment.GetEnvironmentVariable(TEST_HOST) ?? "host.docker.internal"; + + public static string TestResultsDirectory => Path.Combine(Context.LocalWindowsDirectory, "TestResults"); + + public static string CoverletOutput => Path.Combine(TestResultsDirectory, "coverage.cobertura.xml"); + + public static string TestReportsDirectory => Path.Combine(Context.LocalDirectory, "TestReports"); + + public static string TestRunnerLockFilePath => + Path.Combine(Context.LocalWindowsDirectory, $"{TestRunnerProjectName}/TestRunner.lock"); + + public static void GenerateReport(Context context) + { + string sourcePath = Path.Combine(Context.ProjectRoot, "Divert.Windows"); + // spell-checker: ignore sourcedirs targetdir reporttypes + context.DotNetTool( + "reportgenerator", + new DotNetToolSettings + { + ArgumentCustomization = _ => + ProcessArgumentBuilder.FromStrings( + [ + "reportgenerator", + $"-reports:{CoverletOutput}", + $"-sourcedirs:{sourcePath}", + $"-targetdir:{TestReportsDirectory}", + "-reporttypes:Html;MarkdownSummary", + ] + ), + } + ); + } + + public override async Task RunAsync(Context context) + { + context.DotNetBuild( + Path.Combine(Context.ProjectRoot, TestProjectName), + new() + { + MSBuildSettings = new() + { + Properties = + { + ["DivertWindowsTests"] = ["true"], // Enable DivertValueTaskExecutorDelay + }, + }, + } + ); + if (!Directory.Exists(Path.Combine(Context.LocalWindowsDirectory, TestRunnerProjectName))) + { + context.DotNetBuild(Path.Combine(Context.ProjectRoot, TestRunnerProjectName)); + } + + context.Information($"Waiting for test runner lock file {TestRunnerLockFilePath}..."); + int port = 0; + while (port is 0) + { + try + { + string text = await File.ReadAllTextAsync(TestRunnerLockFilePath); + port = int.Parse(text); + break; + } + catch + { + await Task.Delay(TimeSpan.FromMilliseconds(1000)); + } + } + + using var client = new TcpClient(); + await client.ConnectAsync(GetTestHost(), port); + using var stream = client.GetStream(); + using var reader = new StreamReader(stream); + string lastLine = string.Empty; + while (true) + { + string? line = await reader.ReadLineAsync(); + if (line is null) + { + break; + } + lastLine = line; + Console.WriteLine(line); + } + if (!int.TryParse(lastLine, out int exitCode) || exitCode != 0) + { + throw new Exception("Tests failed"); + } + GenerateReport(context); + } +} + +public class TestReport : FrostingTask +{ + public override void Run(Context context) + { + Test.GenerateReport(context); + } +} diff --git a/Automation/UpdatePackages/Directory.Packages.props b/Automation/UpdatePackages/Directory.Packages.props deleted file mode 100644 index a6ee229..0000000 --- a/Automation/UpdatePackages/Directory.Packages.props +++ /dev/null @@ -1,10 +0,0 @@ - - - false - - - - - - - \ No newline at end of file diff --git a/Automation/UpdatePackages/Program.cs b/Automation/UpdatePackages/Program.cs deleted file mode 100644 index 723d72d..0000000 --- a/Automation/UpdatePackages/Program.cs +++ /dev/null @@ -1,89 +0,0 @@ -using System.Collections.Immutable; -using System.Runtime.CompilerServices; -using Microsoft.Build.Construction; - -static string GetFilePath([CallerFilePath] string? path = null) -{ - if (path is null) - { - throw new InvalidOperationException(nameof(path)); - } - return path; -} - -string filePath = GetFilePath(); -string workspacePath = new FileInfo(filePath).Directory?.Parent?.Parent?.FullName!; -Console.WriteLine($"Workspace path: {workspacePath}"); - -// Load package versions. -string packagePropsFileName = "Directory.Packages.props"; -var packageVersions = await Task.Run(() => -{ - string packagePropsPath = Path.Combine(Path.GetDirectoryName(filePath)!, packagePropsFileName); - var props = ProjectRootElement.Open(packagePropsPath); - var result = new Dictionary(); - foreach (var item in props.Items) - { - if (item is - { - ElementName: "PackageVersion", - Include: string packageName, - FirstChild: ProjectMetadataElement - { - Name: "Version", - Value: string version - } - }) - { - result.Add(packageName, version); - } - } - return result; -}); - -// Update package versions in .csproj files. -var foundPackages = new HashSet(); -foreach (var csProjectPath in Directory.EnumerateFiles(workspacePath, "*.csproj", SearchOption.AllDirectories)) -{ - Console.WriteLine(csProjectPath); - var projectRoot = ProjectRootElement.Open(csProjectPath, new(), preserveFormatting: true); - - foreach (var item in projectRoot.Items) - { - if (item is - { - ElementName: "PackageReference", - Include: string packageName, - FirstChild: ProjectMetadataElement packageVersion and - { - Name: "Version", - Value: string version - } - }) - { - if (packageVersions.TryGetValue(packageName, out string? specifiedVersion)) - { - foundPackages.Add(packageName); - if (version != specifiedVersion) - { - Console.WriteLine($"Updating {packageName} version from {version} to {specifiedVersion}."); - packageVersion.Value = specifiedVersion; - } - } - else - { - Console.WriteLine($"{packageName} version is not specified in {packagePropsFileName}."); - } - } - } - - projectRoot.Save(); - Console.WriteLine(); -} - -foreach (var packageName in packageVersions.Keys.Except(foundPackages).ToImmutableSortedSet()) -{ - Console.WriteLine($"Package {packageName} is not referenced."); -} - -Console.WriteLine("Done."); diff --git a/Automation/UpdatePackages/UpdatePackages.csproj b/Automation/UpdatePackages/UpdatePackages.csproj deleted file mode 100644 index 8379413..0000000 --- a/Automation/UpdatePackages/UpdatePackages.csproj +++ /dev/null @@ -1,15 +0,0 @@ - - - - Exe - net6.0 - preview - enable - enable - - - - - - - \ No newline at end of file diff --git a/Directory.Build.props b/Directory.Build.props index d936144..2b3bcf1 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -1,20 +1,12 @@ - - embedded - true - $(MSBuildProjectDirectory)=$(MSBuildProjectName) - - - FreeBSD - Linux - OSX - Windows - Unknown - + + - $([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildProjectDirectory), "Directory.Build.props")) - $([MSBuild]::MakeRelative($(WorkspacePath), $(MSBuildProjectDirectory))) - $(WorkspacePath)/Build/$(OSName)/$(MSBuildProjectRelativePath)/obj - $(WorkspacePath)/Build/$(OSName)/$(MSBuildProjectRelativePath)/bin + $(MSBuildThisFileDirectory) + $([MSBuild]::NormalizeDirectory($(ProjectRoot), ".config", "dotnet")) + $(ConfigDirectory)Packages.props + net8.0 - \ No newline at end of file + + + diff --git a/Divert.Windows.TestRunner/CoverageSettings.xml b/Divert.Windows.TestRunner/CoverageSettings.xml new file mode 100644 index 0000000..48a287b --- /dev/null +++ b/Divert.Windows.TestRunner/CoverageSettings.xml @@ -0,0 +1,11 @@ + + + + + + .*\.g\.cs$ + .*\.generated\.cs$ + + + + diff --git a/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj b/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj new file mode 100644 index 0000000..70d7fbe --- /dev/null +++ b/Divert.Windows.TestRunner/Divert.Windows.TestRunner.csproj @@ -0,0 +1,15 @@ + + + Exe + $(DefaultTargetFramework) + app.manifest + win-x64 + true + $(ProjectRoot).local/windows/$(MSBuildThisFileName) + CA2007 + + + + + + diff --git a/Divert.Windows.TestRunner/Program.cs b/Divert.Windows.TestRunner/Program.cs new file mode 100644 index 0000000..72cbf1f --- /dev/null +++ b/Divert.Windows.TestRunner/Program.cs @@ -0,0 +1,155 @@ +using System.Diagnostics; +using System.Net; +using System.Net.Sockets; +using System.Reflection; +using System.Text; +using System.Threading.Channels; + +// Runs ../Divert.Windows.Tests with coverage collection enabled. +// In "watch" mode, run tests when receiving requests from Devcontainer. + +string appPath = Path.GetDirectoryName(Environment.ProcessPath)!; +string testResultsDirectory = Path.Combine(appPath, "../TestResults"); +string coverageSettingsPath = Path.Combine(appPath, "CoverageSettings.xml"); +string coverageOutputPath = Path.Combine(testResultsDirectory, "coverage.cobertura.xml"); +string testAppPath = Path.Combine(appPath, "../Divert.Windows.Tests"); + +Directory.SetCurrentDirectory(testAppPath); + +Process LaunchTestProcess(bool redirect) => + Process.Start( + new ProcessStartInfo() + { + FileName = Path.Combine(testAppPath, "Divert.Windows.Tests.exe"), + ArgumentList = + { + "--results-directory", + testResultsDirectory, + "--coverage", + "--coverage-settings", + coverageSettingsPath, + "--coverage-output-format", + "cobertura", + "--coverage-output", + coverageOutputPath, + }, + RedirectStandardOutput = redirect, + RedirectStandardError = redirect, + Environment = + { + // MSTest should respect DOTNET_SYSTEM_CONSOLE_ALLOW_ANSI_COLOR_REDIRECTION instead. + ["GITHUB_ACTIONS"] = redirect.ToString().ToLowerInvariant(), // Trick MSTest to output ANSI colors. + }, + } + )!; + +if (args is not ["watch", ..]) +{ + using var process = LaunchTestProcess(redirect: false); + await process.WaitForExitAsync(); + return process.ExitCode; +} + +Console.WriteLine("Starting Test Runner in watch mode..."); +using var mutex = new Mutex(true, Assembly.GetExecutingAssembly().FullName, out bool createdNew); +if (!createdNew) +{ + Console.WriteLine("Another instance is already running. Exiting..."); + return 1; +} + +using var cts = new CancellationTokenSource(); +Console.CancelKeyPress += (s, e) => +{ + e.Cancel = true; + cts.Cancel(); +}; +var token = cts.Token; + +using var listener = new TcpListener(IPAddress.Any, 0); +listener.Start(); +int port = ((IPEndPoint)listener.LocalEndpoint).Port; +Console.WriteLine($"Listening on port {port}..."); +using var lockFile = new FileStream( + Path.Combine(appPath, "TestRunner.lock"), + FileMode.Create, + FileAccess.ReadWrite, + FileShare.Read, + bufferSize: 0, + FileOptions.Asynchronous | FileOptions.DeleteOnClose +); +await lockFile.WriteAsync(Encoding.UTF8.GetBytes(port.ToString() + '\n'), token); +await lockFile.FlushAsync(token); + +try +{ + while (!token.IsCancellationRequested) + { + Console.WriteLine("Waiting for client connection..."); + using var client = await listener.AcceptSocketAsync(token); + Console.WriteLine("Received client connection."); + + using var sessionCts = CancellationTokenSource.CreateLinkedTokenSource(token); + var sessionToken = sessionCts.Token; + using var stream = new NetworkStream(client, ownsSocket: false); + using var process = LaunchTestProcess(redirect: true); + using var _ = sessionToken.Register(() => process.Kill(entireProcessTree: true)); + + var lines = Channel.CreateBounded(1024); + async Task ForwardLines(StreamReader reader) + { + string? line = null; + while (true) + { + line = await reader.ReadLineAsync(sessionToken); + if (line is null) + { + break; + } + await lines.Writer.WriteAsync(line, sessionToken); + } + } + var readStdOutTask = ForwardLines(process.StandardOutput); + var readStdErrTask = ForwardLines(process.StandardError); + + using var writer = new StreamWriter(stream) { AutoFlush = true }; + var readTask = Task.Run( + async () => + { + try + { + await client.ReceiveAsync(new byte[ushort.MaxValue], sessionToken); // Monitor disconnect + } + finally + { + sessionCts.Cancel(); + } + }, + sessionToken + ); + var writeTask = Task.Run( + async () => + { + await foreach (var line in lines.Reader.ReadAllAsync(sessionToken)) + { + Console.WriteLine(line); + await writer.WriteLineAsync(line.AsMemory(), sessionToken); + } + }, + sessionToken + ); + + await process.WaitForExitAsync(token); + lines.Writer.Complete(); + await Task.WhenAll(readStdOutTask, readStdErrTask, writeTask) + .ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + await writer + .WriteLineAsync(process.ExitCode.ToString().AsMemory(), sessionToken) + .ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + client.Close(); + await readTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + } +} +catch (OperationCanceledException) when (token.IsCancellationRequested) { } + +return 0; diff --git a/Divert.Windows.TestRunner/app.manifest b/Divert.Windows.TestRunner/app.manifest new file mode 100644 index 0000000..76ff66a --- /dev/null +++ b/Divert.Windows.TestRunner/app.manifest @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/Divert.Windows.Tests/AssemblyInfo.cs b/Divert.Windows.Tests/AssemblyInfo.cs new file mode 100644 index 0000000..a298fba --- /dev/null +++ b/Divert.Windows.Tests/AssemblyInfo.cs @@ -0,0 +1,4 @@ +using System.Runtime.Versioning; + +[assembly: DoNotParallelize] +[assembly: SupportedOSPlatform("windows6.0.6000")] diff --git a/Divert.Windows.Tests/ChecksumTests.cs b/Divert.Windows.Tests/ChecksumTests.cs new file mode 100644 index 0000000..7c6cfae --- /dev/null +++ b/Divert.Windows.Tests/ChecksumTests.cs @@ -0,0 +1,85 @@ +using System.Buffers.Binary; +using System.Net; +using System.Net.Sockets; +using System.Runtime.InteropServices; + +namespace Divert.Windows.Tests; + +[TestClass] +public class ChecksumTests : DivertTests +{ + [TestMethod] + public async Task InvalidChecksum() + { + using var listener = CreateUdpListener(out int port); + var receive = listener.ReceiveAsync(Token).AsTask(); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port); + using var service = new DivertService(filter); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + // send 3 bytes payload + using var client = new UdpClient(); + client.Connect(IPAddress.Loopback, port); + await client.SendAsync(new byte[] { 1, 2, 3 }, Token); + + (int length, _) = await divertReceive; + var packet = packetBuffer.AsMemory(0, length); + + // Calculate and then invalidate checksums + Assert.IsTrue(DivertHelper.CalculateChecksums(packet.Span)); + ushort ipChecksum = BinaryPrimitives.ReadUInt16BigEndian(packet.Span[10..12]); + ushort udpChecksum = BinaryPrimitives.ReadUInt16BigEndian(packet.Span[26..28]); + BinaryPrimitives.WriteUInt16BigEndian(packet.Span[10..12], (ushort)~ipChecksum); // invalidate IP checksum + BinaryPrimitives.WriteUInt16BigEndian(packet.Span[26..28], (ushort)~udpChecksum); // invalidate UDP checksum + + // Recalculate + addressBuffer[0].IsIPChecksumValid = false; + addressBuffer[0].IsTCPChecksumValid = false; + addressBuffer[0].IsUDPChecksumValid = false; + Assert.IsTrue(DivertHelper.CalculateChecksums(packet.Span, ref addressBuffer[0])); + Assert.AreEqual(ipChecksum, BinaryPrimitives.ReadUInt16BigEndian(packet.Span[10..12])); + Assert.AreEqual(udpChecksum, BinaryPrimitives.ReadUInt16BigEndian(packet.Span[26..28])); + Assert.IsTrue(addressBuffer[0].IsIPChecksumValid); + Assert.IsTrue(addressBuffer[0].IsUDPChecksumValid); + Assert.IsFalse(addressBuffer[0].IsTCPChecksumValid); + + // Recalculate only UDP checksum + BinaryPrimitives.WriteUInt16BigEndian(packet.Span[10..12], (ushort)~ipChecksum); // invalidate IP checksum + BinaryPrimitives.WriteUInt16BigEndian(packet.Span[26..28], (ushort)~udpChecksum); // invalidate UDP checksum + addressBuffer[0].IsIPChecksumValid = false; + addressBuffer[0].IsTCPChecksumValid = false; + addressBuffer[0].IsUDPChecksumValid = false; + Assert.IsTrue( + DivertHelper.CalculateChecksums(packet.Span, ref addressBuffer[0], DivertHelperFlags.NoIPChecksum) + ); + Assert.AreNotEqual(ipChecksum, BinaryPrimitives.ReadUInt16BigEndian(packet.Span[10..12])); + Assert.AreEqual(udpChecksum, BinaryPrimitives.ReadUInt16BigEndian(packet.Span[26..28])); + Assert.IsFalse(addressBuffer[0].IsIPChecksumValid); + Assert.IsTrue(addressBuffer[0].IsUDPChecksumValid); + Assert.IsFalse(addressBuffer[0].IsTCPChecksumValid); + + // Recalculate only IP checksum + BinaryPrimitives.WriteUInt16BigEndian(packet.Span[10..12], (ushort)~ipChecksum); // invalidate IP checksum + BinaryPrimitives.WriteUInt16BigEndian(packet.Span[26..28], (ushort)~udpChecksum); // invalidate UDP checksum + Assert.IsTrue(DivertHelper.CalculateChecksums(packet.Span, DivertHelperFlags.NoUDPChecksum)); + Assert.AreEqual(ipChecksum, BinaryPrimitives.ReadUInt16BigEndian(packet.Span[10..12])); + Assert.AreNotEqual(udpChecksum, BinaryPrimitives.ReadUInt16BigEndian(packet.Span[26..28])); + + // Invalid packet + Assert.IsFalse(DivertHelper.CalculateChecksums(default)); + Assert.AreEqual(0, Marshal.GetLastPInvokeError()); + Assert.IsFalse(DivertHelper.CalculateChecksums(default, ref addressBuffer[0])); + Assert.AreEqual(0, Marshal.GetLastPInvokeError()); + } +} diff --git a/Divert.Windows.Tests/Divert.Windows.Tests.csproj b/Divert.Windows.Tests/Divert.Windows.Tests.csproj new file mode 100644 index 0000000..5d93a34 --- /dev/null +++ b/Divert.Windows.Tests/Divert.Windows.Tests.csproj @@ -0,0 +1,34 @@ + + + Exe + true + $(DefaultTargetFramework) + win-x64 + true + $(ProjectRoot).local/windows/$(MSBuildThisFileName) + false + true + CA2007 + + + + + all + + + + + + + + + + + + + + + all + + + diff --git a/Divert.Windows.Tests/DivertServiceTests.cs b/Divert.Windows.Tests/DivertServiceTests.cs new file mode 100644 index 0000000..624ed62 --- /dev/null +++ b/Divert.Windows.Tests/DivertServiceTests.cs @@ -0,0 +1,346 @@ +using System.ComponentModel; +using System.Diagnostics; +using System.Net; +using System.Net.Sockets; +using Windows.Win32; +using Windows.Win32.Foundation; +using Windows.Win32.Storage.FileSystem; + +namespace Divert.Windows.Tests; + +[TestClass] +public sealed class DivertServiceTests : DivertTests +{ + [TestMethod] + public void Parameters() + { + Assert.Throws(() => + new DivertService(false, priority: DivertService.HighestPriority + 1) + ); + Assert.Throws(() => + new DivertService(false, priority: DivertService.LowestPriority - 1) + ); + + using var service = new DivertService(false, priority: DivertService.HighestPriority); + using var _ = new DivertService(false, priority: DivertService.LowestPriority); + Assert.AreEqual(new Version(2, 2), service.Version); + + Assert.AreEqual(DivertService.DefaultQueueLength, service.QueueLength); + service.QueueLength = DivertService.MinQueueLength; + Assert.AreEqual(DivertService.MinQueueLength, service.QueueLength); + service.QueueLength = DivertService.MaxQueueLength; + Assert.AreEqual(DivertService.MaxQueueLength, service.QueueLength); + Assert.Throws(() => service.QueueLength = DivertService.MinQueueLength - 1); + Assert.Throws(() => service.QueueLength = DivertService.MaxQueueLength + 1); + + Assert.AreEqual(DivertService.DefaultQueueTime, service.QueueTime); + service.QueueTime = DivertService.MinQueueTime; + Assert.AreEqual(DivertService.MinQueueTime, service.QueueTime); + service.QueueTime = DivertService.MaxQueueTime; + Assert.AreEqual(DivertService.MaxQueueTime, service.QueueTime); + Assert.Throws(() => + service.QueueTime = DivertService.MinQueueTime - TimeSpan.FromMilliseconds(1) + ); + Assert.Throws(() => + service.QueueTime = DivertService.MaxQueueTime + TimeSpan.FromMilliseconds(1) + ); + + Assert.AreEqual(DivertService.DefaultQueueSize, service.QueueSize); + service.QueueSize = DivertService.MinQueueSize; + Assert.AreEqual(DivertService.MinQueueSize, service.QueueSize); + service.QueueSize = DivertService.MaxQueueSize; + Assert.AreEqual(DivertService.MaxQueueSize, service.QueueSize); + Assert.Throws(() => service.QueueSize = DivertService.MinQueueSize - 1); + Assert.Throws(() => service.QueueSize = DivertService.MaxQueueSize + 1); + } + + public static IEnumerable InvalidFilterCases() + { + return + [ + [ + DivertLayer.Network, + DivertFilter.Protocol == "invalid", + "Filter expression contains a bad token (11): ...invalid", + ], + [ + DivertLayer.Network, + DivertFilter.Ip & DivertFilter.Layer == DivertLayer.Network & DivertFilter.Loopback, + "Filter expression contains a bad token for layer (7): ...layer = NETWORK and loopback", + ], + [ + DivertLayer.Network, + DivertFilter.Ip & DivertFilter.Event == DivertEvent.SocketBind & DivertFilter.Loopback, + "Filter expression parse error (15): ...BIND and loopback", + ], + ]; + } + + [TestMethod] + [DynamicData(nameof(InvalidFilterCases))] + public void InvalidFilter(DivertLayer layer, DivertFilter filter, string message) + { + var exception = Assert.Throws(() => new DivertService(filter, layer: layer)); + Assert.AreEqual("filter", exception.ParamName); + Assert.StartsWith(message, exception.Message); + } + + [TestMethod] + public async Task ModifyPayload() + { + using var listener = CreateUdpListener(out int port); + var receive = listener.ReceiveAsync(Token).AsTask(); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port); + using var service = new DivertService(filter); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + // send 3 bytes payload + using var client = new UdpClient(); + client.Connect(IPAddress.Loopback, port); + + long begin = Stopwatch.GetTimestamp(); + await client.SendAsync(new byte[] { 1, 2, 3 }, Token); + + var divertResult = await divertReceive; + long end = Stopwatch.GetTimestamp(); + + Assert.AreEqual(20 + 8 + 3, divertResult.DataLength); + var packet = packetBuffer.AsMemory(0, divertResult.DataLength); + CollectionAssert.AreEqual(new byte[] { 1, 2, 3 }, packet.ToArray()[28..]); + Assert.AreEqual(1, divertResult.AddressLength); + Assert.IsFalse(receive.IsCompleted); + Assert.AreEqual(DivertLayer.Network, addressBuffer[0].Layer); + Assert.IsGreaterThanOrEqualTo(begin, addressBuffer[0].Timestamp); + Assert.IsLessThanOrEqualTo(end, addressBuffer[0].Timestamp); + Assert.AreEqual(DivertEvent.NetworkPacket, addressBuffer[0].Event); + Assert.IsFalse(addressBuffer[0].IsSniffed); + Assert.IsTrue(addressBuffer[0].IsOutbound); + Assert.IsTrue(addressBuffer[0].IsLoopback); + Assert.IsFalse(addressBuffer[0].IsImpostor); + Assert.IsFalse(addressBuffer[0].IsIPv6); + var networkData = addressBuffer[0].GetNetworkData(); + Assert.AreEqual(GetLoopbackInterfaceIndex(), (int)networkData.InterfaceIndex); + Assert.AreEqual(0, (int)networkData.SubInterfaceIndex); + Assert.Throws(() => addressBuffer[0].GetFlowData()); + Assert.Throws(() => addressBuffer[0].GetSocketData()); + Assert.Throws(() => addressBuffer[0].GetReflectData()); + + // Re-inject + addressBuffer[0] = new DivertAddress(1, 1) { IsImpostor = true, IsOutbound = true }; + await service.SendAsync(packet, addressBuffer, Token); + var result = await receive; + Assert.HasCount(3, result.Buffer); + CollectionAssert.AreEqual(new byte[] { 1, 2, 3 }, result.Buffer); + + // Re-inject with modified data + addressBuffer[0].Reset(); + addressBuffer[0].IsImpostor = true; + addressBuffer[0].IsOutbound = true; + receive = listener.ReceiveAsync(Token).AsTask(); + new byte[] { 4, 5, 6 }.CopyTo(packet.Span[28..]); + Assert.IsTrue(DivertHelper.CalculateChecksums(packet.Span)); + Assert.IsFalse(receive.IsCompleted); + await service.SendAsync(packet, addressBuffer, CancellationToken.None); + result = await receive; + Assert.HasCount(3, result.Buffer); + CollectionAssert.AreEqual(new byte[] { 4, 5, 6 }, result.Buffer); + } + + [TestMethod] + public async Task Close() + { + using var listener = CreateUdpListener(out int port); + var receive = listener.ReceiveAsync(Token).AsTask(); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port); + using var service = new DivertService(filter); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + service.Dispose(); + var exception = await Assert.ThrowsAsync(async () => await divertReceive); + Assert.IsFalse(Token.IsCancellationRequested); + Assert.AreNotEqual(Token, exception.CancellationToken); + Assert.AreEqual(CancellationToken.None, exception.CancellationToken); + } + + [TestMethod] + public async Task CancelReceive() + { + using var listener = CreateUdpListener(out int port); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port); + using var service = new DivertService(filter); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + + // Cancel before receive + { + using var cts = new CancellationTokenSource(); + cts.Cancel(); + var token = cts.Token; + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, token).AsTask(); + Assert.IsTrue(divertReceive.IsCanceled); + } + + // Cancel on receive + using (var pipe = new ExecutorDelayPipe()) + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(Token); + var token = cts.Token; + + var divertReceive = Task.Run(async () => await service.ReceiveAsync(packetBuffer, addressBuffer, token)); + Assert.IsFalse(divertReceive.IsCompleted); + + await pipe.Stream.WaitForConnectionAsync(token); + cts.Cancel(); + Assert.IsFalse(divertReceive.IsCompleted); + pipe.Stream.WriteByte(0); + await pipe.Stream.FlushAsync(token); + pipe.Stream.Disconnect(); + + var exception = await Assert.ThrowsAsync(async () => await divertReceive); + Assert.IsTrue(token.IsCancellationRequested); + Assert.AreEqual(token, exception.CancellationToken); + } + + // Cancel pending receive + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(Token); + var token = cts.Token; + + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + cts.Cancel(); + var exception = await Assert.ThrowsAsync(async () => await divertReceive); + Assert.IsTrue(token.IsCancellationRequested); + Assert.AreEqual(token, exception.CancellationToken); + } + } + + [TestMethod] + public void CancelSend() + { + using var service = new DivertService(DivertFilter.False); + + var packetBuffer = new byte[100]; + var addressBuffer = new DivertAddress[1]; + + using var cts = CancellationTokenSource.CreateLinkedTokenSource(Token); + var token = cts.Token; + cts.Cancel(); + var divertSend = service.SendAsync(packetBuffer, addressBuffer, token).AsTask(); + Assert.IsTrue(divertSend.IsCanceled); + } + + [TestMethod] + public async Task InsufficientBuffer() + { + using var listener = CreateUdpListener(out int port); + var receive = listener.ReceiveAsync(Token).AsTask(); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port); + using var service = new DivertService(filter); + + var packetBuffer = new byte[10]; // insufficient buffer + var addressBuffer = new DivertAddress[1]; + + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + // send 3 bytes payload + using var client = new UdpClient(); + client.Connect(IPAddress.Loopback, port); + await client.SendAsync(new byte[] { 1, 2, 3 }, Token); + + var exception = await Assert.ThrowsAsync(async () => await divertReceive); + Assert.AreEqual((int)WIN32_ERROR.ERROR_INSUFFICIENT_BUFFER, exception.NativeErrorCode); + } + + [TestMethod] + public async Task DisposeOnReceive() + { + using var listener = CreateUdpListener(out int port); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port); + using var service = new DivertService(filter); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + + using var pipe = new ExecutorDelayPipe(); + var divertReceive = Task.Run(async () => await service.ReceiveAsync(packetBuffer, addressBuffer, Token)); + Assert.IsFalse(divertReceive.IsCompleted); + + await pipe.Stream.WaitForConnectionAsync(Token); + Assert.IsFalse(divertReceive.IsCompleted); + service.Dispose(); + pipe.Stream.WriteByte(0); + await pipe.Stream.FlushAsync(Token); + pipe.Stream.Disconnect(); + + await Assert.ThrowsAsync(async () => await divertReceive); + } + + [TestMethod] + public async Task InvalidHandle() + { + var handle = PInvoke.CreateFile( + Path.GetTempFileName(), + (uint)GENERIC_ACCESS_RIGHTS.GENERIC_READ, + FILE_SHARE_MODE.FILE_SHARE_NONE, + null, + FILE_CREATION_DISPOSITION.OPEN_EXISTING, + FILE_FLAGS_AND_ATTRIBUTES.FILE_ATTRIBUTE_NORMAL | FILE_FLAGS_AND_ATTRIBUTES.FILE_FLAG_OVERLAPPED, + null + ); + using var service = new DivertService(handle); + Assert.AreEqual(handle.DangerousGetHandle(), service.SafeHandle.DangerousGetHandle()); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + var exception = await Assert.ThrowsAsync(async () => + await service.ReceiveAsync(packetBuffer, addressBuffer, Token) + ); + Assert.AreEqual((int)WIN32_ERROR.ERROR_INVALID_PARAMETER, exception.NativeErrorCode); + + exception = Assert.Throws(() => service.QueueLength = DivertService.DefaultQueueLength); + Assert.AreEqual((int)WIN32_ERROR.ERROR_ACCESS_DENIED, exception.NativeErrorCode); + } +} diff --git a/Divert.Windows.Tests/DivertTests.cs b/Divert.Windows.Tests/DivertTests.cs new file mode 100644 index 0000000..7d47daf --- /dev/null +++ b/Divert.Windows.Tests/DivertTests.cs @@ -0,0 +1,43 @@ +using System.Net; +using System.Net.NetworkInformation; +using System.Net.Sockets; + +namespace Divert.Windows.Tests; + +public abstract class DivertTests : IDisposable +{ + private readonly CancellationTokenSource cts; + private readonly CancellationToken token; + + protected CancellationToken Token => token; + + public DivertTests() + { + cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + token = cts.Token; + } + + public void Dispose() + { + cts.Dispose(); + GC.SuppressFinalize(this); + } + + public static UdpClient CreateUdpListener(out int port) + { + var client = new UdpClient(new IPEndPoint(IPAddress.Loopback, 0)); + var localEndPoint = (IPEndPoint)client.Client.LocalEndPoint!; + port = localEndPoint.Port; + return client; + } + + public static int GetLoopbackInterfaceIndex() + { + return NetworkInterface + .GetAllNetworkInterfaces() + .Single(i => i.NetworkInterfaceType == NetworkInterfaceType.Loopback) + .GetIPProperties() + .GetIPv4Properties() + .Index; + } +} diff --git a/Divert.Windows.Tests/ExecutorDelayPipe.cs b/Divert.Windows.Tests/ExecutorDelayPipe.cs new file mode 100644 index 0000000..5b4dbe0 --- /dev/null +++ b/Divert.Windows.Tests/ExecutorDelayPipe.cs @@ -0,0 +1,25 @@ +using System.IO.Pipes; + +namespace Divert.Windows.Tests; + +internal sealed class ExecutorDelayPipe : IDisposable +{ + private const string DIVERT_WINDOWS_TESTS = nameof(DIVERT_WINDOWS_TESTS); + + private readonly string name; + + public NamedPipeServerStream Stream { get; } + + public ExecutorDelayPipe() + { + name = Guid.NewGuid().ToString("N"); + Environment.SetEnvironmentVariable(DIVERT_WINDOWS_TESTS, name); + Stream = new NamedPipeServerStream(name, PipeDirection.InOut); + } + + public void Dispose() + { + Environment.SetEnvironmentVariable(DIVERT_WINDOWS_TESTS, null); + Stream.Dispose(); + } +} diff --git a/Divert.Windows.Tests/FilterTests.cs b/Divert.Windows.Tests/FilterTests.cs new file mode 100644 index 0000000..7005f67 --- /dev/null +++ b/Divert.Windows.Tests/FilterTests.cs @@ -0,0 +1,88 @@ +namespace Divert.Windows.Tests; + +[TestClass] +public class FilterTests +{ + public static IEnumerable FilterCases() + { + return + [ + [(DivertFilter.TCP | DivertFilter.UDP), "tcp or udp"], + [(DivertFilter.TCP & DivertFilter.Loopback), "tcp and loopback"], + [(DivertFilter.TCP | DivertFilter.UDP) & DivertFilter.Outbound, "(tcp or udp) and outbound"], + [DivertFilter.Outbound & (DivertFilter.TCP | DivertFilter.UDP), "outbound and (tcp or udp)"], + [ + (DivertFilter.TCP & (DivertFilter.RemotePort == "80" | DivertFilter.RemotePort == "443")) + | DivertFilter.Outbound, + "(tcp and (remotePort = 80 or remotePort = 443)) or outbound", + ], + [ + DivertFilter.Outbound + & (DivertFilter.TCP | DivertFilter.UDP) + & (DivertFilter.RemotePort == "80" | DivertFilter.RemotePort == "443"), + "outbound and (tcp or udp) and (remotePort = 80 or remotePort = 443)", + ], + [ + !DivertFilter.Loopback | (DivertFilter.TCP & DivertFilter.RemotePort == "443"), + "not loopback or (tcp and remotePort = 443)", + ], + [ + DivertFilter.Loopback | (!DivertFilter.TCP & DivertFilter.RemotePort != "443"), + "loopback or (not tcp and remotePort != 443)", + ], + [ + DivertFilter.Inbound & (DivertFilter.LocalPort > "30000" | DivertFilter.LocalPort < "1024"), + "inbound and (localPort > 30000 or localPort < 1024)", + ], + [ + DivertFilter.RemotePort >= "5000" | DivertFilter.RemotePort <= "6000", + "remotePort >= 5000 or remotePort <= 6000", + ], + [ + DivertFilter.Packet(10) == "0x1" + | DivertFilter.Packet16(20) == "0x2" + | DivertFilter.Packet32(30) == "0x3", + "packet[10] = 0x1 or packet16[20] = 0x2 or packet32[30] = 0x3", + ], + ]; + } + + [TestMethod] + [DynamicData(nameof(FilterCases))] + public void FilterToString(DivertFilter filter, string expected) + { + var result = filter.ToString(); + Assert.AreEqual(expected, result); + } + + [TestMethod] + public void InvalidFilter() + { + // Unmatched parentheses + Assert.Throws(() => DivertFilter.Outbound & "(tcp or udp))"); + Assert.Throws(() => DivertFilter.Outbound & "((tcp or udp)"); + } + + [TestMethod] + public void Equals() + { + var filter = new DivertFilter("(tcp or udp) and outbound"); + Assert.IsTrue(filter.Equals(filter)); + Assert.IsTrue(filter.Equals("(tcp or udp) and outbound")); + Assert.IsTrue(filter.Equals((DivertFilter.TCP | DivertFilter.UDP) & DivertFilter.Outbound)); + Assert.IsFalse(filter.Equals(DivertFilter.TCP)); + Assert.IsFalse(filter.Equals(null)); + Assert.IsFalse(new DivertFilter("5").Equals(5)); + Assert.AreEqual(filter.GetHashCode(), filter.ToString().GetHashCode()); + + var field = new DivertFilter.Field("tcp"); + Assert.IsTrue(field.Equals(field)); + Assert.IsTrue(field.Equals("tcp")); + Assert.IsTrue(field.Equals(DivertFilter.TCP)); + Assert.IsFalse(field.Equals("udp")); + Assert.IsFalse(field.Equals(DivertFilter.UDP)); + Assert.IsFalse(field.Equals(null)); + Assert.IsFalse(new DivertFilter.Field("5").Equals(5)); + Assert.AreEqual(field.GetHashCode(), field.ToString().GetHashCode()); + } +} diff --git a/Divert.Windows.Tests/FlowTests.cs b/Divert.Windows.Tests/FlowTests.cs new file mode 100644 index 0000000..799240b --- /dev/null +++ b/Divert.Windows.Tests/FlowTests.cs @@ -0,0 +1,69 @@ +using System.Net; +using System.Net.Sockets; + +namespace Divert.Windows.Tests; + +[TestClass] +public class FlowTests : DivertTests +{ + [TestMethod] + public async Task FlowData() + { + using var listener = new TcpListener(IPAddress.Loopback, 0); + listener.Start(); + int port = ((IPEndPoint)listener.LocalEndpoint).Port; + + var filter = DivertFilter.TCP & DivertFilter.Loopback & (DivertFilter.RemotePort == port); + using var service = new DivertService( + filter, + DivertLayer.Flow, + flags: DivertFlags.Sniff | DivertFlags.ReceiveOnly + ); + + var addressBuffer = new DivertAddress[1]; + var divertReceive = service.ReceiveAsync(default, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + using var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + var accept = listener.AcceptSocketAsync(Token); + Assert.IsFalse(accept.IsCompleted); + await client.ConnectAsync(IPAddress.Loopback, port, Token); + int clientPort = ((IPEndPoint)client.LocalEndPoint!).Port; + using var serverSocket = await accept; + await client.SendAsync(new byte[] { 1, 2, 3 }, SocketFlags.None); + + var divertResult = await divertReceive; + var address = addressBuffer[0]; + Assert.AreEqual(0, divertResult.DataLength); + Assert.AreEqual(DivertLayer.Flow, address.Layer); + Assert.AreEqual(DivertEvent.FlowEstablished, address.Event); + Assert.Throws(() => address.GetNetworkData()); + var flowData = address.GetFlowData(); + Assert.AreEqual(Environment.ProcessId, (int)flowData.ProcessId); + Assert.AreEqual(IPAddress.Loopback, flowData.LocalAddress); + Assert.AreEqual(IPAddress.Loopback, flowData.RemoteAddress); + Assert.AreEqual((ushort)port, flowData.RemotePort); + Assert.AreEqual((ushort)clientPort, flowData.LocalPort); + Assert.AreEqual((byte)ProtocolType.Tcp, flowData.Protocol); + + addressBuffer[0].Reset(); + divertReceive = service.ReceiveAsync(default, addressBuffer, Token).AsTask(); + + client.Shutdown(SocketShutdown.Both); + serverSocket.Shutdown(SocketShutdown.Both); + serverSocket.Close(); + client.Close(); + + divertResult = await divertReceive; + address = addressBuffer[0]; + Assert.AreEqual(DivertLayer.Flow, address.Layer); + Assert.AreEqual(DivertEvent.FlowDeleted, address.Event); + flowData = address.GetFlowData(); + Assert.AreEqual(Environment.ProcessId, (int)flowData.ProcessId); + Assert.AreEqual(IPAddress.Loopback, flowData.LocalAddress); + Assert.AreEqual(IPAddress.Loopback, flowData.RemoteAddress); + Assert.AreEqual((ushort)port, flowData.RemotePort); + Assert.AreEqual((ushort)clientPort, flowData.LocalPort); + Assert.AreEqual((byte)ProtocolType.Tcp, flowData.Protocol); + } +} diff --git a/Divert.Windows.Tests/HelperTests.cs b/Divert.Windows.Tests/HelperTests.cs new file mode 100644 index 0000000..7dd3c18 --- /dev/null +++ b/Divert.Windows.Tests/HelperTests.cs @@ -0,0 +1,142 @@ +using System.ComponentModel; +using System.Net; +using System.Net.Sockets; +using Windows.Win32.Foundation; + +namespace Divert.Windows.Tests; + +[TestClass] +public class HelperTests : DivertTests +{ + [TestMethod] + public void FormatFilter() + { + var invalidFilter = Array.Empty(); + var exception = Assert.Throws(() => + DivertHelper.FormatFilter(invalidFilter, DivertLayer.Network) + ); + Assert.AreEqual((int)WIN32_ERROR.ERROR_INVALID_PARAMETER, exception.NativeErrorCode); + } + + [TestMethod] + public async Task DecrementTtl() + { + using var listener = CreateUdpListener(out int port); + var receive = listener.ReceiveAsync(Token).AsTask(); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port); + using var service = new DivertService(filter); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + // send 3 bytes payload + using var client = new UdpClient(); + client.Connect(IPAddress.Loopback, port); + + await client.SendAsync(new byte[] { 1, 2, 3 }, Token); + var divertResult = await divertReceive; + var packet = packetBuffer.AsMemory(0, divertResult.DataLength); + + // set ttl to 2 + packet.Span[8] = 2; + Assert.IsTrue(DivertHelper.DecrementTtl(packet.Span)); + Assert.AreEqual(1, packet.Span[8]); + Assert.IsFalse(DivertHelper.DecrementTtl(packet.Span)); + Assert.AreEqual(1, packet.Span[8]); + } + + [TestMethod] + public async Task CompileFilter() + { + using var listener = CreateUdpListener(out int port); + var receive = listener.ReceiveAsync(Token).AsTask(); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port); + var filterBuffer = DivertHelper.CompileFilter(filter, DivertLayer.Network); + using var service = new DivertService(filterBuffer); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + // send 3 bytes payload + using var client = new UdpClient(); + client.Connect(IPAddress.Loopback, port); + + await client.SendAsync(new byte[] { 1, 2, 3 }, Token); + var divertResult = await divertReceive; + + Assert.AreEqual(20 + 8 + 3, divertResult.DataLength); + var packet = packetBuffer.AsMemory(0, divertResult.DataLength); + CollectionAssert.AreEqual(new byte[] { 1, 2, 3 }, packet.ToArray()[28..]); + } + + private static IEnumerable InvalidFilterCases() => DivertServiceTests.InvalidFilterCases(); + + [TestMethod] + [DynamicData(nameof(InvalidFilterCases))] + public void CompileInvalidFilter(DivertLayer layer, DivertFilter filter, string message) + { + var exception = Assert.Throws(() => DivertHelper.CompileFilter(filter, layer)); + Assert.AreEqual("filter", exception.ParamName); + Assert.StartsWith(message, exception.Message); + } + + [TestMethod] + public async Task EvaluateFilter() + { + using var listener = CreateUdpListener(out int port); + var receive = listener.ReceiveAsync(Token).AsTask(); + + var filter = + DivertFilter.UDP + & DivertFilter.Loopback + & DivertFilter.Ip + & !DivertFilter.Impostor + & (DivertFilter.RemotePort == port); + using var service = new DivertService(filter); + + var packetBuffer = new byte[ushort.MaxValue + 40]; + var addressBuffer = new DivertAddress[1]; + + var divertReceive = service.ReceiveAsync(packetBuffer, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + // send 3 bytes payload + using var client = new UdpClient(); + client.Connect(IPAddress.Loopback, port); + + await client.SendAsync(new byte[] { 1, 2, 3 }, Token); + var divertResult = await divertReceive; + var packet = packetBuffer.AsMemory(0, divertResult.DataLength); + var address = addressBuffer[0]; + + Assert.IsTrue(DivertHelper.EvaluateFilter(filter, packet.Span, address)); + Assert.IsTrue( + DivertHelper.EvaluateFilter(DivertHelper.CompileFilter(filter, DivertLayer.Network), packet.Span, address) + ); + Assert.IsTrue(DivertHelper.EvaluateFilter(DivertFilter.UDP, packet.Span, address)); + Assert.IsFalse(DivertHelper.EvaluateFilter(DivertFilter.TCP, packet.Span, address)); + + var exception = Assert.Throws(() => DivertHelper.EvaluateFilter([], packet.Span, address)); + Assert.AreEqual((int)WIN32_ERROR.ERROR_INVALID_PARAMETER, exception.NativeErrorCode); + exception = Assert.Throws(() => DivertHelper.EvaluateFilter(filter, [], address)); + Assert.AreEqual((int)WIN32_ERROR.ERROR_INVALID_PARAMETER, exception.NativeErrorCode); + } +} diff --git a/Divert.Windows.Tests/NativeMethods.txt b/Divert.Windows.Tests/NativeMethods.txt new file mode 100644 index 0000000..7e147bb --- /dev/null +++ b/Divert.Windows.Tests/NativeMethods.txt @@ -0,0 +1,3 @@ +WIN32_ERROR +CreateFile +GENERIC_ACCESS_RIGHTS diff --git a/Divert.Windows.Tests/ReflectTests.cs b/Divert.Windows.Tests/ReflectTests.cs new file mode 100644 index 0000000..1984be5 --- /dev/null +++ b/Divert.Windows.Tests/ReflectTests.cs @@ -0,0 +1,148 @@ +using System.ComponentModel; +using System.Diagnostics; +using System.Net.Sockets; +using Windows.Win32.Foundation; + +namespace Divert.Windows.Tests; + +[TestClass] +public class ReflectTests : DivertTests +{ + [TestMethod] + public async Task ReflectData() + { + using var reflectService = new DivertService( + true, + DivertLayer.Reflect, + flags: DivertFlags.ReceiveOnly | DivertFlags.Sniff | DivertFlags.NoInstall + ); + + var dataBuffer = new byte[ushort.MaxValue]; + var addressBuffer = new DivertAddress[1]; + var divertReceive = reflectService.ReceiveAsync(dataBuffer, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + long begin = Stopwatch.GetTimestamp(); + using var service = new DivertService(false, priority: 5, flags: DivertFlags.WriteOnly); + long end = Stopwatch.GetTimestamp(); + + var divertResult = await divertReceive; + var address = addressBuffer[0]; + Assert.AreEqual(DivertLayer.Reflect, address.Layer); + Assert.AreEqual(DivertEvent.ReflectOpen, address.Event); + var reflectData = address.GetReflectData(); + Assert.IsTrue(reflectData.Timestamp >= begin && reflectData.Timestamp <= end); + Assert.AreEqual(Environment.ProcessId, (int)reflectData.ProcessId); + Assert.AreEqual(DivertLayer.Network, reflectData.Layer); + Assert.AreEqual(DivertFlags.WriteOnly, reflectData.Flags); + Assert.AreEqual(5, reflectData.Priority); + var packet = dataBuffer.AsSpan(0, divertResult.DataLength); + string filterString = DivertHelper.FormatFilter(packet, reflectData.Layer); + Assert.AreEqual("false", filterString); + } + + public static IEnumerable FormatFilterCases() + { + return + [ + [ + DivertLayer.Socket, + DivertFilter.Protocol == ProtocolType.Tcp + | DivertFilter.Protocol == ProtocolType.Udp + | DivertFilter.Protocol == ProtocolType.Icmp + | DivertFilter.Protocol == ProtocolType.IcmpV6, + $"protocol = {(int)ProtocolType.Tcp} or " + + $"protocol = {(int)ProtocolType.Udp} or " + + $"protocol = {(int)ProtocolType.Icmp} or " + + $"protocol = {(int)ProtocolType.IcmpV6}", + ], + [DivertLayer.Socket, DivertFilter.Protocol == ProtocolType.IP, $"protocol = {(int)ProtocolType.IP}"], + [DivertLayer.Socket, DivertFilter.Ip == true, "ip"], + [DivertLayer.Socket, DivertFilter.Ip == false, "not ip"], + [DivertLayer.Socket, DivertFilter.Ip != true, "not ip"], + [DivertLayer.Socket, DivertFilter.Ip != false, "ip"], + [DivertLayer.Socket, DivertFilter.Protocol == ProtocolType.Unknown, "false"], + [DivertLayer.Socket, DivertFilter.Protocol != ProtocolType.Unknown, "true"], + [DivertLayer.Socket, DivertFilter.Protocol == Enum.GetValues().Max() + 1, "false"], + [ + DivertLayer.Socket, + ( + DivertFilter.Event == DivertEvent.SocketBind + | DivertFilter.Event == DivertEvent.SocketConnect + | DivertFilter.Event == DivertEvent.SocketListen + | DivertFilter.Event == DivertEvent.SocketAccept + | DivertFilter.Event == DivertEvent.SocketClose + ), + "event = BIND or event = CONNECT or event = LISTEN or event = ACCEPT or event = CLOSE", + ], + [DivertLayer.Network, DivertFilter.Event == DivertEvent.NetworkPacket, "event = PACKET"], + [DivertLayer.Forward, DivertFilter.Event != DivertEvent.NetworkPacket, "event != PACKET"], + [ + DivertLayer.Flow, + (DivertFilter.Event == DivertEvent.FlowEstablished | DivertFilter.Event == DivertEvent.FlowDeleted), + "event = ESTABLISHED or event = DELETED", + ], + ]; + } + + [TestMethod] + [DynamicData(nameof(FormatFilterCases))] + public async Task FormatFilter(DivertLayer layer, DivertFilter filter, string expected) + { + using var reflectService = new DivertService( + DivertFilter.Event == DivertEvent.ReflectOpen + & DivertFilter.ProcessId == Environment.ProcessId + & ( + DivertFilter.Layer == DivertLayer.Network + | DivertFilter.Layer == DivertLayer.Socket + | DivertFilter.Layer == DivertLayer.Forward + | DivertFilter.Layer == DivertLayer.Flow + | DivertFilter.Layer != DivertLayer.Reflect + ), + DivertLayer.Reflect, + priority: 2, + DivertFlags.ReceiveOnly | DivertFlags.Sniff | DivertFlags.NoInstall + ); + + var dataBuffer = new byte[ushort.MaxValue]; + var addressBuffer = new DivertAddress[1]; + var divertReceive = reflectService.ReceiveAsync(dataBuffer, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + using var service = new DivertService(filter, layer, priority: 1, DivertFlags.Sniff | DivertFlags.ReceiveOnly); + + var divertResult = await divertReceive; + var address = addressBuffer[0]; + Assert.AreEqual(DivertLayer.Reflect, address.Layer); + Assert.AreEqual(DivertEvent.ReflectOpen, address.Event); + var reflectData = address.GetReflectData(); + Assert.AreEqual(Environment.ProcessId, (int)reflectData.ProcessId); + Assert.AreEqual(layer, reflectData.Layer); + Assert.AreEqual(DivertFlags.Sniff | DivertFlags.ReceiveOnly, reflectData.Flags); + Assert.AreEqual(1, reflectData.Priority); + var packet = dataBuffer.AsSpan(0, divertResult.DataLength); + string filterString = DivertHelper.FormatFilter(packet, reflectData.Layer); + Assert.AreEqual(expected, filterString); + } + + [TestMethod] + public void InvalidFilter() + { + var exception = Assert.Throws(() => + { + using var service = new DivertService( + DivertFilter.Event == Enum.GetValues().Max() + 1, + DivertLayer.Reflect + ); + }); + Assert.AreEqual((int)WIN32_ERROR.ERROR_INVALID_PARAMETER, exception.NativeErrorCode); + exception = Assert.Throws(() => + { + using var service = new DivertService( + DivertFilter.Layer == Enum.GetValues().Max() + 1, + DivertLayer.Reflect + ); + }); + Assert.AreEqual((int)WIN32_ERROR.ERROR_INVALID_PARAMETER, exception.NativeErrorCode); + } +} diff --git a/Divert.Windows.Tests/SocketTests.cs b/Divert.Windows.Tests/SocketTests.cs new file mode 100644 index 0000000..9dece24 --- /dev/null +++ b/Divert.Windows.Tests/SocketTests.cs @@ -0,0 +1,40 @@ +using System.Net; +using System.Net.Sockets; + +namespace Divert.Windows.Tests; + +[TestClass] +public class SocketTests : DivertTests +{ + [TestMethod] + public async Task SocketData() + { + var filter = DivertFilter.TCP & DivertFilter.Loopback & (DivertFilter.ProcessId == Environment.ProcessId); + using var service = new DivertService( + filter, + DivertLayer.Socket, + flags: DivertFlags.Sniff | DivertFlags.ReceiveOnly + ); + + var addressBuffer = new DivertAddress[1]; + var divertReceive = service.ReceiveAsync(default, addressBuffer, Token).AsTask(); + Assert.IsFalse(divertReceive.IsCompleted); + + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + socket.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + int port = ((IPEndPoint)socket.LocalEndPoint!).Port; + + var divertResult = await divertReceive; + var address = addressBuffer[0]; + Assert.AreEqual(0, divertResult.DataLength); + Assert.AreEqual(DivertLayer.Socket, address.Layer); + Assert.AreEqual(DivertEvent.SocketBind, address.Event); + var socketData = address.GetSocketData(); + Assert.AreEqual(Environment.ProcessId, (int)socketData.ProcessId); + Assert.AreEqual(IPAddress.Loopback, socketData.LocalAddress); + Assert.AreEqual(IPAddress.Any, socketData.RemoteAddress); + Assert.AreEqual((ushort)port, socketData.LocalPort); + Assert.AreEqual(0, socketData.RemotePort); + Assert.AreEqual((byte)ProtocolType.Tcp, socketData.Protocol); + } +} diff --git a/Divert.Windows/AssemblyInfo.cs b/Divert.Windows/AssemblyInfo.cs new file mode 100644 index 0000000..654c4a8 --- /dev/null +++ b/Divert.Windows/AssemblyInfo.cs @@ -0,0 +1,3 @@ +using System.Runtime.Versioning; + +[assembly: SupportedOSPlatform("windows6.0.6000")] diff --git a/Divert.Windows/AsyncOperation/CancellationHandle.cs b/Divert.Windows/AsyncOperation/CancellationHandle.cs new file mode 100644 index 0000000..dfd8d45 --- /dev/null +++ b/Divert.Windows/AsyncOperation/CancellationHandle.cs @@ -0,0 +1,60 @@ +using System.Runtime.InteropServices; +using Windows.Win32; + +namespace Divert.Windows.AsyncOperation; + +internal sealed unsafe class CancellationHandle(SafeHandle safeHandle) : IDisposable +{ + private static class Status + { + public const int Idle = 0; + public const int Canceled = 1; + public const int Pending = 2; + public const int Disposed = 3; + } + + private int status; + + public SafeHandle SafeHandle => safeHandle; + + // From cancellation registration. + public void RequestOrInvokeCancel(NativeOverlapped* nativeOverlapped) + { + int originalStatus = Interlocked.CompareExchange(ref status, Status.Canceled, Status.Idle); + if (originalStatus is Status.Pending) + { + using (safeHandle.DangerousGetHandle(out var handle)) + { + _ = PInvoke.CancelIoEx(new(handle), nativeOverlapped); + } + } + } + + // After ERROR_IO_PENDING. + public void CancelWhenRequested(NativeOverlapped* nativeOverlapped) + { + int originalStatus = Interlocked.CompareExchange(ref status, Status.Pending, Status.Idle); + if (originalStatus is Status.Canceled) + { + using (safeHandle.DangerousGetHandle(out var handle)) + { + _ = PInvoke.CancelIoEx(new(handle), nativeOverlapped); + } + } + } + + // From completion callback. + public void Reset() + { + int originalStatus = Interlocked.CompareExchange(ref status, Status.Idle, Status.Pending); + if (originalStatus is Status.Canceled) + { + Interlocked.CompareExchange(ref status, Status.Idle, originalStatus); + } + } + + public void Dispose() + { + Interlocked.Exchange(ref status, Status.Disposed); + } +} diff --git a/Divert.Windows/AsyncOperation/DivertReceiveExecutor.cs b/Divert.Windows/AsyncOperation/DivertReceiveExecutor.cs new file mode 100644 index 0000000..a373376 --- /dev/null +++ b/Divert.Windows/AsyncOperation/DivertReceiveExecutor.cs @@ -0,0 +1,37 @@ +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Divert.Windows.AsyncOperation; + +internal sealed class DivertReceiveExecutor : IDivertValueTaskExecutor +{ + private readonly uint[] addressesLengthBuffer = GC.AllocateArray(1, pinned: true); + + public unsafe bool Execute(SafeHandle safeHandle, ref readonly PendingOperation pendingOperation) + { + using var _ = safeHandle.DangerousGetHandle(out var handle); + addressesLengthBuffer[0] = (uint)(pendingOperation.Addresses.Length * sizeof(DivertAddress)); + return NativeMethods.WinDivertRecvEx( + handle, + pendingOperation.PacketBufferHandle.Pointer, + (uint)pendingOperation.PacketBuffer.Length, + null, + 0, + (WINDIVERT_ADDRESS*)pendingOperation.AddressesHandle.Pointer, + (uint*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(addressesLengthBuffer.AsSpan())), + pendingOperation.NativeOverlapped + ); + } + + public async ValueTask ReceiveAsync( + DivertValueTaskSource source, + Memory buffer, + Memory addresses, + CancellationToken cancellationToken + ) + { + int dataLength = await source.ExecuteAsync(this, buffer, addresses, cancellationToken).ConfigureAwait(false); + int addressesLength = (int)addressesLengthBuffer[0] / Marshal.SizeOf(); + return new DivertReceiveResult(dataLength, addressesLength); + } +} diff --git a/Divert.Windows/AsyncOperation/DivertSendExecutor.cs b/Divert.Windows/AsyncOperation/DivertSendExecutor.cs new file mode 100644 index 0000000..a010e5a --- /dev/null +++ b/Divert.Windows/AsyncOperation/DivertSendExecutor.cs @@ -0,0 +1,36 @@ +using System.Runtime.InteropServices; + +namespace Divert.Windows.AsyncOperation; + +internal sealed unsafe class DivertSendExecutor : IDivertValueTaskExecutor +{ + public bool Execute(SafeHandle safeHandle, ref readonly PendingOperation pendingOperation) + { + using var _ = safeHandle.DangerousGetHandle(out var handle); + return NativeMethods.WinDivertSendEx( + handle, + pendingOperation.PacketBufferHandle.Pointer, + (uint)pendingOperation.PacketBuffer.Length, + null, + 0, + (WINDIVERT_ADDRESS*)pendingOperation.AddressesHandle.Pointer, + (uint)(pendingOperation.Addresses.Length * sizeof(DivertAddress)), + pendingOperation.NativeOverlapped + ); + } + + public ValueTask SendAsync( + DivertValueTaskSource source, + ReadOnlyMemory buffer, + ReadOnlyMemory addresses, + CancellationToken cancellationToken + ) + { + return source.ExecuteAsync( + this, + MemoryMarshal.AsMemory(buffer), + MemoryMarshal.AsMemory(addresses), + cancellationToken + ); + } +} diff --git a/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs b/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs new file mode 100644 index 0000000..7c77360 --- /dev/null +++ b/Divert.Windows/AsyncOperation/DivertValueTaskSource.cs @@ -0,0 +1,136 @@ +using System.ComponentModel; +using System.Runtime.InteropServices; +using System.Threading.Channels; +using System.Threading.Tasks.Sources; +using Windows.Win32.Foundation; + +namespace Divert.Windows.AsyncOperation; + +internal sealed unsafe class DivertValueTaskSource : IDisposable, IValueTaskSource, IOCompletionHandler +{ + private readonly Channel pool; + private readonly IOCompletionOperation ioCompletionOperation; + + private ManualResetValueTaskSourceCore source; + private PendingOperation pendingOperation; + + public DivertValueTaskSource( + Channel pool, + DivertHandle divertHandle, + ThreadPoolBoundHandle threadPoolBoundHandle, + bool runContinuationsAsynchronously + ) + { + this.pool = pool; + ioCompletionOperation = new IOCompletionOperation( + divertHandle, + threadPoolBoundHandle, + this + ); + source = new ManualResetValueTaskSourceCore + { + RunContinuationsAsynchronously = runContinuationsAsynchronously, + }; + } + + private SafeHandle SafeHandle => ioCompletionOperation.SafeHandle; + + public void Dispose() + { + pendingOperation.Dispose(); + ioCompletionOperation.Dispose(); + } + + public int GetResult(short token) + { + try + { + return source.GetResult(token); + } + finally + { + source.Reset(); + if (!pool.Writer.TryWrite(this)) + { + Dispose(); + } + } + } + + public ValueTaskSourceStatus GetStatus(short token) => source.GetStatus(token); + + public void OnCompleted( + Action continuation, + object? state, + short token, + ValueTaskSourceOnCompletedFlags flags + ) => source.OnCompleted(continuation, state, token, flags); + + private ref readonly PendingOperation PrepareOperation( + Memory packetBuffer, + Memory addresses, + CancellationToken cancellationToken + ) + { + var nativeOverlapped = ioCompletionOperation.Prepare(cancellationToken); + pendingOperation = new PendingOperation(nativeOverlapped, packetBuffer, addresses, cancellationToken); + return ref pendingOperation; + } + + public void OnCompleted(uint errorCode, uint numBytes) + { + using var _ = pendingOperation; + if (errorCode is (uint)WIN32_ERROR.ERROR_SUCCESS) + { + source.SetResult((int)numBytes); + } + else if (errorCode is (uint)WIN32_ERROR.ERROR_OPERATION_ABORTED) + { + var token = pendingOperation.CancellationToken; + var exception = token.IsCancellationRequested + ? new OperationCanceledException(token) + : new OperationCanceledException(); + source.SetException(exception); + } + else + { + var exception = new Win32Exception((int)errorCode); + source.SetException(exception); + } + } + + public ValueTask ExecuteAsync( + TExecutor executor, + Memory buffer, + Memory addresses, + CancellationToken cancellationToken + ) + where TExecutor : IDivertValueTaskExecutor + { + try + { + ref readonly var pendingOperation = ref PrepareOperation(buffer, addresses, cancellationToken); + executor.DelayExecutionInTests(); + bool success = executor.Execute(SafeHandle, in pendingOperation); + if (!success) + { + int error = Marshal.GetLastPInvokeError(); + if (error is (int)WIN32_ERROR.ERROR_IO_PENDING) + { + ioCompletionOperation.CancelWhenRequested(pendingOperation.NativeOverlapped); + } + else + { + Dispose(); + return ValueTask.FromException(new Win32Exception(error)); + } + } + return new ValueTask(this, source.Version); + } + catch + { + Dispose(); + throw; + } + } +} diff --git a/Divert.Windows/AsyncOperation/IDivertValueTaskExecutor.cs b/Divert.Windows/AsyncOperation/IDivertValueTaskExecutor.cs new file mode 100644 index 0000000..b8d5394 --- /dev/null +++ b/Divert.Windows/AsyncOperation/IDivertValueTaskExecutor.cs @@ -0,0 +1,28 @@ +using System.Diagnostics; +using System.IO.Pipes; +using System.Runtime.InteropServices; + +namespace Divert.Windows.AsyncOperation; + +internal interface IDivertValueTaskExecutor +{ + bool Execute(SafeHandle safeHandle, ref readonly PendingOperation pendingOperation); +} + +internal static class DivertValueTaskExecutorDelay +{ + private const string DIVERT_WINDOWS_TESTS = "DIVERT_WINDOWS_TESTS"; + + [Conditional(DIVERT_WINDOWS_TESTS)] + public static void DelayExecutionInTests(this IDivertValueTaskExecutor executor) + { + if (Environment.GetEnvironmentVariable(DIVERT_WINDOWS_TESTS) is not string name) + { + return; + } + + using var stream = new NamedPipeClientStream(".", name, PipeDirection.InOut); + stream.Connect(); // Notify delay + stream.ReadByte(); // continue + } +} diff --git a/Divert.Windows/AsyncOperation/IOCompletionOperation.cs b/Divert.Windows/AsyncOperation/IOCompletionOperation.cs new file mode 100644 index 0000000..733ed55 --- /dev/null +++ b/Divert.Windows/AsyncOperation/IOCompletionOperation.cs @@ -0,0 +1,87 @@ +using System.Diagnostics; +using System.Runtime.InteropServices; + +namespace Divert.Windows.AsyncOperation; + +internal interface IOCompletionHandler +{ + void OnCompleted(uint errorCode, uint numBytes); +} + +internal sealed unsafe class IOCompletionOperation : IDisposable, IOCompletionHandler + where THandler : IOCompletionHandler +{ + private static void OnIOCompleted(uint errorCode, uint numBytes, NativeOverlapped* pOVERLAP) + { + var operation = (IOCompletionOperation)ThreadPoolBoundHandle.GetNativeOverlappedState(pOVERLAP)!; + operation.OnCompleted(errorCode, numBytes); + } + + private static readonly IOCompletionCallback ioCompletionCallback = OnIOCompleted; + + private readonly ThreadPoolBoundHandle threadPoolBoundHandle; + private readonly CancellationHandle cancellationHandle; + private readonly PreAllocatedOverlapped preAllocatedOverlapped; + private readonly THandler handler; + + private CancellationTokenRegistration cancellationRegistration; + private NativeOverlapped* nativeOverlapped; + + public IOCompletionOperation(SafeHandle safeHandle, ThreadPoolBoundHandle threadPoolBoundHandle, THandler handler) + { + cancellationHandle = new CancellationHandle(safeHandle); + this.threadPoolBoundHandle = threadPoolBoundHandle; + this.handler = handler; + preAllocatedOverlapped = new PreAllocatedOverlapped(ioCompletionCallback, this, null); + } + + public SafeHandle SafeHandle => cancellationHandle.SafeHandle; + + public void Dispose() + { + if (nativeOverlapped is not null) + { + threadPoolBoundHandle.FreeNativeOverlapped(nativeOverlapped); + nativeOverlapped = null; + } + preAllocatedOverlapped.Dispose(); + cancellationHandle.Dispose(); + } + + public NativeOverlapped* Prepare(CancellationToken cancellationToken) + { + Debug.Assert(nativeOverlapped is null); + nativeOverlapped = threadPoolBoundHandle.AllocateNativeOverlapped(preAllocatedOverlapped); + cancellationRegistration = cancellationToken.CanBeCanceled + ? cancellationToken.UnsafeRegister( + static state => + { + var operation = (IOCompletionOperation)state!; + operation.cancellationHandle.RequestOrInvokeCancel(operation.nativeOverlapped); + }, + this + ) + : default; + return nativeOverlapped; + } + + public void CancelWhenRequested(NativeOverlapped* nativeOverlapped) + { + cancellationHandle.CancelWhenRequested(nativeOverlapped); + } + + public void OnCompleted(uint errorCode, uint numBytes) + { + cancellationRegistration.Dispose(); + try + { + handler.OnCompleted(errorCode, numBytes); + } + finally + { + cancellationHandle.Reset(); + threadPoolBoundHandle.FreeNativeOverlapped(nativeOverlapped); + nativeOverlapped = null; + } + } +} diff --git a/Divert.Windows/AsyncOperation/PendingOperation.cs b/Divert.Windows/AsyncOperation/PendingOperation.cs new file mode 100644 index 0000000..b6ca76d --- /dev/null +++ b/Divert.Windows/AsyncOperation/PendingOperation.cs @@ -0,0 +1,28 @@ +using System.Buffers; + +namespace Divert.Windows.AsyncOperation; + +internal unsafe struct PendingOperation( + NativeOverlapped* nativeOverlapped, + Memory packetBuffer, + Memory addresses, + CancellationToken cancellationToken +) : IDisposable +{ + private MemoryHandle packetBufferHandle = packetBuffer.Pin(); + private MemoryHandle addressesHandle = addresses.Pin(); + + public readonly NativeOverlapped* NativeOverlapped => nativeOverlapped; + public readonly MemoryHandle PacketBufferHandle => packetBufferHandle; + public readonly MemoryHandle AddressesHandle => addressesHandle; + + public readonly Memory PacketBuffer => packetBuffer; + public readonly Memory Addresses => addresses; + public readonly CancellationToken CancellationToken => cancellationToken; + + public void Dispose() + { + packetBufferHandle.Dispose(); + addressesHandle.Dispose(); + } +} diff --git a/Divert.Windows/CString.cs b/Divert.Windows/CString.cs index 6f593f8..de23046 100644 --- a/Divert.Windows/CString.cs +++ b/Divert.Windows/CString.cs @@ -2,14 +2,9 @@ namespace Divert.Windows; -internal class CString : IDisposable +internal sealed class CString(string str) : IDisposable { - internal IntPtr Ptr { get; } - - public CString(string str) - { - Ptr = Marshal.StringToHGlobalAnsi(str); - } + internal IntPtr Pointer { get; } = Marshal.StringToHGlobalAnsi(str); private bool disposed; @@ -17,14 +12,8 @@ public void Dispose() { if (!disposed) { - Marshal.FreeHGlobal(Ptr); + Marshal.FreeHGlobal(Pointer); disposed = true; } - GC.SuppressFinalize(this); - } - - ~CString() - { - Dispose(); } } diff --git a/Divert.Windows/Constants.cs b/Divert.Windows/Constants.cs new file mode 100644 index 0000000..d3774c5 --- /dev/null +++ b/Divert.Windows/Constants.cs @@ -0,0 +1,17 @@ +namespace Divert.Windows; + +internal static class Constants +{ + public const short WINDIVERT_PRIORITY_HIGHEST = 3000; + public const short WINDIVERT_PRIORITY_LOWEST = -WINDIVERT_PRIORITY_HIGHEST; + public const int WINDIVERT_PARAM_QUEUE_LENGTH_DEFAULT = 4096; + public const int WINDIVERT_PARAM_QUEUE_LENGTH_MIN = 32; + public const int WINDIVERT_PARAM_QUEUE_LENGTH_MAX = 16384; + public const int WINDIVERT_PARAM_QUEUE_TIME_DEFAULT = 2000; + public const int WINDIVERT_PARAM_QUEUE_TIME_MIN = 100; + public const int WINDIVERT_PARAM_QUEUE_TIME_MAX = 16000; + public const int WINDIVERT_PARAM_QUEUE_SIZE_DEFAULT = 4 * 1024 * 1024; + public const int WINDIVERT_PARAM_QUEUE_SIZE_MIN = 64 * 1024; + public const int WINDIVERT_PARAM_QUEUE_SIZE_MAX = 32 * 1024 * 1024; + public const int WINDIVERT_BATCH_MAX = byte.MaxValue; +} diff --git a/Divert.Windows/Divert.Windows.csproj b/Divert.Windows/Divert.Windows.csproj index 6ae7dc6..4f68965 100644 --- a/Divert.Windows/Divert.Windows.csproj +++ b/Divert.Windows/Divert.Windows.csproj @@ -1,19 +1,32 @@ - - net6.0 + $(DefaultTargetFramework) x64 - enable - enable - true + true + + + + $(DefineConstants);DIVERT_WINDOWS_TESTS + + $(MSBuildProjectDirectory)=$(MSBuildProjectName) + + + + + all + + + + - + all - runtime; build; native; contentfiles; analyzers - - \ No newline at end of file + + + + diff --git a/Divert.Windows/DivertAddress.cs b/Divert.Windows/DivertAddress.cs index 7a8c5d7..0a43dd8 100644 --- a/Divert.Windows/DivertAddress.cs +++ b/Divert.Windows/DivertAddress.cs @@ -1,90 +1,201 @@ using System.Net; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; namespace Divert.Windows; +/// +/// Represents the address of a captured or injected packet. +/// [StructLayout(LayoutKind.Sequential)] -unsafe public struct DivertAddress +public unsafe struct DivertAddress { + /// + /// Network layer data. + /// public struct NetworkData { + /// + /// The interface index on which the packet was captured or to which it will be injected. + /// public uint InterfaceIndex; + + /// + /// The sub-interface index of the interface. + /// public uint SubInterfaceIndex; } + /// + /// Flow layer data. + /// public struct FlowData { + /// + /// The endpoint ID of the flow. + /// public ulong EndpointId; + + /// + /// The parent endpoint ID of the flow. + /// public ulong ParentEndpointId; + + /// + /// The process ID associated with the flow. + /// public uint ProcessId; + + /// + /// The local IP address of the flow. + /// public IPAddress LocalAddress; + + /// + /// The remote IP address of the flow. + /// public IPAddress RemoteAddress; + + /// + /// The local port of the flow. + /// public ushort LocalPort; + + /// + /// The remote port of the flow. + /// public ushort RemotePort; + + /// + /// The protocol of the flow. + /// public byte Protocol; } + /// + /// Socket layer data. + /// public struct SocketData { + /// + /// The endpoint ID of the socket operation. + /// public ulong EndpointId; + + /// + /// The parent endpoint ID of the socket operation. + /// public ulong ParentEndpointId; + + /// + /// The process ID associated with the socket operation. + /// public uint ProcessId; + + /// + /// The local IP address of the socket operation. + /// public IPAddress LocalAddress; + + /// + /// The remote IP address of the socket operation. + /// public IPAddress RemoteAddress; + + /// + /// The local port of the socket operation. + /// public ushort LocalPort; + + /// + /// The remote port of the socket operation. + /// public ushort RemotePort; + + /// + /// The protocol of the socket operation. + /// public byte Protocol; } + /// + /// Reflect layer data. + /// public struct ReflectData { + /// + /// The timestamp when the handle was opened. + /// public long Timestamp; + + /// + /// The process ID that opened the handle. + /// public uint ProcessId; + + /// + /// The layer parameter used when opening the handle. + /// public DivertLayer Layer; + + /// + /// The flags parameter used when opening the handle. + /// public DivertFlags Flags; + + /// + /// The priority parameter used when opening the handle. + /// public short Priority; } private WINDIVERT_ADDRESS address; - internal WINDIVERT_ADDRESS Struct => address; - - internal DivertAddress(WINDIVERT_ADDRESS address) - { - this.address = address; - } - - public DivertAddress(int interfaceIndex, int subInterfaceIndex) + /// + /// Initializes a new instance of the struct with the specified interface and + /// sub-interface indices. + /// + /// The interface index. + /// The sub-interface index. + public DivertAddress(int interfaceIndex, int subInterfaceIndex = 0) { address = new WINDIVERT_ADDRESS { Network = new WINDIVERT_DATA_NETWORK { IfIdx = checked((uint)interfaceIndex), - SubIfIdx = checked((uint)subInterfaceIndex) - } + SubIfIdx = checked((uint)subInterfaceIndex), + }, }; } - public long Timestamp + /// + /// Clears all fields of the . + /// + public void Reset() { - get { return address.Timestamp; } - set { address.Timestamp = value; } + fixed (WINDIVERT_ADDRESS* pAddress = &address) + { + Unsafe.InitBlock(pAddress, 0, (uint)sizeof(WINDIVERT_ADDRESS)); + } } - public DivertLayer Layer - { - get { return (DivertLayer)address.Layer; } - set { address.Layer = (byte)value; } - } + /// + /// Gets the timestamp of the event. + /// + public readonly long Timestamp => address.Timestamp; - public DivertEvent Event - { - get { return (DivertEvent)address.Event; } - set { address.Layer = (byte)value; } - } + /// + /// Gets the layer of the event. + /// + public readonly DivertLayer Layer => (DivertLayer)address.Layer; + + /// + /// Gets the event type. + /// + public readonly DivertEvent Event => (DivertEvent)address.Event; - private bool GetBit(WINDIVERT_ADDRESS_BITS bit) + private readonly bool GetBit(WINDIVERT_ADDRESS_BITS bit) { return (address.Bits & bit) != 0; } @@ -101,54 +212,75 @@ private void SetBit(WINDIVERT_ADDRESS_BITS bit, bool value) } } - public bool IsSniffed - { - get { return GetBit(WINDIVERT_ADDRESS_BITS.Sniffed); } - set { SetBit(WINDIVERT_ADDRESS_BITS.Sniffed, value); } - } + /// + /// Gets a value indicating whether the packet/event was sniffed (not blocked). + /// + public readonly bool IsSniffed => GetBit(WINDIVERT_ADDRESS_BITS.Sniffed); + /// + /// Gets or sets a value indicating whether the packet/event is outbound. + /// public bool IsOutbound { - get { return GetBit(WINDIVERT_ADDRESS_BITS.Outbound); } + readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.Outbound); } set { SetBit(WINDIVERT_ADDRESS_BITS.Outbound, value); } } - public bool IsLoopback - { - get { return GetBit(WINDIVERT_ADDRESS_BITS.Loopback); } - set { SetBit(WINDIVERT_ADDRESS_BITS.Loopback, value); } - } + /// + /// Gets a value indicating whether the packet is loopback. + /// + public readonly bool IsLoopback => GetBit(WINDIVERT_ADDRESS_BITS.Loopback); + /// + /// Gets or sets a value indicating impostor packets. + /// public bool IsImpostor { - get { return GetBit(WINDIVERT_ADDRESS_BITS.Impostor); } + readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.Impostor); } set { SetBit(WINDIVERT_ADDRESS_BITS.Impostor, value); } } - public bool IsIPv6 - { - get { return GetBit(WINDIVERT_ADDRESS_BITS.IPv6); } - set { SetBit(WINDIVERT_ADDRESS_BITS.IPv6, value); } - } + /// + /// Gets a value indicating whether the packet/event is IPv6. + /// + public readonly bool IsIPv6 => GetBit(WINDIVERT_ADDRESS_BITS.IPv6); + /// + /// Gets or sets a value indicating whether the IP checksum is valid. + /// public bool IsIPChecksumValid { - get { return GetBit(WINDIVERT_ADDRESS_BITS.IPChecksum); } + readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.IPChecksum); } set { SetBit(WINDIVERT_ADDRESS_BITS.IPChecksum, value); } } + /// + /// Gets or sets a value indicating whether the TCP checksum is valid. + /// public bool IsTCPChecksumValid { - get { return GetBit(WINDIVERT_ADDRESS_BITS.TCPChecksum); } + readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.TCPChecksum); } set { SetBit(WINDIVERT_ADDRESS_BITS.TCPChecksum, value); } } + /// + /// Gets or sets a value indicating whether the UDP checksum is valid. + /// public bool IsUDPChecksumValid { - get { return GetBit(WINDIVERT_ADDRESS_BITS.UDPChecksum); } + readonly get { return GetBit(WINDIVERT_ADDRESS_BITS.UDPChecksum); } set { SetBit(WINDIVERT_ADDRESS_BITS.UDPChecksum, value); } } + /// + /// Gets the network data associated with the event. + /// + /// + /// The associated with the event. + /// + /// + /// The is not or . + /// public NetworkData GetNetworkData() { return Layer switch @@ -156,16 +288,18 @@ public NetworkData GetNetworkData() DivertLayer.Network or DivertLayer.Forward => new NetworkData { InterfaceIndex = address.Network.IfIdx, - SubInterfaceIndex = address.Network.SubIfIdx + SubInterfaceIndex = address.Network.SubIfIdx, }, _ => throw new InvalidOperationException($"{nameof(Layer)}: {Layer}"), }; } - private IPAddress GetIPAddress(Span bytes) + private readonly IPAddress GetIPAddress(Span bytes) { - bytes.Reverse(); - var address = new IPAddress(bytes); + Span beBytes = stackalloc byte[bytes.Length]; + bytes.CopyTo(beBytes); + beBytes.Reverse(); + var address = new IPAddress(beBytes); if (!IsIPv6) { address = address.MapToIPv4(); @@ -173,6 +307,15 @@ private IPAddress GetIPAddress(Span bytes) return address; } + /// + /// Gets the flow data associated with the event. + /// + /// + /// The associated with the event. + /// + /// + /// The is not . + /// public FlowData GetFlowData() { switch (Layer) @@ -188,13 +331,22 @@ public FlowData GetFlowData() RemoteAddress = GetIPAddress(new Span(flow.RemoteAddr, 16)), LocalPort = address.Flow.LocalPort, RemotePort = address.Flow.RemotePort, - Protocol = address.Flow.Protocol + Protocol = address.Flow.Protocol, }; default: - throw new InvalidOperationException(Layer.ToString()); + throw new InvalidOperationException($"{nameof(Layer)}: {Layer}"); } } + /// + /// Gets the socket data associated with the event. + /// + /// + /// The associated with the event. + /// + /// + /// The is not . + /// public SocketData GetSocketData() { switch (Layer) @@ -210,13 +362,22 @@ public SocketData GetSocketData() RemoteAddress = GetIPAddress(new Span(socket.RemoteAddr, 16)), LocalPort = address.Socket.LocalPort, RemotePort = address.Socket.RemotePort, - Protocol = address.Socket.Protocol + Protocol = address.Socket.Protocol, }; default: - throw new InvalidOperationException(Layer.ToString()); + throw new InvalidOperationException($"{nameof(Layer)}: {Layer}"); } } + /// + /// Gets the reflect data associated with the event. + /// + /// + /// The associated with the event. + /// + /// + /// The is not . + /// public ReflectData GetReflectData() { return Layer switch @@ -227,9 +388,9 @@ public ReflectData GetReflectData() ProcessId = address.Reflect.ProcessId, Layer = (DivertLayer)address.Reflect.Layer, Flags = (DivertFlags)address.Reflect.Flags, - Priority = address.Reflect.Priority + Priority = address.Reflect.Priority, }, - _ => throw new InvalidOperationException(Layer.ToString()), + _ => throw new InvalidOperationException($"{nameof(Layer)}: {Layer}"), }; } } diff --git a/Divert.Windows/DivertEvent.cs b/Divert.Windows/DivertEvent.cs index 9342c58..eb22284 100644 --- a/Divert.Windows/DivertEvent.cs +++ b/Divert.Windows/DivertEvent.cs @@ -1,15 +1,57 @@ namespace Divert.Windows; +/// +/// Types of Divert events. +/// public enum DivertEvent { - NetworkPacket = 0, - FlowEstablished = 1, - FlowDeleted = 2, - SocketBind = 3, - SocketConnect = 4, - SocketListen = 5, - SocketAccept = 6, - SocketClose = 7, - ReflectOpen = 8, - ReflectClose = 9, + /// + /// A network packet. + /// + NetworkPacket = WINDIVERT_EVENT.WINDIVERT_EVENT_NETWORK_PACKET, + + /// + /// A flow has been established. + /// + FlowEstablished = WINDIVERT_EVENT.WINDIVERT_EVENT_FLOW_ESTABLISHED, + + /// + /// A flow has been deleted. + /// + FlowDeleted = WINDIVERT_EVENT.WINDIVERT_EVENT_FLOW_DELETED, + + /// + /// A bind() socket operation. + /// + SocketBind = WINDIVERT_EVENT.WINDIVERT_EVENT_SOCKET_BIND, + + /// + /// A connect() socket operation. + /// + SocketConnect = WINDIVERT_EVENT.WINDIVERT_EVENT_SOCKET_CONNECT, + + /// + /// A listen() socket operation. + /// + SocketListen = WINDIVERT_EVENT.WINDIVERT_EVENT_SOCKET_LISTEN, + + /// + /// An accept() socket operation. + /// + SocketAccept = WINDIVERT_EVENT.WINDIVERT_EVENT_SOCKET_ACCEPT, + + /// + /// A socket is unbound or a connection is closed. + /// + SocketClose = WINDIVERT_EVENT.WINDIVERT_EVENT_SOCKET_CLOSE, + + /// + /// A new WinDivert handle was opened. + /// + ReflectOpen = WINDIVERT_EVENT.WINDIVERT_EVENT_REFLECT_OPEN, + + /// + /// A WinDivert handle was closed. + /// + ReflectClose = WINDIVERT_EVENT.WINDIVERT_EVENT_REFLECT_CLOSE, } diff --git a/Divert.Windows/DivertFilter.cs b/Divert.Windows/DivertFilter.cs index 57cc1e3..b6fb48f 100644 --- a/Divert.Windows/DivertFilter.cs +++ b/Divert.Windows/DivertFilter.cs @@ -1,14 +1,25 @@ +using System.Net.Sockets; using System.Text; using System.Text.RegularExpressions; namespace Divert.Windows; -internal record class ReplaceParenthesesOperation(string Expression); +internal record struct ReplaceParenthesesOperation(string Expression); -public class DivertFilter +/// +/// Helper type to build WinDivert filter expressions. +/// +public partial class DivertFilter { + /// + /// The filter clause. + /// public string Clause { get; } + /// + /// Creates a new filter with the specified clause. + /// + /// The filter clause. public DivertFilter(string clause) { ArgumentNullException.ThrowIfNull(clause); @@ -16,12 +27,8 @@ public DivertFilter(string clause) Clause = clause; } - public static implicit operator DivertFilter(string clause) - { - return new DivertFilter(clause); - } - - private static string ReplaceParentheses(string expression) + // e.g. "(a and (b or c)) or d" -> "() or d" + private static string CollapseParentheses(string expression) { var builder = new StringBuilder(); int index = 0; @@ -65,12 +72,19 @@ private static string ReplaceParentheses(string expression) return builder.ToString(); } - private static readonly string orOp = Regex.Escape("||"); + [GeneratedRegex(@"\s(or|\|\|)\s")] + private static partial Regex OrPatternRegex(); - private static bool MatchOrPattern(string s) => Regex.IsMatch(ReplaceParentheses(s), @$"\s(or|{orOp})\s"); + private static bool MatchOrPattern(string s) => OrPatternRegex().IsMatch(CollapseParentheses(s)); - private static bool MatchAndPattern(string s) => Regex.IsMatch(ReplaceParentheses(s), @"\s(and|&&)\s"); + [GeneratedRegex(@"\s(and|&&)\s")] + private static partial Regex AndPatternRegex(); + private static bool MatchAndPattern(string s) => AndPatternRegex().IsMatch(CollapseParentheses(s)); + + /// + /// Combines two filters with an AND operation. + /// public static DivertFilter operator &(DivertFilter left, DivertFilter right) { string leftClause = left.Clause; @@ -87,6 +101,9 @@ private static string ReplaceParentheses(string expression) return new DivertFilter(clause); } + /// + /// Combines two filters with an OR operation. + /// public static DivertFilter operator |(DivertFilter left, DivertFilter right) { string leftClause = left.Clause; @@ -103,15 +120,41 @@ private static string ReplaceParentheses(string expression) return new DivertFilter(clause); } + /// + /// Returns the filter clause. + /// + /// The filter clause. public override string ToString() { return Clause; } + /// + public override bool Equals(object? obj) + { + return obj switch + { + null => false, + DivertFilter filter => Clause == filter.Clause, + string s => Clause == s, + _ => false, + }; + } + + /// + public override int GetHashCode() => Clause.GetHashCode(); + + /// + /// A layer specific property for matching packets/events. + /// public class Field { private readonly string field; + /// + /// Creates a new field with the specified name. + /// + /// The field name. public Field(string field) { ArgumentNullException.ThrowIfNull(field); @@ -119,92 +162,197 @@ public Field(string field) this.field = field; } + /// public override bool Equals(object? obj) { - if (obj is null) - { - return false; - } - else if (obj is Field filter) + return obj switch { - return field == filter.field; - } - else if (obj is string s) - { - return field == s; - } - else - { - return false; - } + null => false, + Field filter => field == filter.field, + string s => field == s, + _ => false, + }; } - public override int GetHashCode() => base.GetHashCode(); + /// + public override int GetHashCode() => field.GetHashCode(); + /// + /// Returns the field name. + /// + /// The field name. public override string ToString() => field; + /// + /// Implicitly converts a string to a Field. + /// + /// The field name. public static implicit operator Field(string field) { return new Field(field); } + /// + /// Implicitly converts a Field to a DivertFilter. + /// public static implicit operator DivertFilter(Field field) { return new DivertFilter(field.ToString()); } + /// + /// Combines two fields with an AND operation. + /// public static DivertFilter operator &(Field left, Field right) { DivertFilter value = left; return value & right; } + /// + /// Combines two fields with an OR operation. + /// public static DivertFilter operator |(Field left, Field right) { DivertFilter value = left; return value | right; } + /// + /// Negates a field. + /// public static DivertFilter operator !(Field value) { - if (MatchAndPattern(value.field) || MatchOrPattern(value.field)) - { - value = new Field($"({value.field})"); - } string clause = $"not {value}"; return new DivertFilter(clause); } + private static string Macro(bool value) => value ? "TRUE" : "FALSE"; + + private static string Macro(ProtocolType protocolType) => + protocolType switch + { + ProtocolType.Tcp => "TCP", + ProtocolType.Udp => "UDP", + ProtocolType.Icmp => "ICMP", + ProtocolType.IcmpV6 => "ICMPV6", + _ => ((int)protocolType).ToString(), + }; + + private static string Macro(DivertEvent e) => + e switch + { + DivertEvent.NetworkPacket => "PACKET", + DivertEvent.FlowEstablished => "ESTABLISHED", + DivertEvent.FlowDeleted => "DELETED", + DivertEvent.SocketBind => "BIND", + DivertEvent.SocketConnect => "CONNECT", + DivertEvent.SocketListen => "LISTEN", + DivertEvent.SocketAccept => "ACCEPT", + DivertEvent.ReflectOpen => "OPEN", + DivertEvent.SocketClose or DivertEvent.ReflectClose => "CLOSE", + _ => e.ToString(), + }; + + private static string Macro(DivertLayer layer) => + layer switch + { + DivertLayer.Network => "NETWORK", + DivertLayer.Forward => "NETWORK_FORWARD", + DivertLayer.Flow => "FLOW", + DivertLayer.Socket => "SOCKET", + DivertLayer.Reflect => "REFLECT", + _ => layer.ToString(), + }; + + /// + /// Equality operator for Field and object. + /// public static DivertFilter operator ==(Field left, object right) { string clause = $"{left} = {right}"; return new DivertFilter(clause); } + /// + /// Inequality operator for Field and object. + /// public static DivertFilter operator !=(Field left, object right) { string clause = $"{left} != {right}"; return new DivertFilter(clause); } + /// + /// Equality operator for Field and bool. + /// + public static DivertFilter operator ==(Field left, bool right) => left == Macro(right); + + /// + /// Inequality operator for Field and bool. + /// + public static DivertFilter operator !=(Field left, bool right) => left != Macro(right); + + /// + /// Equality operator for Field and . + /// + public static DivertFilter operator ==(Field left, ProtocolType right) => left == Macro(right); + + /// + /// Inequality operator for Field and . + /// + public static DivertFilter operator !=(Field left, ProtocolType right) => left != Macro(right); + + /// + /// Equality operator for Field and . + /// + public static DivertFilter operator ==(Field left, DivertEvent right) => left == Macro(right); + + /// + /// Inequality operator for Field and . + /// + public static DivertFilter operator !=(Field left, DivertEvent right) => left != Macro(right); + + /// + /// Equality operator for Field and . + /// + public static DivertFilter operator ==(Field left, DivertLayer right) => left == Macro(right); + + /// + /// Inequality operator for Field and . + /// + public static DivertFilter operator !=(Field left, DivertLayer right) => left != Macro(right); + + /// + /// Less than operator for Field and object. + /// public static DivertFilter operator <(Field left, object right) { string clause = $"{left} < {right}"; return new DivertFilter(clause); } + /// + /// Greater than operator for Field and object. + /// public static DivertFilter operator >(Field left, object right) { string clause = $"{left} > {right}"; return new DivertFilter(clause); } + /// + /// Less than or equal operator for Field and object. + /// public static DivertFilter operator <=(Field left, object right) { string clause = $"{left} <= {right}"; return new DivertFilter(clause); } + /// + /// Greater than or equal operator for Field and object. + /// public static DivertFilter operator >=(Field left, object right) { string clause = $"{left} >= {right}"; @@ -212,73 +360,195 @@ public static implicit operator DivertFilter(Field field) } } + /// + /// The true expression. + /// public static DivertFilter True { get; } = "true"; + /// + /// The false expression. + /// public static DivertFilter False { get; } = "false"; + /// + /// The value zero. + /// public static Field Zero { get; } = "zero"; + /// + /// The packet/event timestamp. + /// public static Field Timestamp { get; } = "timestamp"; + /// + /// The event type. + /// public static Field Event { get; } = "event"; + /// + /// Is outbound? + /// public static Field Outbound { get; } = "outbound"; + /// + /// Is inbound? + /// public static Field Inbound { get; } = "inbound"; + /// + /// The interface index. + /// public static Field InterfaceIndex { get; } = "ifIdx"; + /// + /// The sub-interface index. + /// public static Field SubInterfaceIndex { get; } = "subIfIdx"; + /// + /// Is loopback packet? + /// public static Field Loopback { get; } = "loopback"; + /// + /// Is impostor packet? + /// public static Field Impostor { get; } = "impostor"; + /// + /// Is IPv4 fragment? + /// public static Field Fragment { get; } = "fragment"; + /// + /// The endpoint ID. + /// public static Field EndpointId { get; } = "endpointId"; + /// + /// The parent endpoint ID. + /// public static Field ParentEndpointId { get; } = "parentEndpointId"; + /// + /// The process ID. + /// public static Field ProcessId { get; } = "processId"; + /// + /// 8-bit random number. + /// public static Field Random8 { get; } = "random8"; + /// + /// 16-bit random number. + /// public static Field Random16 { get; } = "random16"; + /// + /// 32-bit random number. + /// public static Field Random32 { get; } = "random32"; + /// + /// The layer of the WinDivert handle. + /// public static Field Layer { get; } = "layer"; + /// + /// The priority of the WinDivert handle. + /// public static Field Priority { get; } = "priority"; + /// + /// The i-th byte of the packet. + /// public static Field Packet(int i) => $"packet[{i}]"; + /// + /// The i-th 16-bit word of the packet. + /// public static Field Packet16(int i) => $"packet16[{i}]"; + /// + /// The i-th 32-bit word of the packet. + /// public static Field Packet32(int i) => $"packet32[{i}]"; + /// + /// The packet length. + /// public static Field Length { get; } = "length"; + /// + /// Is IPv4? + /// public static Field Ip { get; } = "ip"; + /// + /// Is IPv6? + /// public static Field Ipv6 { get; } = "ipv6"; + /// + /// Is ICMP? + /// public static Field ICMP { get; } = "icmp"; + /// + /// Is ICMPv6? + /// + // spell-checker:ignore icmpv6 public static Field ICMPv6 { get; } = "icmpv6"; + /// + /// Is TCP? + /// public static Field TCP { get; } = "tcp"; + /// + /// Is UDP? + /// public static Field UDP { get; } = "udp"; + /// + /// The protocol. + /// public static Field Protocol { get; } = "protocol"; + /// + /// The local address. + /// public static Field LocalAddress { get; } = "localAddr"; + /// + /// The local port. + /// public static Field LocalPort { get; } = "localPort"; + /// + /// The remote address. + /// public static Field RemoteAddress { get; } = "remoteAddr"; + /// + /// The remote port. + /// public static Field RemotePort { get; } = "remotePort"; + + /// + /// Implicitly converts a string to a . + /// + public static implicit operator DivertFilter(string clause) + { + return new DivertFilter(clause); + } + + /// + /// Implicitly converts a bool to a . + /// + public static implicit operator DivertFilter(bool value) + { + return value ? True : False; + } } diff --git a/Divert.Windows/DivertFlags.cs b/Divert.Windows/DivertFlags.cs index 7f3dfd7..5cf99c8 100644 --- a/Divert.Windows/DivertFlags.cs +++ b/Divert.Windows/DivertFlags.cs @@ -1,15 +1,53 @@ namespace Divert.Windows; +/// +/// Flags for configuring the behavior of the WinDivert handle. +/// [Flags] public enum DivertFlags { + /// + /// Default mode: packets are dropped and captured. + /// None = 0, + + /// + /// Packet sniffing mode: packets are captured but not dropped. + /// Sniff = 0x0001, + + /// + /// Packets are silently dropped. + /// Drop = 0x0002, + + /// + /// Receive-only mode: disables sending packets. + /// ReceiveOnly = 0x0004, + + /// + /// Same as . + /// ReadOnly = ReceiveOnly, + + /// + /// Send-only mode: disables receiving packets. + /// SendOnly = 0x0008, + + /// + /// Same as . + /// WriteOnly = SendOnly, + + /// + /// Prevents automatic installation of the WinDivert driver. + /// NoInstall = 0x0010, - Fragments = 0x0020 + + /// + /// Captures fragmented packets. + /// + Fragments = 0x0020, } diff --git a/Divert.Windows/DivertHandle.cs b/Divert.Windows/DivertHandle.cs new file mode 100644 index 0000000..2575fb3 --- /dev/null +++ b/Divert.Windows/DivertHandle.cs @@ -0,0 +1,32 @@ +using Microsoft.Win32.SafeHandles; + +namespace Divert.Windows; + +/// +/// Safe handle for a WinDivert handle. +/// +public sealed class DivertHandle : SafeHandleZeroOrMinusOneIsInvalid +{ + /// + /// Creates a new instance of the class. + /// + /// + /// The WinDivert handle. + /// + /// + /// Whether the handle should be released when the SafeHandle is disposed. + /// + public DivertHandle(IntPtr handle, bool ownsHandle = true) + : base(ownsHandle) + { + this.handle = handle; + } + + /// + /// Releases the WinDivert handle. + /// + /// + /// true if the handle was released successfully; otherwise, false. + /// + protected override bool ReleaseHandle() => NativeMethods.WinDivertClose(handle); +} diff --git a/Divert.Windows/DivertHelper.cs b/Divert.Windows/DivertHelper.cs index c6d1b22..60ae7dc 100644 --- a/Divert.Windows/DivertHelper.cs +++ b/Divert.Windows/DivertHelper.cs @@ -1,22 +1,200 @@ using System.ComponentModel; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using Windows.Win32.Foundation; namespace Divert.Windows; -unsafe public static class DivertHelper +/// +/// WinDivert helper methods. +/// +public static unsafe class DivertHelper { - public static void CalculateChecksums(Span packet, DivertHelperFlags flags) + /// + /// Calculates the checksums for the specified packet. + /// + /// + /// The packet data. + /// + /// The flags to disable individual checksum calculations. + /// true if the checksums were calculated successfully; otherwise, false. + public static bool CalculateChecksums(Span packet, DivertHelperFlags flags = DivertHelperFlags.None) { fixed (byte* pPacket = packet) { - bool success = NativeMethods.WinDivertHelperCalcChecksums( - pPacket, (uint)packet.Length, - null, - (ulong)flags); + return NativeMethods.WinDivertHelperCalcChecksums(pPacket, (uint)packet.Length, null, (ulong)flags); + } + } + + /// + /// Calculates the checksums for the specified packet. + /// + /// + /// The packet data. + /// + /// + /// A reference to a structure where the corresponding ChecksumValid fields will be set. + /// + /// The flags to disable individual checksum calculations. + /// true if the checksums were calculated successfully; otherwise, false. + public static bool CalculateChecksums( + Span packet, + ref DivertAddress address, + DivertHelperFlags flags = DivertHelperFlags.None + ) + { + fixed (byte* pPacket = packet) + fixed (DivertAddress* pAddress = &address) + { + return NativeMethods.WinDivertHelperCalcChecksums( + pPacket, + (uint)packet.Length, + (WINDIVERT_ADDRESS*)pAddress, + (ulong)flags + ); + } + } + + /// + /// Decrements the TTL field in the IP header of the specified packet. + /// + /// The packet data. + /// true if the result is non-zero; otherwise, false. + public static bool DecrementTtl(Span packet) + { + fixed (byte* pPacket = packet) + { + return NativeMethods.WinDivertHelperDecrementTTL(pPacket, (uint)packet.Length); + } + } + + /// + /// Compiles a WinDivert filter string into a compact object representation. + /// + /// + /// The filter to compile. + /// + /// The layer. + /// The length of the buffer. + /// The compiled filter. + /// + /// The filter is invalid. + /// + public static ReadOnlySpan CompileFilter( + DivertFilter filter, + DivertLayer layer, + int bufferLength = ushort.MaxValue + ) + { + ArgumentNullException.ThrowIfNull(filter); + + using var s = new CString(filter.Clause); + Span buffer = GC.AllocateArray(bufferLength, pinned: true); + var pBuffer = Unsafe.AsPointer(ref MemoryMarshal.GetReference(buffer)); + IntPtr errorStr; + uint errorPos; + bool success = NativeMethods.WinDivertHelperCompileFilter( + s.Pointer, + (WINDIVERT_LAYER)layer, + (byte*)pBuffer, + (uint)buffer.Length, + &errorStr, + &errorPos + ); + if (!success) + { + string? errorString = Marshal.PtrToStringAnsi(errorStr); + throw new ArgumentException( + $"{errorString} ({errorPos}): ...{filter.Clause[(int)errorPos..]}", + nameof(filter) + ); + } + + return buffer; + } + + private static bool EvaluateFilter(IntPtr filter, ReadOnlySpan packet, in DivertAddress address) + { + fixed (byte* pPacket = packet) + fixed (DivertAddress* pAddress = &address) + { + bool success = NativeMethods.WinDivertHelperEvalFilter( + filter, + pPacket, + (uint)packet.Length, + (WINDIVERT_ADDRESS*)pAddress + ); if (!success) { - throw new Win32Exception(Marshal.GetLastWin32Error()); + int error = Marshal.GetLastPInvokeError(); + if (error is not (int)WIN32_ERROR.ERROR_SUCCESS) + { + throw new Win32Exception(error); + } } + return success; + } + } + + /// + /// Evaluates a compiled WinDivert filter against the specified packet and address. + /// + /// The compiled filter. + /// The packet data. + /// The address information. + /// true if the packet matches the filter; otherwise, false. + public static bool EvaluateFilter(ReadOnlySpan filter, ReadOnlySpan packet, in DivertAddress address) + { + fixed (byte* pFilter = filter) + fixed (byte* pPacket = packet) + fixed (DivertAddress* pAddress = &address) + { + return EvaluateFilter(new IntPtr(pFilter), packet, address); + } + } + + /// + /// Evaluates a WinDivert filter string against the specified packet and address. + /// + /// The filter to evaluate. + /// The packet data. + /// The address information. + /// true if the packet matches the filter; otherwise, false. + public static bool EvaluateFilter(DivertFilter filter, ReadOnlySpan packet, in DivertAddress address) + { + using var s = new CString(filter.Clause); + fixed (byte* pPacket = packet) + fixed (DivertAddress* pAddress = &address) + { + return EvaluateFilter(s.Pointer, packet, address); + } + } + + /// + /// Formats a compiled WinDivert filter into a human-readable string. + /// + /// The compiled filter. + /// The layer. + /// The maximum length of the formatted string. + /// The formatted filter string. + public static string FormatFilter(Span filter, DivertLayer layer, int maxLength = ushort.MaxValue) + { + Span buffer = GC.AllocateArray(maxLength, pinned: true); + var pBuffer = Unsafe.AsPointer(ref MemoryMarshal.GetReference(buffer)); + fixed (byte* pFilter = filter) + { + bool success = NativeMethods.WinDivertHelperFormatFilter( + new(pFilter), + (WINDIVERT_LAYER)layer, + (byte*)pBuffer, + (uint)buffer.Length + ); + if (!success) + { + throw new Win32Exception(Marshal.GetLastPInvokeError()); + } + + return Marshal.PtrToStringAnsi(new IntPtr(pBuffer))!; } } } diff --git a/Divert.Windows/DivertHelperFlags.cs b/Divert.Windows/DivertHelperFlags.cs index 5e547d5..78ad53c 100644 --- a/Divert.Windows/DivertHelperFlags.cs +++ b/Divert.Windows/DivertHelperFlags.cs @@ -1,12 +1,38 @@ namespace Divert.Windows; +/// +/// Flags to disable individual checksum calculations. +/// [Flags] public enum DivertHelperFlags { + /// + /// No flags specified. + /// None = 0, + + /// + /// Disables IP checksum calculation. + /// NoIPChecksum = 1, + + /// + /// Disables ICMP checksum calculation. + /// NoICMPChecksum = 2, + + /// + /// Disables ICMPv6 checksum calculation. + /// NoICMPv6Checksum = 4, + + /// + /// Disables TCP checksum calculation. + /// NoTCPChecksum = 8, - NoUDPChecksum = 16 + + /// + /// Disables UDP checksum calculation. + /// + NoUDPChecksum = 16, } diff --git a/Divert.Windows/DivertIOControl.cs b/Divert.Windows/DivertIOControl.cs new file mode 100644 index 0000000..ba4615c --- /dev/null +++ b/Divert.Windows/DivertIOControl.cs @@ -0,0 +1,88 @@ +using System.ComponentModel; +using System.Runtime.InteropServices; +using Windows.Win32; +using Windows.Win32.Foundation; + +namespace Divert.Windows; + +/// +/// Replaces and as they +/// are not applicable to thread pool bound handles. +/// +internal static unsafe class DivertIOControl +{ + private static uint CTL_CODE(uint deviceType, uint function, uint method, uint access) + { + return (deviceType << 16) | (access << 14) | (function << 2) | method; + } + + private static readonly uint SetParamControlCode = CTL_CODE( + PInvoke.FILE_DEVICE_NETWORK, + 0x925, + PInvoke.METHOD_IN_DIRECT, + (uint)FileAccess.ReadWrite + ); + + private static readonly uint GetParamControlCode = CTL_CODE( + PInvoke.FILE_DEVICE_NETWORK, + 0x926, + PInvoke.METHOD_OUT_DIRECT, + (uint)FileAccess.Read + ); + + private static void ManualResetCallback(uint errorCode, uint numBytes, NativeOverlapped* pOVERLAP) + { + var manualResetEvent = (ManualResetEventSlim)ThreadPoolBoundHandle.GetNativeOverlappedState(pOVERLAP)!; + manualResetEvent.Set(); + } + + private static readonly IOCompletionCallback manualResetCallback = ManualResetCallback; + + private static ulong DeviceIOControl(ThreadPoolBoundHandle threadPoolBoundHandle, WINDIVERT_IOCTL* ioctl, uint code) + { + ulong value; + using var eventHandle = new ManualResetEventSlim(initialState: false); + var nativeOverlapped = threadPoolBoundHandle.AllocateNativeOverlapped(manualResetCallback, eventHandle, null); + try + { + using var _ = threadPoolBoundHandle.Handle.DangerousGetHandle(out var handle); + bool success = PInvoke.DeviceIoControl( + new HANDLE(handle), + code, + ioctl, + (uint)sizeof(WINDIVERT_IOCTL), + &value, + sizeof(ulong), + null, + nativeOverlapped + ); + if (!success) + { + int error = Marshal.GetLastPInvokeError(); + if (error is not (int)WIN32_ERROR.ERROR_IO_PENDING) + { + throw new Win32Exception(error); + } + eventHandle.Wait(); + } + } + finally + { + threadPoolBoundHandle.FreeNativeOverlapped(nativeOverlapped); + } + + return value; + } + + public static void SetParam(ThreadPoolBoundHandle threadPoolBoundHandle, WINDIVERT_PARAM param, ulong value) + { + var ioctl = new WINDIVERT_IOCTL { SetParam = param, Value = value }; + DeviceIOControl(threadPoolBoundHandle, &ioctl, SetParamControlCode); + } + + public static ulong GetParam(ThreadPoolBoundHandle threadPoolBoundHandle, WINDIVERT_PARAM param) + { + var ioctl = new WINDIVERT_IOCTL { GetParam = param }; + return DeviceIOControl(threadPoolBoundHandle, &ioctl, GetParamControlCode); + } +} diff --git a/Divert.Windows/DivertLayer.cs b/Divert.Windows/DivertLayer.cs index 435cd14..4abed19 100644 --- a/Divert.Windows/DivertLayer.cs +++ b/Divert.Windows/DivertLayer.cs @@ -1,10 +1,32 @@ namespace Divert.Windows; +/// +/// Specifies the WinDivert layer. +/// public enum DivertLayer { - Network = 0, - Forward = 1, - Flow = 2, - Socket = 3, - Reflect = 4 + /// + /// Network packets to/from the local machine. + /// + Network = WINDIVERT_LAYER.WINDIVERT_LAYER_NETWORK, + + /// + /// Network packets being forwarded through the local machine. + /// + Forward = WINDIVERT_LAYER.WINDIVERT_LAYER_NETWORK_FORWARD, + + /// + /// Network flow events. + /// + Flow = WINDIVERT_LAYER.WINDIVERT_LAYER_FLOW, + + /// + /// Socket events. + /// + Socket = WINDIVERT_LAYER.WINDIVERT_LAYER_SOCKET, + + /// + /// WinDivert handle events. + /// + Reflect = WINDIVERT_LAYER.WINDIVERT_LAYER_REFLECT, } diff --git a/Divert.Windows/DivertReceiveResult.cs b/Divert.Windows/DivertReceiveResult.cs new file mode 100644 index 0000000..8c7ee8e --- /dev/null +++ b/Divert.Windows/DivertReceiveResult.cs @@ -0,0 +1,30 @@ +namespace Divert.Windows; + +/// +/// Represents the result of a Divert receive operation. +/// +/// The length of the received data. +/// The length of the addresses. +public readonly struct DivertReceiveResult(int dataLength, int addressLength) +{ + /// + /// Gets the length of the received data. + /// + public int DataLength { get; } = dataLength; + + /// + /// Gets the length of the addresses. + /// + public int AddressLength { get; } = addressLength; + + /// + /// Deconstructs the result into its components. + /// + /// The length of the received data. + /// The length of the addresses. + public void Deconstruct(out int dataLength, out int addressLength) + { + dataLength = DataLength; + addressLength = AddressLength; + } +} diff --git a/Divert.Windows/DivertService.cs b/Divert.Windows/DivertService.cs index 946bced..e0d48c1 100644 --- a/Divert.Windows/DivertService.cs +++ b/Divert.Windows/DivertService.cs @@ -1,268 +1,331 @@ using System.ComponentModel; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; -using System.Runtime.Versioning; -using Windows.Win32; +using System.Threading.Channels; +using Divert.Windows.AsyncOperation; using Windows.Win32.Foundation; -[assembly: SupportedOSPlatform("windows6.0.6000")] - namespace Divert.Windows; -unsafe sealed public class DivertService : IDisposable +/// +/// Main entry point for WinDivert operations. +/// +public sealed unsafe class DivertService : IDisposable { - internal HANDLE Handle { get; } + /// + /// The highest priority for a WinDivert handle. + /// + public const int HighestPriority = Constants.WINDIVERT_PRIORITY_HIGHEST; + + /// + /// The lowest priority for a WinDivert handle. + /// + public const int LowestPriority = Constants.WINDIVERT_PRIORITY_LOWEST; + + /// + /// The default packet queue length for receive operations. + /// + public const int DefaultQueueLength = Constants.WINDIVERT_PARAM_QUEUE_LENGTH_DEFAULT; + + /// + /// The minimum packet queue length for receive operations. + /// + public const int MinQueueLength = Constants.WINDIVERT_PARAM_QUEUE_LENGTH_MIN; + + /// + /// The maximum packet queue length for receive operations. + /// + public const int MaxQueueLength = Constants.WINDIVERT_PARAM_QUEUE_LENGTH_MAX; /// - /// Opens a WinDivert handle for the given filter. + /// The default packet queue time. /// - /// A packet filter string specified in the WinDivert filter language. + public static TimeSpan DefaultQueueTime => TimeSpan.FromMilliseconds(Constants.WINDIVERT_PARAM_QUEUE_TIME_DEFAULT); + + /// + /// The minimum packet queue time. + /// + public static TimeSpan MinQueueTime => TimeSpan.FromMilliseconds(Constants.WINDIVERT_PARAM_QUEUE_TIME_MIN); + + /// + /// The maximum packet queue time. + /// + public static TimeSpan MaxQueueTime => TimeSpan.FromMilliseconds(Constants.WINDIVERT_PARAM_QUEUE_TIME_MAX); + + /// + /// The default max number of bytes in the packet queue for receive operations. + /// + public const int DefaultQueueSize = Constants.WINDIVERT_PARAM_QUEUE_SIZE_DEFAULT; + + /// + /// The minimum max number of bytes in the packet queue for receive operations. + /// + public const int MinQueueSize = Constants.WINDIVERT_PARAM_QUEUE_SIZE_MIN; + + /// + /// The maximum max number of bytes in the packet queue for receive operations. + /// + public const int MaxQueueSize = Constants.WINDIVERT_PARAM_QUEUE_SIZE_MAX; + + /// + /// The maximum number of packets in a single send or receive operation. + /// + public const int MaxBatchSize = Constants.WINDIVERT_BATCH_MAX; + + private readonly DivertHandle divertHandle; + private readonly bool runContinuationsAsynchronously; + + private readonly ThreadPoolBoundHandle threadPoolBoundHandle; + private readonly Channel sendVtsPool; + private readonly Channel receiveVtsPool; + private readonly DivertReceiveExecutor receiveExecutor; + private readonly DivertSendExecutor sendExecutor; + + private DivertService(DivertHandle divertHandle) + { + this.divertHandle = divertHandle; + threadPoolBoundHandle = ThreadPoolBoundHandle.BindHandle(divertHandle); + sendVtsPool = Channel.CreateUnbounded(); + receiveVtsPool = Channel.CreateUnbounded(); + receiveExecutor = new DivertReceiveExecutor(); + sendExecutor = new DivertSendExecutor(); + } + + private static DivertHandle OpenHandle( + ReadOnlySpan filter, + DivertLayer layer, + short priority, + DivertFlags flags + ) + { + if (priority < LowestPriority || priority > HighestPriority) + { + throw new ArgumentOutOfRangeException(nameof(priority)); + } + + var pBuffer = Unsafe.AsPointer(ref MemoryMarshal.GetReference(filter)); + var handle = NativeMethods.WinDivertOpen(new(pBuffer), (WINDIVERT_LAYER)layer, priority, (ulong)flags); + if (new HANDLE(handle) == HANDLE.INVALID_HANDLE_VALUE) + { + throw new Win32Exception(Marshal.GetLastPInvokeError()); + } + return new DivertHandle(handle); + } + + /// + /// Initializes a new instance of the class. + /// + /// The packet filter. /// The layer. /// The priority of the handle. - /// Additional flags. + /// The handle flags. + /// Whether to force continuations to run asynchronously. public DivertService( DivertFilter filter, DivertLayer layer = DivertLayer.Network, short priority = 0, - DivertFlags flags = DivertFlags.None) + DivertFlags flags = DivertFlags.None, + bool runContinuationsAsynchronously = true + ) + : this(OpenHandle(DivertHelper.CompileFilter(filter, layer), layer, priority, flags)) { - ArgumentNullException.ThrowIfNull(filter); - - using var s = new CString(filter.Clause); - IntPtr errorStr; - uint errorPos; - bool success = NativeMethods.WinDivertHelperCompileFilter( - s.Ptr, - (WINDIVERT_LAYER)layer, - null, - 0, - &errorStr, - &errorPos); - if (!success) - { - string? errorString = Marshal.PtrToStringAnsi(errorStr); - throw new ArgumentException($"{errorPos}: {errorString}", nameof(filter)); - } + this.runContinuationsAsynchronously = runContinuationsAsynchronously; + } - var handle = NativeMethods.WinDivertOpen(s.Ptr, (WINDIVERT_LAYER)layer, priority, (ulong)flags); - if (handle.Value.ToInt64() == -1) - { - throw new Win32Exception(Marshal.GetLastWin32Error()); - } - Handle = handle; + /// + /// Initializes a new instance of the class. + /// + /// The packet filter. + /// The layer. + /// The priority of the handle. + /// The handle flags. + /// Whether to force continuations to run asynchronously. + public DivertService( + ReadOnlySpan filter, + DivertLayer layer = DivertLayer.Network, + short priority = 0, + DivertFlags flags = DivertFlags.None, + bool runContinuationsAsynchronously = true + ) + : this(OpenHandle(filter, layer, priority, flags)) + { + this.runContinuationsAsynchronously = runContinuationsAsynchronously; + } + + /// + /// Initializes a new instance of the class. + /// + /// An existing WinDivert handle. + /// Whether to force continuations to run asynchronously. + public DivertService(SafeHandle handle, bool runContinuationsAsynchronously = true) + : this(new DivertHandle(handle.DangerousGetHandle(), ownsHandle: false)) + { + this.runContinuationsAsynchronously = runContinuationsAsynchronously; } private bool disposed = false; - public void Dispose() + private static void DisposeVtsPool(Channel pool) { - if (!disposed) + pool.Writer.TryComplete(); + while (pool.Reader.TryRead(out var vts)) { - bool success = NativeMethods.WinDivertClose(Handle); - if (!success) - { - throw new Win32Exception(Marshal.GetLastWin32Error()); - } - disposed = true; + vts.Dispose(); } - GC.SuppressFinalize(this); } - ~DivertService() + /// + /// Releases all resources used by the . + /// + public void Dispose() { - Dispose(); + if (disposed) + { + return; + } + disposed = true; + + DisposeVtsPool(sendVtsPool); + DisposeVtsPool(receiveVtsPool); + threadPoolBoundHandle.Dispose(); + divertHandle.Dispose(); } - internal void ThrowIfDisposed() + /// + /// Gets the underlying WinDivert handle. + /// + public DivertHandle SafeHandle => divertHandle; + + private DivertValueTaskSource GetVts(Channel vtsPool) { - if (disposed) + if (!vtsPool.Reader.TryRead(out var vts)) { - throw new ObjectDisposedException(nameof(DivertService)); + vts = new DivertValueTaskSource( + sendVtsPool, + divertHandle, + threadPoolBoundHandle, + runContinuationsAsynchronously + ); } + return vts; } - public void Shutdown(DivertShutdown how) + /// + /// Receives one or more packets from the network stack. + /// + /// Buffer to receive the packet data. + /// Buffer to receive the packet addresses. + /// Token to observe cancellation requests. + /// A ValueTask representing the asynchronous receive operation. + public ValueTask ReceiveAsync( + Memory buffer, + Memory addresses, + CancellationToken cancellationToken = default + ) { - ThrowIfDisposed(); - - bool success = NativeMethods.WinDivertShutdown(Handle, (WINDIVERT_SHUTDOWN)how); - if (!success) + ObjectDisposedException.ThrowIf(disposed, this); + if (cancellationToken.IsCancellationRequested) { - throw new Win32Exception(Marshal.GetLastWin32Error()); + return ValueTask.FromCanceled(cancellationToken); } + + var vts = GetVts(receiveVtsPool); + return receiveExecutor.ReceiveAsync(vts, buffer, addresses, cancellationToken); } - public (int packetLength, DivertAddress address) Receive(Span buffer) + /// + /// Injects one or more packets into the network stack. + /// + /// Buffer containing the packet data. + /// Addresses of the packets to be injected. + /// Token to observe cancellation requests. + /// A ValueTask representing the asynchronous send operation. + public ValueTask SendAsync( + ReadOnlyMemory buffer, + ReadOnlyMemory addresses, + CancellationToken cancellationToken = default + ) { - ThrowIfDisposed(); - - uint packetLength; - WINDIVERT_ADDRESS address; - fixed (byte* pBuffer = buffer) + ObjectDisposedException.ThrowIf(disposed, this); + if (cancellationToken.IsCancellationRequested) { - bool success = NativeMethods.WinDivertRecv( - Handle, - pBuffer, checked((uint)buffer.Length), - &packetLength, - &address); - if (!success) - { - throw new Win32Exception(Marshal.GetLastWin32Error()); - } + return ValueTask.FromCanceled(cancellationToken); } - return ((int)packetLength, new DivertAddress(address)); + + var vts = GetVts(sendVtsPool); + return sendExecutor.SendAsync(vts, buffer, addresses, cancellationToken); } - public (int packetLength, int addressLength) ReceiveEx( - Span buffer, - Span addresses, - CancellationToken cancellationToken = default) + /// + /// Gets the version of the WinDivert driver. + /// + public Version Version { - ThrowIfDisposed(); - cancellationToken.ThrowIfCancellationRequested(); - - using var eventHandle = new ManualResetEvent(initialState: false); - using var _ = cancellationToken.Register(() => eventHandle.Set()); - var overlapped = new NativeOverlapped + get { - EventHandle = eventHandle.SafeWaitHandle.DangerousGetHandle() - }; - uint addrLen = (uint)(sizeof(DivertAddress) * addresses.Length); - fixed (byte* pBuffer = buffer) - fixed (DivertAddress* pAddr = addresses) + int major = (int) + DivertIOControl.GetParam(threadPoolBoundHandle, WINDIVERT_PARAM.WINDIVERT_PARAM_VERSION_MAJOR); + int minor = (int) + DivertIOControl.GetParam(threadPoolBoundHandle, WINDIVERT_PARAM.WINDIVERT_PARAM_VERSION_MINOR); + return new Version(major, minor); + } + } + + /// + /// Gets or sets the length of the receive queue. + /// + public int QueueLength + { + get => (int)DivertIOControl.GetParam(threadPoolBoundHandle, WINDIVERT_PARAM.WINDIVERT_PARAM_QUEUE_LENGTH); + set { - bool success = NativeMethods.WinDivertRecvEx( - Handle, - pBuffer, checked((uint)buffer.Length), - null, - 0, - (WINDIVERT_ADDRESS*)pAddr, - &addrLen, - &overlapped); - if (!success) + if (value < MinQueueLength || value > MaxQueueLength) { - int error = Marshal.GetLastWin32Error(); - if (error == 997) // ERROR_IO_PENDING - { - eventHandle.WaitOne(); - if (cancellationToken.IsCancellationRequested) - { - success = PInvoke.CancelIoEx(Handle, &overlapped); - if (!success) - { - error = Marshal.GetLastWin32Error(); - if (error != 1168) // ERROR_NOT_FOUND - { - throw new Win32Exception(error); - } - } - } - } - else - { - throw new Win32Exception(error); - } + throw new ArgumentOutOfRangeException(nameof(value)); } - uint packetsLength; - success = PInvoke.GetOverlappedResult(Handle, &overlapped, &packetsLength, true); - if (!success) - { - int error = Marshal.GetLastWin32Error(); - if (error == 995) // ERROR_OPERATION_ABORTED - { - throw new OperationCanceledException(); - } - else - { - throw new Win32Exception(error); - } - } - int addressLength = (int)addrLen / sizeof(DivertAddress); - return ((int)packetsLength, addressLength); + + DivertIOControl.SetParam(threadPoolBoundHandle, WINDIVERT_PARAM.WINDIVERT_PARAM_QUEUE_LENGTH, (ulong)value); } } - public void Send(ReadOnlySpan buffer, DivertAddress address) + /// + /// Gets or sets the maximum packet queue time. + /// + public TimeSpan QueueTime { - ThrowIfDisposed(); - - var divertAddress = address.Struct; - fixed (byte* pBuffer = buffer) + get => + TimeSpan.FromMilliseconds( + DivertIOControl.GetParam(threadPoolBoundHandle, WINDIVERT_PARAM.WINDIVERT_PARAM_QUEUE_TIME) + ); + set { - bool success = NativeMethods.WinDivertSend( - Handle, - pBuffer, checked((uint)buffer.Length), - null, - &divertAddress); - if (!success) + if (value < MinQueueTime || value > MaxQueueTime) { - throw new Win32Exception(Marshal.GetLastWin32Error()); + throw new ArgumentOutOfRangeException(nameof(value)); } + + DivertIOControl.SetParam( + threadPoolBoundHandle, + WINDIVERT_PARAM.WINDIVERT_PARAM_QUEUE_TIME, + (ulong)value.TotalMilliseconds + ); } } - public void SendEx( - ReadOnlySpan buffer, - Span addresses, - CancellationToken cancellationToken = default) + /// + /// Gets or sets the maximum number of bytes in the receive queue. + /// + public int QueueSize { - ThrowIfDisposed(); - cancellationToken.ThrowIfCancellationRequested(); - - using var eventHandle = new ManualResetEvent(initialState: false); - using var _ = cancellationToken.Register(() => eventHandle.Set()); - var overlapped = new NativeOverlapped - { - EventHandle = eventHandle.SafeWaitHandle.DangerousGetHandle() - }; - fixed (byte* pBuffer = buffer) - fixed (DivertAddress* pAddr = addresses) + get => (int)DivertIOControl.GetParam(threadPoolBoundHandle, WINDIVERT_PARAM.WINDIVERT_PARAM_QUEUE_SIZE); + set { - bool success = NativeMethods.WinDivertSendEx( - Handle, - pBuffer, checked((uint)buffer.Length), - null, - 0, - (WINDIVERT_ADDRESS*)pAddr, - (uint)(sizeof(WINDIVERT_ADDRESS) * addresses.Length), - &overlapped); - if (!success) + if (value < MinQueueSize || value > MaxQueueSize) { - int error = Marshal.GetLastWin32Error(); - if (error == 997) // ERROR_IO_PENDING - { - eventHandle.WaitOne(); - if (cancellationToken.IsCancellationRequested) - { - success = PInvoke.CancelIoEx(Handle, &overlapped); - if (!success) - { - error = Marshal.GetLastWin32Error(); - if (error != 1168) // ERROR_NOT_FOUND - { - throw new Win32Exception(error); - } - } - } - } - else - { - throw new Win32Exception(error); - } - } - uint bytesSent; - success = PInvoke.GetOverlappedResult(Handle, &overlapped, &bytesSent, true); - if (!success) - { - int error = Marshal.GetLastWin32Error(); - if (error == 995) // ERROR_OPERATION_ABORTED - { - throw new OperationCanceledException(); - } - else - { - throw new Win32Exception(error); - } + throw new ArgumentOutOfRangeException(nameof(value)); } + + DivertIOControl.SetParam(threadPoolBoundHandle, WINDIVERT_PARAM.WINDIVERT_PARAM_QUEUE_SIZE, (ulong)value); } } } diff --git a/Divert.Windows/DivertShutdown.cs b/Divert.Windows/DivertShutdown.cs index 2173006..067ff9d 100644 --- a/Divert.Windows/DivertShutdown.cs +++ b/Divert.Windows/DivertShutdown.cs @@ -1,8 +1,22 @@ namespace Divert.Windows; +/// +/// Specifies how to shut down a WinDivert handle. +/// public enum DivertShutdown { - Receive = 0x1, - Send = 0x2, - Both = 0x3 + /// + /// Stops new packets from being queued for receiving. + /// + Receive = WINDIVERT_SHUTDOWN.WINDIVERT_SHUTDOWN_RECV, + + /// + /// Stops new packets from being injected. + /// + Send = WINDIVERT_SHUTDOWN.WINDIVERT_SHUTDOWN_SEND, + + /// + /// Stops both receiving and sending of packets. + /// + Both = WINDIVERT_SHUTDOWN.WINDIVERT_SHUTDOWN_BOTH, } diff --git a/Divert.Windows/NativeMethods.cs b/Divert.Windows/NativeMethods.cs index 1cdd2a5..563fa65 100644 --- a/Divert.Windows/NativeMethods.cs +++ b/Divert.Windows/NativeMethods.cs @@ -1,91 +1,115 @@ using System.Runtime.InteropServices; -using Windows.Win32.Foundation; namespace Divert.Windows; -unsafe internal static class NativeMethods +internal static unsafe partial class NativeMethods { private const string dllName = "WinDivert.dll"; - [DllImport(dllName, SetLastError = true)] - public static extern HANDLE WinDivertOpen( - IntPtr filter, - WINDIVERT_LAYER layer, - short priority, - ulong flags); + [LibraryImport(dllName, SetLastError = true)] + public static partial IntPtr WinDivertOpen(IntPtr filter, WINDIVERT_LAYER layer, short priority, ulong flags); - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertRecv( - HANDLE handle, + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertRecv( + IntPtr handle, void* pPacket, uint packetLen, uint* pRecvLen, - WINDIVERT_ADDRESS* pAddr); + WINDIVERT_ADDRESS* pAddr + ); - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertRecvEx( - HANDLE handle, + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertRecvEx( + IntPtr handle, void* pPacket, uint packetLen, uint* pRecvLen, ulong flags, WINDIVERT_ADDRESS* pAddr, uint* pAddrLen, - NativeOverlapped* lpOverlapped); + NativeOverlapped* lpOverlapped + ); - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertSend( - HANDLE handle, + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertSend( + IntPtr handle, void* pPacket, uint packetLen, uint* pSendLen, - WINDIVERT_ADDRESS* pAddr); + WINDIVERT_ADDRESS* pAddr + ); - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertSendEx( - HANDLE handle, + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertSendEx( + IntPtr handle, void* pPacket, uint packetLen, uint* pSendLen, ulong flags, WINDIVERT_ADDRESS* pAddr, uint addrLen, - NativeOverlapped* lpOverlapped); - - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertShutdown( - HANDLE handle, - WINDIVERT_SHUTDOWN how); - - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertClose( - HANDLE handle); - - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertSetParam( - HANDLE handle, - WINDIVERT_PARAM param, - ulong value); - - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertGetParam( - HANDLE handle, - WINDIVERT_PARAM param, - ulong* pValue); - - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertHelperCalcChecksums( + NativeOverlapped* lpOverlapped + ); + + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertShutdown(IntPtr handle, WINDIVERT_SHUTDOWN how); + + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertClose(IntPtr handle); + + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertSetParam(IntPtr handle, WINDIVERT_PARAM param, ulong value); + + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertGetParam(IntPtr handle, WINDIVERT_PARAM param, ulong* pValue); + + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertHelperCalcChecksums( void* pPacket, uint packetLen, WINDIVERT_ADDRESS* pAddr, - ulong flags); + ulong flags + ); + + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertHelperDecrementTTL(void* pPacket, uint packetLen); - [DllImport(dllName, SetLastError = true)] - public static extern BOOL WinDivertHelperCompileFilter( + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertHelperCompileFilter( IntPtr filter, WINDIVERT_LAYER layer, byte* @object, uint objLen, IntPtr* errorStr, - uint* errorPos); + uint* errorPos + ); + + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertHelperFormatFilter( + IntPtr filter, + WINDIVERT_LAYER layer, + byte* buffer, + uint bufLen + ); + + [LibraryImport(dllName, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + public static partial bool WinDivertHelperEvalFilter( + IntPtr filter, + void* pPacket, + uint packetLen, + WINDIVERT_ADDRESS* pAddr + ); } diff --git a/Divert.Windows/NativeMethods.txt b/Divert.Windows/NativeMethods.txt index fbcf31c..8926c54 100644 --- a/Divert.Windows/NativeMethods.txt +++ b/Divert.Windows/NativeMethods.txt @@ -1,2 +1,7 @@ -GetOverlappedResult +INVALID_HANDLE_VALUE +WIN32_ERROR CancelIoEx +FILE_DEVICE_NETWORK +METHOD_IN_DIRECT +METHOD_OUT_DIRECT +DeviceIoControl diff --git a/Divert.Windows/NativeTypes.cs b/Divert.Windows/NativeTypes.cs index 30ae8e8..3f57897 100644 --- a/Divert.Windows/NativeTypes.cs +++ b/Divert.Windows/NativeTypes.cs @@ -8,7 +8,7 @@ internal enum WINDIVERT_LAYER WINDIVERT_LAYER_NETWORK_FORWARD = 1, WINDIVERT_LAYER_FLOW = 2, WINDIVERT_LAYER_SOCKET = 3, - WINDIVERT_LAYER_REFLECT = 4 + WINDIVERT_LAYER_REFLECT = 4, } [StructLayout(LayoutKind.Explicit, Size = 8)] @@ -22,7 +22,7 @@ internal struct WINDIVERT_DATA_NETWORK } [StructLayout(LayoutKind.Explicit, Size = 64)] -unsafe internal struct WINDIVERT_DATA_FLOW +internal unsafe struct WINDIVERT_DATA_FLOW { [FieldOffset(0)] public ulong EndpointId; @@ -50,7 +50,7 @@ unsafe internal struct WINDIVERT_DATA_FLOW } [StructLayout(LayoutKind.Explicit, Size = 64)] -unsafe internal struct WINDIVERT_DATA_SOCKET +internal unsafe struct WINDIVERT_DATA_SOCKET { [FieldOffset(0)] public ulong EndpointId; @@ -110,7 +110,7 @@ internal enum WINDIVERT_ADDRESS_BITS : byte } [StructLayout(LayoutKind.Explicit, Size = 80)] -unsafe internal struct WINDIVERT_ADDRESS +internal unsafe struct WINDIVERT_ADDRESS { [FieldOffset(0)] public long Timestamp; @@ -166,12 +166,25 @@ internal enum WINDIVERT_PARAM WINDIVERT_PARAM_QUEUE_TIME = 1, WINDIVERT_PARAM_QUEUE_SIZE = 2, WINDIVERT_PARAM_VERSION_MAJOR = 3, - WINDIVERT_PARAM_VERSION_MINOR = 4 + WINDIVERT_PARAM_VERSION_MINOR = 4, } internal enum WINDIVERT_SHUTDOWN { WINDIVERT_SHUTDOWN_RECV = 0x1, WINDIVERT_SHUTDOWN_SEND = 0x2, - WINDIVERT_SHUTDOWN_BOTH = 0x3 + WINDIVERT_SHUTDOWN_BOTH = 0x3, +} + +[StructLayout(LayoutKind.Explicit, Size = 16)] +internal struct WINDIVERT_IOCTL +{ + [FieldOffset(0)] + public WINDIVERT_PARAM GetParam; + + [FieldOffset(0)] + public ulong Value; + + [FieldOffset(8)] + public WINDIVERT_PARAM SetParam; } diff --git a/Divert.Windows/SafeHandleExtensions.cs b/Divert.Windows/SafeHandleExtensions.cs new file mode 100644 index 0000000..0124f5b --- /dev/null +++ b/Divert.Windows/SafeHandleExtensions.cs @@ -0,0 +1,30 @@ +using System.Runtime.InteropServices; + +namespace Divert.Windows; + +internal ref struct SafeHandleReference(T safeHandle) : IDisposable + where T : SafeHandle +{ + private bool disposed; + + public void Dispose() + { + if (!disposed) + { + safeHandle.DangerousRelease(); + disposed = true; + } + } +} + +internal static class SafeHandleExtensions +{ + public static SafeHandleReference DangerousGetHandle(this T safeHandle, out IntPtr handle) + where T : SafeHandle + { + bool success = false; + safeHandle.DangerousAddRef(ref success); + handle = safeHandle.DangerousGetHandle(); + return new SafeHandleReference(safeHandle); + } +} diff --git a/Examples/Flow/Flow.csproj b/Examples/Flow/Flow.csproj new file mode 100644 index 0000000..8927562 --- /dev/null +++ b/Examples/Flow/Flow.csproj @@ -0,0 +1,15 @@ + + + Exe + $(DefaultTargetFramework) + app.manifest + win-x64 + true + $(ProjectRoot).local/windows/Examples/$(MSBuildThisFileName) + CA2007 + + + + + + diff --git a/Examples/Flow/Program.cs b/Examples/Flow/Program.cs new file mode 100644 index 0000000..af780d7 --- /dev/null +++ b/Examples/Flow/Program.cs @@ -0,0 +1,136 @@ +using System.Net; +using System.Net.Sockets; +using System.Runtime.Versioning; +using Divert.Windows; + +[assembly: SupportedOSPlatform("windows6.0.6000")] + +// Logs flow events on a loopback TCP listener. + +using var service = new DivertService( + DivertFilter.ProcessId == Environment.ProcessId & DivertFilter.TCP, + DivertLayer.Flow, + flags: DivertFlags.Sniff | DivertFlags.ReceiveOnly +) +{ + QueueTime = DivertService.MaxQueueTime, +}; + +static void WriteLine(string prefix, string message) +{ + Console.WriteLine($"{prefix}: {message}"); +} + +using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); +listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); +listener.Listen(); +int port = ((IPEndPoint)listener.LocalEndPoint!).Port; +WriteLine(nameof(TcpListener), $"Listening on port {port}..."); +using var cts = new CancellationTokenSource(); +var token = cts.Token; + +var sniff = Task.Run(async () => +{ + try + { + var addresses = new DivertAddress[1]; + while (true) + { + await service.ReceiveAsync(Memory.Empty, addresses, token); + var @event = addresses[0].Event; + var socketData = addresses[0].GetFlowData(); + if (socketData.LocalPort == port) + { + WriteLine( + nameof(DivertService), + $"{@event} from {nameof(TcpListener)}, " + + $"local port = {socketData.LocalPort}, remote port = {socketData.RemotePort}." + ); + } + else + { + WriteLine( + nameof(DivertService), + $"{@event} from {nameof(TcpClient)}, " + + $"local port = {socketData.LocalPort}, remote port = {socketData.RemotePort}." + ); + } + } + } + catch (OperationCanceledException) when (token.IsCancellationRequested) { } + catch (Exception e) + { + WriteLine("Error", e.ToString()); + } +}); + +var listen = Task.Run(async () => +{ + try + { + while (true) + { + var client = await listener.AcceptAsync(token); + int clientPort = ((IPEndPoint)client.RemoteEndPoint!).Port; + WriteLine(nameof(TcpListener), $"Accepted connection from port {clientPort}."); + _ = Task.Run(async () => + { + using var _ = client; + var buffer = new byte[1024]; + + while (true) + { + int length = await client.ReceiveAsync(buffer, token); + if (length is 0) + { + WriteLine(nameof(TcpListener), "Closing connection..."); + client.Shutdown(SocketShutdown.Send); + break; + } + } + }); + } + } + catch (OperationCanceledException) when (token.IsCancellationRequested) { } + catch (Exception e) + { + WriteLine("Error", e.ToString()); + } +}); + +WriteLine("Info", "Preparing TCP client..."); +try +{ + for (int i = 0; i < 3; i++) + { + await Task.Delay(TimeSpan.FromSeconds(3), token); + Console.WriteLine(); + using var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + WriteLine(nameof(TcpClient), $"Connecting to port {port}..."); + await client.ConnectAsync(IPAddress.Loopback, port, token); + WriteLine(nameof(TcpClient), $"Connected to port {port}."); + + WriteLine(nameof(TcpClient), "Closing connection..."); + client.Shutdown(SocketShutdown.Send); + var buffer = new byte[1024]; + while (true) + { + int length = await client.ReceiveAsync(buffer, token); + if (length is 0) + { + break; + } + } + } +} +catch (Exception e) +{ + WriteLine("Error", e.ToString()); +} + +await Task.Delay(TimeSpan.FromSeconds(3), token); +cts.Cancel(); +await Task.WhenAll(sniff, listen).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); +Console.WriteLine(); +WriteLine("Info", "Done."); +Console.ReadLine(); diff --git a/Examples/Flow/app.manifest b/Examples/Flow/app.manifest new file mode 100644 index 0000000..76ff66a --- /dev/null +++ b/Examples/Flow/app.manifest @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/Examples/Greeting/Greeting.csproj b/Examples/Greeting/Greeting.csproj new file mode 100644 index 0000000..8927562 --- /dev/null +++ b/Examples/Greeting/Greeting.csproj @@ -0,0 +1,15 @@ + + + Exe + $(DefaultTargetFramework) + app.manifest + win-x64 + true + $(ProjectRoot).local/windows/Examples/$(MSBuildThisFileName) + CA2007 + + + + + + diff --git a/Examples/Greeting/Program.cs b/Examples/Greeting/Program.cs new file mode 100644 index 0000000..341ee83 --- /dev/null +++ b/Examples/Greeting/Program.cs @@ -0,0 +1,25 @@ +using System.Net; +using System.Net.Sockets; +using System.Runtime.Versioning; +using System.Text; +using Divert.Windows; + +// Captures a UDP packet and print to console. + +[assembly: SupportedOSPlatform("windows6.0.6000")] + +using var client = new UdpClient(new IPEndPoint(IPAddress.Loopback, 0)); +var localEndPoint = (IPEndPoint)client.Client.LocalEndPoint!; +int port = localEndPoint.Port; +Console.WriteLine($"Created UDP client on port {port}."); + +using var service = new DivertService(DivertFilter.UDP & DivertFilter.LocalPort == port); +var buffer = new byte[1024]; +var receive = service.ReceiveAsync(buffer, new DivertAddress[1]).AsTask(); + +Console.WriteLine("Sending packet to self..."); +await client.SendAsync("Hello"u8.ToArray(), localEndPoint); + +(int packetLength, _) = await receive; +string message = Encoding.UTF8.GetString(buffer.AsSpan(28, packetLength)); +Console.WriteLine($"{message} from WinDivert!"); diff --git a/Examples/Greeting/app.manifest b/Examples/Greeting/app.manifest new file mode 100644 index 0000000..76ff66a --- /dev/null +++ b/Examples/Greeting/app.manifest @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/Examples/Http/Http.csproj b/Examples/Http/Http.csproj new file mode 100644 index 0000000..8927562 --- /dev/null +++ b/Examples/Http/Http.csproj @@ -0,0 +1,15 @@ + + + Exe + $(DefaultTargetFramework) + app.manifest + win-x64 + true + $(ProjectRoot).local/windows/Examples/$(MSBuildThisFileName) + CA2007 + + + + + + diff --git a/Examples/Http/Program.cs b/Examples/Http/Program.cs new file mode 100644 index 0000000..9796597 --- /dev/null +++ b/Examples/Http/Program.cs @@ -0,0 +1,138 @@ +using System.Buffers.Binary; +using System.Collections.Concurrent; +using System.Net; +using System.Runtime.Versioning; +using Divert.Windows; + +// Directs all HTTP traffic to a localhost:8080. + +[assembly: SupportedOSPlatform("windows6.0.6000")] + +if (!(args is [string arg, ..] && ushort.TryParse(arg, out ushort listenPort))) +{ + listenPort = 8080; +} + +Console.WriteLine($"Redirecting all HTTP traffic to http://localhost:{listenPort}..."); +using var listener = new HttpListener() { Prefixes = { $"http://*:{listenPort}/" } }; +listener.Start(); +var listen = Task.Run(async () => +{ + var buffer = "Hello from local HTTP server!"u8.ToArray(); + while (true) + { + var context = await listener.GetContextAsync(); + _ = Task.Run(async () => + { + var remoteEndPoint = context.Request.RemoteEndPoint; + Console.WriteLine( + $"Received request from {remoteEndPoint.Address}:{remoteEndPoint.Port} for {context.Request.Url}" + ); + var response = context.Response; + response.ContentLength64 = buffer.Length; + response.ContentType = "text/plain"; + await response.OutputStream.WriteAsync(buffer); + response.OutputStream.Close(); + response.Close(); + }); + } +}); + +using var outService = new DivertService(DivertFilter.Outbound & DivertFilter.RemotePort == 80); +using var inService = new DivertService( + (DivertFilter.RemoteAddress == IPAddress.Loopback | DivertFilter.RemoteAddress == IPAddress.IPv6Loopback) + & DivertFilter.LocalPort == listenPort +); +var portMapping = new ConcurrentDictionary(); + +var redirect = Task.Run(async () => +{ + var buffer = new byte[ushort.MaxValue + 40]; + var addresses = new DivertAddress[1]; + while (true) + { + var result = await outService.ReceiveAsync(buffer, addresses); + ushort sourcePort; + IPAddress originalSource; + IPAddress originalDestination; + if (addresses[0].IsIPv6) + { + sourcePort = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(40)); + originalSource = new IPAddress(buffer.AsSpan(8, 16)); + originalDestination = new IPAddress(buffer.AsSpan(24, 16)); + IPAddress.IPv6Loopback.GetAddressBytes().CopyTo(buffer, 8); + IPAddress.IPv6Loopback.GetAddressBytes().CopyTo(buffer, 24); + BinaryPrimitives.WriteUInt16BigEndian(buffer.AsSpan(42), listenPort); + } + else + { + int ihl = (buffer[0] & 0x0F) * 4; + sourcePort = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(ihl)); + originalSource = new IPAddress(buffer.AsSpan(12, 4)); + originalDestination = new IPAddress(buffer.AsSpan(16, 4)); + IPAddress.Loopback.GetAddressBytes().CopyTo(buffer, 12); + IPAddress.Loopback.GetAddressBytes().CopyTo(buffer, 16); + BinaryPrimitives.WriteUInt16BigEndian(buffer.AsSpan(ihl + 2), listenPort); + } + Console.WriteLine( + $"Redirecting outbound HTTP packet from {originalSource}:{sourcePort} to {originalDestination}:80..." + ); + portMapping[sourcePort] = (originalSource, originalDestination); + var packet = buffer.AsMemory(0, result.DataLength); + DivertHelper.CalculateChecksums(packet.Span, ref addresses[0]); + await outService.SendAsync(packet, addresses); + } +}); + +var inject = Task.Run(async () => +{ + var buffer = new byte[ushort.MaxValue + 40]; + var addresses = new DivertAddress[1]; + while (true) + { + (int packetLength, _) = await inService.ReceiveAsync(buffer, addresses); + ushort targetPort; + IPAddress? originalSource = null; + IPAddress? originalDestination = null; + if (addresses[0].IsIPv6) + { + targetPort = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(42)); + if (portMapping.TryGetValue(targetPort, out var value)) + { + (originalSource, originalDestination) = value; + originalDestination.GetAddressBytes().CopyTo(buffer, 8); + originalSource.GetAddressBytes().CopyTo(buffer, 24); + BinaryPrimitives.WriteUInt16BigEndian(buffer.AsSpan(40), 80); + } + } + else + { + int ihl = (buffer[0] & 0x0F) * 4; + targetPort = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(ihl + 2)); + if (portMapping.TryGetValue(targetPort, out var value)) + { + (originalSource, originalDestination) = value; + originalDestination.GetAddressBytes().CopyTo(buffer, 12); + originalSource.GetAddressBytes().CopyTo(buffer, 16); + BinaryPrimitives.WriteUInt16BigEndian(buffer.AsSpan(ihl), 80); + } + } + if (originalSource is not null) + { + Console.WriteLine( + $"Injecting inbound HTTP packet from {originalDestination}:{80} to {originalSource}:{targetPort}..." + ); + } + var packet = buffer.AsMemory(0, packetLength); + DivertHelper.CalculateChecksums(packet.Span, ref addresses[0]); + await inService.SendAsync(packet, addresses); + } +}); + +using var client = new HttpClient(); +string response = await client.GetStringAsync("http://example.com/"); +Console.WriteLine($"Response from example.com (injected): {response}"); + +Console.WriteLine("Ready."); + +await await Task.WhenAny(listen, redirect, inject); diff --git a/Examples/Http/app.manifest b/Examples/Http/app.manifest new file mode 100644 index 0000000..76ff66a --- /dev/null +++ b/Examples/Http/app.manifest @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/Examples/Ping/Ping.csproj b/Examples/Ping/Ping.csproj index ae0f9c1..ee85910 100644 --- a/Examples/Ping/Ping.csproj +++ b/Examples/Ping/Ping.csproj @@ -1,18 +1,15 @@ - Exe - net6.0 - enable - enable + $(DefaultTargetFramework) + app.manifest + win-x64 + true + $(ProjectRoot).local/windows/Examples/$(MSBuildThisFileName) + CA2007 - - - - - - \ No newline at end of file + diff --git a/Examples/Ping/Program.cs b/Examples/Ping/Program.cs index 918c2c7..2d2798b 100644 --- a/Examples/Ping/Program.cs +++ b/Examples/Ping/Program.cs @@ -1,38 +1,31 @@ -using System.Diagnostics; -using System.Net; +using System.Net; using System.Net.NetworkInformation; +using System.Net.Sockets; using System.Runtime.Versioning; -using System.Security.Principal; using Divert.Windows; [assembly: SupportedOSPlatform("windows6.0.6000")] -var identity = WindowsIdentity.GetCurrent(); -var principal = new WindowsPrincipal(identity); -if (!principal.IsInRole(WindowsBuiltInRole.Administrator)) +// Divert and re-inject ICMP packets to/from default gateway (or loopback if none). + +string? remoteIP = args.Length > 0 ? args[0] : null; +if (remoteIP is null) { - var startInfo = new ProcessStartInfo - { - FileName = Environment.ProcessPath, - Verb = "runas", - Arguments = args.Length > 0 ? args[0] : string.Empty, - UseShellExecute = true - }; - Process.Start(startInfo); - Environment.Exit(0); + var gateWays = + from nic in NetworkInterface.GetAllNetworkInterfaces() + where + nic is { OperationalStatus: OperationalStatus.Up, NetworkInterfaceType: not NetworkInterfaceType.Loopback } + from gateway in nic.GetIPProperties().GatewayAddresses + let address = gateway.Address + where address is { AddressFamily: AddressFamily.InterNetwork } + select address; + remoteIP = gateWays.FirstOrDefault()?.ToString(); } - -string remoteIP = args.Length > 0 ? args[0] : "8.8.8.8"; +remoteIP ??= IPAddress.Loopback.ToString(); var outboundFilter = - DivertFilter.Outbound - & !DivertFilter.Loopback - & DivertFilter.RemoteAddress == remoteIP - & DivertFilter.ICMP; -var inboundFilter = - DivertFilter.Inbound - & DivertFilter.RemoteAddress == remoteIP - & DivertFilter.ICMP; + DivertFilter.Outbound & !DivertFilter.Loopback & DivertFilter.RemoteAddress == remoteIP & DivertFilter.ICMP; +var inboundFilter = DivertFilter.Inbound & DivertFilter.RemoteAddress == remoteIP & DivertFilter.ICMP; Console.WriteLine($"{nameof(outboundFilter)}: {outboundFilter}"); Console.WriteLine($"{nameof(inboundFilter)}: {inboundFilter}"); @@ -52,18 +45,20 @@ { Console.WriteLine( $"Reply from {reply.Address} (Ping): " - + $"bytes={reply.Buffer?.Length} " - + $"time={reply.RoundtripTime}ms " - + $"TTL={reply.Options?.Ttl}"); + + $"bytes={reply.Buffer?.Length} " + + $"time={reply.RoundtripTime}ms " + + $"TTL={reply.Options?.Ttl}" + ); } else { Console.WriteLine(reply.Status); + break; } } }); -var outbound = Task.Run(() => +var outbound = Task.Run(async () => { using var cts = new CancellationTokenSource(); cts.CancelAfter(3500); @@ -73,11 +68,11 @@ { try { - var (packetLength, _) = outDivert.ReceiveEx(buffer, addresses, cts.Token); - var packet = buffer.AsSpan(0, packetLength); - var remoteAddress = new IPAddress(packet[16..20]); + (int packetLength, _) = await outDivert.ReceiveAsync(buffer, addresses, cts.Token).ConfigureAwait(false); + var packet = buffer.AsMemory(0, packetLength); + var remoteAddress = new IPAddress(packet[16..20].Span); Console.WriteLine($"Pinging {remoteAddress} with {packet.Length - 28} bytes of data (Divert):"); - outDivert.Send(packet, addresses[0]); + await outDivert.SendAsync(packet, addresses, cts.Token); } catch (OperationCanceledException) { @@ -92,25 +87,27 @@ } }); -var inbound = Task.Run(() => +var inbound = Task.Run(async () => { using var cts = new CancellationTokenSource(); cts.CancelAfter(2500); var buffer = new byte[ushort.MaxValue]; + var addresses = new DivertAddress[1]; while (true) { try { - var (packetLength, address) = inDivert.Receive(buffer); - var packet = buffer.AsSpan(0, packetLength); - var remoteAddress = new IPAddress(packet[12..16]); - long timestamp = BitConverter.ToInt64(packet.Slice(28, sizeof(long))); + (int packetLength, _) = await inDivert.ReceiveAsync(buffer, addresses, cts.Token).ConfigureAwait(false); + var packet = buffer.AsMemory(0, packetLength); + var remoteAddress = new IPAddress(packet[12..16].Span); + long timestamp = BitConverter.ToInt64(packet.Slice(28, sizeof(long)).Span); Console.WriteLine( - $"Reply from {remoteAddress} (Divert): " + $"Reply from {remoteAddress} (Divert): " + $"bytes={packet.Length - 28} " + $"time={DateTimeOffset.Now.ToUnixTimeMilliseconds() - timestamp}ms " - + $"TTL={packet[8]}"); - inDivert.SendEx(packet, new[] { address }, cts.Token); + + $"TTL={packet.Span[8]}" + ); + await inDivert.SendAsync(packet, addresses, cts.Token); } catch (OperationCanceledException) { @@ -125,4 +122,6 @@ } }); -await Task.WhenAll(ping, outbound, inbound); +await Task.WhenAll(ping, outbound, inbound).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); +Console.WriteLine("Done."); +Console.ReadLine(); diff --git a/Examples/Ping/app.manifest b/Examples/Ping/app.manifest new file mode 100644 index 0000000..76ff66a --- /dev/null +++ b/Examples/Ping/app.manifest @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/Examples/Socket/Program.cs b/Examples/Socket/Program.cs new file mode 100644 index 0000000..7cafa92 --- /dev/null +++ b/Examples/Socket/Program.cs @@ -0,0 +1,131 @@ +using System.Net; +using System.Net.Sockets; +using System.Runtime.Versioning; +using Divert.Windows; + +[assembly: SupportedOSPlatform("windows6.0.6000")] + +// Logs socket events on a loopback TCP listener. + +using var service = new DivertService( + DivertFilter.ProcessId == Environment.ProcessId & DivertFilter.TCP, + DivertLayer.Socket, + flags: DivertFlags.Sniff | DivertFlags.ReceiveOnly +) +{ + QueueTime = DivertService.MaxQueueTime, +}; + +static void WriteLine(string prefix, string message) +{ + Console.WriteLine($"{prefix}: {message}"); +} + +using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); +listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); +listener.Listen(); +int port = ((IPEndPoint)listener.LocalEndPoint!).Port; +WriteLine(nameof(TcpListener), $"Listening on port {port}..."); +using var cts = new CancellationTokenSource(); +var token = cts.Token; + +var sniff = Task.Run(async () => +{ + try + { + var addresses = new DivertAddress[1]; + while (true) + { + await service.ReceiveAsync(Memory.Empty, addresses, token); + var @event = addresses[0].Event; + var socketData = addresses[0].GetSocketData(); + if (socketData.LocalPort == port) + { + WriteLine( + nameof(DivertService), + $"{@event} from {nameof(TcpListener)}, " + + $"local port = {socketData.LocalPort}, remote port = {socketData.RemotePort}." + ); + } + else + { + WriteLine( + nameof(DivertService), + $"{@event} from {nameof(TcpClient)}, " + + $"local port = {socketData.LocalPort}, remote port = {socketData.RemotePort}." + ); + } + } + } + catch (OperationCanceledException) when (token.IsCancellationRequested) { } + catch (Exception e) + { + WriteLine("Error", e.ToString()); + } +}); + +var listen = Task.Run(async () => +{ + try + { + while (true) + { + var client = await listener.AcceptAsync(token); + int clientPort = ((IPEndPoint)client.RemoteEndPoint!).Port; + WriteLine(nameof(TcpListener), $"Accepted connection from port {clientPort}."); + _ = Task.Run(async () => + { + using var _ = client; + var buffer = new byte[1024]; + + while (true) + { + int length = await client.ReceiveAsync(buffer, token); + if (length is 0) + { + WriteLine(nameof(TcpListener), "Closing connection..."); + client.Shutdown(SocketShutdown.Send); + break; + } + } + }); + } + } + catch (OperationCanceledException) when (token.IsCancellationRequested) { } + catch (Exception e) + { + WriteLine("Error", e.ToString()); + } +}); + +try +{ + WriteLine("Info", "Preparing TCP client..."); + await Task.Delay(TimeSpan.FromSeconds(3), token); + using var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + WriteLine(nameof(TcpClient), $"Connecting to port {port}..."); + await client.ConnectAsync(IPAddress.Loopback, port, token); + WriteLine(nameof(TcpClient), $"Connected to port {port}."); + + WriteLine(nameof(TcpClient), "Closing connection..."); + client.Shutdown(SocketShutdown.Send); + var buffer = new byte[1024]; + while (true) + { + int length = await client.ReceiveAsync(buffer, token); + if (length is 0) + { + break; + } + } +} +catch (Exception e) +{ + WriteLine("Error", e.ToString()); +} + +await Task.Delay(TimeSpan.FromSeconds(3), token); +cts.Cancel(); +await Task.WhenAll(sniff, listen).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); +WriteLine("Info", "Done."); +Console.ReadLine(); diff --git a/Examples/Socket/Socket.csproj b/Examples/Socket/Socket.csproj new file mode 100644 index 0000000..8927562 --- /dev/null +++ b/Examples/Socket/Socket.csproj @@ -0,0 +1,15 @@ + + + Exe + $(DefaultTargetFramework) + app.manifest + win-x64 + true + $(ProjectRoot).local/windows/Examples/$(MSBuildThisFileName) + CA2007 + + + + + + diff --git a/Examples/Socket/app.manifest b/Examples/Socket/app.manifest new file mode 100644 index 0000000..76ff66a --- /dev/null +++ b/Examples/Socket/app.manifest @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..0a04128 --- /dev/null +++ b/LICENSE @@ -0,0 +1,165 @@ + GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. + + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. + + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. diff --git a/License b/License deleted file mode 100644 index 646bbf8..0000000 --- a/License +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2022 gdlol - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/ReadMe.md b/ReadMe.md index 6129734..d3714d7 100644 --- a/ReadMe.md +++ b/ReadMe.md @@ -1,5 +1,31 @@ # Divert.Windows -WinDivert .NET APIs. +High quality .NET APIs for WinDivert. See https://reqrypt.org/windivert.html. + +# Example + +```csharp +using var client = new UdpClient(new IPEndPoint(IPAddress.Loopback, 0)); +var localEndPoint = (IPEndPoint)client.Client.LocalEndPoint!; +int port = localEndPoint.Port; +Console.WriteLine($"Created UDP client on port {port}."); + +using var service = new DivertService(DivertFilter.UDP & DivertFilter.LocalPort == port); +var buffer = new byte[1024]; +var receive = service.ReceiveAsync(buffer, new DivertAddress[1]).AsTask(); + +Console.WriteLine("Sending packet to self..."); +await client.SendAsync("Hello"u8.ToArray(), localEndPoint); + +(int packetLength, _) = await receive; +string message = Encoding.UTF8.GetString(buffer.AsSpan(28, packetLength)); +Console.WriteLine($"{message} from WinDivert!"); +``` + +See also [Examples/](Examples/) + +# License + +[LGPL-3.0](LICENSE) diff --git a/Workspace.proj b/Workspace.proj new file mode 100644 index 0000000..204d901 --- /dev/null +++ b/Workspace.proj @@ -0,0 +1,6 @@ + + + + + + diff --git a/package.json b/package.json new file mode 100644 index 0000000..dba536a --- /dev/null +++ b/package.json @@ -0,0 +1,24 @@ +{ + "private": true, + "type": "module", + "scripts": { + "build": "pnpm run cake Build", + "cake": "dotnet run --project Automation --target", + "format": "pnpm run cake Format", + "git-push": "pnpm run cake GitPush", + "lint": "pnpm run cake Lint", + "pack": "pnpm run cake Pack", + "publish": "pnpm run cake Publish", + "restore": "pnpm run cake Restore", + "test": "pnpm run cake Test", + "test-report": "pnpm run cake TestReport" + }, + "dependencies": { + "@prettier/plugin-xml": "^3.4.2", + "cspell": "^9.2.1", + "prettier": "^3.6.2", + "prettier-plugin-ini": "^1.3.0", + "prettier-plugin-packagejson": "^2.5.19", + "prettier-plugin-sh": "^0.18.0" + } +}