diff --git a/documentation/query_documentation.md b/documentation/query_documentation.md index 4c2421151..57877dac9 100644 --- a/documentation/query_documentation.md +++ b/documentation/query_documentation.md @@ -85,6 +85,52 @@ See `NucleotideEquals`. See `HasNucleotideMutation`. +#### `NucleotideMutationProfile` + +``` +{ + "distance": number, + "sequenceName": string, // optional; uses the default sequence if omitted + // exactly one of these + "querySequence": string + "sequenceId": string + "mutations": {"position": number, "symbol": string}[] +} +``` + +This filter is true if a sequence has at most `distance` **differences** from a profile sequence. + +A difference at a position is considered when the database sequence's symbol is **not** ambiguity-compatible with the profile's symbol at that position. Ambiguity-compatible means the database symbol appears in the set of symbols compatible with the profile symbol (e.g. `R` is compatible with `A` because `R` represents A or G; `N` is compatible with any definitive base). Positions where the profile symbol is `N` (missing) are always skipped and never counted as differences. + +**Profile input — exactly one of:** + +- `querySequence`: a full sequence string of the same length as the reference. Each character must be a valid symbol for the sequence type. +- `sequenceId`: the primary key of a sequence already in the database. That sequence is used as the profile. +- `mutations`: an array of mutations relative to the reference. Positions are 1-based. Positions not listed retain the reference symbol. An empty array means the profile equals the reference. + - Each entry: `{"position": number, "symbol": string}` where `symbol` is a single valid character. + +**Example** — find all sequences within 5 mutations of the reference: +```json +{ + "type": "NucleotideMutationProfile", + "distance": 5, + "mutations": [] +} +``` + +**Example** — find sequences within 2 mutations of a specific stored sequence: +```json +{ + "type": "NucleotideMutationProfile", + "distance": 2, + "sequenceId": "EPI_ISL_123456" +} +``` + +#### `AminoAcidMutationProfile` + +See `NucleotideMutationProfile`. Applies to amino acid sequences; `symbol` values must be valid amino acid symbols. + #### `Lineage`: `{"column": string, "value": string | null, "includeSublineages": boolean, ["recombinantFollowingMode": string]}` This filter is true if the lineage in column `column` is equal to or an alias of `value`. diff --git a/performance/CMakeLists.txt b/performance/CMakeLists.txt index fa915a8b8..aeafa980f 100644 --- a/performance/CMakeLists.txt +++ b/performance/CMakeLists.txt @@ -1,8 +1,10 @@ function(add_benchmark BENCH_NAME) add_executable(${BENCH_NAME} ${BENCH_NAME}.cpp) target_link_libraries(${BENCH_NAME} silolib) + target_include_directories(${BENCH_NAME} PRIVATE ${CMAKE_SOURCE_DIR}/performance) endfunction(add_benchmark) add_benchmark(mutation_benchmark) add_benchmark(many_string_equals) add_benchmark(many_short_read_filters) +add_benchmark(nof_sequence_filter) diff --git a/performance/many_short_read_filters.cpp b/performance/many_short_read_filters.cpp index f925c0222..d99e3ff81 100644 --- a/performance/many_short_read_filters.cpp +++ b/performance/many_short_read_filters.cpp @@ -1,6 +1,5 @@ #include #include -#include #include #include #include @@ -9,6 +8,7 @@ #include #include +#include "sequence_generator.h" #include "silo/append/database_inserter.h" #include "silo/append/ndjson_line_reader.h" #include "silo/initialize/initializer.h" @@ -25,253 +25,8 @@ using silo::Database; namespace { -constexpr size_t DEFAULT_READ_COUNT = 5'000'000; -constexpr size_t DEFAULT_READ_LENGTH = 200; -constexpr double DEFAULT_MUTATION_RATE = 0.001; -constexpr double DEFAULT_DEATH_RATE = 0.1; -constexpr size_t DEFAULT_GENERATIONS = 5; -constexpr size_t DEFAULT_CHILDREN_PER_NODE = 3; constexpr size_t DEFAULT_QUERY_COUNT = 10'000; -void changeCwdToTestFolder() { - // Look for the test data directory (`testBaseData`) in the current directory and up to - // directories above the current directory. If found, change the current working - // directory to the directory containing the test data directory - size_t search_depth = 4; - std::filesystem::path candidate_directory = std::filesystem::current_path().string(); - for (size_t i = 0; i < search_depth; i++, candidate_directory = candidate_directory / "..") { - if (std::filesystem::exists(candidate_directory / "testBaseData/exampleDataset")) { - std::filesystem::current_path(candidate_directory); - return; - } - } - throw std::runtime_error(fmt::format( - "Should be run in root of repository, got {} and could not find root by heuristics", - std::filesystem::current_path().string() - )); -} - -std::string readReferenceFromFile() { - auto reference_genomes = - silo::ReferenceGenomes::readFromFile("testBaseData/exampleDataset/reference_genomes.json"); - if (reference_genomes.raw_nucleotide_sequences.empty()) { - throw std::runtime_error("No nucleotide sequences found in reference genomes file"); - } - return reference_genomes.raw_nucleotide_sequences.at(0); -} - -using silo::Nucleotide; - -// Simple tree-based sequence evolution model -class SequenceTreeGenerator { - std::mt19937 rng; - const std::string& reference; - double mutation_rate; - double death_rate; - size_t generations; - size_t children_per_node; - - char mutateBase(char base) { - std::uniform_int_distribution dist(0, 3); - char new_base; - do { - Nucleotide::Symbol new_symbol = Nucleotide::SYMBOLS.at(dist(rng)); - new_base = Nucleotide::symbolToChar(new_symbol); - } while (new_base == base); - return new_base; - } - - std::string mutateSequence(std::string_view sequence) { - std::string mutated{sequence}; - const size_t seq_length = sequence.size(); - - // Sample the number of mutations from a binomial distribution - std::binomial_distribution num_mutations_dist(seq_length, mutation_rate); - size_t num_mutations = num_mutations_dist(rng); - - // Randomly choose positions to mutate - std::uniform_int_distribution pos_dist(0, seq_length - 1); - for (size_t i = 0; i < num_mutations; ++i) { - size_t pos = pos_dist(rng); - mutated[pos] = mutateBase(mutated[pos]); - } - return mutated; - } - - public: - SequenceTreeGenerator( - const std::string& ref, - uint64_t seed = 42, - double mut_rate = DEFAULT_MUTATION_RATE, - double death = DEFAULT_DEATH_RATE, - size_t gens = DEFAULT_GENERATIONS, - size_t children = DEFAULT_CHILDREN_PER_NODE - ) - : rng(seed), - reference(ref), - mutation_rate(mut_rate), - death_rate(death), - generations(gens), - children_per_node(children) {} - - // Generate evolved sequences using a tree model - std::vector generateEvolvedSequences() { - std::vector all_generated_sequences = {reference}; - std::vector current_generation = {reference}; - std::vector next_generation; - std::bernoulli_distribution survives(1.0 - death_rate); - - for (size_t gen = 0; gen < generations; ++gen) { - next_generation.clear(); - for (const auto& seq : current_generation) { - for (size_t child = 0; child < children_per_node; ++child) { - if (survives(rng)) { - all_generated_sequences.push_back(mutateSequence(seq)); - next_generation.push_back(all_generated_sequences.back()); - } - } - } - if (next_generation.empty()) { - // If all died, keep at least one survivor - next_generation.push_back(all_generated_sequences.back()); - } - current_generation = std::move(next_generation); - } - return all_generated_sequences; - } -}; - -struct ShortRead { - size_t id; - size_t offset; - std::string sequence; -}; - -// Lazy generator for short reads - generates on-demand without materializing all reads -class ShortReadGenerator { - std::vector evolved_sequences; - std::mt19937 rng; - std::uniform_int_distribution seq_dist; - size_t count; - size_t read_length; - size_t num_positions; - - public: - class iterator { - ShortReadGenerator* generator; - size_t current_id; - - public: - using iterator_category = std::input_iterator_tag; - using value_type = ShortRead; - using difference_type = std::ptrdiff_t; - using pointer = const ShortRead*; - using reference = ShortRead; - - iterator(ShortReadGenerator* gen, size_t id) - : generator(gen), - current_id(id) {} - - ShortRead operator*() { return generator->generateAt(current_id); } - - iterator& operator++() { - ++current_id; - return *this; - } - - iterator operator++(int) { - iterator tmp = *this; - ++current_id; - return tmp; - } - - bool operator==(const iterator& other) const { return current_id == other.current_id; } - bool operator!=(const iterator& other) const { return current_id != other.current_id; } - }; - - ShortReadGenerator( - const std::string& reference, - size_t count, - size_t read_length, - uint64_t seed = 42 - ) - : count(count), - read_length(read_length) { - SequenceTreeGenerator tree_gen(reference, seed); - evolved_sequences = tree_gen.generateEvolvedSequences(); - - SPDLOG_INFO("Generated {} evolved sequences from tree model", evolved_sequences.size()); - - const size_t seq_length = reference.size(); - SILO_ASSERT(read_length < seq_length); - num_positions = seq_length - read_length + 1; - - rng.seed(seed + 1000); - seq_dist = std::uniform_int_distribution(0, evolved_sequences.size() - 1); - } - - ShortRead generateAt(size_t read_id) { - const size_t offset = (read_id * num_positions) / count; - const auto& source_seq = evolved_sequences[seq_dist(rng)]; - return {read_id, offset, source_seq.substr(offset, read_length)}; - } - - iterator begin() { return iterator(this, 0); } - iterator end() { return iterator(this, count); } - - [[nodiscard]] size_t size() const { return count; } -}; - -std::stringstream generateShortReadNdjson( - const std::string& reference, - size_t count = DEFAULT_READ_COUNT, - size_t read_length = DEFAULT_READ_LENGTH -) { - ShortReadGenerator generator(reference, count, read_length); - std::stringstream buffer; - - for (const auto& read : generator) { - buffer << fmt::format( - R"({{"readId":"read_{}","samplingDate":"2024-01-01","locationName":"generated","main":{{"insertions":[],"offset":{},"sequence":"{}"}}}})", - read.id, - read.offset, - read.sequence - ) << "\n"; - } - - return buffer; -} - -std::shared_ptr initializeDatabaseWithSingleReference(std::string reference) { - auto database_config = silo::config::DatabaseConfig::getValidatedConfig(R"( -schema: - instanceName: test - metadata: - - name: readId - type: string - - name: samplingDate - type: date - - name: locationName - type: string - primaryKey: readId -)"); - - silo::ReferenceGenomes reference_genomes{{{"main", reference}}, {}}; - - auto database = std::make_shared(); - database->createTable( - silo::schema::TableName::getDefault(), - silo::initialize::Initializer::createSchemaFromConfigFiles( - std::move(database_config), - std::move(reference_genomes), - {}, - silo::common::PhyloTree{}, - /*without_unaligned_sequences=*/true - ) - ); - return database; -} - struct TestDatabaseResult { std::shared_ptr database; size_t reference_length; @@ -280,14 +35,12 @@ struct TestDatabaseResult { TestDatabaseResult setupTestDatabase() { std::string reference = readReferenceFromFile(); SPDLOG_INFO("Read reference sequence of length {}", reference.size()); - size_t ref_length = reference.size(); + const size_t ref_length = reference.size(); auto input_buffer = generateShortReadNdjson(reference); SPDLOG_INFO("Generated short read NDJSON data"); - auto database = initializeDatabaseWithSingleReference(reference); - - auto input_data_stream = silo::append::NdjsonLineReader{input_buffer}; + auto database = initializeDatabaseWithShortReadSchema(reference); database->appendData(silo::schema::TableName::getDefault(), input_buffer); return {database, ref_length}; @@ -306,20 +59,18 @@ class QueryGenerator { std::string generateQuery() { std::uniform_int_distribution pos_dist(1, reference_length - 1); - size_t position = pos_dist(rng); + const size_t position = pos_dist(rng); - bool use_all_symbols = (query_counter++ % 2 == 1); + const bool use_all_symbols = (query_counter++ % 2 == 1); if (use_all_symbols) { - // Query all 5 symbols (A, C, G, T, -) at the same position in an OR return fmt::format( R"({{"action":{{"type":"Aggregated"}},"filterExpression":{{"children":[{{"children":[{{"children":[{{"column":"locationName","value":"generated","type":"StringEquals"}}],"type":"Or"}},{{"column":"samplingDate","from":"2024-01-01","to":"2024-01-07","type":"DateBetween"}}],"type":"And"}},{{"children":[{{"position":{0},"symbol":"A","type":"NucleotideEquals"}},{{"position":{0},"symbol":"C","type":"NucleotideEquals"}},{{"position":{0},"symbol":"G","type":"NucleotideEquals"}},{{"position":{0},"symbol":"T","type":"NucleotideEquals"}},{{"position":{0},"symbol":"-","type":"NucleotideEquals"}}],"type":"Or"}},{{"column":"samplingDate","from":"2024-01-01","to":"2024-01-07","type":"DateBetween"}}],"type":"And"}}}})", position ); } - // Query a single random symbol at the position std::uniform_int_distribution sym_dist(0, SYMBOLS.size() - 1); - char symbol = SYMBOLS[sym_dist(rng)]; + const char symbol = SYMBOLS[sym_dist(rng)]; return fmt::format( R"({{"action":{{"type":"Aggregated"}},"filterExpression":{{"children":[{{"children":[{{"children":[{{"column":"locationName","value":"generated","type":"StringEquals"}}],"type":"Or"}},{{"column":"samplingDate","from":"2024-01-01","to":"2024-01-07","type":"DateBetween"}}],"type":"And"}},{{"position":{},"symbol":"{}","type":"NucleotideEquals"}},{{"column":"samplingDate","from":"2024-01-01","to":"2024-01-07","type":"DateBetween"}}],"type":"And"}}}})", position, @@ -334,7 +85,6 @@ void executeAllQueries( size_t query_count = DEFAULT_QUERY_COUNT ) { QueryGenerator query_gen(reference_length); - for (size_t query_num = 1; query_num <= query_count; ++query_num) { if (query_num % 1000 == 0) { SPDLOG_INFO("Executing query number {}", query_num); @@ -358,7 +108,7 @@ void run() { auto [database, reference_length] = setupTestDatabase(); - while(true){ + while (true) { SPDLOG_INFO("Starting full query set benchmark ({} queries):", DEFAULT_QUERY_COUNT); auto start = std::chrono::high_resolution_clock::now(); executeAllQueries(database, reference_length); @@ -370,7 +120,7 @@ void run() { } // namespace -int main(){ +int main() { try { run(); } catch (std::exception& e) { diff --git a/performance/nof_sequence_filter.cpp b/performance/nof_sequence_filter.cpp new file mode 100644 index 000000000..cb9dcab0f --- /dev/null +++ b/performance/nof_sequence_filter.cpp @@ -0,0 +1,192 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "sequence_generator.h" +#include "silo/query_engine/exec_node/ndjson_sink.h" +#include "silo/query_engine/action_query.h" +#include "silo/query_engine/planner.h" +#include "silo/query_engine/binder.h" + +using silo::query_engine::ActionQuery; +using silo::query_engine::Planner; +using silo::query_engine::Binder; +using silo::Database; + +namespace { + +// ---- Benchmark infrastructure ---- + +struct BenchmarkResult { + double avg_ms; + double min_ms; + double max_ms; +}; + +BenchmarkResult runBenchmark( + const std::shared_ptr& database, + const std::string& query_str, + int iterations +) { + std::vector durations; + durations.reserve(iterations); + + for (int i = 0; i < iterations; ++i) { + auto query = ActionQuery::parseQuery(query_str); + + // rewrite() and compile() — including the full NOf DP pass — happen inside + // createQueryPlan, so the timer must start before it. + const auto start = std::chrono::high_resolution_clock::now(); + auto query_tree = Binder::bindQuery(std::move(query), database->tables); + auto query_plan = Planner::planQuery(std::move(query_tree), database->tables, {}, "benchmark_query"); + std::ofstream null_output("/dev/null"); + silo::query_engine::exec_node::NdjsonSink sink{&null_output, query_plan.results_schema}; + query_plan.executeAndWrite(sink, /*timeout_in_seconds=*/60); + const auto end = std::chrono::high_resolution_clock::now(); + durations.push_back( + std::chrono::duration_cast(end - start).count() + ); + } + + const int64_t sum = std::accumulate(durations.begin(), durations.end(), int64_t{0}); + const int64_t min_val = *std::min_element(durations.begin(), durations.end()); + const int64_t max_val = *std::max_element(durations.begin(), durations.end()); + + return BenchmarkResult{ + .avg_ms = static_cast(sum) / static_cast(iterations) / 1000.0, + .min_ms = static_cast(min_val) / 1000.0, + .max_ms = static_cast(max_val) / 1000.0, + }; +} + +// ---- Query construction ---- + +// NucleotideMutationProfile with querySequence rewrites to +// Not(N-Of(SymbolInSet children, distance+1, false)) +// with one child per non-N position in query_sequence (~genome_length children for a full +// sequence). This exercises the single-pass NOf optimisation at large scale. +std::string buildMutationProfileQuery(const std::string& query_sequence, uint32_t distance) { + return fmt::format( + R"({{"action":{{"type":"Aggregated"}},"filterExpression":{{"type":"NucleotideMutationProfile","distance":{},"querySequence":"{}"}}}})", + distance, + query_sequence + ); +} + +// ---- Database setup ---- + +std::shared_ptr setupShortReadDatabase(const std::string& reference, size_t read_count) { + SPDLOG_INFO( + "Generating {} short reads (length {})...", read_count, DEFAULT_READ_LENGTH + ); + auto ndjson = generateShortReadNdjson(reference, read_count); + auto database = initializeDatabaseWithShortReadSchema(reference); + database->appendData(silo::schema::TableName::getDefault(), ndjson); + SPDLOG_INFO("Short-read database ready."); + return database; +} + +std::shared_ptr setupFullSequenceDatabase(const std::string& reference, size_t read_count) { + SPDLOG_INFO("Generating {} full-length sequences...", read_count); + auto ndjson = generateFullSequenceNdjson(reference, read_count); + auto database = initializeDatabaseWithFullSequenceSchema(reference); + database->appendData(silo::schema::TableName::getDefault(), ndjson); + SPDLOG_INFO("Full-sequence database ready."); + return database; +} + +// ---- Main benchmark runner ---- + +void runMutationProfileBenchmarks( + const std::string& label, + const std::shared_ptr& database, + const std::string& query_sequence, + const std::vector& distances, + int iterations +) { + SPDLOG_INFO("=== {} ===", label); + SPDLOG_INFO( + " Query sequence length: {} (generates ~{} N-Of children)", + query_sequence.size(), + query_sequence.size() + ); + + for (const uint32_t distance : distances) { + const auto query = buildMutationProfileQuery(query_sequence, distance); + const auto result = runBenchmark(database, query, iterations); + SPDLOG_INFO( + " MutationProfile(distance={:>4}): avg={:.2f}ms min={:.2f}ms max={:.2f}ms", + distance, + result.avg_ms, + result.min_ms, + result.max_ms + ); + } + + SPDLOG_INFO(""); +} + +void run() { + changeCwdToTestFolder(); + SILO_ASSERT(arrow::compute::Initialize().ok()); + + const std::string reference = readReferenceFromFile(); + SPDLOG_INFO("Reference genome length: {}", reference.size()); + SPDLOG_INFO(""); + + // Generate an evolved sequence to use as the query profile. + // Using a leaf of the tree maximises divergence from the reference (~5 * genome_length * + // mutation_rate mutations), giving a realistic large NOf with many non-trivial children. + SequenceTreeGenerator tree_gen(reference); + const auto evolved = tree_gen.generateEvolvedSequences(); + const std::string& query_sequence = evolved.back(); + SPDLOG_INFO( + "Using evolved sequence as query profile ({} sequences generated, using last)", + evolved.size() + ); + SPDLOG_INFO(""); + + const auto short_read_db = setupShortReadDatabase(reference, DEFAULT_FULL_SEQ_COUNT); + const auto short_read_db_large = setupShortReadDatabase(reference, DEFAULT_READ_COUNT); + const auto full_seq_db = setupFullSequenceDatabase(reference, DEFAULT_FULL_SEQ_COUNT); + + // distance=0 tests the "almost nothing matches" extreme (exact profile match). + // Large distances test the "almost everything matches" extreme. + const std::vector distances = {0, 5, 50, 200}; + constexpr int ITERATIONS = 10; + + SPDLOG_INFO("Running MutationProfile benchmarks ({} iterations per case)", ITERATIONS); + SPDLOG_INFO(""); + + runMutationProfileBenchmarks( + "Short-read database", short_read_db, query_sequence, distances, ITERATIONS + ); + runMutationProfileBenchmarks( + "Large short-read database", short_read_db_large, query_sequence, distances, ITERATIONS + ); + runMutationProfileBenchmarks( + "Full-sequence database", full_seq_db, query_sequence, distances, ITERATIONS + ); + + SPDLOG_INFO("=== Benchmark complete ==="); +} + +} // namespace + +int main() { + try { + run(); + } catch (const std::exception& e) { + SPDLOG_ERROR(e.what()); + return EXIT_FAILURE; + } +} diff --git a/performance/sequence_generator.h b/performance/sequence_generator.h new file mode 100644 index 000000000..d671484dd --- /dev/null +++ b/performance/sequence_generator.h @@ -0,0 +1,310 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "silo/append/ndjson_line_reader.h" +#include "silo/common/phylo_tree.h" +#include "silo/config/database_config.h" +#include "silo/database.h" +#include "silo/initialize/initializer.h" +#include "silo/storage/reference_genomes.h" + +// Header-only utilities shared between performance benchmarks. +// All definitions live inside an anonymous namespace so that each benchmark +// translation unit gets its own copy without ODR conflicts. + +namespace { + +// --- Filesystem helpers --- + +void changeCwdToTestFolder() { + size_t search_depth = 4; + std::filesystem::path candidate = std::filesystem::current_path(); + for (size_t i = 0; i < search_depth; ++i, candidate = candidate / "..") { + if (std::filesystem::exists(candidate / "testBaseData/exampleDataset")) { + std::filesystem::current_path(candidate); + return; + } + } + throw std::runtime_error(fmt::format( + "Should be run from the repository root; could not find it from {}", + std::filesystem::current_path().string() + )); +} + +std::string readReferenceFromFile() { + auto reference_genomes = + silo::ReferenceGenomes::readFromFile("testBaseData/exampleDataset/reference_genomes.json"); + if (reference_genomes.raw_nucleotide_sequences.empty()) { + throw std::runtime_error("No nucleotide sequences found in reference_genomes.json"); + } + return reference_genomes.raw_nucleotide_sequences.at(0); +} + +// --- Sequence evolution model --- + +constexpr double DEFAULT_MUTATION_RATE = 0.001; +constexpr double DEFAULT_DEATH_RATE = 0.1; +constexpr size_t DEFAULT_GENERATIONS = 5; +constexpr size_t DEFAULT_CHILDREN_PER_NODE = 3; + +class SequenceTreeGenerator { + std::mt19937 rng; + const std::string& reference; + double mutation_rate; + double death_rate; + size_t generations; + size_t children_per_node; + + char mutateBase(char base) { + std::uniform_int_distribution dist(0, 3); + char new_base; + do { + silo::Nucleotide::Symbol new_symbol = silo::Nucleotide::SYMBOLS.at(dist(rng)); + new_base = silo::Nucleotide::symbolToChar(new_symbol); + } while (new_base == base); + return new_base; + } + + std::string mutateSequence(std::string_view sequence) { + std::string mutated{sequence}; + std::binomial_distribution num_mutations_dist(sequence.size(), mutation_rate); + const size_t num_mutations = num_mutations_dist(rng); + std::uniform_int_distribution pos_dist(0, sequence.size() - 1); + for (size_t i = 0; i < num_mutations; ++i) { + const size_t pos = pos_dist(rng); + mutated[pos] = mutateBase(mutated[pos]); + } + return mutated; + } + + public: + SequenceTreeGenerator( + const std::string& ref, + uint64_t seed = 42, + double mut_rate = DEFAULT_MUTATION_RATE, + double death = DEFAULT_DEATH_RATE, + size_t gens = DEFAULT_GENERATIONS, + size_t children = DEFAULT_CHILDREN_PER_NODE + ) + : rng(seed), + reference(ref), + mutation_rate(mut_rate), + death_rate(death), + generations(gens), + children_per_node(children) {} + + std::vector generateEvolvedSequences() { + std::vector all_sequences = {reference}; + std::vector current_gen = {0}; + std::bernoulli_distribution survives(1.0 - death_rate); + for (size_t gen = 0; gen < generations; ++gen) { + std::vector next_gen; + for (size_t seq_index : current_gen) { + for (size_t child = 0; child < children_per_node; ++child) { + if (survives(rng)) { + all_sequences.push_back(mutateSequence(all_sequences.at(seq_index))); + next_gen.push_back(all_sequences.size() - 1); + } + } + } + if (next_gen.empty()) { + next_gen.push_back(all_sequences.size() - 1); + } + current_gen = std::move(next_gen); + } + return all_sequences; + } +}; + +// --- Short-read generation --- + +constexpr size_t DEFAULT_READ_COUNT = 5'000'000; +constexpr size_t DEFAULT_READ_LENGTH = 200; + +struct ShortRead { + size_t id; + size_t offset; + std::string sequence; +}; + +class ShortReadGenerator { + std::vector evolved_sequences; + std::mt19937 rng; + std::uniform_int_distribution seq_dist; + size_t count; + size_t read_length; + size_t num_positions; + + public: + class iterator { + ShortReadGenerator* generator; + size_t current_id; + + public: + using iterator_category = std::input_iterator_tag; + using value_type = ShortRead; + using difference_type = std::ptrdiff_t; + + iterator(ShortReadGenerator* gen, size_t id) + : generator(gen), + current_id(id) {} + + ShortRead operator*() { return generator->generateAt(current_id); } + iterator& operator++() { + ++current_id; + return *this; + } + iterator operator++(int) { + iterator tmp = *this; + ++current_id; + return tmp; + } + bool operator==(const iterator& other) const { return current_id == other.current_id; } + bool operator!=(const iterator& other) const { return current_id != other.current_id; } + }; + + ShortReadGenerator( + const std::string& reference, + size_t count, + size_t read_length, + uint64_t seed = 42 + ) + : count(count), + read_length(read_length) { + if (read_length > reference.size()) { + throw std::invalid_argument(fmt::format( + "read_length ({}) exceeds reference length ({})", read_length, reference.size() + )); + } + SequenceTreeGenerator tree_gen(reference, seed); + evolved_sequences = tree_gen.generateEvolvedSequences(); + SPDLOG_INFO("Generated {} evolved sequences from tree model", evolved_sequences.size()); + num_positions = reference.size() - read_length + 1; + rng.seed(seed + 1000); + seq_dist = std::uniform_int_distribution(0, evolved_sequences.size() - 1); + } + + ShortRead generateAt(size_t read_id) { + const size_t offset = (read_id * num_positions) / count; + const auto& source_seq = evolved_sequences[seq_dist(rng)]; + return {read_id, offset, source_seq.substr(offset, read_length)}; + } + + iterator begin() { return iterator(this, 0); } + iterator end() { return iterator(this, count); } + [[nodiscard]] size_t size() const { return count; } +}; + +// --- NDJSON generators --- + +std::stringstream generateShortReadNdjson( + const std::string& reference, + size_t count = DEFAULT_READ_COUNT, + size_t read_length = DEFAULT_READ_LENGTH +) { + ShortReadGenerator generator(reference, count, read_length); + std::stringstream buffer; + for (const auto& read : generator) { + buffer << fmt::format( + R"({{"readId":"read_{}","samplingDate":"2024-01-01","locationName":"generated","main":{{"insertions":[],"offset":{},"sequence":"{}"}}}})", + read.id, + read.offset, + read.sequence + ) << "\n"; + } + return buffer; +} + +constexpr size_t DEFAULT_FULL_SEQ_COUNT = 100'000; + +std::stringstream generateFullSequenceNdjson( + const std::string& reference, + size_t count = DEFAULT_FULL_SEQ_COUNT +) { + SequenceTreeGenerator tree_gen(reference); + const auto evolved = tree_gen.generateEvolvedSequences(); + SPDLOG_INFO( + "Repeating {} evolved sequences to fill {} full-sequence entries", + evolved.size(), + count + ); + std::stringstream buffer; + for (size_t i = 0; i < count; ++i) { + const auto& seq = evolved[i % evolved.size()]; + buffer << fmt::format( + R"({{"key":"{}","main":{{"sequence":"{}","insertions":[]}}}})", i, seq + ) << "\n"; + } + return buffer; +} + +// --- Database initializers --- + +std::shared_ptr initializeDatabaseWithShortReadSchema( + const std::string& reference +) { + auto database_config = silo::config::DatabaseConfig::getValidatedConfig(R"( +schema: + instanceName: test + metadata: + - name: readId + type: string + - name: samplingDate + type: date + - name: locationName + type: string + primaryKey: readId +)"); + silo::ReferenceGenomes reference_genomes{{{"main", reference}}, {}}; + auto database = std::make_shared(); + database->createTable( + silo::schema::TableName::getDefault(), + silo::initialize::Initializer::createSchemaFromConfigFiles( + std::move(database_config), + std::move(reference_genomes), + {}, + silo::common::PhyloTree{}, + /*without_unaligned_sequences=*/true + ) + ); + return database; +} + +std::shared_ptr initializeDatabaseWithFullSequenceSchema( + const std::string& reference +) { + auto database_config = silo::config::DatabaseConfig::getValidatedConfig(R"( +schema: + instanceName: test + metadata: + - name: key + type: string + primaryKey: key +)"); + silo::ReferenceGenomes reference_genomes{{{"main", reference}}, {}}; + auto database = std::make_shared(); + database->createTable( + silo::schema::TableName::getDefault(), + silo::initialize::Initializer::createSchemaFromConfigFiles( + std::move(database_config), + std::move(reference_genomes), + {}, + silo::common::PhyloTree{}, + /*without_unaligned_sequences=*/true + ) + ); + return database; +} + +} // namespace diff --git a/src/silo/common/string_utils.h b/src/silo/common/string_utils.h index 56776b7b6..27e81c35e 100644 --- a/src/silo/common/string_utils.h +++ b/src/silo/common/string_utils.h @@ -1,5 +1,7 @@ #pragma once +#include +#include #include #include #include @@ -30,4 +32,28 @@ std::string tieAsString( std::string_view suffix ); +template +std::string joinWithLimit( + const std::vector& items, + std::string_view delimiter = ", ", + size_t limit = 10 +) { + std::string res; + const size_t items_to_print = std::min(items.size(), limit); + + for (size_t i = 0; i < items_to_print; ++i) { + if (i > 0) { + res += delimiter; + } + // Assumes items[i] has a toString() method or works with fmt + res += items[i]->toString(); + } + + if (items.size() > items_to_print) { + res += fmt::format("{}... ({} more)", delimiter, items.size() - items_to_print); + } + + return res; +} + } // namespace silo diff --git a/src/silo/query_engine/filter/expressions/and.cpp b/src/silo/query_engine/filter/expressions/and.cpp index e13abfa92..609f9ee6f 100644 --- a/src/silo/query_engine/filter/expressions/and.cpp +++ b/src/silo/query_engine/filter/expressions/and.cpp @@ -8,9 +8,9 @@ #include #include #include -#include #include +#include "silo/common/string_utils.h" #include "silo/query_engine/filter/expressions/expression.h" #include "silo/query_engine/filter/operators/complement.h" #include "silo/query_engine/filter/operators/empty.h" @@ -30,13 +30,10 @@ And::And(ExpressionVector&& children) : children(std::move(children)) {} std::string And::toString() const { - std::vector child_strings; - std::ranges::transform( - children, - std::back_inserter(child_strings), - [&](const std::unique_ptr& child) { return child->toString(); } - ); - return "And(" + boost::algorithm::join(child_strings, " & ") + ")"; + std::string res = "And("; + res += joinWithLimit(children, " & "); + res += ")"; + return res; } namespace { diff --git a/src/silo/query_engine/filter/expressions/expression.cpp b/src/silo/query_engine/filter/expressions/expression.cpp index 266f1aba5..c4d7f45c9 100644 --- a/src/silo/query_engine/filter/expressions/expression.cpp +++ b/src/silo/query_engine/filter/expressions/expression.cpp @@ -21,6 +21,7 @@ #include "silo/query_engine/filter/expressions/is_null.h" #include "silo/query_engine/filter/expressions/lineage_filter.h" #include "silo/query_engine/filter/expressions/maybe.h" +#include "silo/query_engine/filter/expressions/mutation_profile.h" #include "silo/query_engine/filter/expressions/negation.h" #include "silo/query_engine/filter/expressions/nof.h" #include "silo/query_engine/filter/expressions/or.h" @@ -111,6 +112,10 @@ void from_json(const nlohmann::json& json, std::unique_ptr& filter) filter = json.get>(); } else if (expression_type == "IsNotNull") { filter = std::make_unique(json.get>()); + } else if (expression_type == "NucleotideMutationProfile") { + filter = json.get>>(); + } else if (expression_type == "AminoAcidMutationProfile") { + filter = json.get>>(); } else { throw query_engine::IllegalQueryException( "Unknown object filter type '" + expression_type + "'" diff --git a/src/silo/query_engine/filter/expressions/mutation_profile.cpp b/src/silo/query_engine/filter/expressions/mutation_profile.cpp new file mode 100644 index 000000000..76ab67a90 --- /dev/null +++ b/src/silo/query_engine/filter/expressions/mutation_profile.cpp @@ -0,0 +1,411 @@ +#include "silo/query_engine/filter/expressions/mutation_profile.h" + +#include +#include +#include +#include + +#include +#include + +#include "silo/common/aa_symbols.h" +#include "silo/common/nucleotide_symbols.h" +#include "silo/query_engine/filter/expressions/expression.h" +#include "silo/query_engine/filter/expressions/negation.h" +#include "silo/query_engine/filter/expressions/nof.h" +#include "silo/query_engine/filter/expressions/symbol_in_set.h" +#include "silo/query_engine/filter/operators/operator.h" +#include "silo/query_engine/illegal_query_exception.h" +#include "silo/query_engine/query_compilation_exception.h" +#include "silo/query_engine/query_parse_sequence_name.h" +#include "silo/schema/database_schema.h" +#include "silo/storage/column/horizontal_coverage_index.h" +#include "silo/storage/column/sequence_column.h" +#include "silo/storage/table.h" + +namespace silo::query_engine::filter::expressions { + +template +MutationProfile::MutationProfile( + std::optional sequence_name, + uint32_t distance, + ProfileInput input +) + : sequence_name(std::move(sequence_name)), + distance(distance), + input(std::move(input)) {} + +template +std::string MutationProfile::toString() const { + const std::string seq_prefix = sequence_name ? sequence_name.value() + ":" : ""; + const std::string input_str = std::visit( + [](const auto& inp) -> std::string { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return "querySequence=" + inp.sequence.substr(0, 20) + "..."; + } else if constexpr (std::is_same_v) { + return "sequenceId=" + inp.id; + } else { + return "mutations(count=" + std::to_string(inp.mutations.size()) + ")"; + } + }, + input + ); + return fmt::format("MutationProfile({}distance={},{})", seq_prefix, distance, input_str); +} + +template +std::vector MutationProfile::buildProfileFromQuerySequence( + const storage::column::SequenceColumn& sequence_column +) const { + const auto& query_sequence = std::get(input).sequence; + const size_t ref_len = sequence_column.metadata->reference_sequence.size(); + CHECK_SILO_QUERY( + query_sequence.size() == ref_len, + "querySequence length {} does not match the reference sequence length {} for {} " + "MutationProfile", + query_sequence.size(), + ref_len, + SymbolType::SYMBOL_NAME + ); + + std::vector profile; + profile.reserve(ref_len); + for (char character : query_sequence) { + const auto symbol = SymbolType::charToSymbol(character); + CHECK_SILO_QUERY( + symbol.has_value(), + "Invalid {} symbol '{}' in querySequence for MutationProfile", + SymbolType::SYMBOL_NAME, + character + ); + profile.push_back(symbol.value()); + } + return profile; +} + +namespace { + +template +std::vector reconstructSequenceAtRow( + const storage::column::SequenceColumn& sequence_column, + uint32_t row_id +) { + roaring::Roaring single_row; + single_row.add(row_id); + + std::vector sequences = {sequence_column.local_reference_sequence_string}; + sequence_column.vertical_sequence_index.overwriteSymbolsInSequences(sequences, single_row); + sequence_column.horizontal_coverage_index.template overwriteCoverageInSequence( + sequences, single_row + ); + + std::vector profile; + profile.reserve(sequences[0].size()); + for (const char character : sequences[0]) { + const auto sym = SymbolType::charToSymbol(character); + SILO_ASSERT(sym.has_value()); + profile.push_back(sym.value()); + } + return profile; +} + +} // namespace + +template +std::vector MutationProfile::buildProfileFromSequenceId( + const storage::Table& table, + const std::string& valid_sequence_name +) const { + const auto& seq_id = std::get(input).id; + const auto& primary_key_name = table.schema->primary_key.name; + const auto primary_key_type = table.schema->primary_key.type; + + const auto& seq_col = + table.columns.getColumns().at(valid_sequence_name); + + std::optional found_row_id; + + if (primary_key_type == schema::ColumnType::STRING) { + const auto& primary_key_column = table.columns.string_columns.at(primary_key_name); + for (uint32_t row_id = 0; row_id < static_cast(primary_key_column.numValues()); + ++row_id) { + if (primary_key_column.getValueString(row_id) == seq_id) { + found_row_id = row_id; + break; + } + } + } else if (primary_key_type == schema::ColumnType::INDEXED_STRING) { + const auto& primary_key_column = table.columns.indexed_string_columns.at(primary_key_name); + const auto bitmap_opt = primary_key_column.filter(std::optional(seq_id)); + if (bitmap_opt.has_value() && !bitmap_opt.value()->isEmpty()) { + found_row_id = bitmap_opt.value()->minimum(); + } + } else { + throw IllegalQueryException(fmt::format( + "Unsupported primary key column type for {} MutationProfile sequenceId lookup", + SymbolType::SYMBOL_NAME + )); + } + + if (found_row_id.has_value()) { + return reconstructSequenceAtRow(seq_col, found_row_id.value()); + } + + CHECK_SILO_QUERY( + false, + "No sequence found with primary key '{}' in {} MutationProfile", + seq_id, + SymbolType::SYMBOL_NAME + ); + SILO_UNREACHABLE(); +} + +template +std::vector MutationProfile::buildProfileFromMutations( + const storage::column::SequenceColumn& sequence_column +) const { + const auto& mutation_list = std::get(input).mutations; + const size_t ref_len = sequence_column.metadata->reference_sequence.size(); + + // Start with a copy of the reference sequence + std::vector profile( + sequence_column.metadata->reference_sequence.begin(), + sequence_column.metadata->reference_sequence.end() + ); + + for (const auto& mutation : mutation_list) { + CHECK_SILO_QUERY( + mutation.position_idx < ref_len, + "{} MutationProfile mutation position {} is out of bounds (reference length {})", + SymbolType::SYMBOL_NAME, + mutation.position_idx + 1, + ref_len + ); + profile[mutation.position_idx] = mutation.symbol; + } + return profile; +} + +template +std::unique_ptr MutationProfile::rewrite( + const storage::Table& table, + AmbiguityMode /*mode*/ +) const { + CHECK_SILO_QUERY( + sequence_name.has_value() || table.schema->getDefaultSequenceName(), + "Database does not have a default sequence name for {} sequences. " + "You need to provide the sequence name with the {} MutationProfile filter.", + SymbolType::SYMBOL_NAME, + SymbolType::SYMBOL_NAME + ); + + const auto valid_sequence_name = + validateSequenceNameOrGetDefault(sequence_name, *table.schema); + + const auto& sequence_column = + table.columns.getColumns().at(valid_sequence_name); + + // Build the profile sequence + std::vector profile; + if (std::holds_alternative(input)) { + profile = buildProfileFromQuerySequence(sequence_column); + } else if (std::holds_alternative(input)) { + profile = buildProfileFromSequenceId(table, valid_sequence_name); + } else { + profile = buildProfileFromMutations(sequence_column); + } + + // For each position, build a "difference" child expression: + // difference at pos = SymbolInSet(seq_name, pos, symbols NOT compatible with profile[pos]) + // where "compatible" means: symbols in AMBIGUITY_SYMBOLS[profile[pos]] + ExpressionVector difference_children; + for (size_t pos = 0; pos < profile.size(); ++pos) { + const auto profile_symbol = profile[pos]; + + // Skip positions where the profile has the missing/unknown symbol + if (profile_symbol == SymbolType::SYMBOL_MISSING) { + continue; + } + + // Compute symbols that are NOT compatible with profile_symbol + // (i.e., symbols that count as "definitely different" from profile_symbol) + const auto& compatible_symbols = SymbolType::AMBIGUITY_SYMBOLS.at(profile_symbol); + std::vector difference_symbols; + for (const auto sym : SymbolType::SYMBOLS) { + if (std::find(compatible_symbols.begin(), compatible_symbols.end(), sym) == + compatible_symbols.end()) { + difference_symbols.push_back(sym); + } + } + + if (difference_symbols.empty()) { + continue; + } + + difference_children.push_back(std::make_unique>( + valid_sequence_name, static_cast(pos), std::move(difference_symbols) + )); + } + + // Return Not(NOf(difference_children, distance+1, false)) + // = "at most 'distance' differences" (conservative) + auto at_least_distance_plus_one = std::make_unique( + std::move(difference_children), + static_cast(distance) + 1, + /*match_exactly=*/false + ); + return std::make_unique(std::move(at_least_distance_plus_one)); +} + +template +std::unique_ptr MutationProfile::compile( + const storage::Table& /*table*/ +) const { + throw QueryCompilationException{ + "{} MutationProfile expression must be eliminated in the query rewrite phase", + SymbolType::SYMBOL_NAME + }; +} + +template +// NOLINTNEXTLINE(readability-identifier-naming,readability-function-cognitive-complexity) +void from_json(const nlohmann::json& json, std::unique_ptr>& filter) { + CHECK_SILO_QUERY( + json.contains("distance"), + "The field 'distance' is required in a {} MutationProfile expression", + SymbolType::SYMBOL_NAME + ); + CHECK_SILO_QUERY( + json["distance"].is_number_unsigned(), + "The field 'distance' in a {} MutationProfile expression must be an unsigned integer", + SymbolType::SYMBOL_NAME + ); + const uint32_t distance = json["distance"].get(); + + std::optional seq_name; + if (json.contains("sequenceName")) { + seq_name = json["sequenceName"].get(); + } + + const bool has_query_sequence = json.contains("querySequence"); + const bool has_sequence_id = json.contains("sequenceId"); + const bool has_mutations = json.contains("mutations"); + + const int input_count = static_cast(has_query_sequence) + + static_cast(has_sequence_id) + static_cast(has_mutations); + + CHECK_SILO_QUERY( + input_count == 1, + "Exactly one of 'querySequence', 'sequenceId', or 'mutations' must be provided in a {} " + "MutationProfile expression, but {} were provided", + SymbolType::SYMBOL_NAME, + input_count + ); + + if (has_query_sequence) { + CHECK_SILO_QUERY( + json["querySequence"].is_string(), + "The field 'querySequence' in a {} MutationProfile expression must be a string", + SymbolType::SYMBOL_NAME + ); + filter = std::make_unique>( + seq_name, + distance, + typename MutationProfile::QuerySequenceInput{ + json["querySequence"].get() + } + ); + return; + } + + if (has_sequence_id) { + CHECK_SILO_QUERY( + json["sequenceId"].is_string(), + "The field 'sequenceId' in a {} MutationProfile expression must be a string", + SymbolType::SYMBOL_NAME + ); + filter = std::make_unique>( + seq_name, + distance, + typename MutationProfile::SequenceIdInput{json["sequenceId"].get() + } + ); + return; + } + + // has_mutations == true + CHECK_SILO_QUERY( + json["mutations"].is_array(), + "The field 'mutations' in a {} MutationProfile expression must be an array", + SymbolType::SYMBOL_NAME + ); + + std::vector::Mutation> mutations; + for (const auto& mut_json : json["mutations"]) { + CHECK_SILO_QUERY( + mut_json.contains("position"), + "Each mutation in a {} MutationProfile expression must have a 'position' field", + SymbolType::SYMBOL_NAME + ); + CHECK_SILO_QUERY( + mut_json["position"].is_number_unsigned(), + "The field 'position' in a {} MutationProfile mutation must be an unsigned integer", + SymbolType::SYMBOL_NAME + ); + CHECK_SILO_QUERY( + mut_json.contains("symbol"), + "Each mutation in a {} MutationProfile expression must have a 'symbol' field", + SymbolType::SYMBOL_NAME + ); + CHECK_SILO_QUERY( + mut_json["symbol"].is_string(), + "The field 'symbol' in a {} MutationProfile mutation must be a string", + SymbolType::SYMBOL_NAME + ); + + const uint32_t position_1indexed = mut_json["position"].get(); + CHECK_SILO_QUERY( + position_1indexed > 0, + "The field 'position' in a {} MutationProfile mutation is 1-indexed. Value 0 is not " + "allowed.", + SymbolType::SYMBOL_NAME + ); + const uint32_t position_idx = position_1indexed - 1; + + const std::string& symbol_str = mut_json["symbol"]; + CHECK_SILO_QUERY( + symbol_str.size() == 1, + "The field 'symbol' in a {} MutationProfile mutation must be exactly one character", + SymbolType::SYMBOL_NAME + ); + const auto symbol = SymbolType::charToSymbol(symbol_str[0]); + CHECK_SILO_QUERY( + symbol.has_value(), + "Invalid {} symbol '{}' in MutationProfile mutations", + SymbolType::SYMBOL_NAME, + symbol_str + ); + + mutations.push_back({position_idx, symbol.value()}); + } + + filter = std::make_unique>( + seq_name, distance, typename MutationProfile::MutationsInput{std::move(mutations)} + ); +} + +template void from_json( + const nlohmann::json& json, + std::unique_ptr>& filter +); + +template void from_json( + const nlohmann::json& json, + std::unique_ptr>& filter +); + +template class MutationProfile; +template class MutationProfile; + +} // namespace silo::query_engine::filter::expressions diff --git a/src/silo/query_engine/filter/expressions/mutation_profile.h b/src/silo/query_engine/filter/expressions/mutation_profile.h new file mode 100644 index 000000000..12b0f5e40 --- /dev/null +++ b/src/silo/query_engine/filter/expressions/mutation_profile.h @@ -0,0 +1,81 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include "silo/query_engine/filter/expressions/expression.h" +#include "silo/query_engine/filter/operators/operator.h" +#include "silo/storage/column/sequence_column.h" +#include "silo/storage/table.h" + +namespace silo::query_engine::filter::expressions { + +template +class MutationProfile : public Expression { + public: + struct Mutation { + uint32_t position_idx; // 0-indexed + typename SymbolType::Symbol symbol; + }; + + struct QuerySequenceInput { + std::string sequence; + }; + + struct SequenceIdInput { + std::string id; + }; + + struct MutationsInput { + std::vector mutations; + }; + + using ProfileInput = std::variant; + + private: + std::optional sequence_name; + uint32_t distance; + ProfileInput input; + + [[nodiscard]] std::vector buildProfileFromQuerySequence( + const storage::column::SequenceColumn& sequence_column + ) const; + + [[nodiscard]] std::vector buildProfileFromSequenceId( + const storage::Table& table, + const std::string& valid_sequence_name + ) const; + + [[nodiscard]] std::vector buildProfileFromMutations( + const storage::column::SequenceColumn& sequence_column + ) const; + + public: + explicit MutationProfile( + std::optional sequence_name, + uint32_t distance, + ProfileInput input + ); + + [[nodiscard]] std::string toString() const override; + + [[nodiscard]] std::unique_ptr rewrite( + const storage::Table& table, + AmbiguityMode mode + ) const override; + + [[nodiscard]] std::unique_ptr compile(const storage::Table& table + ) const override; +}; + +template +// NOLINTNEXTLINE(readability-identifier-naming) +void from_json(const nlohmann::json& json, std::unique_ptr>& filter); + +} // namespace silo::query_engine::filter::expressions diff --git a/src/silo/query_engine/filter/expressions/nof.cpp b/src/silo/query_engine/filter/expressions/nof.cpp index 0a0abe1c1..a435eefd2 100644 --- a/src/silo/query_engine/filter/expressions/nof.cpp +++ b/src/silo/query_engine/filter/expressions/nof.cpp @@ -6,6 +6,7 @@ #include +#include "silo/common/string_utils.h" #include "silo/query_engine/filter/expressions/and.h" #include "silo/query_engine/filter/expressions/expression.h" #include "silo/query_engine/filter/expressions/negation.h" @@ -166,10 +167,7 @@ std::string NOf::toString() const { } else { res = "[" + std::to_string(number_of_matchers) + "-of:"; } - for (const auto& child : children) { - res += child->toString(); - res += ", "; - } + res += joinWithLimit(children); res += "]"; return res; } @@ -233,7 +231,6 @@ std::unique_ptr NOf::rewriteToNonExact( this->number_of_matchers + 1, /*match_exactly=*/false ); - ; ExpressionVector and_children; and_children.push_back(std::move(at_least_k)); and_children.push_back(std::make_unique(std::move(at_least_k_plus_one))); @@ -253,6 +250,13 @@ std::unique_ptr NOf::compile(const storage::Table& table) c auto [non_negated_child_operators, negated_child_operators, updated_number_of_matchers] = mapChildExpressions(table); + if (updated_number_of_matchers < 0) { + if (match_exactly) { + return std::make_unique(table.sequence_count); + } + return std::make_unique(table.sequence_count); + } + return toOperator( updated_number_of_matchers, std::move(non_negated_child_operators), diff --git a/src/silo/query_engine/filter/expressions/or.cpp b/src/silo/query_engine/filter/expressions/or.cpp index e79cbfb0e..9b677f46f 100644 --- a/src/silo/query_engine/filter/expressions/or.cpp +++ b/src/silo/query_engine/filter/expressions/or.cpp @@ -4,9 +4,9 @@ #include #include -#include #include +#include "silo/common/string_utils.h" #include "silo/query_engine/filter/expressions/expression.h" #include "silo/query_engine/filter/expressions/false.h" #include "silo/query_engine/filter/expressions/string_in_set.h" @@ -27,13 +27,10 @@ Or::Or(ExpressionVector&& children) : children(std::move(children)) {} std::string Or::toString() const { - std::vector child_strings; - std::ranges::transform( - children, - std::back_inserter(child_strings), - [&](const std::unique_ptr& child) { return child->toString(); } - ); - return "Or(" + boost::algorithm::join(child_strings, " | ") + ")"; + std::string res = "Or("; + res += joinWithLimit(children, " | "); + res += ")"; + return res; } std::vector Or::collectChildren(const ExpressionVector& children) { diff --git a/src/silo/query_engine/filter/operators/intersection.cpp b/src/silo/query_engine/filter/operators/intersection.cpp index 300e24138..2f6d1eb0e 100644 --- a/src/silo/query_engine/filter/operators/intersection.cpp +++ b/src/silo/query_engine/filter/operators/intersection.cpp @@ -7,6 +7,7 @@ #include #include "evobench/evobench.hpp" +#include "silo/common/string_utils.h" #include "silo/query_engine/copy_on_write_bitmap.h" #include "silo/query_engine/filter/operators/complement.h" #include "silo/query_engine/filter/operators/operator.h" @@ -42,14 +43,11 @@ Intersection::Intersection( Intersection::~Intersection() noexcept = default; std::string Intersection::toString() const { - std::string res = "(" + children[0]->toString(); + std::string res = "Intersection("; + + res += "non_negated: (" + joinWithLimit(children) + ") "; + res += "negated: (" + joinWithLimit(negated_children) + ") "; - for (uint32_t i = 1; i < children.size(); i++) { - res += " & " + children[i]->toString(); - } - for (const auto& child : negated_children) { - res += " &! " + child->toString(); - } res += ")"; return res; } diff --git a/src/silo/query_engine/filter/operators/threshold.cpp b/src/silo/query_engine/filter/operators/threshold.cpp index 4ea694ef7..0be5f96f1 100644 --- a/src/silo/query_engine/filter/operators/threshold.cpp +++ b/src/silo/query_engine/filter/operators/threshold.cpp @@ -7,6 +7,7 @@ #include #include "evobench/evobench.hpp" +#include "silo/common/string_utils.h" #include "silo/query_engine/copy_on_write_bitmap.h" #include "silo/query_engine/filter/operators/complement.h" #include "silo/query_engine/filter/operators/operator.h" @@ -42,18 +43,16 @@ Threshold::Threshold( Threshold::~Threshold() noexcept = default; std::string Threshold::toString() const { - std::string res; + std::string res = "Threshold("; if (match_exactly) { res += "="; } else { res += ">="; } - for (const auto& child : this->non_negated_children) { - res += ", " + child->toString(); - } - for (const auto& child : this->non_negated_children) { - res += ", ! " + child->toString(); - } + res += fmt::format("{}-of ", number_of_matchers); + + res += "non_negated: (" + joinWithLimit(non_negated_children) + ") "; + res += "negated: (" + joinWithLimit(negated_children) + ") "; res += ")"; return res; } @@ -93,7 +92,7 @@ CopyOnWriteBitmap Threshold::evaluate() const { ); // Number of loop iterations for (int i = 1; i < non_negated_child_count; ++i) { - auto bitmap = non_negated_children[i]->evaluate(); + const auto bitmap = non_negated_children[i]->evaluate(); // positions higher than (i-1) cannot have been reached yet, are therefore all 0s and the // conjunction would return 0 // positions lower than n - k + i - 1 are unable to affect the result, because only (k - i) @@ -118,7 +117,7 @@ CopyOnWriteBitmap Threshold::evaluate() const { const int i = local_i + non_negated_child_count; // positions higher than (i-1) cannot have been reached yet, are therefore all 0s and the // conjunction would return 0 - // positions lower than n - k + i - 1 are unable to affect the result, because only (k - i) + // positions lower than (n-1) - (k-i) are unable to affect the result, because only (k-i) // iterations are left for (int j = std::min(max_table_index, i); j > std::max(0, n - k + i - 1); --j) { bitmaps[j] |= bitmaps[j - 1] - bitmap.getConstReference(); diff --git a/src/silo/query_engine/filter/operators/union.cpp b/src/silo/query_engine/filter/operators/union.cpp index 7cf9b5bcc..3fb6194ec 100644 --- a/src/silo/query_engine/filter/operators/union.cpp +++ b/src/silo/query_engine/filter/operators/union.cpp @@ -7,6 +7,7 @@ #include #include "evobench/evobench.hpp" +#include "silo/common/string_utils.h" #include "silo/query_engine/copy_on_write_bitmap.h" #include "silo/query_engine/filter/operators/complement.h" #include "silo/query_engine/filter/operators/operator.h" @@ -20,11 +21,8 @@ Union::Union(OperatorVector&& children, uint32_t row_count) Union::~Union() noexcept = default; std::string Union::toString() const { - std::string res = "(" + children[0]->toString(); - for (size_t i = 1; i < children.size(); ++i) { - const auto& child = children[i]; - res += " | " + child->toString(); - } + std::string res = "("; + res += joinWithLimit(children, " | "); res += ")"; return res; } diff --git a/src/silo/test/mutation_profile.test.cpp b/src/silo/test/mutation_profile.test.cpp new file mode 100644 index 000000000..f2240fb18 --- /dev/null +++ b/src/silo/test/mutation_profile.test.cpp @@ -0,0 +1,402 @@ +#include + +#include "silo/test/query_fixture.test.h" + +namespace { +using silo::ReferenceGenomes; +using silo::test::QueryTestData; +using silo::test::QueryTestScenario; + +nlohmann::json createData( + const std::string& primary_key, + const std::string& nucleotide_sequence, + const std::string& amino_acid_sequence +) { + return { + {"primaryKey", primary_key}, + {"segment1", {{"sequence", nucleotide_sequence}, {"insertions", nlohmann::json::array()}}}, + {"gene1", {{"sequence", amino_acid_sequence}, {"insertions", nlohmann::json::array()}}}, + {"gene2", nullptr} + }; +} + +// segment1 reference: ATGCN (length 5) +// gene1 reference: M* (length 2; * = STOP codon, a definitive AA symbol) +// +// Sequences (segment1 | gene1): +// seq_ref: ATGCN | M* (0 nuc diffs; 0 AA diffs from their respective references) +// seq_1mut: CTGCN | C* (1 nuc diff: pos1 A→C; 1 AA diff: pos1 M→C) +// seq_2mut: CTCCN | M* (2 nuc diffs: pos1, pos3; 0 AA diffs) +// seq_3mut: CTCTN | M* (3 nuc diffs: pos1, pos3, pos4; 0 AA diffs) +// seq_all_n: NNNNN | M* (0 conservative nuc diffs; 0 AA diffs) +// seq_mixed_amb: RTGCN | M* (0 conservative nuc diffs; 0 AA diffs) + +const auto DATABASE_CONFIG = R"( +defaultNucleotideSequence: "segment1" +schema: + instanceName: "test" + metadata: + - name: "primaryKey" + type: "string" + primaryKey: "primaryKey" +)"; + +const auto REFERENCE_GENOMES = + ReferenceGenomes{{{"segment1", "ATGCN"}}, {{"gene1", "M*"}, {"gene2", "M*"}}}; + +// Note: reference has N at position 5, so that position is always skipped in profile comparisons. +// Effective profile length for distance counting = 4 positions (ATGC). + +const QueryTestData TEST_DATA{ + .ndjson_input_data = + { + createData("seq_ref", "ATGCN", "M*"), // 0 diffs from reference profile + createData("seq_1mut", "CTGCN", "C*"), // 1 diff from reference profile (pos1) + createData("seq_2mut", "CTCCN", "M*"), // 2 diffs from reference profile (pos1, pos3) + createData("seq_3mut", "CTCTN", "M*"), // 3 diffs from reference profile (pos1,3,4) + createData("seq_all_n", "NNNNN", "M*"), // 0 conservative diffs (N is compatible) + createData("seq_mixed_amb", "RTGCN", "M*") // 0 diffs: R∈AMBIGUITY[A], conservative + }, + .database_config = DATABASE_CONFIG, + .reference_genomes = REFERENCE_GENOMES, + .without_unaligned_sequences = true +}; + +// ------ Tests using mutations input method ------ + +// Profile = reference (empty mutations list), distance=0 +// Should match: seq_ref (0 diffs), seq_all_n (0 diffs), seq_mixed_amb (0 diffs, R compatible w/ A) +const QueryTestScenario MUTATIONS_DISTANCE_0 = { + .name = "MUTATIONS_DISTANCE_0", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "NucleotideMutationProfile", + "distance": 0, + "mutations": [] + } + })"), + .expected_query_result = nlohmann::json::parse( + R"([{"primaryKey":"seq_ref"},{"primaryKey":"seq_all_n"},{"primaryKey":"seq_mixed_amb"}])" + ) +}; + +// Profile = reference, distance=1 → matches 0 or 1 difference +const QueryTestScenario MUTATIONS_DISTANCE_1 = { + .name = "MUTATIONS_DISTANCE_1", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "NucleotideMutationProfile", + "distance": 1, + "mutations": [] + } + })"), + .expected_query_result = nlohmann::json::parse( + R"([{"primaryKey":"seq_ref"},{"primaryKey":"seq_1mut"},{"primaryKey":"seq_all_n"},{"primaryKey":"seq_mixed_amb"}])" + ) +}; + +// Profile = reference, distance=2 → matches 0, 1, or 2 differences +const QueryTestScenario MUTATIONS_DISTANCE_2 = { + .name = "MUTATIONS_DISTANCE_2", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "NucleotideMutationProfile", + "distance": 2, + "mutations": [] + } + })"), + .expected_query_result = nlohmann::json::parse( + R"([{"primaryKey":"seq_ref"},{"primaryKey":"seq_1mut"},{"primaryKey":"seq_2mut"},{"primaryKey":"seq_all_n"},{"primaryKey":"seq_mixed_amb"}])" + ) +}; + +// Profile with explicit mutation C at 1-based position 1, distance=0 +// Profile = CTGCN → should match seq_1mut (which has CTGCN) and seq_all_n +// seq_mixed_amb does not match, because C is not in {A,G}=R +const QueryTestScenario MUTATIONS_WITH_PROFILE_DISTANCE_0 = { + .name = "MUTATIONS_WITH_PROFILE_DISTANCE_0", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "NucleotideMutationProfile", + "distance": 0, + "mutations": [{"position": 1, "symbol": "C"}] + } + })"), + .expected_query_result = + nlohmann::json::parse(R"([{"primaryKey":"seq_1mut"},{"primaryKey":"seq_all_n"}])") +}; + +// ------ Tests using querySequence input method ------ + +// querySequence = "ATGCN" (same as reference), distance=0 +// Same as MUTATIONS_DISTANCE_0 but N at position 5 is skipped (profile has N = SYMBOL_MISSING) +const QueryTestScenario QUERY_SEQUENCE_DISTANCE_0 = { + .name = "QUERY_SEQUENCE_DISTANCE_0", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "NucleotideMutationProfile", + "distance": 0, + "querySequence": "ATGCN" + } + })"), + .expected_query_result = nlohmann::json::parse( + R"([{"primaryKey":"seq_ref"},{"primaryKey":"seq_all_n"},{"primaryKey":"seq_mixed_amb"}])" + ) +}; + +// querySequence wrong length +const QueryTestScenario QUERY_SEQUENCE_WRONG_LENGTH = { + .name = "QUERY_SEQUENCE_WRONG_LENGTH", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "NucleotideMutationProfile", + "distance": 0, + "querySequence": "ATG" + } + })"), + .expected_error_message = + "querySequence length 3 does not match the reference sequence length 5 for Nucleotide " + "MutationProfile" +}; + +// ------ Tests using sequenceId input method ------ + +// Use seq_1mut as profile (CTGCN), distance=0 +// Should match: seq_1mut (exact match), seq_all_n (N compatible with everything) +const QueryTestScenario SEQUENCE_ID_DISTANCE_0 = { + .name = "SEQUENCE_ID_DISTANCE_0", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "NucleotideMutationProfile", + "distance": 0, + "sequenceId": "seq_1mut" + } + })"), + .expected_query_result = + nlohmann::json::parse(R"([{"primaryKey":"seq_1mut"},{"primaryKey":"seq_all_n"}])") +}; + +// sequenceId not found +const QueryTestScenario SEQUENCE_ID_NOT_FOUND = { + .name = "SEQUENCE_ID_NOT_FOUND", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "NucleotideMutationProfile", + "distance": 0, + "sequenceId": "nonexistent_id" + } + })"), + .expected_error_message = + "No sequence found with primary key 'nonexistent_id' in Nucleotide MutationProfile" +}; + +// No input method provided +const QueryTestScenario NO_INPUT_METHOD = { + .name = "NO_INPUT_METHOD", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "NucleotideMutationProfile", + "distance": 0 + } + })"), + .expected_error_message = + "Exactly one of 'querySequence', 'sequenceId', or 'mutations' must be provided in a " + "Nucleotide MutationProfile expression, but 0 were provided" +}; + +// Two input methods provided +const QueryTestScenario TWO_INPUT_METHODS = { + .name = "TWO_INPUT_METHODS", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "NucleotideMutationProfile", + "distance": 0, + "mutations": [], + "querySequence": "ATGCN" + } + })"), + .expected_error_message = + "Exactly one of 'querySequence', 'sequenceId', or 'mutations' must be provided in a " + "Nucleotide MutationProfile expression, but 2 were provided" +}; + +// ------ AminoAcid tests ------ +// Only seq_1mut differs from the gene1 reference (M→C at pos1). All other rows have "M*". + +// Profile = gene1 reference (empty mutations list), distance=0 +// Matches every row whose AA sequence == "M*"; excludes seq_1mut ("C*", 1 diff). +const QueryTestScenario AA_MUTATIONS_REFERENCE_DISTANCE_0 = { + .name = "AA_MUTATIONS_REFERENCE_DISTANCE_0", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "AminoAcidMutationProfile", + "sequenceName": "gene1", + "distance": 0, + "mutations": [] + } + })"), + .expected_query_result = nlohmann::json::parse( + R"([{"primaryKey":"seq_ref"},{"primaryKey":"seq_2mut"},{"primaryKey":"seq_3mut"},{"primaryKey":"seq_all_n"},{"primaryKey":"seq_mixed_amb"}])" + ) +}; + +// Profile = gene1 reference, distance=1 +// seq_1mut has exactly 1 AA diff (M→C) which is ≤ 1 → all 6 rows match. +const QueryTestScenario AA_MUTATIONS_REFERENCE_DISTANCE_1 = { + .name = "AA_MUTATIONS_REFERENCE_DISTANCE_1", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "AminoAcidMutationProfile", + "sequenceName": "gene1", + "distance": 1, + "mutations": [] + } + })"), + .expected_query_result = nlohmann::json::parse( + R"([{"primaryKey":"seq_ref"},{"primaryKey":"seq_1mut"},{"primaryKey":"seq_2mut"},{"primaryKey":"seq_3mut"},{"primaryKey":"seq_all_n"},{"primaryKey":"seq_mixed_amb"}])" + ) +}; + +// Profile = "C*" via mutations=[{pos:1, sym:"C"}], distance=0 +// Only seq_1mut has C at pos1; all others have M (not compatible with C) → only seq_1mut matches. +const QueryTestScenario AA_MUTATIONS_WITH_PROFILE_DISTANCE_0 = { + .name = "AA_MUTATIONS_WITH_PROFILE_DISTANCE_0", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "AminoAcidMutationProfile", + "sequenceName": "gene1", + "distance": 0, + "mutations": [{"position": 1, "symbol": "C"}] + } + })"), + .expected_query_result = nlohmann::json::parse(R"([{"primaryKey":"seq_1mut"}])") +}; + +// querySequence = "M*" (same as gene1 reference), distance=0 — exercises querySequence parsing for +// AA +const QueryTestScenario AA_QUERY_SEQUENCE_DISTANCE_0 = { + .name = "AA_QUERY_SEQUENCE_DISTANCE_0", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "AminoAcidMutationProfile", + "sequenceName": "gene1", + "distance": 0, + "querySequence": "M*" + } + })"), + .expected_query_result = nlohmann::json::parse( + R"([{"primaryKey":"seq_ref"},{"primaryKey":"seq_2mut"},{"primaryKey":"seq_3mut"},{"primaryKey":"seq_all_n"},{"primaryKey":"seq_mixed_amb"}])" + ) +}; + +// sequenceId = "seq_1mut" → reconstructed profile is "C*", distance=0 +// Only seq_1mut itself has "C*" → only seq_1mut matches. +const QueryTestScenario AA_SEQUENCE_ID_DISTANCE_0 = { + .name = "AA_SEQUENCE_ID_DISTANCE_0", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "AminoAcidMutationProfile", + "sequenceName": "gene1", + "distance": 0, + "sequenceId": "seq_1mut" + } + })"), + .expected_query_result = nlohmann::json::parse(R"([{"primaryKey":"seq_1mut"}])") +}; + +// sequenceName refers to a gene that does not exist in the schema → error +const QueryTestScenario AA_INVALID_SEQUENCE_NAME = { + .name = "AA_INVALID_SEQUENCE_NAME", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "AminoAcidMutationProfile", + "sequenceName": "nonexistent_gene", + "distance": 0, + "mutations": [] + } + })"), + .expected_error_message = + "Database does not contain the AminoAcid Sequence with name: 'nonexistent_gene'" +}; + +// No sequenceName provided and no default AA sequence in the config → error +const QueryTestScenario AA_NO_SEQUENCE_NAME = { + .name = "AA_NO_SEQUENCE_NAME", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "AminoAcidMutationProfile", + "distance": 0, + "mutations": [] + } + })"), + .expected_error_message = + "Database does not have a default sequence name for AminoAcid sequences. " + "You need to provide the sequence name with the AminoAcid MutationProfile filter." +}; + +// Mutation position is outside the bounds of the specified AA sequence → error +const QueryTestScenario AA_MUTATION_OUT_OF_BOUNDS = { + .name = "AA_MUTATION_OUT_OF_BOUNDS", + .query = nlohmann::json::parse(R"({ + "action": {"type": "Details", "fields": ["primaryKey"]}, + "filterExpression": { + "type": "AminoAcidMutationProfile", + "sequenceName": "gene1", + "distance": 0, + "mutations": [{"position": 123456, "symbol": "C"}] + } + })"), + .expected_error_message = + "AminoAcid MutationProfile mutation position 123456 is out of bounds (reference length 2)" +}; + +} // namespace + +QUERY_TEST( + MutationProfile, + TEST_DATA, + ::testing::Values( + MUTATIONS_DISTANCE_0, + MUTATIONS_DISTANCE_1, + MUTATIONS_DISTANCE_2, + MUTATIONS_WITH_PROFILE_DISTANCE_0, + QUERY_SEQUENCE_DISTANCE_0, + QUERY_SEQUENCE_WRONG_LENGTH, + SEQUENCE_ID_DISTANCE_0, + SEQUENCE_ID_NOT_FOUND, + NO_INPUT_METHOD, + TWO_INPUT_METHODS + ) +); + +QUERY_TEST( + AminoAcidMutationProfile, + TEST_DATA, + ::testing::Values( + AA_MUTATIONS_REFERENCE_DISTANCE_0, + AA_MUTATIONS_REFERENCE_DISTANCE_1, + AA_MUTATIONS_WITH_PROFILE_DISTANCE_0, + AA_QUERY_SEQUENCE_DISTANCE_0, + AA_SEQUENCE_ID_DISTANCE_0, + AA_INVALID_SEQUENCE_NAME, + AA_NO_SEQUENCE_NAME, + AA_MUTATION_OUT_OF_BOUNDS + ) +);