diff --git a/.github/workflows/pr-check.yml b/.github/workflows/pr-check.yml new file mode 100644 index 0000000..78aa82c --- /dev/null +++ b/.github/workflows/pr-check.yml @@ -0,0 +1,34 @@ +name: PR Check + +on: + pull_request: + branches: [develop] + push: + branches: [develop] + +jobs: + check: + name: Build, Lint and Test + runs-on: ubuntu-latest + env: + CONFIGURATION: Release + + steps: + - uses: actions/checkout@v4 + + - name: Setup .NET + uses: actions/setup-dotnet@v4 + with: + dotnet-version: 10.0.x + + - name: Restore dependencies + run: dotnet restore FAI.slnx + + - name: Build + run: dotnet build FAI.slnx --no-restore --configuration ${{ env.CONFIGURATION }} + + - name: Check format (Lint) + run: dotnet format FAI.slnx --verify-no-changes --no-restore + + - name: Test + run: dotnet test --no-build --configuration ${{ env.CONFIGURATION }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..749b1eb --- /dev/null +++ b/.gitignore @@ -0,0 +1,366 @@ +## Ignore Visual Studio temporary files, build results, and +## files generated by popular Visual Studio add-ons. 
+## +## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore + +# User-specific files +*.rsuser +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Mono auto generated files +mono_crash.* + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +[Ww][Ii][Nn]32/ +[Aa][Rr][Mm]/ +[Aa][Rr][Mm]64/ +bld/ +[Bb]in/ +[Oo]bj/ +[Oo]ut/ +[Ll]og/ +[Ll]ogs/ + +# Visual Studio 2015/2017 cache/options directory +.vs/ +# Uncomment if you have tasks that create the project's static files in wwwroot +#wwwroot/ + +# Visual Studio 2017 auto generated files +Generated\ Files/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUnit +*.VisualState.xml +TestResult.xml +nunit-*.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# Benchmark Results +BenchmarkDotNet.Artifacts/ + +# .NET Core +project.lock.json +project.fragment.lock.json +artifacts/ + +# ASP.NET Scaffolding +ScaffoldingReadMe.txt + +# StyleCop +StyleCopReport.xml + +# Files built by Visual Studio +*_i.c +*_p.c +*_h.h +*.ilk +*.meta +*.obj +*.iobj +*.pch +*.pdb +*.ipdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*_wpftmp.csproj +*.log +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# Visual Studio Trace Files +*.e2e + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# AxoCover is a Code Coverage Tool +.axoCover/* +!.axoCover/settings.json + +# Coverlet is a free, cross platform Code Coverage 
Tool +coverage*.json +coverage*.xml +coverage*.info + +# Visual Studio code coverage results +*.coverage +*.coveragexml + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +# Note: Comment the next line if you want to checkin your web deploy settings, +# but database connection strings (with potential passwords) will be unencrypted +*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings. Comment the next line if you want to +# checkin your Azure Web App publish settings, but sensitive information contained +# in these scripts will be unencrypted +PublishScripts/ + +# NuGet Packages +*.nupkg +# NuGet Symbol Packages +*.snupkg +# The packages folder can be ignored because of Package Restore +**/[Pp]ackages/* +# except build/, which is used as an MSBuild target. 
+!**/[Pp]ackages/build/ +# Uncomment if necessary however generally it will be regenerated when needed +#!**/[Pp]ackages/repositories.config +# NuGet v3's project.json files produces more ignorable files +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt +*.appx +*.appxbundle +*.appxupload + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!?*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.jfm +*.pfx +*.publishsettings +orleans.codegen.cs + +# Including strong name files can present a security risk +# (https://github.com/github/gitignore/pull/2483#issue-259490424) +#*.snk + +# Since there are multiple workflows, uncomment next line to ignore bower_components +# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) +#bower_components/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm +ServiceFabricBackup/ +*.rptproj.bak + +# SQL Server files +*.mdf +*.ldf +*.ndf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings +*.rptproj.rsuser +*- [Bb]ackup.rdl +*- [Bb]ackup ([0-9]).rdl +*- [Bb]ackup ([0-9][0-9]).rdl + +# Microsoft Fakes +FakesAssemblies/ + +# GhostDoc plugin setting file +*.GhostDoc.xml + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat +node_modules/ + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 
+*.vbw + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# Paket dependency manager +.paket/paket.exe +paket-files/ + +# FAKE - F# Make +.fake/ + +# CodeRush personal settings +.cr/personal + +# Python Tools for Visual Studio (PTVS) +__pycache__/ +*.pyc + +# Cake - Uncomment if you are using it +# tools/** +# !tools/packages.config + +# Tabs Studio +*.tss + +# Telerik's JustMock configuration file +*.jmconfig + +# BizTalk build output +*.btp.cs +*.btm.cs +*.odx.cs +*.xsd.cs + +# OpenCover UI analysis results +OpenCover/ + +# Azure Stream Analytics local run output +ASALocalRun/ + +# MSBuild Binary and Structured Log +*.binlog + +# NVidia Nsight GPU debugger configuration file +*.nvuser + +# MFractors (Xamarin productivity tool) working folder +.mfractor/ + +# Local History for Visual Studio +.localhistory/ + +# BeatPulse healthcheck temp database +healthchecksdb + +# Backup folder for Package Reference Convert tool in Visual Studio 2017 +MigrationBackup/ + +# Ionide (cross platform F# VS Code tools) working folder +.ionide/ + +# Fody - auto-generated XML schema +FodyWeavers.xsd + +# Jetbrains folder +.idea/ diff --git a/.roo/rules-architect/AGENTS.md b/.roo/rules-architect/AGENTS.md new file mode 100644 index 0000000..b8375ce --- /dev/null +++ b/.roo/rules-architect/AGENTS.md @@ -0,0 +1,15 @@ +# AGENTS.md + +This file provides guidance to agents when working in Architect mode within this repository. + +## Architectural Principles (Non-Obvious) +- **Extreme Performance**: The core goal is 7X-14X speedup over standard Python stacks. Every design decision must prioritize throughput and latency. +- **Pipeline Abstraction**: The library centers on [`IPipeline`](src/FAI.Core/Abstractions.cs:88). 
It separates "what" to run (Inference Steps) from "how" to execute (Pipeline Batch Executors). +- **Batching Strategy**: Performance comes from specialized batching. Architects should consider new implementations of [`IPipelineBatchExecutor`](src/FAI.Core/Abstractions.cs:70) for specific hardware or data patterns (e.g., [`TokenCountSortingBatchExecutor`](src/FAI.NLP/PipelineBatchExecutors/TokenCountSortingBatchExecutor.cs)). +- **Hardware Agnostic**: Inference logic should be decoupled from the framework (ONNX, PyTorch, etc.) and hardware (CPU, GPU, OpenVino). + +## Core Layout +- `FAI.Core`: Foundation interfaces and base execution logic. +- `FAI.NLP` / `FAI.Vision`: Domain-specific implementations (tokenizers, preprocessors). +- `FAI.Onnx`: Concrete model execution using ONNX Runtime. +- `*.Extensions.DI`: Fluent builders and ServiceCollection integration. diff --git a/.roo/rules-ask/AGENTS.md b/.roo/rules-ask/AGENTS.md new file mode 100644 index 0000000..c20961b --- /dev/null +++ b/.roo/rules-ask/AGENTS.md @@ -0,0 +1,9 @@ +# AGENTS.md + +This file provides guidance to agents when working in Ask mode within this repository. + +## Documentation Rules (Non-Obvious Only) +- **Performance Benchmarks**: Canonical performance gains (7X-14X) are documented in [`README.md`](README.md:20) and compared against standard Python stacks in the [`Examples/`](Examples/) directory. +- **Design Context**: High-level architecture and the motivation for the library (performance on a budget) are found in [`docs/high-level-design.md`](docs/high-level-design.md) and [`docs/Testimonial.md`](docs/Testimonial.md). +- **Core Abstractions**: The fundamental execution logic is defined in [`Abstractions.cs`](src/FAI.Core/Abstractions.cs). Refer to this file when explaining how the system works. +- **Python vs C#**: The repository includes Python examples to demonstrate the migration story; when asked about usage, prioritize showing the C# implementation using `PipelineBuilder`. 
diff --git a/.roo/rules-code/AGENTS.md b/.roo/rules-code/AGENTS.md new file mode 100644 index 0000000..681d700 --- /dev/null +++ b/.roo/rules-code/AGENTS.md @@ -0,0 +1,25 @@ +# AGENTS.md + +This file provides guidance to agents when working in Code mode within this repository. + +## Performance & Memory (CRITICAL) +- **Zero-Allocation**: Aim for zero-allocation in the hot path. Use `Span`, `ReadOnlySpan`, `Memory`, and `ReadOnlyMemory` to avoid copying data. +- **Tensors**: Always use `System.Numerics.Tensors`. Check [`TensorExtensions.cs`](src/FAI.Core/TensorExtensions.cs) for optimized operations. +- **Pooling**: Use or implement pooling for expensive objects (e.g., [`PooledModelExecutor.cs`](src/FAI.Core/ModelExecutors/PooledModelExecutor.cs)). +- **Concurrency**: Use `SemaphoreSlim` for throttling and `Channel` for producer/consumer patterns to manage throughput without blocking. + +## Coding Rules (Non-Obvious) +- **Project Commands**: + - **Build**: `dotnet build FAI.slnx` + - **Lint**: `dotnet format` + - **Test**: `dotnet test` + - **Post-Test**: + - Always run `dotnet format` after tests pass. + - Commit units of work after tests pass. +- **Modern C# (.NET 10 / C# 14)**: + - Prefer collection expressions `[1, 2, 3]` over `new float[] { 1, 2, 3 }`. + - Use `System.Threading.Lock` instead of `new object()` for locking. +- **Stability**: When working on tests, NEVER change the library code unless implementing a new feature (follow TDD). +- **DI Assembly**: Use [`PipelineBuilder`](src/FAI.Core.Extensions.DI/PipelineBuilder.cs) to construct pipelines. +- **Middleware Order**: `PipelineBuilder.Use` adds executors in a stack (last-in-first-out execution flow). +- **Inference Implementation**: Implement [`IInferenceSteps`](src/FAI.Core/Abstractions.cs:95) or inherit [`InferenceSteps<...>`](src/FAI.Core/Abstractions.cs:113). 
diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..7c09d2c --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,32 @@ +# AGENTS.md + +This file provides guidance to agents when working with code in this repository. + +## Commands +- **Build**: `dotnet build FAI.slnx` +- **Lint**: `dotnet format` (part of pre-commit hooks) +- **Test**: `dotnet test` (Infrastructure initialized using `xunit.v3` and MTP in `test/` folder) +- **Post-Test**: + - Always run `dotnet format` after tests pass. + - Commit units of work after tests pass. + +## Code Style (Non-Obvious) +- **Formatting**: 4 spaces, `LF` line endings, 160 chars max width. +- **Naming**: `_camelCase` for private/static fields; `PascalCase` for types, methods, and properties. +- **Modern C# (.NET 10 / C# 14)**: + - Prefer collection expressions `[1, 2, 3]` over `new float[] { 1, 2, 3 }`. + - Use `System.Threading.Lock` instead of `new object()` for locking. +- **Tensors**: Uses `System.Numerics.Tensors` (dotnet 9+ feature). + +## Stability & Testing +- **Library Stability**: When working on tests, NEVER change the library code unless implementing a new feature (follow TDD). +- **Testing Style Guide**: + - **Assertions**: Use explicit collection matching for ranges and outputs. Avoid partial assertions like `Assert.Single` when the full state can be verified. + - **Mocks**: When testing components that offload work (e.g., `BackgroundPipelineBatchExecutor`), always verify that the exact data passed to the component reached the inner dependency. + - **DI Testing**: Focus on verifying that the correct implementation types are resolved and that the component chain is assembled in the intended order. + - **Collection Expressions**: Use `[1, 2, 3]` instead of `new int[] { 1, 2, 3 }` in all test code. + +## Critical Patterns +- **Middleware Chain**: `IPipelineBatchExecutor` follows a decorator/middleware pattern. 
+- **DI Fluent API**: Use `PipelineBuilder` to assemble pipelines; executors are added in stack order (last added runs after previous). +- **Abstractions**: All ML tasks must implement `IInferenceSteps` or extend `InferenceSteps<...>`. diff --git a/Examples/SentimentInference/Example.SentimentInference.Api/Program.cs b/Examples/SentimentInference/Example.SentimentInference.Api/Program.cs index 27b0d80..c951a64 100644 --- a/Examples/SentimentInference/Example.SentimentInference.Api/Program.cs +++ b/Examples/SentimentInference/Example.SentimentInference.Api/Program.cs @@ -9,15 +9,15 @@ builder.Services.AddDefaultSentimentInference(options); -builder.Services.AddKeyedSingleton>("orchestrated", - (sp, _) => - { - var inference = sp.GetRequiredService>(); - return new InferenceOrchestrator, string, bool>( - new Lazy>(() => inference), 10, 5, TimeSpan.FromMicroseconds(10) - ); - } - ); +// builder.Services.AddKeyedSingleton>("orchestrated", +// (sp, _) => +// { +// var inference = sp.GetRequiredService>(); +// return new InferenceOrchestrator, string, bool>( +// new Lazy>(() => inference), 10, 5, TimeSpan.FromMicroseconds(10) +// ); +// } +// ); builder.Services.AddOpenApi(); @@ -32,7 +32,7 @@ app.MapPost("/predict", async ([FromBody] string sentence, IInference inference) => await inference.Predict(sentence)); -app.MapPost("/predict-orchestrated", async ([FromBody] string sentence, [FromKeyedServices("orchestrated")] IInference inference) - => await inference.Predict(sentence)); +// app.MapPost("/predict-orchestrated", async ([FromBody] string sentence, [FromKeyedServices("orchestrated")] IInference inference) +// => await inference.Predict(sentence)); app.Run(); diff --git a/Examples/python/uv.lock b/Examples/python/uv.lock index 474460b..d9c601f 100644 --- a/Examples/python/uv.lock +++ b/Examples/python/uv.lock @@ -247,6 +247,7 @@ dependencies = [ dev = [ { name = "black" }, { name = "flake8" }, + { name = "flake8-pyproject" }, { name = "isort" }, { name = 
"pre-commit" }, { name = "pylint" }, @@ -261,6 +262,7 @@ dev = [ requires-dist = [ { name = "black", marker = "extra == 'dev'" }, { name = "flake8", marker = "extra == 'dev'" }, + { name = "flake8-pyproject", marker = "extra == 'dev'" }, { name = "isort", marker = "extra == 'dev'" }, { name = "onnx", specifier = "==1.16.1" }, { name = "optimum" }, diff --git a/FAI.slnx b/FAI.slnx index c7d403c..53c7daa 100644 --- a/FAI.slnx +++ b/FAI.slnx @@ -20,4 +20,13 @@ + + + + + + + + + diff --git a/docs/tests.md b/docs/tests.md new file mode 100644 index 0000000..efe1d5d --- /dev/null +++ b/docs/tests.md @@ -0,0 +1,52 @@ +# Testing Guide + +FAI uses `xunit.v3` with Microsoft Testing Platform (MTP) for testing. + +## Running Tests + +All tests are located in the `test/` directory and use `net10.0`. + +### Command Line + +To run all tests: + +```bash +dotnet test +``` + +To run tests for a specific project: + +```bash +dotnet test test/FAI.Core.Tests/FAI.Core.Tests.csproj +``` + +### Visual Studio Code + +You can use the built-in Test Explorer in VS Code to run and debug tests. + +## Infrastructure + +The testing infrastructure is centralized in `test/Directory.Build.props`, which includes: +- `xunit.v3` +- `Microsoft.NET.Test.Sdk` +- MTP runner enabled via `true` + +## Project Coverage + +The test suite covers the following areas: + +- **Core**: [`test/FAI.Core.Tests`](test/FAI.Core.Tests) - Unit tests for core abstractions, batching logic, and tensor utilities. +- **DI (Dependency Injection)**: [`test/FAI.Core.Extensions.DI.Tests`](test/FAI.Core.Extensions.DI.Tests) - Verification of pipeline assembly and service registrations. +- **NLP**: [`test/FAI.NLP.Tests`](test/FAI.NLP.Tests) and [`test/FAI.NLP.Extensions.DI.Tests`](test/FAI.NLP.Extensions.DI.Tests) - Tokenization, NLP batching, and text-specific tasks. +- **Onnx**: [`test/FAI.Onnx.Tests`](test/FAI.Onnx.Tests) - ONNX model execution, device pools, and tensor utilities. 
+- **Evaluation**: [`test/FAI.Extensions.Evaluation.Tests`](test/FAI.Extensions.Evaluation.Tests) - Pipeline for batch evaluation of models against datasets. +- **Integration**: [`test/FAI.IntegrationTests`](test/FAI.IntegrationTests) - End-to-end verification of full pipeline assembly and execution using logical mocks. + +## Integration Testing Strategy + +Testing every possible permutation of pipeline components (Schedulers, Slicers, Executors) is neither feasible nor desirable. Instead, we follow a "Representative Combinations" strategy: + +1. **Common Architectural Patterns**: We prioritize testing configurations used in production-like scenarios, such as `Background -> Partition -> Tokenizer -> Model`. +2. **Component Breadth**: Every major abstraction (e.g., `IPipelineBatchExecutor`, `IBatchSchedular`) must be included in at least one end-to-end integration test. +3. **Lifecycle & Concurrency**: We focus on cross-component state management, ensuring that data flows correctly through asynchronous boundaries (like `BackgroundPipelineBatchExecutor`) and partitioning logic. +4. **Logical Mocks**: We use manual implementation of `IModelExecutor` (e.g., `LogicalMockModelExecutor`) to verify the *library's* orchestration logic without the overhead or non-determinism of actual ONNX/PyTorch models. 
diff --git a/src/FAI.Core.Extensions.DI/FAI.Core.Extensions.DI.csproj b/src/FAI.Core.Extensions.DI/FAI.Core.Extensions.DI.csproj index 83a3da0..d163397 100644 --- a/src/FAI.Core.Extensions.DI/FAI.Core.Extensions.DI.csproj +++ b/src/FAI.Core.Extensions.DI/FAI.Core.Extensions.DI.csproj @@ -10,6 +10,12 @@ + + + <_Parameter1>FAI.Core.Extensions.DI.Tests + + + diff --git a/src/FAI.Core/InferenceOrchestrator.cs b/src/FAI.Core/InferenceOrchestrator.cs deleted file mode 100644 index 7369c25..0000000 --- a/src/FAI.Core/InferenceOrchestrator.cs +++ /dev/null @@ -1,173 +0,0 @@ -using System.Diagnostics; -using System.Threading.Channels; -using FAI.Core.Abstractions; - -namespace FAI.Core; - -/// -/// Orchestrates inference operations by batching requests and managing concurrency for dynamic requests. -/// -/// The type of the inference model implementing . -/// The type of the input query for inference. -/// The type of the result produced by the inference. -#pragma warning disable CA1001 -public sealed class InferenceOrchestrator : IInference where TInference : IInference -#pragma warning restore CA1001 -{ - private readonly Lazy _modelInstance; - private readonly int _maxBatchSize; - private readonly TimeSpan _emptyQueueSleepDuration; - private readonly Channel<(TQuery, TaskCompletionSource, ActivityContext?)> _queue; - private readonly SemaphoreSlim _semaphore; - private static readonly ActivitySource ActivitySource = new ActivitySource("ModelPredictionOrchestrator"); - - /// - /// Initializes a new instance of the class. - /// - /// A lazy-loaded instance of the inference model. - /// The maximum number of requests to process in a single batch. - /// The maximum number of concurrent batches allowed. - /// The duration to wait when the queue is empty before checking again. 
- public InferenceOrchestrator( - Lazy modelInstance, - int maxBatchSize, - int maxConcurrentBatches, - TimeSpan emptyQueueSleepDuration) - { - _modelInstance = modelInstance; - _maxBatchSize = maxBatchSize; - _emptyQueueSleepDuration = emptyQueueSleepDuration; - _queue = Channel.CreateBounded<(TQuery, TaskCompletionSource, ActivityContext?)>(maxBatchSize * maxConcurrentBatches * 3); - _semaphore = new SemaphoreSlim(maxConcurrentBatches); - StartBackgroundProcessing(); - } - - /// - /// Predicts the result for a single input query asynchronously. - /// - /// The input query for the prediction. - /// A task representing the asynchronous operation. The task result contains the predicted result. - public async Task Predict(TQuery inputQuery) - { - var tcs = new TaskCompletionSource(); - - // Start a tracing span for enqueuing the request - using (var activity = ActivitySource.StartActivity("enqueue-prediction-request", ActivityKind.Producer)) - { - var context = activity?.Context; - await _queue.Writer.WriteAsync((inputQuery, tcs, context)); - } - - // Start a tracing span for waiting for the response - using (var activity = ActivitySource.StartActivity("wait-for-prediction-response", ActivityKind.Consumer)) - { - return await tcs.Task.ConfigureAwait(false); - } - } - - /// - /// Starts the background processing task to handle batched inference requests. - /// - private void StartBackgroundProcessing() - { - Task.Run(async () => - { - var tasks = new List(); - var modelInstance = _modelInstance.Value; - while (true) - { - if (_queue.Reader.Count == 0) - { - await Task.Delay(_emptyQueueSleepDuration); - continue; - } - - await _semaphore.WaitAsync(); - - var batchTask = RunDynamicBatchAsync(modelInstance).ContinueWith(ReleaseSemaphore); - tasks.Add(batchTask); - - tasks.RemoveAll(t => t.IsCompleted); - } - }); - } - - /// - /// Releases the semaphore after a batch task completes. - /// - /// The completed task. 
- private void ReleaseSemaphore(Task t) => _semaphore.Release(); - - /// - /// Processes a batch of inference requests dynamically. - /// - /// The inference model to use for processing the batch. - private async Task RunDynamicBatchAsync(IInference model) - { - // Start a tracing span for the batch processing - using var activity = ActivitySource.StartActivity("orchestrated-predict", ActivityKind.Consumer); - - List<(TQuery, TaskCompletionSource, ActivityContext?)> requests = GetAvailableRequestsAsync(_maxBatchSize); - if (requests.Count == 0) - return; - - activity?.SetTag("dynamic_batch_size", requests.Count); - - // Link contexts from each request to the batch span - foreach (var (_, _, context) in requests) - { - if (context.HasValue) - { - activity?.AddLink(new ActivityLink(context.Value)); - } - } - - try - { - var queries = requests.Select(r => r.Item1).ToArray(); - var results = await model.BatchPredict(queries).ConfigureAwait(false); - - for (int i = 0; i < results.Length; i++) - { - requests[i].Item2.SetResult(results[i]); - } - } - catch (Exception ex) - { - foreach (var (_, tcs, _) in requests) - { - tcs.SetException(ex); - } - - activity?.SetStatus(ActivityStatusCode.Error, ex.Message); - } - } - - /// - /// Retrieves available requests from the queue up to the specified maximum count. - /// - /// The maximum number of requests to retrieve. - /// A list of requests retrieved from the queue. - private List<(TQuery, TaskCompletionSource, ActivityContext?)> GetAvailableRequestsAsync(int maxCount) - { - var requests = new List<(TQuery, TaskCompletionSource, ActivityContext?)>(); - while (requests.Count < maxCount && _queue.Reader.TryRead(out var item)) - { - requests.Add(item); - } - - return requests; - } - - /// - /// Predicts the results for a batch of input queries asynchronously. - /// - /// A read-only memory containing the batch of input queries. - /// A task representing the asynchronous operation. 
The task result contains an array of predicted results. - public Task BatchPredict(ReadOnlyMemory input) => _modelInstance.Value.BatchPredict(input); - - public Task BatchPredict(ReadOnlyMemory input, Memory output) - { - return _modelInstance.Value.BatchPredict(input, output); - } -} diff --git a/src/FAI.Onnx/FAI.Onnx.csproj b/src/FAI.Onnx/FAI.Onnx.csproj index 7e274b0..5ce047c 100644 --- a/src/FAI.Onnx/FAI.Onnx.csproj +++ b/src/FAI.Onnx/FAI.Onnx.csproj @@ -7,6 +7,12 @@ $(NoWarn);SYSLIB5001 + + + <_Parameter1>FAI.Onnx.Tests + + + diff --git a/test/Directory.Build.props b/test/Directory.Build.props new file mode 100644 index 0000000..41a844f --- /dev/null +++ b/test/Directory.Build.props @@ -0,0 +1,23 @@ + + + net10.0 + latest + enable + enable + false + true + + true + true + + + + + + + + + + + + diff --git a/test/FAI.Core.Extensions.DI.Tests/FAI.Core.Extensions.DI.Tests.csproj b/test/FAI.Core.Extensions.DI.Tests/FAI.Core.Extensions.DI.Tests.csproj new file mode 100644 index 0000000..7ddb2c6 --- /dev/null +++ b/test/FAI.Core.Extensions.DI.Tests/FAI.Core.Extensions.DI.Tests.csproj @@ -0,0 +1,5 @@ + + + + + diff --git a/test/FAI.Core.Extensions.DI.Tests/FaiBuilderExtensionsTests.cs b/test/FAI.Core.Extensions.DI.Tests/FaiBuilderExtensionsTests.cs new file mode 100644 index 0000000..3661a32 --- /dev/null +++ b/test/FAI.Core.Extensions.DI.Tests/FaiBuilderExtensionsTests.cs @@ -0,0 +1,52 @@ +using FAI.Core.Abstractions; +using FAI.Core.Extensions.DI; +using Microsoft.Extensions.DependencyInjection; + +namespace FAI.Core.Extensions.DI.Tests; + +public class FAIBuilderExtensionsTests +{ + private readonly IServiceCollection _services = new ServiceCollection(); + + [Fact] + public async Task UsePartitioning_AssemblesCorrectExecutor() + { + // Arrange + _services.AddSingleton, MockInferenceSteps>(); + var tracker = new List(); + _services.AddSingleton(tracker); + var mockSlicer = Substitute.For>(); + mockSlicer.Slice(Arg.Any>()).Returns([new Range(0, 1)]); + + var 
mockSchedular = Substitute.For>(); + mockSchedular.RunInExecutor(Arg.Any>(), Arg.Any>(), Arg.Any>(), Arg.Any>()) + .Returns(Task.CompletedTask); + + // Act + _services.AddPipeline() + .UsePartitioning(p => + { + p.WithSlicer(_ => mockSlicer); + p.WithSchedular(_ => mockSchedular); + }); + + var sp = _services.BuildServiceProvider(); + var pipeline = sp.GetRequiredService>(); + + var output = new int[1]; + string[] inputs = ["test"]; + await pipeline.BatchPredict(inputs, output); + + // Assert + await mockSchedular.Received(1).RunInExecutor( + Arg.Any>(), + Arg.Any>(), + Arg.Any>(), + Arg.Any>()); + } + + private class MockInferenceSteps : IInferenceSteps + { + public Task ProcessBatch(ReadOnlyMemory inputs, Memory outputs) => Task.CompletedTask; + } +} diff --git a/test/FAI.Core.Extensions.DI.Tests/LocalServiceCollectionTests.cs b/test/FAI.Core.Extensions.DI.Tests/LocalServiceCollectionTests.cs new file mode 100644 index 0000000..fe95449 --- /dev/null +++ b/test/FAI.Core.Extensions.DI.Tests/LocalServiceCollectionTests.cs @@ -0,0 +1,113 @@ +using Microsoft.Extensions.DependencyInjection; + +namespace FAI.Core.Extensions.DI.Tests; + +public class LocalServiceCollectionTests +{ + [Fact] + public void LocalServices_AreIsolatedFromGlobal() + { + // Arrange + var globalServices = new ServiceCollection(); + + // Act + globalServices.AddLocalServices(local => + { + local.AddSingleton(); + }); + + var sp = globalServices.BuildServiceProvider(); + + // Assert + Assert.Null(sp.GetService()); + } + + [Fact] + public void CopyToGlobal_MakesServiceAvailableInGlobal() + { + // Arrange + var globalServices = new ServiceCollection(); + + // Act + globalServices.AddLocalServices(local => + { + local.AddSingleton(); + local.CopyToGlobal(); + }); + + var sp = globalServices.BuildServiceProvider(); + + // Assert + var service = sp.GetService(); + Assert.NotNull(service); + Assert.IsType(service); + } + + [Fact] + public void CopyToGlobal_PreservesLocalDependencies() + { + // Arrange + 
var globalServices = new ServiceCollection(); + + // Act + globalServices.AddLocalServices(local => + { + local.AddSingleton(); + local.AddSingleton(); + local.CopyToGlobal(); + }); + + var sp = globalServices.BuildServiceProvider(); + + // Assert + var service = sp.GetService() as ServiceWithDependency; + Assert.NotNull(service); + Assert.NotNull(service.Dependency); + + // Dependency should NOT be in global + Assert.Null(sp.GetService()); + } + + [Fact] + public void LocalServices_CanResolveFromGlobal() + { + // Arrange + var globalServices = new ServiceCollection(); + globalServices.AddSingleton(); + + IMyService? localService = null; + + // Act + globalServices.AddLocalServices(local => + { + local.AddSingleton(); + local.CopyToGlobal(); + }); + + var sp = globalServices.BuildServiceProvider(); + localService = sp.GetRequiredService(); + + // Assert + var service = Assert.IsType(localService); + Assert.NotNull(service.GlobalService); + } + + private interface IMyService; + private class MyService : IMyService; + + private interface IDependency; + private class MyDependency : IDependency; + + private class ServiceWithDependency(IDependency dependency) : IMyService + { + public IDependency Dependency { get; } = dependency; + } + + private interface IGlobalService; + private class GlobalService : IGlobalService; + + private class ServiceWithGlobalDependency(IGlobalService globalService) : IMyService + { + public IGlobalService GlobalService { get; } = globalService; + } +} diff --git a/test/FAI.Core.Extensions.DI.Tests/PartitionBuilderTests.cs b/test/FAI.Core.Extensions.DI.Tests/PartitionBuilderTests.cs new file mode 100644 index 0000000..e526f35 --- /dev/null +++ b/test/FAI.Core.Extensions.DI.Tests/PartitionBuilderTests.cs @@ -0,0 +1,92 @@ +using FAI.Core.Abstractions; +using FAI.Core.BatchSchedulers; +using FAI.Core.Configurations.PipelineBatchExecutors; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; + +namespace 
FAI.Core.Extensions.DI.Tests; + +public class PartitionBuilderTests +{ + private readonly IServiceCollection _services = new ServiceCollection(); + + [Fact] + public void WithSlicer_SetsSlicerFactory() + { + // Arrange + var builder = new PartitionBatchExecutorBuilder(_services); + var mockSlicer = Substitute.For>(); + + // Act + builder.WithSlicer(_ => mockSlicer); + var slicer = builder.BuildSlicer(_services.BuildServiceProvider()); + + // Assert + Assert.Same(mockSlicer, slicer); + } + + [Fact] + public void WithSchedular_SetsSchedularFactory() + { + // Arrange + var builder = new PartitionBatchExecutorBuilder(_services); + var mockSchedular = Substitute.For>(); + + // Act + builder.WithSchedular(_ => mockSchedular); + var schedular = builder.BuildSchedular(_services.BuildServiceProvider()); + + // Assert + Assert.Same(mockSchedular, schedular); + } + + [Fact] + public void WithSerialSchedular_RegistersOptionsAndSchedular() + { + // Arrange + var config = new ConfigurationBuilder() + .AddInMemoryCollection(new Dictionary + { + ["FAI:Serial:BatchSize"] = "10" + }) + .Build(); + _services.AddSingleton(config); + + var builder = new PartitionBatchExecutorBuilder(_services); + + // Act + builder.WithSerialSchedular("FAI:Serial"); + var sp = _services.BuildServiceProvider(); + var schedular = builder.BuildSchedular(sp); + + // Assert + Assert.IsType>(schedular); + var options = sp.GetRequiredService(); + Assert.Equal(10, options.BatchSize); + } + + [Fact] + public void WithParallelSchedular_RegistersOptionsAndSchedular() + { + // Arrange + var config = new ConfigurationBuilder() + .AddInMemoryCollection(new Dictionary + { + ["FAI:Parallel:MaxConcurrency"] = "4" + }) + .Build(); + _services.AddSingleton(config); + + var builder = new PartitionBatchExecutorBuilder(_services); + + // Act + builder.WithParallelSchedular("FAI:Parallel"); + var sp = _services.BuildServiceProvider(); + var schedular = builder.BuildSchedular(sp); + + // Assert + Assert.IsType>(schedular); + 
var options = sp.GetRequiredService(); + Assert.Equal(4, options.MaxConcurrency); + } +} diff --git a/test/FAI.Core.Extensions.DI.Tests/PipelineBuilderTests.cs b/test/FAI.Core.Extensions.DI.Tests/PipelineBuilderTests.cs new file mode 100644 index 0000000..f6b7aaf --- /dev/null +++ b/test/FAI.Core.Extensions.DI.Tests/PipelineBuilderTests.cs @@ -0,0 +1,108 @@ +using FAI.Core.Abstractions; +using FAI.Core.PipelineBatchExecutors; +using Microsoft.Extensions.DependencyInjection; + +namespace FAI.Core.Extensions.DI.Tests; + +public class PipelineBuilderTests +{ + private readonly IServiceCollection _services = new ServiceCollection(); + + [Fact] + public void Build_WithNoMiddleware_ReturnsPipelineWithDefaultSink() + { + // Arrange + _services.AddSingleton, MockInferenceSteps>(); + var builder = new PipelineBuilder(_services); + + // Act + var pipeline = builder.Build(_services.BuildServiceProvider()); + + // Assert + Assert.NotNull(pipeline); + } + + [Fact] + public async Task Build_MaintainsExpectedChainOrder() + { + // Arrange + var tracker = new List(); + _services.AddSingleton(tracker); + + var builder = new PipelineBuilder(_services); + builder.Use((next, sp) => new OrderTrackingExecutor(next, "First", sp.GetRequiredService>())); + builder.Use((next, sp) => new OrderTrackingExecutor(next, "Second", sp.GetRequiredService>())); + builder.UseSink(sp => new OrderTrackingSink("Sink", sp.GetRequiredService>())); + + var sp = _services.BuildServiceProvider(); + + // Act + var pipeline = builder.Build(sp); + var outputs = new int[1]; + string[] inputs = ["test"]; + await pipeline.BatchPredict(inputs, outputs); + + // Assert + // Logic in PipelineBuilder: for (int i = _batchExecutorFactories.Count - 1; i >= 0; i--) + // So the LAST added executor is the one wrapping the sink or the previous one? + // Actually, the loop wraps 'current' (initially sink) with executors from last to first. + // If i = 1 (Second), it wraps Sink. 
current = Second(Sink) + // If i = 0 (First), it wraps current. current = First(Second(Sink)) + // So the order should be First -> Second -> Sink. + Assert.Equal(["First", "Second", "Sink"], tracker); + } + + [Fact] + public void AddInferenceSteps_RegistersService() + { + // Arrange + var builder = new PipelineBuilder(_services); + + // Act + builder.AddInferenceSteps(); + + // Assert + var descriptor = _services.FirstOrDefault(d => d.ServiceType == typeof(IInferenceSteps)); + Assert.NotNull(descriptor); + Assert.Equal(typeof(MockInferenceSteps), descriptor.ImplementationType); + } + + [Fact] + public void AddModelExecutor_RegistersFactory() + { + // Arrange + var builder = new PipelineBuilder(_services); + var mockExecutor = Substitute.For>(); + + // Act + builder.AddModelExecutor(_ => mockExecutor); + + // Assert + var sp = _services.BuildServiceProvider(); + var resolved = sp.GetService>(); + Assert.Same(mockExecutor, resolved); + } + + private class MockInferenceSteps : IInferenceSteps + { + public Task ProcessBatch(ReadOnlyMemory inputs, Memory outputs) => Task.CompletedTask; + } + + private class OrderTrackingExecutor(IPipelineBatchExecutor next, string name, List tracker) : IPipelineBatchExecutor + { + public async Task ExecuteBatchPredict(ReadOnlyMemory inputs, Memory outputSpan) + { + tracker.Add(name); + await next.ExecuteBatchPredict(inputs, outputSpan); + } + } + + private class OrderTrackingSink(string name, List tracker) : IPipelineBatchExecutor + { + public Task ExecuteBatchPredict(ReadOnlyMemory inputs, Memory outputSpan) + { + tracker.Add(name); + return Task.CompletedTask; + } + } +} diff --git a/test/FAI.Core.Tests/BatchSchedularTests/ParallelBatchSchedularTests.cs b/test/FAI.Core.Tests/BatchSchedularTests/ParallelBatchSchedularTests.cs new file mode 100644 index 0000000..18ddd71 --- /dev/null +++ b/test/FAI.Core.Tests/BatchSchedularTests/ParallelBatchSchedularTests.cs @@ -0,0 +1,83 @@ +using FAI.Core.Abstractions; +using 
FAI.Core.BatchSchedulers; +using FAI.Core.Configurations.PipelineBatchExecutors; +using NSubstitute; + +namespace FAI.Core.Tests.BatchSchedularTests; + +public class ParallelBatchSchedularTests +{ + [Fact] + public async Task RunInExecutor_ExecutesAllRanges() + { + // Arrange + var options = new ParallelBatchSchedularOptions { MaxConcurrency = 2 }; + var scheduler = new ParallelBatchSchedular(options); + var executor = Substitute.For>(); + Range[] ranges = [new Range(0, 2), new Range(2, 4), new Range(4, 5)]; + var inputs = new int[5].AsMemory(); + var outputs = new int[5].AsMemory(); + + // Act + await scheduler.RunInExecutor(executor, ranges, inputs, outputs); + + // Assert + await executor.Received(3).ExecuteBatchPredict(Arg.Any>(), Arg.Any>()); + } + + [Fact] + public async Task RunInExecutor_RespectsConcurrencyLimit() + { + // Arrange + var options = new ParallelBatchSchedularOptions { MaxConcurrency = 1 }; + var scheduler = new ParallelBatchSchedular(options); + var executor = Substitute.For>(); + + int activeTasks = 0; + int maxSeenActiveTasks = 0; + var lockObj = new System.Threading.Lock(); + + executor.ExecuteBatchPredict(Arg.Any>(), Arg.Any>()) + .Returns(async _ => + { + lock (lockObj) + { + activeTasks++; + maxSeenActiveTasks = Math.Max(maxSeenActiveTasks, activeTasks); + } + await Task.Delay(10); + lock (lockObj) + { + activeTasks--; + } + }); + + var ranges = Enumerable.Range(0, 10).Select(i => new Range(i, i + 1)).ToList(); + var inputs = new int[10].AsMemory(); + var outputs = new int[10].AsMemory(); + + // Act + await scheduler.RunInExecutor(executor, ranges, inputs, outputs); + + // Assert + Assert.Equal(1, maxSeenActiveTasks); + } + + [Fact] + public async Task RunInExecutor_HandlesEmptyInputs() + { + // Arrange + var options = new ParallelBatchSchedularOptions { MaxConcurrency = 2 }; + var scheduler = new ParallelBatchSchedular(options); + var executor = Substitute.For>(); + var ranges = Enumerable.Empty(); + var inputs = ReadOnlyMemory.Empty; + 
var outputs = Memory.Empty; + + // Act + await scheduler.RunInExecutor(executor, ranges, inputs, outputs); + + // Assert + await executor.DidNotReceive().ExecuteBatchPredict(Arg.Any>(), Arg.Any>()); + } +} diff --git a/test/FAI.Core.Tests/BatchSchedularTests/SerialBatchSchedularTests.cs b/test/FAI.Core.Tests/BatchSchedularTests/SerialBatchSchedularTests.cs new file mode 100644 index 0000000..c3f273e --- /dev/null +++ b/test/FAI.Core.Tests/BatchSchedularTests/SerialBatchSchedularTests.cs @@ -0,0 +1,26 @@ +using FAI.Core.Abstractions; +using FAI.Core.BatchSchedulers; +using NSubstitute; + +namespace FAI.Core.Tests.BatchSchedularTests; + +public class SerialBatchSchedularTests +{ + [Fact] + public async Task RunInExecutor_ExecutesAllRangesSequentially() + { + // Arrange + var scheduler = new SerialBatchSchedular(); + var executor = Substitute.For>(); + var ranges = new[] { new Range(0, 2), new Range(2, 5) }; + var inputs = new int[5].AsMemory(); + var outputs = new int[5].AsMemory(); + + // Act + await scheduler.RunInExecutor(executor, ranges, inputs, outputs); + + // Assert + await executor.Received(1).ExecuteBatchPredict(Arg.Is>(m => m.Length == 2), Arg.Is>(m => m.Length == 2)); + await executor.Received(1).ExecuteBatchPredict(Arg.Is>(m => m.Length == 3), Arg.Is>(m => m.Length == 3)); + } +} diff --git a/test/FAI.Core.Tests/BatchSlicerTests/FixedSizeBatchSlicerTests.cs b/test/FAI.Core.Tests/BatchSlicerTests/FixedSizeBatchSlicerTests.cs new file mode 100644 index 0000000..e28627e --- /dev/null +++ b/test/FAI.Core.Tests/BatchSlicerTests/FixedSizeBatchSlicerTests.cs @@ -0,0 +1,81 @@ +using FAI.Core.BatchSlicers; + +namespace FAI.Core.Tests.BatchSlicerTests; + +public class FixedSizeBatchSlicerTests +{ + [Fact] + public void Slice_InputSmallerThanBatchSize_ReturnsSingleRange() + { + // Arrange + var options = new FixedSizeBatchSlicerOptions(10); + var slicer = new FixedSizeBatchSlicer(options); + var inputs = new int[5]; + + // Act + var ranges = 
slicer.Slice(inputs).ToList(); + + // Assert + Assert.Equal([0..5], ranges); + } + + [Fact] + public void Slice_InputMatchesBatchSize_ReturnsSingleRange() + { + // Arrange + var options = new FixedSizeBatchSlicerOptions(5); + var slicer = new FixedSizeBatchSlicer(options); + var inputs = new int[5]; + + // Act + var ranges = slicer.Slice(inputs).ToList(); + + // Assert + Assert.Equal([0..5], ranges); + } + + [Fact] + public void Slice_InputIsMultipleOfBatchSize_ReturnsMultipleRanges() + { + // Arrange + var options = new FixedSizeBatchSlicerOptions(2); + var slicer = new FixedSizeBatchSlicer(options); + var inputs = new int[4]; + + // Act + var ranges = slicer.Slice(inputs).ToList(); + + // Assert + Assert.Equal([0..2, 1..3, 2..4], ranges); + } + + [Fact] + public void Slice_InputWithRemainder_ReturnsRangesIncludingRemainder() + { + // Arrange + var options = new FixedSizeBatchSlicerOptions(3); + var slicer = new FixedSizeBatchSlicer(options); + var inputs = new int[7]; + + // Act + var ranges = slicer.Slice(inputs).ToList(); + + // Assert + Assert.Equal([0..3, 1..4, 2..5, 3..6, 4..7], ranges); + } + + [Fact] + public void Slice_EmptyInput_ReturnsEmpty() + { + // Arrange + var options = new FixedSizeBatchSlicerOptions(5); + var slicer = new FixedSizeBatchSlicer(options); + var inputs = Array.Empty(); + + // Act + var ranges = slicer.Slice(inputs).ToList(); + + // Assert + Assert.Empty(ranges); + } +} diff --git a/test/FAI.Core.Tests/ClassificationTaskTests.cs b/test/FAI.Core.Tests/ClassificationTaskTests.cs new file mode 100644 index 0000000..f4bb8cf --- /dev/null +++ b/test/FAI.Core.Tests/ClassificationTaskTests.cs @@ -0,0 +1,74 @@ +using System.Numerics.Tensors; +using FAI.Core.Abstractions; +using FAI.Core.Configurations.InferenceTasks; +using FAI.Core.InferenceTasks.Classification; +using FAI.Core.ResultTypes; +using NSubstitute; + +namespace FAI.Core.Tests; + +public class ClassificationTaskTests +{ + private class TestClassificationTask : 
ClassificationTask[], float, string, float> + { + public TestClassificationTask( + IPreprocessor[], float> preprocessor, + IModelExecutor modelExecutor, + ClassificationOptions pipelineOptions) + : base(preprocessor, modelExecutor, pipelineOptions) + { + } + } + + private class MockModelExecutor : IModelExecutor + { + public Task[]> RunAsync(Tensor[] inputs) => throw new NotImplementedException(); + + public Task RunAsync(Tensor[] inputs, Action, int> postProcess) + { + postProcess(new ReadOnlyTensorSpan([1.0f, 2.0f, -1.0f, 5.0f], [2, 2]), 0); + return Task.CompletedTask; + } + } + + [Fact] + public async Task RunModel_CorrectlyProcessesLogits() + { + // Arrange + string[] choices = ["Negative", "Positive"]; + var options = new ClassificationOptions(choices); + var preprocessor = Substitute.For[], float>>(); + var modelExecutor = new MockModelExecutor(); + + var task = new TestClassificationTask(preprocessor, modelExecutor, options); + + string[] inputs = ["input1", "input2"]; + Tensor[] tensors = [Tensor.Create([2, 5])]; // 2 items, 5 features (but we only need 2 for choices) + + // Act + var results = await task.RunModel(inputs.AsMemory(), tensors); + + // Assert + Assert.Equal(2, results.Length); + Assert.Equal("Positive", results[0].Choice); + Assert.Equal("Positive", results[1].Choice); + Assert.True(results[0].Score > 0.5f); + Assert.True(results[1].Score > 0.5f); + } + + [Fact] + public void GetClassificationResult_ReturnsHighestProbability() + { + // Arrange + string[] choices = ["A", "B", "C"]; + var options = new ClassificationOptions(choices); + float[] logits = [1.0f, 5.0f, 2.0f]; + + // Act + var result = options.GetClassificationResult(logits); + + // Assert + Assert.Equal("B", result.Choice); + Assert.True(result.Score > 0.9f); // Softmax of [1, 5, 2] should highly favor 5 + } +} diff --git a/test/FAI.Core.Tests/FAI.Core.Tests.csproj b/test/FAI.Core.Tests/FAI.Core.Tests.csproj new file mode 100644 index 0000000..178940f --- /dev/null +++ 
b/test/FAI.Core.Tests/FAI.Core.Tests.csproj @@ -0,0 +1,5 @@ + + + + + diff --git a/test/FAI.Core.Tests/InferenceStepsTests.cs b/test/FAI.Core.Tests/InferenceStepsTests.cs new file mode 100644 index 0000000..30b78ad --- /dev/null +++ b/test/FAI.Core.Tests/InferenceStepsTests.cs @@ -0,0 +1,82 @@ +using FAI.Core.Abstractions; +using NSubstitute; + +namespace FAI.Core.Tests; + +public class InferenceStepsTests +{ + private class TestInferenceSteps : InferenceSteps + { + public override int Preprocess(ReadOnlySpan input) => input.Length; + + public override Task RunModel(ReadOnlyMemory input, int preprocesses) + { + return Task.FromResult((double)preprocesses * 2.0); + } + + public override void PostProcess(ReadOnlySpan inputs, int preprocesses, double modelOutput, Span outputs) + { + for (int i = 0; i < outputs.Length; i++) + { + outputs[i] = $"{modelOutput}"; + } + } + } + + [Fact] + public async Task ProcessBatch_ExecutesAllStepsInOrder() + { + // Arrange + var steps = new TestInferenceSteps(); + string[] inputs = ["a", "b", "c"]; + var outputs = new string[3]; + + // Act + await steps.ProcessBatch(inputs, outputs); + + // Assert + Assert.All(outputs, o => Assert.Equal("6", o)); + } + + private class MockInferenceSteps : TestInferenceSteps + { + public bool PreprocessCalled { get; private set; } + public bool RunModelCalled { get; private set; } + public bool PostProcessCalled { get; private set; } + + public override int Preprocess(ReadOnlySpan input) + { + PreprocessCalled = true; + return base.Preprocess(input); + } + + public override Task RunModel(ReadOnlyMemory input, int preprocesses) + { + RunModelCalled = true; + return base.RunModel(input, preprocesses); + } + + public override void PostProcess(ReadOnlySpan inputs, int preprocesses, double modelOutput, Span outputs) + { + PostProcessCalled = true; + base.PostProcess(inputs, preprocesses, modelOutput, outputs); + } + } + + [Fact] + public async Task ProcessBatch_ExecutesAllSteps() + { + // Arrange + var 
steps = new MockInferenceSteps(); + ReadOnlyMemory inputs = (string[])["a", "b"]; + var outputs = new string[2].AsMemory(); + + // Act + await steps.ProcessBatch(inputs, outputs); + + // Assert + Assert.True(steps.PreprocessCalled); + Assert.True(steps.RunModelCalled); + Assert.True(steps.PostProcessCalled); + } +} diff --git a/test/FAI.Core.Tests/PipelineBatchExecutorTests/BackgroundPipelineBatchExecutorTests.cs b/test/FAI.Core.Tests/PipelineBatchExecutorTests/BackgroundPipelineBatchExecutorTests.cs new file mode 100644 index 0000000..ad92697 --- /dev/null +++ b/test/FAI.Core.Tests/PipelineBatchExecutorTests/BackgroundPipelineBatchExecutorTests.cs @@ -0,0 +1,88 @@ +using FAI.Core.Abstractions; +using FAI.Core.PipelineBatchExecutors; +using NSubstitute; + +namespace FAI.Core.Tests.PipelineBatchExecutorTests; + +public class BackgroundPipelineBatchExecutorTests +{ + [Fact] + public async Task ExecuteBatchPredict_OffloadsToWorkerAndPropagatesResult() + { + // Arrange + int[] inputData = [10, 20, 30, 40, 50]; + ReadOnlyMemory inputs = inputData; + int[] outputArray = new int[5]; + Memory outputs = outputArray; + + var next = Substitute.For>(); + next.ExecuteBatchPredict(Arg.Any>(), Arg.Any>()) + .Returns(callInfo => + { + var input = callInfo.ArgAt>(0); + var output = callInfo.ArgAt>(1); + + // Simulate actual processing: double each input value + for (int i = 0; i < input.Length; i++) + { + output.Span[i] = input.Span[i] * 2; + } + return Task.CompletedTask; + }); + + var executor = new BackgroundPipelineBatchExecutor(next, workerCount: 1); + + // Act + await executor.ExecuteBatchPredict(inputs, outputs); + + // Assert + await next.Received(1).ExecuteBatchPredict( + Arg.Is>(m => m.Length == 5), + Arg.Any>()); + + Assert.Equal([20, 40, 60, 80, 100], outputArray); + } + + [Fact] + public async Task ExecuteBatchPredict_PropagatesException() + { + // Arrange + int[] inputData = [1, 2, 3]; + ReadOnlyMemory inputs = inputData; + Memory outputs = new int[3]; + + var next = 
Substitute.For>(); + var exception = new InvalidOperationException("Model execution failed"); + next.ExecuteBatchPredict(Arg.Any>(), Arg.Any>()) + .Returns(Task.FromException(exception)); + + var executor = new BackgroundPipelineBatchExecutor(next, workerCount: 2); + + // Act & Assert + var ex = await Assert.ThrowsAsync(() => + executor.ExecuteBatchPredict(inputs, outputs)); + Assert.Equal("Model execution failed", ex.Message); + } + + [Fact] + public async Task ExecuteBatchPredict_HandlesEmptyInput() + { + // Arrange + ReadOnlyMemory inputs = Array.Empty(); + Memory outputs = Array.Empty(); + + var next = Substitute.For>(); + next.ExecuteBatchPredict(Arg.Any>(), Arg.Any>()) + .Returns(Task.CompletedTask); + + var executor = new BackgroundPipelineBatchExecutor(next, workerCount: 1); + + // Act + await executor.ExecuteBatchPredict(inputs, outputs); + + // Assert + await next.Received(1).ExecuteBatchPredict( + Arg.Is>(m => m.Length == 0), + Arg.Is>(m => m.Length == 0)); + } +} diff --git a/test/FAI.Core.Tests/PipelineBatchExecutorTests/PartitionPipelineBatchExecutorTests.cs b/test/FAI.Core.Tests/PipelineBatchExecutorTests/PartitionPipelineBatchExecutorTests.cs new file mode 100644 index 0000000..9180478 --- /dev/null +++ b/test/FAI.Core.Tests/PipelineBatchExecutorTests/PartitionPipelineBatchExecutorTests.cs @@ -0,0 +1,98 @@ +using FAI.Core.Abstractions; +using FAI.Core.PipelineBatchExecutors; +using NSubstitute; + +namespace FAI.Core.Tests.PipelineBatchExecutorTests; + +public class PartitionPipelineBatchExecutorTests +{ + [Fact] + public async Task ExecuteBatchPredict_SlicesAndSchedules() + { + // Arrange - Demonstrate batch partitioning for parallel processing + int[] inputData = [100, 200, 300, 400, 500, 600]; + ReadOnlyMemory inputs = inputData; + int[] outputData = new int[6]; + Memory outputs = outputData; + + // Mock slicer to partition batch into smaller chunks + var slicer = Substitute.For>(); + Range[] partitions = [0..2, 2..4, 4..6]; // 3 partitions of 2 
items each + slicer.Slice(Arg.Any>()).Returns(partitions); + + // Mock scheduler to coordinate parallel execution + var schedular = Substitute.For>(); + schedular.RunInExecutor( + Arg.Any>(), + Arg.Any>(), + Arg.Any>(), + Arg.Any>()) + .Returns(callInfo => + { + var executor = callInfo.ArgAt>(0); + var ranges = callInfo.ArgAt>(1); + var inp = callInfo.ArgAt>(2); + var outp = callInfo.ArgAt>(3); + + // Simulate parallel processing of each partition + foreach (var range in ranges) + { + for (int i = range.Start.Value; i < range.End.Value; i++) + { + outp.Span[i] = inp.Span[i] / 10; // Divide by 10 + } + } + return Task.CompletedTask; + }); + + var innerExecutor = Substitute.For>(); + var executor = new PartitionPipelineBatchExecutor(schedular, slicer, innerExecutor); + + // Act + await executor.ExecuteBatchPredict(inputs, outputs); + + // Assert + slicer.Received(1).Slice(Arg.Any>()); + await schedular.Received(1).RunInExecutor( + innerExecutor, + Arg.Is>(r => r.SequenceEqual(partitions)), + Arg.Any>(), + Arg.Any>()); + + Assert.Equal([10, 20, 30, 40, 50, 60], outputData); + } + + [Fact] + public async Task ExecuteBatchPredict_HandlesEmptyInput() + { + // Arrange + ReadOnlyMemory inputs = Array.Empty(); + Memory outputs = Array.Empty(); + + var slicer = Substitute.For>(); + Range[] emptyRanges = []; + slicer.Slice(Arg.Any>()).Returns(emptyRanges); + + var schedular = Substitute.For>(); + schedular.RunInExecutor( + Arg.Any>(), + Arg.Any>(), + Arg.Any>(), + Arg.Any>()) + .Returns(Task.CompletedTask); + + var innerExecutor = Substitute.For>(); + var executor = new PartitionPipelineBatchExecutor(schedular, slicer, innerExecutor); + + // Act + await executor.ExecuteBatchPredict(inputs, outputs); + + // Assert + slicer.Received(1).Slice(Arg.Any>()); + await schedular.Received(1).RunInExecutor( + innerExecutor, + Arg.Is>(r => !r.Any()), + Arg.Any>(), + Arg.Any>()); + } +} diff --git a/test/FAI.Core.Tests/PipelineBatchExecutorTests/PipelineLinkBatchExecutorTests.cs 
b/test/FAI.Core.Tests/PipelineBatchExecutorTests/PipelineLinkBatchExecutorTests.cs new file mode 100644 index 0000000..de82e4b --- /dev/null +++ b/test/FAI.Core.Tests/PipelineBatchExecutorTests/PipelineLinkBatchExecutorTests.cs @@ -0,0 +1,93 @@ +using System.Buffers; +using FAI.Core.Abstractions; +using FAI.Core.PipelineBatchExecutors; +using NSubstitute; + +namespace FAI.Core.Tests.PipelineBatchExecutorTests; + +public class PipelineLinkBatchExecutorTests +{ + [Fact] + public async Task ExecuteBatchPredict_TransformsUserIdsToStringsForNextPipeline() + { + // Arrange: Transform user IDs (int) to strings for downstream text processing pipeline + var nextPipeline = Substitute.For>(); + Func userIdToString = userId => $"USER_{userId:D6}"; // Format: USER_000123 + var pool = ArrayPool.Shared; + var executor = new PipelineLinkBatchExecutor(nextPipeline, userIdToString, pool); + + int[] userIdsArray = [123, 456, 789]; + ReadOnlyMemory userIds = userIdsArray; + Memory sentimentScores = new int[3]; + + nextPipeline.BatchPredict(Arg.Any>(), Arg.Any>()) + .Returns(async callInfo => + { + var transformedInputs = callInfo.ArgAt>(0); + var outputs = callInfo.ArgAt>(1); + + // Simulate sentiment analysis: longer strings = higher scores + for (int i = 0; i < transformedInputs.Length; i++) + { + outputs.Span[i] = transformedInputs.Span[i].Length; + } + + await Task.CompletedTask; + }); + + // Act + await executor.ExecuteBatchPredict(userIds, sentimentScores); + + // Assert: Verify transformation was applied correctly + string[] expectedTransformed = ["USER_000123", "USER_000456", "USER_000789"]; + await nextPipeline.Received(1).BatchPredict( + Arg.Is>(m => m.ToArray().SequenceEqual(expectedTransformed)), + Arg.Any>()); + + // Verify outputs were populated by next pipeline + Assert.Equal(11, sentimentScores.Span[0]); // "USER_000123".Length = 11 + Assert.Equal(11, sentimentScores.Span[1]); // "USER_000456".Length = 11 + Assert.Equal(11, sentimentScores.Span[2]); // 
"USER_000789".Length = 11 + } + + [Fact] + public async Task ExecuteBatchPredict_UsesArrayPoolForIntermediateBuffer() + { + // Arrange: Demonstrate zero-allocation pattern using ArrayPool + var nextPipeline = Substitute.For>(); + Func temperatureFormatter = temp => $"{temp:F1}°C"; + var pool = ArrayPool.Shared; + var executor = new PipelineLinkBatchExecutor(nextPipeline, temperatureFormatter, pool); + + double[] temperaturesArray = [36.6, 37.2, 38.5, 39.1]; + ReadOnlyMemory temperatures = temperaturesArray; + Memory isFever = new bool[4]; + + nextPipeline.BatchPredict(Arg.Any>(), Arg.Any>()) + .Returns(async callInfo => + { + var formattedTemps = callInfo.ArgAt>(0); + var outputs = callInfo.ArgAt>(1); + + // Determine fever status: > 37.5°C + for (int i = 0; i < formattedTemps.Length; i++) + { + // Parse temperature from formatted string + string temp = formattedTemps.Span[i]; + double value = double.Parse(temp.Replace("°C", "")); + outputs.Span[i] = value > 37.5; + } + + await Task.CompletedTask; + }); + + // Act + await executor.ExecuteBatchPredict(temperatures, isFever); + + // Assert: Verify correct transformation and results + Assert.False(isFever.Span[0]); // 36.6 - normal + Assert.False(isFever.Span[1]); // 37.2 - normal + Assert.True(isFever.Span[2]); // 38.5 - fever + Assert.True(isFever.Span[3]); // 39.1 - fever + } +} diff --git a/test/FAI.Core.Tests/PipelineBatchExecutorTests/RoutingPipelineBatchExecutorTests.cs b/test/FAI.Core.Tests/PipelineBatchExecutorTests/RoutingPipelineBatchExecutorTests.cs new file mode 100644 index 0000000..ed377bb --- /dev/null +++ b/test/FAI.Core.Tests/PipelineBatchExecutorTests/RoutingPipelineBatchExecutorTests.cs @@ -0,0 +1,146 @@ +using FAI.Core.Abstractions; +using FAI.Core.PipelineBatchExecutors; +using NSubstitute; + +namespace FAI.Core.Tests.PipelineBatchExecutorTests; + +public class RoutingPipelineBatchExecutorTests +{ + [Fact] + public async Task ExecuteBatchPredict_RoutesCorrectlyAndMergesOutputs() + { + // Arrange 
- Simulate routing specific batch items to specialized executors + int[] inputData = [10, 20, 30, 40, 50]; + ReadOnlyMemory inputs = inputData; + int[] outputData = new int[5]; + Memory outputs = outputData; + + var executor1 = Substitute.For>(); + executor1.ExecuteBatchPredict(Arg.Any>(), Arg.Any>()) + .Returns(callInfo => + { + var input = callInfo.ArgAt>(0); + var output = callInfo.ArgAt>(1); + + // Process by incrementing + for (int i = 0; i < input.Length; i++) + { + output.Span[i] = input.Span[i] + 1; + } + return Task.CompletedTask; + }); + + var strategy = Substitute.For>(); + // Route items at index 0, 2, 4 to executor1 (odd positions remain unprocessed) + var routingResult = new BatchExecutionRoutingResult(executor1, [0..1, 2..3, 4..5]); + strategy.Route(Arg.Any[]>(), Arg.Any>()) + .Returns([routingResult]); + + var routingExecutor = new RoutingPipelineBatchExecutor([executor1], strategy); + + // Act + await routingExecutor.ExecuteBatchPredict(inputs, outputs); + + // Assert + Assert.Equal(11, outputData[0]); // 10 + 1 (routed) + Assert.Equal(0, outputData[1]); // Not routed + Assert.Equal(31, outputData[2]); // 30 + 1 (routed) + Assert.Equal(0, outputData[3]); // Not routed + Assert.Equal(51, outputData[4]); // 50 + 1 (routed) + } + + [Fact] + public async Task ExecuteBatchPredict_MultipleExecutors_RoutesCorrectly() + { + // Arrange - Demonstrate routing to different executors (e.g., CPU vs GPU models) + int[] inputData = [1, 2, 3, 4]; + ReadOnlyMemory inputs = inputData; + int[] outputData = new int[4]; + Memory outputs = outputData; + + // Fast executor (CPU) - multiplies by 10 + var fastExecutor = Substitute.For>(); + fastExecutor.ExecuteBatchPredict(Arg.Any>(), Arg.Any>()) + .Returns(callInfo => + { + var input = callInfo.ArgAt>(0); + var output = callInfo.ArgAt>(1); + for (int i = 0; i < input.Length; i++) + { + output.Span[i] = input.Span[i] * 10; + } + return Task.CompletedTask; + }); + + // Accurate executor (GPU) - multiplies by 100 + var 
accurateExecutor = Substitute.For>(); + accurateExecutor.ExecuteBatchPredict(Arg.Any>(), Arg.Any>()) + .Returns(callInfo => + { + var input = callInfo.ArgAt>(0); + var output = callInfo.ArgAt>(1); + for (int i = 0; i < input.Length; i++) + { + output.Span[i] = input.Span[i] * 100; + } + return Task.CompletedTask; + }); + + var strategy = Substitute.For>(); + // Route first 2 items to fast executor, last 2 to accurate executor + var route1 = new BatchExecutionRoutingResult(fastExecutor, [0..2]); + var route2 = new BatchExecutionRoutingResult(accurateExecutor, [2..4]); + strategy.Route(Arg.Any[]>(), Arg.Any>()) + .Returns([route1, route2]); + + var routingExecutor = new RoutingPipelineBatchExecutor([fastExecutor, accurateExecutor], strategy); + + // Act + await routingExecutor.ExecuteBatchPredict(inputs, outputs); + + // Assert + Assert.Equal([10, 20, 300, 400], outputData); + } + + [Fact] + public async Task ExecuteBatchPredict_EmptyResults_DoesNothing() + { + // Arrange - No routing strategy returns empty results + int[] inputData = [1, 2, 3, 4, 5]; + ReadOnlyMemory inputs = inputData; + int[] outputData = new int[5]; + Memory outputs = outputData; + + var strategy = Substitute.For>(); + strategy.Route(Arg.Any[]>(), Arg.Any>()) + .Returns([]); + + var routingExecutor = new RoutingPipelineBatchExecutor([], strategy); + + // Act + await routingExecutor.ExecuteBatchPredict(inputs, outputs); + + // Assert - All outputs remain at default (0) + Assert.All(outputData, x => Assert.Equal(0, x)); + } + + [Fact] + public async Task ExecuteBatchPredict_HandlesEmptyInput() + { + // Arrange + ReadOnlyMemory inputs = Array.Empty(); + Memory outputs = Array.Empty(); + + var strategy = Substitute.For>(); + strategy.Route(Arg.Any[]>(), Arg.Any>()) + .Returns([]); + + var routingExecutor = new RoutingPipelineBatchExecutor([], strategy); + + // Act + await routingExecutor.ExecuteBatchPredict(inputs, outputs); + + // Assert + strategy.Received(1).Route(Arg.Any[]>(), Arg.Any>()); + } +} 
diff --git a/test/FAI.Core.Tests/PipelineBatchExecutorTests/SinkPipelineBatchExecutorTests.cs b/test/FAI.Core.Tests/PipelineBatchExecutorTests/SinkPipelineBatchExecutorTests.cs new file mode 100644 index 0000000..48d9c05 --- /dev/null +++ b/test/FAI.Core.Tests/PipelineBatchExecutorTests/SinkPipelineBatchExecutorTests.cs @@ -0,0 +1,107 @@ +using FAI.Core.Abstractions; +using FAI.Core.PipelineBatchExecutors; +using NSubstitute; + +namespace FAI.Core.Tests.PipelineBatchExecutorTests; + +public class SinkPipelineBatchExecutorTests +{ + [Fact] + public async Task ExecuteBatchPredict_ProcessesImageBatchThroughInferenceSteps() + { + // Arrange: Sink executor terminates pipeline by invoking actual ML inference + var inferenceSteps = Substitute.For>(); + var executor = new SinkPipelineBatchExecutor(inferenceSteps); + + string[] imagePathsArray = ["cat.jpg", "dog.jpg", "bird.jpg"]; + ReadOnlyMemory imagePaths = imagePathsArray.AsMemory(); + Memory classifications = new int[3]; + + inferenceSteps.ProcessBatch(Arg.Any>(), Arg.Any>()) + .Returns(async callInfo => + { + var inputs = callInfo.ArgAt>(0); + var outputs = callInfo.ArgAt>(1); + + // Simulate image classification: assign class IDs based on filename + for (int i = 0; i < inputs.Length; i++) + { + string filename = inputs.Span[i]; + outputs.Span[i] = filename.Contains("cat") ? 0 : + filename.Contains("dog") ? 
1 : 2; + } + + await Task.CompletedTask; + }); + + // Act + await executor.ExecuteBatchPredict(imagePaths, classifications); + + // Assert: Verify inference steps were invoked and outputs populated + string[] expectedPaths = ["cat.jpg", "dog.jpg", "bird.jpg"]; + await inferenceSteps.Received(1).ProcessBatch( + Arg.Is>(m => m.ToArray().SequenceEqual(expectedPaths)), + Arg.Any>()); + + Assert.Equal(0, classifications.Span[0]); // cat + Assert.Equal(1, classifications.Span[1]); // dog + Assert.Equal(2, classifications.Span[2]); // bird + } + + [Fact] + public async Task ExecuteBatchPredict_HandlesEmptyBatch() + { + // Arrange: Demonstrate graceful handling of empty batches + var inferenceSteps = Substitute.For>(); + var executor = new SinkPipelineBatchExecutor(inferenceSteps); + + var inputs = ReadOnlyMemory.Empty; + var outputs = Memory.Empty; + + // Act + await executor.ExecuteBatchPredict(inputs, outputs); + + // Assert: Should still invoke inference steps (even with empty batch) + await inferenceSteps.Received(1).ProcessBatch( + Arg.Is>(m => m.Length == 0), + Arg.Is>(m => m.Length == 0)); + } + + [Fact] + public async Task ExecuteBatchPredict_PropagatesOutputsFromInferenceSteps() + { + // Arrange: Verify that outputs from inference steps are correctly propagated + var inferenceSteps = Substitute.For>(); + var executor = new SinkPipelineBatchExecutor(inferenceSteps); + + float[] sensorReadingsArray = [23.5f, 25.1f, 22.8f, 24.3f]; + ReadOnlyMemory sensorReadings = sensorReadingsArray.AsMemory(); + Memory normalizedValues = new float[4]; + + inferenceSteps.ProcessBatch(Arg.Any>(), Arg.Any>()) + .Returns(async callInfo => + { + var inputs = callInfo.ArgAt>(0); + var outputs = callInfo.ArgAt>(1); + + // Normalize sensor readings to 0-1 range + float min = 22.8f; + float max = 25.1f; + for (int i = 0; i < inputs.Length; i++) + { + outputs.Span[i] = (inputs.Span[i] - min) / (max - min); + } + + await Task.CompletedTask; + }); + + // Act + await 
executor.ExecuteBatchPredict(sensorReadings, normalizedValues); + + // Assert: Verify normalized outputs + Assert.Equal(0.304f, normalizedValues.Span[0], precision: 3); // (23.5-22.8)/(25.1-22.8) + Assert.Equal(1.000f, normalizedValues.Span[1], precision: 3); // (25.1-22.8)/(25.1-22.8) + Assert.Equal(0.000f, normalizedValues.Span[2], precision: 3); // (22.8-22.8)/(25.1-22.8) + Assert.Equal(0.652f, normalizedValues.Span[3], precision: 3); // (24.3-22.8)/(25.1-22.8) + } +} diff --git a/test/FAI.Core.Tests/PipelineBatchExecutorTests/StreamedBatchExecutorTests.cs b/test/FAI.Core.Tests/PipelineBatchExecutorTests/StreamedBatchExecutorTests.cs new file mode 100644 index 0000000..ddd0b82 --- /dev/null +++ b/test/FAI.Core.Tests/PipelineBatchExecutorTests/StreamedBatchExecutorTests.cs @@ -0,0 +1,212 @@ +using FAI.Core.Abstractions; +using FAI.Core.PipelineBatchExecutors; +using Microsoft.Extensions.Logging.Abstractions; +using NSubstitute; + +namespace FAI.Core.Tests.PipelineBatchExecutorTests; + +public class StreamedBatchExecutorTests +{ + /// + /// Simulates image preprocessing: counts pixels, runs classification model, assigns class labels + /// + private class ImageClassificationInferenceSteps : InferenceSteps + { + // Preprocess: Count characters in image path (simulating pixel counting) + public override int Preprocess(ReadOnlySpan input) => input[0].Length; + + // RunModel: Simulate neural network producing classification scores + public override Task RunModel(ReadOnlyMemory input, int pixelCount) + { + // Generate mock confidence scores based on pixel count + float[] scores = [pixelCount / 100f, (100 - pixelCount) / 100f, 0.5f]; + return Task.FromResult(scores); + } + + // PostProcess: Convert scores to class labels + public override void PostProcess(ReadOnlySpan inputs, int preprocesses, float[] modelOutput, Span outputs) + { + for (int i = 0; i < outputs.Length; i++) + { + int maxIndex = 0; + for (int j = 1; j < modelOutput.Length; j++) + { + if (modelOutput[j] > 
modelOutput[maxIndex]) + maxIndex = j; + } + outputs[i] = maxIndex == 0 ? "cat" : maxIndex == 1 ? "dog" : "bird"; + } + } + } + + [Fact] + public async Task ExecuteBatchPredict_StreamsImageClassificationThroughPipeline() + { + // Arrange: Demonstrate streaming pipeline with Preprocess → Model → PostProcess stages + var inference = new ImageClassificationInferenceSteps(); + var executor = new StreamedBatchExecutor( + inference, + maxBatchSize: null, + maxConcurrency: null, + parallelTokenization: false, + NullLogger>.Instance); + + string[] imagePathsArray = ["images/cat_001.jpg", "images/dog_002.jpg", "images/bird_003.jpg"]; + ReadOnlyMemory imagePaths = imagePathsArray.AsMemory(); + Memory predictions = new string[3]; + + // Act + await executor.ExecuteBatchPredict(imagePaths, predictions); + + // Assert: Verify full pipeline execution + Assert.Equal("dog", predictions.Span[0]); // cat_001.jpg → 18 chars → dog wins + Assert.Equal("dog", predictions.Span[1]); // dog_002.jpg → 18 chars → dog wins + Assert.Equal("dog", predictions.Span[2]); // bird_003.jpg → 19 chars → dog wins + } + + /// + /// Simulates text sentiment analysis with realistic preprocessing and scoring + /// + private class SentimentAnalysisInferenceSteps : InferenceSteps + { + // Preprocess: Token count + public override int Preprocess(ReadOnlySpan input) => input[0].Split(' ').Length; + + // RunModel: Sentiment score (-1 to 1) + public override Task RunModel(ReadOnlyMemory input, int tokenCount) + { + // Longer texts tend to be more positive (simplified heuristic) + float sentimentScore = (tokenCount - 5) / 10f; + return Task.FromResult(sentimentScore); + } + + // PostProcess: Convert to binary positive/negative + public override void PostProcess(ReadOnlySpan inputs, int preprocesses, float modelOutput, Span outputs) + { + for (int i = 0; i < outputs.Length; i++) + { + outputs[i] = modelOutput > 0; // Positive if score > 0 + } + } + } + + [Fact] + public async Task 
ExecuteBatchPredict_ProcessesBatchesWithChunking() + { + // Arrange: Demonstrate chunked processing for large batches + var inference = new SentimentAnalysisInferenceSteps(); + var executor = new StreamedBatchExecutor( + inference, + maxBatchSize: 2, // Process 2 reviews at a time + maxConcurrency: 1, + parallelTokenization: false, + NullLogger>.Instance); + + string[] reviewsArray = [ + "Great product", // 2 tokens → -0.3 → negative + "Absolutely loved it highly recommend", // 5 tokens → 0.0 → negative + "Best purchase I have ever made in my entire life", // 10 tokens → 0.5 → positive + "Amazing quality and fast shipping service", // 6 tokens → 0.1 → positive + "Perfect" // 1 token → -0.4 → negative + ]; + ReadOnlyMemory reviews = reviewsArray.AsMemory(); + Memory sentiments = new bool[5]; + + // Act + await executor.ExecuteBatchPredict(reviews, sentiments); + + // Assert: Verify chunked processing results + Assert.False(sentiments.Span[0]); // "Great product" → 2 tokens → negative + Assert.False(sentiments.Span[1]); // "Absolutely loved..." → 5 tokens → neutral/negative + Assert.True(sentiments.Span[2]); // "Best purchase..." → 10 tokens → positive + Assert.True(sentiments.Span[3]); // "Amazing quality..." 
→ 6 tokens → positive + Assert.False(sentiments.Span[4]); // "Perfect" → 1 token → negative + } + + private class FailingModelInferenceSteps : SentimentAnalysisInferenceSteps + { + public override Task RunModel(ReadOnlyMemory input, int tokenCount) + { + throw new InvalidOperationException("Model inference failed"); + } + } + + [Fact] + public async Task ExecuteBatchPredict_ModelFailure_PropagatesException() + { + // Arrange: Demonstrate error handling in model execution stage + var inference = new FailingModelInferenceSteps(); + var executor = new StreamedBatchExecutor( + inference, + maxBatchSize: null, + maxConcurrency: null, + parallelTokenization: false, + NullLogger>.Instance); + + string[] textsArray = ["This will fail", "Another text"]; + ReadOnlyMemory texts = textsArray.AsMemory(); + Memory outputs = new bool[2]; + + // Act & Assert + await Assert.ThrowsAsync(() => executor.ExecuteBatchPredict(texts, outputs)); + } + + private class FailingPostProcessInferenceSteps : SentimentAnalysisInferenceSteps + { + public override void PostProcess(ReadOnlySpan inputs, int preprocesses, float modelOutput, Span outputs) + { + throw new InvalidOperationException("Post-processing failure"); + } + } + + [Fact] + public async Task ExecuteBatchPredict_PostProcessFailure_PropagatesException() + { + // Arrange: Demonstrate error handling in post-processing stage + var inference = new FailingPostProcessInferenceSteps(); + var executor = new StreamedBatchExecutor( + inference, + maxBatchSize: null, + maxConcurrency: null, + parallelTokenization: false, + NullLogger>.Instance); + + string[] textsArray = ["Test input"]; + ReadOnlyMemory texts = textsArray.AsMemory(); + Memory outputs = new bool[1]; + + // Act & Assert + await Assert.ThrowsAsync(() => executor.ExecuteBatchPredict(texts, outputs)); + } + + [Fact] + public async Task ExecuteBatchPredict_ParallelPreprocessing_ProcessesChunksConcurrently() + { + // Arrange: Demonstrate parallel preprocessing for better throughput + 
var inference = new SentimentAnalysisInferenceSteps(); + var executor = new StreamedBatchExecutor( + inference, + maxBatchSize: 2, + maxConcurrency: 4, // Allow parallel preprocessing + parallelTokenization: true, // Enable parallel mode + NullLogger>.Instance); + + string[] reviewsArray = [ + "Good value for money and quality", + "Would definitely buy this again", + "Not satisfied with the product quality", + "Exceeded my expectations completely" + ]; + ReadOnlyMemory reviews = reviewsArray.AsMemory(); + Memory sentiments = new bool[4]; + + // Act + await executor.ExecuteBatchPredict(reviews, sentiments); + + // Assert: Verify all batches processed correctly + Assert.True(sentiments.Span[0]); // 5 tokens → positive + Assert.True(sentiments.Span[1]); // 5 tokens → positive + Assert.True(sentiments.Span[2]); // 6 tokens → positive + Assert.True(sentiments.Span[3]); // 4 tokens → negative (borderline) + } +} diff --git a/test/FAI.Core.Tests/PipelineTests/AccumulatingPipelineTests.cs b/test/FAI.Core.Tests/PipelineTests/AccumulatingPipelineTests.cs new file mode 100644 index 0000000..bb76974 --- /dev/null +++ b/test/FAI.Core.Tests/PipelineTests/AccumulatingPipelineTests.cs @@ -0,0 +1,140 @@ +using FAI.Core.Abstractions; +using FAI.Core.Pipelines; +using Microsoft.Extensions.Logging.Abstractions; +using NSubstitute; + +namespace FAI.Core.Tests.PipelineTests; + +public class AccumulatingPipelineTests +{ + [Fact] + public async Task Predict_MaxBatchSizeTrigger_FlushesBatch() + { + // Arrange + var executor = Substitute.For>(); + var options = new AccumulatingPipelineOptions + { + MaxBatchSize = 3, + MaxLatency = TimeSpan.FromSeconds(10) + }; + var policy = Substitute.For>(); + var pipeline = new AccumulatingPipeline(executor, options, policy, NullLogger>.Instance); + + executor.ExecuteBatchPredict(Arg.Any>(), Arg.Any>()) + .Returns(async x => + { + var input = (ReadOnlyMemory)x[0]; + var output = (Memory)x[1]; + for (int i = 0; i < input.Length; i++) output.Span[i] = 
input.Span[i] * 2; + await Task.CompletedTask; + }); + + // Act + var t1 = pipeline.Predict(1); + var t2 = pipeline.Predict(2); + var t3 = pipeline.Predict(3); + + var results = await Task.WhenAll(t1, t2, t3); + + // Assert + Assert.Equal(new[] { 2, 4, 6 }, results); + await executor.Received(1).ExecuteBatchPredict(Arg.Is>(m => m.Length == 3), Arg.Any>()); + } + + [Fact] + public async Task Predict_MaxLatencyTrigger_FlushesPartialBatch() + { + // Arrange + var executor = Substitute.For>(); + var options = new AccumulatingPipelineOptions + { + MaxBatchSize = 10, + MaxLatency = TimeSpan.FromMilliseconds(50) + }; + var policy = Substitute.For>(); + var pipeline = new AccumulatingPipeline(executor, options, policy, NullLogger>.Instance); + + executor.ExecuteBatchPredict(Arg.Any>(), Arg.Any>()) + .Returns(async x => + { + var input = (ReadOnlyMemory)x[0]; + var output = (Memory)x[1]; + for (int i = 0; i < input.Length; i++) output.Span[i] = input.Span[i] * 2; + await Task.CompletedTask; + }); + + // Act + var result = await pipeline.Predict(5); + + // Assert + Assert.Equal(10, result); + await executor.Received(1).ExecuteBatchPredict(Arg.Is>(m => m.Length == 1), Arg.Any>()); + } + + [Fact] + public async Task Predict_FailedBatch_InvokesPolicy() + { + // Arrange + var executor = Substitute.For>(); + var options = new AccumulatingPipelineOptions + { + MaxBatchSize = 1, + MaxLatency = TimeSpan.FromSeconds(1) + }; + var policy = Substitute.For>(); + var pipeline = new AccumulatingPipeline(executor, options, policy, NullLogger>.Instance); + + var exception = new Exception("Fail"); + executor.ExecuteBatchPredict(Arg.Any>(), Arg.Any>()) + .Returns(Task.FromException(exception)); + + policy.HandleAsync(Arg.Any>(), Arg.Any>(), executor, exception, Arg.Any()) + .Returns(x => + { + var output = (Memory)x[1]; + output.Span[0] = -1; // Fallback + return Task.CompletedTask; + }); + + // Act + var result = await pipeline.Predict(100); + + // Assert + Assert.Equal(-1, result); + await 
policy.Received(1).HandleAsync(Arg.Any>(), Arg.Any>(), executor, exception, Arg.Any()); + } + + [Fact] + public async Task Predict_Disposed_ThrowsObjectDisposedException() + { + // Arrange + var executor = Substitute.For>(); + var options = new AccumulatingPipelineOptions + { + MaxBatchSize = 10, + MaxLatency = TimeSpan.FromSeconds(10) + }; + var policy = Substitute.For>(); + var pipeline = new AccumulatingPipeline(executor, options, policy, NullLogger>.Instance); + + pipeline.Dispose(); + + // Act & Assert + await Assert.ThrowsAsync(() => pipeline.Predict(2)); + } + + [Fact] + public async Task BatchPredict_EmptyInput_ReturnsEmptyResults() + { + // Arrange + var executor = Substitute.For>(); + var options = new AccumulatingPipelineOptions { MaxBatchSize = 1 }; + var pipeline = new AccumulatingPipeline(executor, options, Substitute.For>(), NullLogger>.Instance); + + // Act + var results = await pipeline.BatchPredict(ReadOnlyMemory.Empty); + + // Assert + Assert.Empty(results); + } +} diff --git a/test/FAI.Core.Tests/TensorExtensionsTests.cs b/test/FAI.Core.Tests/TensorExtensionsTests.cs new file mode 100644 index 0000000..08d9ce3 --- /dev/null +++ b/test/FAI.Core.Tests/TensorExtensionsTests.cs @@ -0,0 +1,66 @@ +using System.Numerics.Tensors; +using FAI.Core; + +namespace FAI.Core.Tests; + +public class TensorExtensionsTests +{ + [Fact] + public void AsSpan_TensorSpan_ReturnsCorrectSpan() + { + // Arrange + float[] data = [1.0f, 2.0f, 3.0f]; + var tensorSpan = new TensorSpan(data, [3]); + + // Act + var span = tensorSpan.AsSpan(); + + // Assert + Assert.Equal(3, span.Length); + Assert.Equal(1.0f, span[0]); + Assert.Equal(2.0f, span[1]); + Assert.Equal(3.0f, span[2]); + + // Verify it's a reference to the same data + span[1] = 42.0f; + Assert.Equal(42.0f, data[1]); + } + + [Fact] + public void AsSpan_ReadOnlyTensorSpan_ReturnsCorrectSpan() + { + // Arrange + float[] data = [1.0f, 2.0f, 3.0f]; + var tensorSpan = new ReadOnlyTensorSpan(data, [3]); + + // Act + var span 
= tensorSpan.AsSpan(); + + // Assert + Assert.Equal(3, span.Length); + Assert.Equal(1.0f, span[0]); + Assert.Equal(2.0f, span[1]); + Assert.Equal(3.0f, span[2]); + } + + [Fact] + public void AsMemory_Tensor_ReturnsCorrectMemory() + { + // Arrange + float[] data = [1.0f, 2.0f, 3.0f]; + var tensor = Tensor.Create(data, [3]); + + // Act + var memory = tensor.AsMemory(); + + // Assert + Assert.Equal(3, memory.Length); + Assert.Equal(1.0f, memory.Span[0]); + Assert.Equal(2.0f, memory.Span[1]); + Assert.Equal(3.0f, memory.Span[2]); + + // Verify reference + memory.Span[1] = 42.0f; + Assert.Equal(42.0f, data[1]); + } +} diff --git a/test/FAI.Extensions.Evaluation.Tests/EvaluationPipelineTests.cs b/test/FAI.Extensions.Evaluation.Tests/EvaluationPipelineTests.cs new file mode 100644 index 0000000..2ab09c3 --- /dev/null +++ b/test/FAI.Extensions.Evaluation.Tests/EvaluationPipelineTests.cs @@ -0,0 +1,134 @@ +using FAI.Core.Abstractions; +using Microsoft.Extensions.Logging.Abstractions; +using NSubstitute; + +namespace FAI.Extensions.Evaluation.Tests; + +public class EvaluationPipelineTests +{ + public record TestLoaderInput(int Count); + public record TestInferenceInput(int Id); + public record TestLoadedInput(int Id) : IInferenceInputGetter + { + public TestInferenceInput InferenceInput => new(Id); + } + public record TestInferenceOutput(int Id, string Prediction); + public record TestEvaluationResult(double Accuracy); + + private static async IAsyncEnumerable ToAsync(IEnumerable items) + { + foreach (var item in items) + { + yield return item; + await Task.Yield(); + } + } + + [Fact] + public async Task Evaluate_SimpleFlow_ShouldReturnCorrectResults() + { + // Arrange + var dataLoader = Substitute.For>(); + var inference = Substitute.For>(); + var evaluator = Substitute.For>(); + var logger = NullLogger>.Instance; + var options = new EvaluationPipelineOptions(); + + var pipeline = new EvaluationPipeline( + dataLoader, inference, evaluator, logger, options); + + var inputs = 
Enumerable.Range(0, 5).Select(i => new TestLoadedInput(i)).ToArray(); + dataLoader.LoadData(Arg.Any()).Returns(ToAsync(inputs)); + + inference.BatchPredict(Arg.Any>()) + .Returns(args => Task.FromResult(((ReadOnlyMemory)args[0]).ToArray().Select(x => new TestInferenceOutput(x.Id, "ok")).ToArray())); + + evaluator.Evaluate(Arg.Any>()) + .Returns(async args => + { + var stream = (IAsyncEnumerable<(TestLoadedInput[], TestInferenceOutput[])>)args[0]; + await foreach (var _ in stream) { } + return new TestEvaluationResult(1.0); + }); + + // Act + var result = await pipeline.Evaluate(new TestLoaderInput(5)); + + // Assert + Assert.Equal(5, result.SampleSize); + Assert.Equal(1.0, result.Evaluation.Accuracy); + dataLoader.Received(1).LoadData(Arg.Is(x => x.Count == 5)); + await inference.Received(1).BatchPredict(Arg.Any>()); + } + + [Fact] + public async Task Evaluate_WithChunking_ShouldCallInferenceMultipleTimes() + { + // Arrange + var dataLoader = Substitute.For>(); + var inference = Substitute.For>(); + var evaluator = Substitute.For>(); + var logger = NullLogger>.Instance; + var options = new EvaluationPipelineOptions(LoadingChunkSize: 2); + + var pipeline = new EvaluationPipeline( + dataLoader, inference, evaluator, logger, options); + + var inputs = Enumerable.Range(0, 5).Select(i => new TestLoadedInput(i)).ToArray(); + dataLoader.LoadData(Arg.Any()).Returns(ToAsync(inputs)); + + inference.BatchPredict(Arg.Any>()) + .Returns(args => Task.FromResult(((ReadOnlyMemory)args[0]).ToArray().Select(x => new TestInferenceOutput(x.Id, "ok")).ToArray())); + + evaluator.Evaluate(Arg.Any>()) + .Returns(async args => + { + var stream = (IAsyncEnumerable<(TestLoadedInput[], TestInferenceOutput[])>)args[0]; + await foreach (var _ in stream) { } + return new TestEvaluationResult(0.8); + }); + + // Act + var result = await pipeline.Evaluate(new TestLoaderInput(5)); + + // Assert + Assert.Equal(5, result.SampleSize); + // Expect 2 calls of 2 elements and 1 call of 1 element + await 
inference.Received(3).BatchPredict(Arg.Any>()); + } + + [Fact] + public async Task Evaluate_WithParallelOptions_ShouldExecuteSuccessfully() + { + // Arrange + var dataLoader = Substitute.For>(); + var inference = Substitute.For>(); + var evaluator = Substitute.For>(); + var logger = NullLogger>.Instance; + var options = new EvaluationPipelineOptions(LoadingChunkSize: 2, ParallelLoading: true, ParallelEvaluation: true); + + var pipeline = new EvaluationPipeline( + dataLoader, inference, evaluator, logger, options); + + var inputs = Enumerable.Range(0, 4).Select(i => new TestLoadedInput(i)).ToArray(); + dataLoader.LoadData(Arg.Any()).Returns(ToAsync(inputs)); + + inference.BatchPredict(Arg.Any>()) + .Returns(args => Task.FromResult(((ReadOnlyMemory)args[0]).ToArray().Select(x => new TestInferenceOutput(x.Id, "ok")).ToArray())); + + evaluator.Evaluate(Arg.Any>()) + .Returns(async args => + { + var stream = (IAsyncEnumerable<(TestLoadedInput[], TestInferenceOutput[])>)args[0]; + await foreach (var _ in stream) { } + return new TestEvaluationResult(1.0); + }); + + // Act + var result = await pipeline.Evaluate(new TestLoaderInput(4)); + + // Assert + Assert.Equal(4, result.SampleSize); + await evaluator.Received(1).Evaluate(Arg.Any>()); + } +} diff --git a/test/FAI.Extensions.Evaluation.Tests/FAI.Extensions.Evaluation.Tests.csproj b/test/FAI.Extensions.Evaluation.Tests/FAI.Extensions.Evaluation.Tests.csproj new file mode 100644 index 0000000..b2a7da8 --- /dev/null +++ b/test/FAI.Extensions.Evaluation.Tests/FAI.Extensions.Evaluation.Tests.csproj @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/test/FAI.IntegrationTests/FAI.IntegrationTests.csproj b/test/FAI.IntegrationTests/FAI.IntegrationTests.csproj new file mode 100644 index 0000000..7c88293 --- /dev/null +++ b/test/FAI.IntegrationTests/FAI.IntegrationTests.csproj @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + diff --git a/test/FAI.IntegrationTests/GlobalUsings.cs b/test/FAI.IntegrationTests/GlobalUsings.cs new 
file mode 100644 index 0000000..4b536fc --- /dev/null +++ b/test/FAI.IntegrationTests/GlobalUsings.cs @@ -0,0 +1,20 @@ +global using System.Numerics.Tensors; +global using System.Runtime.CompilerServices; +global using FAI.Core.Abstractions; +global using FAI.Core.BatchSchedulers; +global using FAI.Core.Configurations.InferenceTasks; +global using FAI.Core.Configurations.PipelineBatchExecutors; +global using FAI.Core.Extensions.DI; +global using FAI.Core.PipelineBatchExecutors; +global using FAI.Core.Pipelines; +global using FAI.Core.ResultTypes; +global using FAI.NLP.Configuration; +global using FAI.NLP.InferenceTasks.TextClassification; +global using FAI.NLP.InferenceTasks.TextMultipleChoice; +global using FAI.NLP.Tests.Mocks; +global using FAI.NLP.Tokenization; +global using FAI.Onnx.ModelExecutors; +global using FAI.Onnx.Tests.Utils; +global using FluentAssertions; +global using Microsoft.Extensions.DependencyInjection; +global using Xunit; diff --git a/test/FAI.IntegrationTests/LogicalMockModelExecutor.cs b/test/FAI.IntegrationTests/LogicalMockModelExecutor.cs new file mode 100644 index 0000000..bb7d9e8 --- /dev/null +++ b/test/FAI.IntegrationTests/LogicalMockModelExecutor.cs @@ -0,0 +1,39 @@ +namespace FAI.IntegrationTests; + +public class LogicalMockModelExecutor : IModelExecutor +{ + private readonly float[][] _outputs; + private int _callCount = 0; + + public LogicalMockModelExecutor(float[][] outputs) + { + _outputs = outputs; + } + + public Task[]> RunAsync(Tensor[] inputs) + { + var data = _outputs[_callCount % _outputs.Length]; + var output = Tensor.Create(data, [(nint)data.Length]); + _callCount++; + return Task.FromResult(new[] { output }); + } + + public Task RunAsync(Tensor[] inputs, Action, int> postProcess) + { + int batchSize = (int)inputs[0].Lengths[0]; + int outputSize = _outputs[0].Length; + + float[] batchOutput = new float[batchSize * outputSize]; + for (int i = 0; i < batchSize; i++) + { + var row = _outputs[_callCount % _outputs.Length]; 
+ _callCount++; + row.AsSpan().CopyTo(batchOutput.AsSpan(i * outputSize)); + } + + var batchTensor = Tensor.Create(batchOutput, [(nint)batchSize, (nint)outputSize]); + postProcess(batchTensor, 0); // Assuming model has 1 output tensor + + return Task.CompletedTask; + } +} diff --git a/test/FAI.IntegrationTests/MultipleChoiceIntegrationTests.cs b/test/FAI.IntegrationTests/MultipleChoiceIntegrationTests.cs new file mode 100644 index 0000000..4f4b024 --- /dev/null +++ b/test/FAI.IntegrationTests/MultipleChoiceIntegrationTests.cs @@ -0,0 +1,40 @@ +using FAI.NLP.PipelineBatchExecutors; + +namespace FAI.IntegrationTests; + +public class MultipleChoiceIntegrationTests +{ + [Fact] + public async Task FullPipeline_ShouldHandleMultipleChoice() + { + // Arrange + var services = new ServiceCollection(); + + services.AddPipeline>() + .Use>>(); + + var tokenizer = DummyTokenizerFactory.Create(); + services.AddSingleton(tokenizer); + services.AddSingleton(new TextMultipleChoiceOptions { MaxChoices = 2 }); + + // Mock model: always returns [0.9, 0.1] logits + services.AddSingleton>(new LogicalMockModelExecutor([[0.9f, 0.1f]])); + + services.AddSingleton>, TextMultipleChoiceTask>(); + + var provider = services.BuildServiceProvider(); + var pipeline = provider.GetRequiredService>>(); + + var input = new TextMultipleChoiceInput( + "Question", + [new TokenizedText("choice 1"), new TokenizedText("choice 2")] + ); + + // Act + var results = await pipeline.BatchPredict(new[] { input }); + + // Assert + results.Should().HaveCount(1); + results[0].ChoiceIndex.Should().Be(0); // Chosen first choice + } +} diff --git a/test/FAI.IntegrationTests/PipelineConfigurationIntegrationTests.cs b/test/FAI.IntegrationTests/PipelineConfigurationIntegrationTests.cs new file mode 100644 index 0000000..eeee8fa --- /dev/null +++ b/test/FAI.IntegrationTests/PipelineConfigurationIntegrationTests.cs @@ -0,0 +1,49 @@ +using FAI.Core.BatchSlicers; +using FAI.NLP.PipelineBatchExecutors; + +namespace 
FAI.IntegrationTests; + +public class PipelineConfigurationIntegrationTests +{ + [Fact] + public async Task ComplexPipeline_WithBackgroundAndPartitioning_ShouldProcessBatches() + { + // Arrange + var services = new ServiceCollection(); + + services.AddPipeline>() + .Use>>() + .Use>>() + .Use>>(); + + // Setup dependencies + var tokenizer = DummyTokenizerFactory.Create(); + services.AddSingleton(tokenizer); + services.AddSingleton>, ParallelBatchSchedular>>(); + services.AddSingleton, FixedSizeBatchSlicer>(); + + // Add options for executors + services.AddSingleton(new ParallelBatchSchedularOptions(2)); + services.AddSingleton(new FixedSizeBatchSlicerOptions(5)); + services.AddSingleton(new BackgroundPipelineBatchExecutorOptions(2)); + + var options = new ClassificationOptions([false, true]); + services.AddSingleton(options); + + // Mock model: always returns high probability for 'true' + services.AddSingleton>(new LogicalMockModelExecutor([[0.1f, 0.9f]])); + + services.AddSingleton>, TextClassification>(); + + var provider = services.BuildServiceProvider(); + var pipeline = provider.GetRequiredService>>(); + + // Act + var inputs = Enumerable.Range(0, 10).Select(i => new TokenizedText($"test {i}")).ToArray(); + var results = await pipeline.BatchPredict(inputs); + + // Assert + results.Should().HaveCount(10); + results.All(r => r.Choice).Should().BeTrue(); + } +} diff --git a/test/FAI.IntegrationTests/TextClassificationIntegrationTests.cs b/test/FAI.IntegrationTests/TextClassificationIntegrationTests.cs new file mode 100644 index 0000000..2449b01 --- /dev/null +++ b/test/FAI.IntegrationTests/TextClassificationIntegrationTests.cs @@ -0,0 +1,38 @@ +using FAI.NLP.PipelineBatchExecutors; + +namespace FAI.IntegrationTests; + +public class TextClassificationIntegrationTests +{ + [Fact] + public async Task FullPipeline_ShouldClassifyText() + { + // Arrange + var services = new ServiceCollection(); + + var options = new ClassificationOptions([false, true]); + 
services.AddSingleton(options); + + services.AddPipeline>() + .Use>>(); + + var tokenizer = DummyTokenizerFactory.Create(); + services.AddSingleton(tokenizer); + + // Mock model: always returns high probability for 'true' (index 1) + services.AddSingleton>(new LogicalMockModelExecutor([[0.1f, 0.9f]])); + + services.AddSingleton>, TextClassification>(); + + var provider = services.BuildServiceProvider(); + var pipeline = provider.GetRequiredService>>(); + + // Act + var input = new TokenizedText("hello"); + var results = await pipeline.BatchPredict(new[] { input }); + + // Assert + results.Should().HaveCount(1); + results[0].Choice.Should().BeTrue(); + } +} diff --git a/test/FAI.NLP.Extensions.DI.Tests/BatchExecutorExtensionsTests.cs b/test/FAI.NLP.Extensions.DI.Tests/BatchExecutorExtensionsTests.cs new file mode 100644 index 0000000..b9f04b2 --- /dev/null +++ b/test/FAI.NLP.Extensions.DI.Tests/BatchExecutorExtensionsTests.cs @@ -0,0 +1,171 @@ +using System.Text; +using FAI.Core.Abstractions; +using FAI.Core.Configurations.InferenceTasks; +using FAI.Core.Extensions.DI; +using FAI.Core.ResultTypes; +using FAI.NLP.BatchSlicer; +using FAI.NLP.Configuration; +using FAI.NLP.Configuration.PipelineBatchExecutors; +using FAI.NLP.Extensions.DI; +using FAI.NLP.PipelineBatchExecutors; +using FAI.NLP.Tokenization; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.ML.Tokenizers; + +namespace FAI.NLP.Extensions.DI.Tests; + +public class BatchExecutorExtensionsTests +{ + private readonly IServiceCollection _services = new ServiceCollection(); + + public BatchExecutorExtensionsTests() + { + _services.AddSingleton(CreateDummyTokenizer()); + } + + [Fact] + public void UseTokenSorting_WithExplicitOptions_RegistersExecutor() + { + // Arrange + _services.AddSingleton(Substitute.For>()); + var builder = _services.AddPipeline(); + var options = new TokenCountSortingBatchExecutorOptions(false); + + // Act + 
builder.UseTokenSorting(options); + var sp = _services.BuildServiceProvider(); + var pipeline = sp.GetRequiredService>(); + + // Assert + Assert.NotNull(pipeline); + } + + [Fact] + public void UseTokenSorting_WithSection_BindsConfiguration() + { + // Arrange + var configData = new Dictionary + { + ["NLP:Sorting:Ascending"] = "false" + }; + var config = new ConfigurationBuilder().AddInMemoryCollection(configData).Build(); + _services.AddSingleton(config); + _services.AddSingleton(Substitute.For>()); + + var builder = _services.AddPipeline(); + + // Act + builder.UseTokenSorting("NLP:Sorting"); + var sp = _services.BuildServiceProvider(); + var pipeline = sp.GetRequiredService>(); + + // Assert + Assert.NotNull(pipeline); + var options = sp.GetRequiredService(); + Assert.False(options.Ascending); + } + + [Fact] + public void UseTokenizing_RegistersExecutor() + { + // Arrange + _services.AddSingleton(Substitute.For>()); + var builder = _services.AddPipeline(); + + // Act + builder.UseTokenizing(); + var sp = _services.BuildServiceProvider(); + var pipeline = sp.GetRequiredService>(); + + // Assert + Assert.NotNull(pipeline); + } + + [Fact] + public void WithTextClassification_RegistersStepsAndOptions() + { + // Arrange + var configData = new Dictionary + { + ["NLP:Classification:Choices:0"] = "Positive", + ["NLP:Classification:Choices:1"] = "Negative", + ["NLP:Classification:StoreLogits"] = "true" + }; + var config = new ConfigurationBuilder().AddInMemoryCollection(configData).Build(); + _services.AddSingleton(config); + _services.AddSingleton(Substitute.For>()); + + var builder = _services.AddPipeline>(); + + // Act + builder.WithTextClassification("NLP:Classification"); + var sp = _services.BuildServiceProvider(); + + // Assert + var steps = sp.GetService>>(); + Assert.NotNull(steps); + var options = sp.GetService>(); + Assert.NotNull(options); + Assert.Equal(["Positive", "Negative"], options.Choices); + Assert.True(options.StoreLogits); + } + + [Fact] + public void 
WithMaxPaddedTokens_RegistersSlicerAndOptions() + { + // Arrange + var configData = new Dictionary + { + ["NLP:Partition:MaxTokenCount"] = "128", + ["NLP:Partition:MaxPaddedTokenRatio"] = "0.5" + }; + var config = new ConfigurationBuilder().AddInMemoryCollection(configData).Build(); + _services.AddSingleton(config); + _services.AddSingleton(Substitute.For>()); + + var builder = _services.AddPipeline(); + + // Act + builder.UsePartitioning(p => p.WithMaxPaddedTokens("NLP:Partition")); + var sp = _services.BuildServiceProvider(); + + // Assert + var slicer = sp.GetService>(); + Assert.NotNull(slicer); + Assert.IsType>(slicer); + var options = sp.GetService(); + Assert.NotNull(options); + Assert.Equal(128, options.MaxTokenCount); + Assert.Equal(0.5, options.MaxPaddedTokenRatio); + } + + public record MockTokenizable(int TokenCount) : ITokenizable + { + public int MaxTokenLength => TokenCount; + public int SentenceCount => 1; + public void Tokenize(PretrainedTokenizer pretrainedTokenizer) { } + } + + private static PretrainedTokenizer CreateDummyTokenizer() + { + var vocab = new StringBuilder(); + vocab.AppendLine("[PAD]"); + vocab.AppendLine("[CLS]"); + vocab.AppendLine("[SEP]"); + vocab.AppendLine("[MASK]"); + vocab.AppendLine("[UNK]"); + + using var ms = new MemoryStream(Encoding.UTF8.GetBytes(vocab.ToString())); + var bertTokenizer = BertTokenizer.Create(ms); + + var options = new PretrainedTokenizerOptions + { + MaxTokenLength = 128, + PaddingToken = 0, + TruncationOption = TruncationOption.Longest + }; + + return new PretrainedTokenizer(bertTokenizer, options); + } +} diff --git a/test/FAI.NLP.Extensions.DI.Tests/FAI.NLP.Extensions.DI.Tests.csproj b/test/FAI.NLP.Extensions.DI.Tests/FAI.NLP.Extensions.DI.Tests.csproj new file mode 100644 index 0000000..7e6ce06 --- /dev/null +++ b/test/FAI.NLP.Extensions.DI.Tests/FAI.NLP.Extensions.DI.Tests.csproj @@ -0,0 +1,5 @@ + + + + + diff --git a/test/FAI.NLP.Tests/BatchSlicer/MaxPaddedTokensBatchSlicerTests.cs 
b/test/FAI.NLP.Tests/BatchSlicer/MaxPaddedTokensBatchSlicerTests.cs new file mode 100644 index 0000000..59981c5 --- /dev/null +++ b/test/FAI.NLP.Tests/BatchSlicer/MaxPaddedTokensBatchSlicerTests.cs @@ -0,0 +1,68 @@ +using FAI.NLP.BatchSlicer; +using FAI.NLP.Configuration.PipelineBatchExecutors; +using FAI.NLP.Tokenization; + +namespace FAI.NLP.Tests.BatchSlicer; + +public class MaxPaddedTokensBatchSlicerTests +{ + public record TestTokenizable(int TokenCount) : ITokenizable + { + public int MaxTokenLength => TokenCount; + public int SentenceCount => 1; + public void Tokenize(PretrainedTokenizer tokenizer) { } + } + + [Fact] + public void Slice_RespectsMaxTokenCount() + { + // Arrange + var options = new MaxPaddedTokensSlicerOptions + { + MaxTokenCount = 10, + MaxPaddedTokenRatio = 0.5 // Allow up to 50% padding + }; + var slicer = new MaxPaddedTokensBatchSlicer(options); + + // Items with token counts: 4, 4, 4 + // Batch 1: (4+4) * 2 = 8 <= 10. OK. + // Batch 1 + next: (4+4+4) * 3 = 12 > 10. Break. + TestTokenizable[] inputs = [new(4), new(4), new(4)]; + + // Act + var ranges = slicer.Slice(inputs).ToList(); + + // Assert + Assert.Equal(2, ranges.Count); + Assert.Equal(0..2, ranges[0]); + Assert.Equal(2..3, ranges[1]); + } + + [Fact] + public void Slice_RespectsMaxPaddedTokenRatio() + { + // Arrange + var options = new MaxPaddedTokensSlicerOptions + { + MaxTokenCount = 100, + MaxPaddedTokenRatio = 0.1 // Only 10% padding allowed + }; + var slicer = new MaxPaddedTokensBatchSlicer(options); + + // Slicer assumes input is sorted ASCENDING by MaxTokenLength. + // Item 1: 2 tokens. + // Item 2: 10 tokens. + // If batched: MaxLen = 10, Count = 2. Padded = 20. Sum = 12. + // factor = 1.0 - 0.1 = 0.9. + // newSum < newPadded * factor <=> 12 < 20 * 0.9 (18) => TRUE. BREAKS. 
+ TestTokenizable[] inputs = [new(2), new(10)]; + + // Act + var ranges = slicer.Slice(inputs).ToList(); + + // Assert + Assert.Equal(2, ranges.Count); + Assert.Equal(0..1, ranges[0]); + Assert.Equal(1..2, ranges[1]); + } +} diff --git a/test/FAI.NLP.Tests/FAI.NLP.Tests.csproj b/test/FAI.NLP.Tests/FAI.NLP.Tests.csproj new file mode 100644 index 0000000..525377c --- /dev/null +++ b/test/FAI.NLP.Tests/FAI.NLP.Tests.csproj @@ -0,0 +1,5 @@ + + + + + diff --git a/test/FAI.NLP.Tests/InferenceTasks/TextClassificationTaskTests.cs b/test/FAI.NLP.Tests/InferenceTasks/TextClassificationTaskTests.cs new file mode 100644 index 0000000..5df32ac --- /dev/null +++ b/test/FAI.NLP.Tests/InferenceTasks/TextClassificationTaskTests.cs @@ -0,0 +1,57 @@ +using System.Numerics.Tensors; +using FAI.Core.Abstractions; +using FAI.Core.Configurations.InferenceTasks; +using FAI.Core.ResultTypes; +using FAI.NLP.InferenceTasks.TextClassification; +using FAI.NLP.Tests.Mocks; +using FAI.NLP.Tokenization; + +namespace FAI.NLP.Tests.InferenceTasks; + +public class TextClassificationTaskTests +{ + [Fact] + public async Task TextClassification_E2E_Mapping_Works() + { + // Arrange + var tokenizer = DummyTokenizerFactory.Create(); + var modelExecutor = Substitute.For>(); + + var options = new ClassificationOptions + { + Choices = ["Negative", "Positive"] + }; + + var task = new TextClassification(tokenizer, modelExecutor, options); + + TokenizedText[] inputs = [new("hello")]; + var outputs = new ClassificationResult[1]; + + // Mock model output: 2 labels, [0.1f, 0.9f] -> Positive + // The signature we use in ClassificationTask is usually RunAsync(Tensor[] inputs, Action, int> callback) + // Wait, let's check ClassificationTask.cs if possible or assume standard. 
+ // Actually IModelExecutor has Task RunAsync(Tensor[] inputs, Action, int> postProcess); + + modelExecutor.RunAsync(Arg.Any[]>(), Arg.Any, int>>()) + .Returns(x => + { + var callback = x.ArgAt, int>>(1); + var logits = Tensor.CreateFromShape([1, 2]); + logits[0, 0] = 0.1f; + logits[0, 1] = 0.9f; + + // ReadOnlyTensorSpan is a ref struct, we might need a internal way to create it or just use Arg.Invoke if NSubstitute supports it for ref structs (it usually doesn't in that way) + // However, we can use Arg.Do to capture the callback and then invoke it if we can construct the ref struct. + callback(logits, 0); + return Task.CompletedTask; + }); + + // Act + // Use ProcessBatch as it is an IInferenceSteps + await task.ProcessBatch(inputs, outputs); + + // Assert + Assert.Equal("Positive", outputs[0].Choice); + Assert.True(outputs[0].Score > 0.5f); + } +} diff --git a/test/FAI.NLP.Tests/InferenceTasks/TextMultipleChoiceTaskTests.cs b/test/FAI.NLP.Tests/InferenceTasks/TextMultipleChoiceTaskTests.cs new file mode 100644 index 0000000..567abdd --- /dev/null +++ b/test/FAI.NLP.Tests/InferenceTasks/TextMultipleChoiceTaskTests.cs @@ -0,0 +1,53 @@ +using System.Numerics.Tensors; +using FAI.Core.Abstractions; +using FAI.Core.ResultTypes; +using FAI.NLP.Configuration; +using FAI.NLP.InferenceTasks.TextMultipleChoice; +using FAI.NLP.Tests.Mocks; +using FAI.NLP.Tokenization; + +namespace FAI.NLP.Tests.InferenceTasks; + +public class TextMultipleChoiceTaskTests +{ + [Fact] + public async Task TextMultipleChoice_FlatteningAndInference_Works() + { + // Arrange + var tokenizer = DummyTokenizerFactory.Create(); + var modelExecutor = Substitute.For>(); + + var options = new TextMultipleChoiceOptions + { + MaxChoices = 4, + StoreLogits = true + }; + + var task = new TextMultipleChoiceTask(tokenizer, modelExecutor, options); + + // Input with 2 choices + var inputs = new TextMultipleChoiceInput[] + { + new("context", [new("choice 1"), new("choice 2")]) + }; + var outputs = new 
using System.Text;
using FAI.NLP.Configuration;
using FAI.NLP.Tokenization;
using Microsoft.ML.Tokenizers;

namespace FAI.NLP.Tests.Mocks;

/// <summary>
/// Builds a <see cref="PretrainedTokenizer"/> over a tiny BERT vocabulary,
/// sufficient for unit tests that only need "hello" and "world" plus the
/// standard special tokens.
/// </summary>
public static class DummyTokenizerFactory
{
    // Vocabulary order fixes the token ids:
    // [PAD]=0, [unused0..9]=1..10, [CLS]=11, [SEP]=12, [MASK]=13, [UNK]=14, hello=15, world=16.
    private static readonly string[] VocabEntries =
    [
        "[PAD]",
        "[unused0]", "[unused1]", "[unused2]", "[unused3]", "[unused4]",
        "[unused5]", "[unused6]", "[unused7]", "[unused8]", "[unused9]",
        "[CLS]", "[SEP]", "[MASK]", "[UNK]",
        "hello",
        "world",
    ];

    /// <summary>Creates a tokenizer backed by the dummy vocabulary above.</summary>
    /// <param name="maxTokenLength">Maximum sequence length the tokenizer will emit.</param>
    public static PretrainedTokenizer Create(int maxTokenLength = 128)
    {
        // One vocab entry per line, exactly as a BERT vocab.txt would be laid out.
        var vocabText = new StringBuilder();
        foreach (var entry in VocabEntries)
        {
            vocabText.AppendLine(entry);
        }

        using var vocabStream = new MemoryStream(Encoding.UTF8.GetBytes(vocabText.ToString()));
        var bertTokenizer = BertTokenizer.Create(vocabStream);

        var options = new PretrainedTokenizerOptions
        {
            MaxTokenLength = maxTokenLength,
            PaddingToken = 0,
            TruncationOption = TruncationOption.Longest
        };

        return new PretrainedTokenizer(bertTokenizer, options);
    }
}
using FAI.Core.Abstractions;
using FAI.NLP.Configuration.PipelineBatchExecutors;
using FAI.NLP.PipelineBatchExecutors;
using FAI.NLP.Tokenization;

namespace FAI.NLP.Tests.PipelineBatchExecutorTests;

public class TokenBatchSizeBatchExecutorTests
{
    // Minimal ITokenizable stub: fixed token count, no-op tokenization.
    public record TestTokenizable(int TokenCount) : ITokenizable
    {
        public int MaxTokenLength => TokenCount;
        public int SentenceCount => 1;
        public void Tokenize(PretrainedTokenizer tokenizer) { }
    }

    [Fact]
    public async Task ExecuteBatchPredict_SplitsIntoCorrectBatches()
    {
        // Arrange
        // NOTE(review): generic type arguments in this method were lost in the patch
        // transcription (angle-bracketed text stripped); <TestTokenizable, int> is
        // inferred from the inputs/outputs arrays below — confirm against the repo.
        var mockExecutor = Substitute.For<IPipelineBatchExecutor<TestTokenizable, int>>();
        var options = new TokenBatchSizeBatchExecutorOptions { MaxTokensCount = 10 };
        var executor = new TokenBatchSizeBatchExecutor<TestTokenizable, int>(mockExecutor, options);

        // Inputs (token counts): 6, 5, 4, 7 under a 10-token budget per batch:
        // Batch 1: [6]
        // Batch 2: [5, 4]
        // Batch 3: [7]
        TestTokenizable[] inputs = [new(6), new(5), new(4), new(7)];
        var outputs = new int[4];

        // Act
        await executor.ExecuteBatchPredict(inputs, outputs);

        // Assert
        // Distinguish Batch 1 and Batch 3 by using Received(2) or checking contents if possible without ref struct issues.
        // Since we cannot easily use .Span in Arg.Is due to expression tree limitations,
        // we check the number of matching calls per batch length instead.

        // One call with Length 2
        await mockExecutor.Received(1).ExecuteBatchPredict(
            Arg.Is<ReadOnlyMemory<TestTokenizable>>(m => m.Length == 2),
            Arg.Any<Memory<int>>());

        // Two calls with Length 1
        await mockExecutor.Received(2).ExecuteBatchPredict(
            Arg.Is<ReadOnlyMemory<TestTokenizable>>(m => m.Length == 1),
            Arg.Any<Memory<int>>());
    }
}
using FAI.NLP.Tests.Mocks;
using FAI.NLP.Tokenization;

namespace FAI.NLP.Tests.Tokenization;

// Class fixture: builds the dummy tokenizer once and shares it across all tests
// in PretrainedTokenizerTests (xUnit IClassFixture lifetime).
public class PretrainedTokenizerFixture : IDisposable
{
    public PretrainedTokenizer Tokenizer { get; } = DummyTokenizerFactory.Create();

    public void Dispose()
    {
        // Cleanup if needed
    }
}

// NOTE(review): the IClassFixture generic argument was stripped in the patch
// transcription; <PretrainedTokenizerFixture> is inferred from the constructor.
public class PretrainedTokenizerTests : IClassFixture<PretrainedTokenizerFixture>
{
    private readonly PretrainedTokenizer _tokenizer;

    public PretrainedTokenizerTests(PretrainedTokenizerFixture fixture)
    {
        _tokenizer = fixture.Tokenizer;
    }

    [Fact]
    public void Tokenize_SingleInput_ReturnsCorrectIds()
    {
        // Arrange
        string text = "hello world";

        // Act
        var ids = _tokenizer.Tokenize(text);

        // Assert
        // Based on the dummy vocab:
        // [PAD]=0, [unused0..9]=1..10, [CLS]=11, [SEP]=12, [MASK]=13, [UNK]=14, hello=15, world=16
        Assert.Contains(15, ids); // hello
        Assert.Contains(16, ids); // world
    }

    [Fact]
    public void BatchTokenize_Strings_ReturnsCorrectTensorShape()
    {
        // Arrange
        string[] inputs = ["hello", "hello world"];

        // Act
        var result = _tokenizer.BatchTokenize(inputs);

        // Assert
        Assert.Equal(2, result.BatchSize);
        Assert.True(result.MaxTokenCount >= 2);
        // Token and attention-mask tensors must always share the same shape.
        Assert.Equal(result.Tokens.Lengths, result.Mask.Lengths);
    }

    [Fact]
    public void BatchTokensToTensors_PadsCorrectly()
    {
        // Arrange
        // Pre-tokenized ids: [hello] and [hello, world] from the dummy vocab.
        List<int>[] inputs = [[15], [15, 16]]; // hello, hello world

        // Act
        var result = _tokenizer.BatchTokensToTensors(inputs, maxTokenSize: 2);

        // Assert
        Assert.Equal(2, result.BatchSize);
        Assert.Equal(2, result.MaxTokenCount);

        // Row 0: [hello, PAD] -> [15, 0]; mask marks the padded slot with 0.
        Assert.Equal(15, result.Tokens[0, 0]);
        Assert.Equal(0, result.Tokens[0, 1]);
        Assert.Equal(1, result.Mask[0, 0]);
        Assert.Equal(0, result.Mask[0, 1]);

        // Row 1: [hello, world] -> [15, 16]; no padding needed.
        Assert.Equal(15, result.Tokens[1, 0]);
        Assert.Equal(16, result.Tokens[1, 1]);
        Assert.Equal(1, result.Mask[1, 0]);
        Assert.Equal(1, result.Mask[1, 1]);
    }
}
using FAI.Onnx.Configuration;
using FAI.Onnx.ModelExecutorPools;
using FAI.Onnx.ModelExecutors;

namespace FAI.Onnx.Tests.ModelExecutorPools;

// NOTE(review): generic arguments on IClassFixture, List and MultiDeviceObjectPool
// were stripped in the patch transcription; the types below are inferred from
// usage (OnnxModelFixture / OnnxModelExecutor) — confirm against the repo.
public class MultiDeviceObjectPoolTests(OnnxModelFixture fixture) : IClassFixture<OnnxModelFixture>
{
    private readonly string _modelPath = fixture.ModelPath;

    [Fact]
    public void Get_ShouldReturnExecutorsInRoundRobinOrder()
    {
        // Arrange: two executors over the same minimal test model.
        var options = new OnnxModelExecutorOptions().ConfigureOnnxOptions(opt =>
        {
            opt.ModelDir = Path.GetDirectoryName(_modelPath)!;
            opt.ModelFileName = Path.GetFileName(_modelPath);
        });

        var exec1 = OnnxModelExecutor.FromPretrained(options);
        var exec2 = OnnxModelExecutor.FromPretrained(options);
        var executors = new List<OnnxModelExecutor> { exec1, exec2 };
        var pool = new MultiDeviceObjectPool<OnnxModelExecutor>(executors);

        // Act & Assert: reference equality proves strict round-robin cycling
        // (1 -> 2 -> back to 1).
        Assert.Same(exec1, pool.Get());
        Assert.Same(exec2, pool.Get());
        Assert.Same(exec1, pool.Get());
    }
}
using FAI.Onnx.Utils;

namespace FAI.Onnx.Tests.Utils;

public class CircularAtomicCounterTests
{
    [Fact]
    public void Next_ShouldReturnSequentialValues()
    {
        // Arrange
        var counter = new CircularAtomicCounter(3);

        // Act & Assert: counts 0..maxValue-1, then wraps back to 0.
        Assert.Equal(0, counter.Next());
        Assert.Equal(1, counter.Next());
        Assert.Equal(2, counter.Next());
        Assert.Equal(0, counter.Next());
    }

    [Fact]
    public void Constructor_ShouldThrow_WhenMaxValueIsZeroOrLess()
    {
        // NOTE(review): the exception type argument was stripped in the patch
        // transcription; ArgumentOutOfRangeException is inferred — confirm.
        Assert.Throws<ArgumentOutOfRangeException>(() => new CircularAtomicCounter(0));
        Assert.Throws<ArgumentOutOfRangeException>(() => new CircularAtomicCounter(-1));
    }

    [Fact]
    public void Next_ShouldBeThreadSafe()
    {
        // Arrange
        const int maxValue = 10;
        const int iterations = 1000;
        const int threadCount = 10;
        var counter = new CircularAtomicCounter(maxValue);
        var results = new int[maxValue];

        // Act: hammer the counter from many threads, tallying each returned slot.
        // Interlocked.Increment replaces the previous lock-per-iteration: a single
        // array-cell increment needs no mutual exclusion beyond an atomic add, and
        // keeping a lock in the hot loop would serialize the very contention this
        // test is trying to create. (Also drops the .NET 9-only System.Threading.Lock.)
        Parallel.For(0, threadCount, _ =>
        {
            for (int i = 0; i < iterations; i++)
            {
                Interlocked.Increment(ref results[counter.Next()]);
            }
        });

        // Assert: no ticks lost or duplicated, and — because totalIncrements is an
        // exact multiple of maxValue and the counter wraps atomically — every value
        // in [0, maxValue) must have been produced the same number of times.
        int totalIncrements = threadCount * iterations;
        Assert.Equal(totalIncrements, results.Sum());
        foreach (int count in results)
        {
            Assert.Equal(totalIncrements / maxValue, count);
        }
    }
}
using System.Numerics.Tensors;
using FAI.Onnx.Configuration;
using FAI.Onnx.ModelExecutors;
using Microsoft.ML.OnnxRuntime;

namespace FAI.Onnx.Tests.Utils;

/// <summary>
/// Supplies a tiny in-memory ONNX model for tests: a single Cast node that
/// converts an INT64 tensor named 'input' (shape [1, 3]) to a FLOAT tensor
/// named 'output'. IR version 6, opset 11.
/// </summary>
public static class OnnxTestModelFactory
{
    // Base64-encoded bytes of the minimal Cast model described above.
    private const string ConstantModelBase64 =
        "CAY6WwogCgVpbnB1dBIGb3V0cHV0IgRDYXN0KgkKAnRvGAGgAQISBHRlc3RaFwoFaW5wdXQSDgoMCAcSCAoCCAEKAggDYhgKBm91dHB1dBIOCgwIARIICgIIAQoCCANCBAoAEAs=";

    /// <summary>Decodes the embedded model into raw ONNX bytes.</summary>
    public static byte[] CreateMinimalModelBytes() =>
        Convert.FromBase64String(ConstantModelBase64);

    /// <summary>
    /// Writes the model to a uniquely named file in the temp directory and
    /// returns its path. The caller is responsible for deleting the file.
    /// </summary>
    public static string CreateTemporaryModelFile()
    {
        var modelFile = Path.Combine(Path.GetTempPath(), $"{Guid.NewGuid()}.onnx");
        File.WriteAllBytes(modelFile, CreateMinimalModelBytes());
        return modelFile;
    }

    /// <summary>Opens an inference session directly over the in-memory model bytes.</summary>
    public static InferenceSession CreateSession() =>
        new(CreateMinimalModelBytes());
}