diff --git a/R/OHDSIAssistant/R/strategus_cohort_methods_shell.R b/R/OHDSIAssistant/R/strategus_cohort_methods_shell.R index 90404cc..77e0022 100644 --- a/R/OHDSIAssistant/R/strategus_cohort_methods_shell.R +++ b/R/OHDSIAssistant/R/strategus_cohort_methods_shell.R @@ -1561,8 +1561,8 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me outcomeCohortIds = NULL, comparisonLabel = NULL, topK = 20, - maxResults = 20, - candidateLimit = 20, + maxResults = 3, + candidateLimit = 10, indexDir = Sys.getenv("PHENOTYPE_INDEX_DIR", "data/phenotype_index"), negativeControlConceptSetId = NULL, includeCovariateConceptSetId = NULL, @@ -1586,10 +1586,121 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me if (!dir.exists(path)) dir.create(path, recursive = TRUE, showWarnings = FALSE) } + compact_dialogue_context <- function(value) { + if (!is.list(value) || length(value) == 0) return(list()) + keep <- lapply(value, function(item) { + if (is.null(item)) return(FALSE) + if (is.character(item) && length(item) == 1 && !nzchar(trimws(item))) return(FALSE) + if (is.atomic(item) && length(item) == 0) return(FALSE) + if (is.list(item) && length(item) == 0) return(FALSE) + TRUE + }) + keep_idx <- which(vapply(keep, isTRUE, logical(1))) + if (length(keep_idx) == 0) return(list()) + value[keep_idx] + } + + dialogue_state <- new.env(parent = emptyenv()) + dialogue_state$current_step <- "" + dialogue_state$current_role <- "" + dialogue_state$current_context <- list() + + set_dialogue_context <- function(step = "", role = "", context = list()) { + dialogue_state$current_step <- as.character(step %||% "") + dialogue_state$current_role <- as.character(role %||% "") + dialogue_state$current_context <- compact_dialogue_context(context %||% list()) + invisible(NULL) + } + + render_workflow_dialogue <- function(response) { + core <- response$dialogue %||% response + cat(" +== OHDSI Guidance == +") + answer <- as.character(core$answer %||% "") + if (nzchar(trimws(answer))) { + cat(answer, " +") + } else { + cat("No contextual guidance was returned. +") + } + guidance <- core$current_step_guidance %||% list() + if (length(guidance) > 0) { + cat("Current step guidance: +") + for (item in guidance) cat(sprintf(" - %s +", item)) + } + cautions <- core$cautions %||% list() + if (length(cautions) > 0) { + cat("Cautions: +") + for (item in cautions) cat(sprintf(" - %s +", item)) + } + next_actions <- core$suggested_next_actions %||% list() + if (length(next_actions) > 0) { + cat("Suggested next actions: +") + for (item in next_actions) cat(sprintf(" - %s +", item)) + } + cat(" +") + } + + handle_workflow_dialogue_command <- function(entered) { + trimmed <- trimws(as.character(entered %||% "")) + if (!isTRUE(interactive) || !startsWith(trimmed, "/ohdsi")) { + return(list(handled = FALSE, value = entered)) + } + question <- trimws(sub("^/ohdsi", "", trimmed)) + if (!nzchar(question)) { + cat("Enter a question after /ohdsi. Example: /ohdsi why is washout important here? +") + return(list(handled = TRUE, value = "")) + } + if (!ensure_acp_ready(acpUrl)) { + cat("ACP bridge unavailable. Connect ACP before using /ohdsi. +") + return(list(handled = TRUE, value = "")) + } + body <- list( + user_prompt = question, + study_intent = as.character(studyIntent %||% ""), + workflow_type = "cohort_methods", + current_step = as.character(dialogue_state$current_step %||% ""), + current_role = as.character(dialogue_state$current_role %||% ""), + current_context = compact_dialogue_context(dialogue_state$current_context %||% list()) + ) + message("Calling ACP flow: workflow_context_dialogue") + response <- tryCatch( + .acp_post("/flows/workflow_context_dialogue", body), + error = function(e) list(status = "error", error = conditionMessage(e)) + ) + if (!identical(response$status %||% "", "ok")) { + cat(sprintf("OHDSI guidance failed: %s +", as.character(response$error %||% "unknown error"))) + return(list(handled = TRUE, value = "")) + } + render_workflow_dialogue(response) + list(handled = TRUE, value = "") + } + + readline_with_dialogue <- function(prompt) { + repeat { + entered <- readline(prompt) + handled <- handle_workflow_dialogue_command(entered) + if (isTRUE(handled$handled)) next + return(handled$value) + } + } + prompt_yesno <- function(prompt, default = TRUE) { if (!isTRUE(interactive)) return(default) suffix <- if (default) "[Y/n]" else "[y/N]" - resp <- tolower(trimws(readline(sprintf("%s %s ", prompt, suffix)))) + resp <- tolower(trimws(readline_with_dialogue(sprintf("%s %s ", prompt, suffix)))) if (resp == "") return(default) if (resp %in% c("y", "yes")) return(TRUE) if (resp %in% c("n", "no")) return(FALSE) @@ -1636,7 +1747,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me label ) } - entered <- trimws(readline(prompt)) + entered <- trimws(readline_with_dialogue(prompt)) candidate <- if (nzchar(entered)) entered else if (nchar(current, type = "chars") <= max_chars) current else "" if (!nzchar(candidate)) { cat(sprintf("Analysis label must be %s characters or fewer. Please enter a shorter label.\n", max_chars)) @@ -1771,7 +1882,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me if (length(ids) > 1) stop(sprintf("%s must contain exactly one cohort ID.", label)) if (length(ids) == 1) return(as.integer(ids[[1]])) if (!isTRUE(interactive)) stop(sprintf("Missing %s.", label)) - entered <- trimws(readline(sprintf("%s cohort ID: ", label))) + entered <- trimws(readline_with_dialogue(sprintf("%s cohort ID: ", label))) ids <- parse_ids(entered) ids <- ids[!is.na(ids)] if (length(ids) != 1) stop(sprintf("%s must contain exactly one cohort ID.", label)) @@ -1783,7 +1894,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me ids <- unique(ids[!is.na(ids)]) if (length(ids) > 0) return(as.integer(ids)) if (!isTRUE(interactive)) stop(sprintf("Missing %s.", label)) - entered <- trimws(readline(sprintf("%s cohort IDs (comma-separated): ", label))) + entered <- trimws(readline_with_dialogue(sprintf("%s cohort IDs (comma-separated): ", label))) ids <- parse_ids(entered) ids <- unique(ids[!is.na(ids)]) if (length(ids) == 0) stop(sprintf("%s must include at least one cohort ID.", label)) @@ -1796,7 +1907,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me if (length(ids) > 1) stop(sprintf("%s must contain at most one ID.", label)) if (length(ids) == 1) return(validate_positive_integer(ids[[1]], label)) if (!isTRUE(interactive)) return(NULL) - entered <- trimws(readline(prompt %||% sprintf("%s ID [optional]: ", label))) + entered <- trimws(readline_with_dialogue(prompt %||% sprintf("%s ID [optional]: ", label))) if (!nzchar(entered)) return(NULL) ids <- parse_ids(entered) ids <- unique(ids[!is.na(ids)]) @@ -1815,7 +1926,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me repeat { prompt_text <- trimws(as.character(prompt %||% "")) rendered_prompt <- if (nzchar(prompt_text)) sprintf("%s %s ", prompt_text, suffix) else sprintf("%s ", suffix) - entered <- tolower(trimws(readline(rendered_prompt))) + entered <- tolower(trimws(readline_with_dialogue(rendered_prompt))) if (entered == "") return(default) if (entered %in% options$yes) return(TRUE) if (entered %in% options$no) return(FALSE) @@ -1827,7 +1938,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me if (!isTRUE(interactive)) return(default) repeat { default_value <- if (is.null(default)) "" else as.character(default) - entered <- trimws(readline(sprintf("%s [%s]: ", prompt, default_value))) + entered <- trimws(readline_with_dialogue(sprintf("%s [%s]: ", prompt, default_value))) if (entered == "" && !is.null(default)) return(default) if (entered == "") { cat("A value is required.\n") @@ -1851,7 +1962,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me repeat { prompt_text <- trimws(as.character(prompt %||% "")) rendered_prompt <- if (nzchar(prompt_text)) sprintf("%s%s: ", prompt_text, prompt_suffix) else sprintf("%s: ", prompt_suffix) - entered <- trimws(readline(rendered_prompt)) + entered <- trimws(readline_with_dialogue(rendered_prompt)) if (entered == "") { if (allow_null) return(NULL) if (is.null(default)) { @@ -1886,7 +1997,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me repeat { prompt_text <- trimws(as.character(prompt %||% "")) rendered_prompt <- if (nzchar(prompt_text)) sprintf("%s%s: ", prompt_text, prompt_suffix) else sprintf("%s: ", prompt_suffix) - entered <- trimws(readline(rendered_prompt)) + entered <- trimws(readline_with_dialogue(rendered_prompt)) if (entered == "") { if (is.null(default)) { cat("A value is required.\n") @@ -1928,7 +2039,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me repeat { prompt_text <- trimws(as.character(prompt %||% "")) rendered_prompt <- if (nzchar(prompt_text)) sprintf("%s [%s]: ", prompt_text, default) else sprintf("[%s]: ", default) - entered <- trimws(readline(rendered_prompt)) + entered <- trimws(readline_with_dialogue(rendered_prompt)) if (entered == "") { return(default) } @@ -1950,7 +2061,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me collected <- integer(0) repeat { - entered <- trimws(readline("Outcome cohort ID: ")) + entered <- trimws(readline_with_dialogue("Outcome cohort ID: ")) parsed <- parse_ids(entered) parsed <- parsed[!is.na(parsed)] if (length(parsed) != 1) { @@ -2015,6 +2126,15 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me first_nonempty(rec$cohortName, rec$phenotype_name, rec$name, "") } + recommendation_identifier <- function(rec) { + first_nonempty( + as.character(rec$cohortId %||% ""), + as.character(rec$phenotype_id %||% ""), + as.character(rec$id %||% ""), + "" + ) + } + recommendation_cohort_id <- function(rec) { direct <- suppressWarnings(as.integer(rec$cohortId %||% NA_integer_)) if (!is.na(direct)) return(direct) @@ -2025,6 +2145,36 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me suppressWarnings(as.integer(phenotype_id)) } + recommendation_is_ohdsi_computable <- function(rec) { + identifier <- recommendation_identifier(rec) + if (!nzchar(identifier)) return(FALSE) + if (grepl("^[0-9]+$", identifier)) return(TRUE) + grepl("^ohdsi:[0-9]+$", identifier) + } + + recommendation_id_label <- function(rec) { + cohort_id <- recommendation_cohort_id(rec) + if (!is.na(cohort_id)) return(as.character(cohort_id)) + identifier <- recommendation_identifier(rec) + if (nzchar(identifier)) return(identifier) + "?" + } + + unsupported_recommendation_message <- function(rec, role_label) { + identifier <- recommendation_identifier(rec) + if (!nzchar(identifier)) identifier <- "unknown" + sprintf( + paste( + "Selected %s phenotype %s (%s), but this workflow can only continue with a computable OHDSI cohort definition.", + "Descriptive phenotypes such as CIPHER recommendations are not yet convertible to executable cohort JSON in the shell.", + "Choose an OHDSI-backed phenotype for now." + ), + tolower(role_label), + recommendation_name(rec), + identifier + ) + } + lookup_catalog_value <- function(cohort_id, catalog_df, field = "name", fallback = NULL) { idx <- which(catalog_df$cohortId == as.integer(cohort_id))[1] if (!is.na(idx)) { @@ -2086,7 +2236,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me prompt_statement <- function(label, default = NULL) { if (!isTRUE(interactive)) return(default) default_value <- trimws(as.character(default %||% "")) - entered <- readline(sprintf("%s statement [%s]: ", label, default_value)) + entered <- readline_with_dialogue(sprintf("%s statement [%s]: ", label, default_value)) if (nzchar(trimws(entered))) trimws(entered) else default_value } @@ -2109,6 +2259,10 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me collect_recommendation_selection <- function(recommendations, role_label, allow_multiple = FALSE) { if (length(recommendations) == 0) return(integer(0)) if (!isTRUE(interactive)) { + unsupported <- vapply(recommendations, function(rec) !isTRUE(recommendation_is_ohdsi_computable(rec)), logical(1)) + if (any(unsupported)) { + stop(unsupported_recommendation_message(recommendations[[which(unsupported)[1]]], role_label)) + } if (isTRUE(allow_multiple)) { return(as.integer(vapply(recommendations, recommendation_cohort_id, integer(1)))) } @@ -2117,9 +2271,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me labels <- vapply(seq_along(recommendations), function(i) { rec <- recommendations[[i]] - cohort_id <- recommendation_cohort_id(rec) - cohort_id_label <- if (is.na(cohort_id)) "?" else as.character(cohort_id) - sprintf("%s (ID %s)", recommendation_name(rec), cohort_id_label) + sprintf("%s (ID %s)", recommendation_name(rec), recommendation_id_label(rec)) }, character(1)) picks <- utils::select.list( labels, @@ -2127,6 +2279,14 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me title = sprintf("Select %s phenotype%s", tolower(role_label), if (isTRUE(allow_multiple)) "s" else "") ) if (!length(picks) || !any(nzchar(picks))) return(integer(0)) + selected_recs <- lapply(picks, function(label) { + idx <- which(labels == label)[1] + recommendations[[idx]] + }) + unsupported <- vapply(selected_recs, function(rec) !isTRUE(recommendation_is_ohdsi_computable(rec)), logical(1)) + if (any(unsupported)) { + stop(unsupported_recommendation_message(selected_recs[[which(unsupported)[1]]], role_label)) + } selected_ids <- vapply(picks, function(label) { idx <- which(labels == label)[1] recommendation_cohort_id(recommendations[[idx]]) @@ -2147,8 +2307,12 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me selected_cache_label = NULL, selected_cache_dir = NULL, cohort_method_cache = NULL, - incidence_cache = NULL) { + incidence_cache = NULL, + recommendation_role = NULL, + workflow_type = "cohort_methods", + exclude_metadata = NULL) { role_key <- tolower(role_label) + recommendation_role <- tolower(trimws(as.character(recommendation_role %||% role_key))) preferred_selected_ids <- normalize_selected_ids( preferred_selected_ids, sprintf("%s cohort ID%s", role_label, if (isTRUE(allow_multiple)) "s" else ""), @@ -2222,6 +2386,19 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me } } + set_dialogue_context( + paste0(role_key, "_recommendation"), + recommendation_role, + context = list( + statement = statement, + top_k = top_k, + max_results = max_results, + candidate_limit = candidate_limit, + workflow_type = workflow_type, + exclude_metadata = exclude_metadata + ) + ) + recommendation_response <- NULL recommendation_path <- output_path used_cached_recommendation <- FALSE @@ -2236,7 +2413,10 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me study_intent = statement, top_k = top_k, max_results = max_results, - candidate_limit = candidate_limit + candidate_limit = candidate_limit, + recommendation_role = recommendation_role, + workflow_type = workflow_type, + exclude_metadata = exclude_metadata ) message(sprintf("Calling ACP flow: phenotype_recommendation (%s)", role_key)) recommendation_response <- tryCatch( @@ -2250,15 +2430,39 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me recommendations_core <- recommendation_response$recommendations %||% recommendation_response recommendations <- recommendations_core$phenotype_recommendations %||% list() + no_candidate_reason <- as.character(recommendation_response$fallback_reason %||% recommendations_core$fallback_reason %||% "") + + if (isTRUE(interactive) && length(recommendations) == 0 && !is.null(recommendation_response)) { + cat(sprintf(" +== %s Phenotype Recommendations == +", role_label)) + if (identical(no_candidate_reason, "no_direct_role_match")) { + cat("No sufficiently direct computable OHDSI phenotype match was found for this cohort statement. +") + cat("Enter a cohort ID manually if you want to continue with a known cohort definition. +") + } else if (identical(no_candidate_reason, "no_viable_candidates_after_rerank")) { + cat("No viable phenotype candidates were identified from the current search results. +") + cat("Enter a cohort ID manually if you want to continue with a known cohort definition. +") + } else { + cat("No phenotype recommendations were returned. +") + cat("Enter a cohort ID manually if you want to continue with a known cohort definition. +") + } + } if (isTRUE(interactive) && length(recommendations) > 0) { cat(sprintf("\n== %s Phenotype Recommendations ==\n", role_label)) for (i in seq_along(recommendations)) { rec <- recommendations[[i]] - cohort_id <- recommendation_cohort_id(rec) - cohort_id_label <- if (is.na(cohort_id)) "?" else as.character(cohort_id) - cat(sprintf("%d. %s (ID %s)\n", i, recommendation_name(rec), cohort_id_label)) + cat(sprintf("%d. %s (ID %s)\n", i, recommendation_name(rec), recommendation_id_label(rec))) if (!is.null(rec$justification)) cat(sprintf(" %s\n", rec$justification)) + if (!isTRUE(recommendation_is_ohdsi_computable(rec))) { + cat(" Not directly computable in this workflow; descriptive phenotype conversion is not yet implemented.\n") + } } ok_any <- prompt_yesno(sprintf("Are any of these acceptable for the %s?", role_key), default = TRUE) if (!ok_any && ensure_acp_ready(acpUrl)) { @@ -2271,7 +2475,10 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me top_k = top_k, max_results = max_results, candidate_limit = candidate_limit, - candidate_offset = candidate_limit + candidate_offset = candidate_limit, + recommendation_role = recommendation_role, + workflow_type = workflow_type, + exclude_metadata = exclude_metadata ) message(sprintf("Calling ACP flow: phenotype_recommendation (%s window 2)", role_key)) recommendation_response <- tryCatch( @@ -2286,10 +2493,11 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me cat(sprintf("\n== %s Phenotype Recommendations (window 2) ==\n", role_label)) for (i in seq_along(recommendations)) { rec <- recommendations[[i]] - cohort_id <- recommendation_cohort_id(rec) - cohort_id_label <- if (is.na(cohort_id)) "?" else as.character(cohort_id) - cat(sprintf("%d. %s (ID %s)\n", i, recommendation_name(rec), cohort_id_label)) + cat(sprintf("%d. %s (ID %s)\n", i, recommendation_name(rec), recommendation_id_label(rec))) if (!is.null(rec$justification)) cat(sprintf(" %s\n", rec$justification)) + if (!isTRUE(recommendation_is_ohdsi_computable(rec))) { + cat(" Not directly computable in this workflow; descriptive phenotype conversion is not yet implemented.\n") + } } ok_any <- prompt_yesno(sprintf("Are any of these acceptable for the %s?", role_key), default = TRUE) } @@ -2345,9 +2553,32 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me ) } + resolve_index_definition_path <- function(source_id, index_def_dir) { + source_text <- trimws(as.character(source_id %||% "")) + candidates <- character(0) + if (nzchar(source_text)) { + candidates <- c(candidates, file.path(index_def_dir, sprintf("%s.json", source_text))) + if (grepl("^[0-9]+$", source_text)) { + candidates <- c(candidates, file.path(index_def_dir, sprintf("ohdsi__%s.json", source_text))) + } + if (grepl("^[A-Za-z0-9_]+:[A-Za-z0-9_.-]+$", source_text)) { + candidates <- c( + candidates, + file.path(index_def_dir, sprintf("%s.json", gsub(":", "__", source_text, fixed = TRUE))) + ) + } + } + candidates <- unique(candidates[nzchar(candidates)]) + hit <- candidates[file.exists(candidates)][1] + if (length(hit) == 0 || is.na(hit) || !nzchar(hit)) return(NA_character_) + hit + } + copy_cohort_json_multi <- function(source_id, dest_id, dest_dirs, index_def_dir) { - src <- file.path(index_def_dir, sprintf("%s.json", source_id)) - if (!file.exists(src)) stop(sprintf("Cohort JSON not found: %s", src)) + src <- resolve_index_definition_path(source_id, index_def_dir) + if (is.na(src) || !file.exists(src)) { + stop(sprintf("Cohort JSON not found for source %s in %s", source_id, index_def_dir)) + } dests <- character(0) for (dest_dir in dest_dirs) { ensure_dir(dest_dir) @@ -2457,6 +2688,15 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me for (cid in names(response_by_id)) { if (identical(cid, "_meta")) next + set_dialogue_context( + paste0(role_key, "_improvements"), + role_key, + context = list( + role_statement = role_statement, + cohort_id = as.integer(cid), + study_intent = studyIntent + ) + ) resp <- response_by_id[[cid]] core <- resp$full_result %||% resp items <- core$phenotype_improvements %||% list() @@ -2631,16 +2871,16 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me } assert_cohort_json_exists <- function(source_id, index_def_dir, label) { - src <- file.path(index_def_dir, sprintf("%s.json", source_id)) - if (!file.exists(src)) { - stop(sprintf("%s cohort JSON not found: %s", label, src)) + src <- resolve_index_definition_path(source_id, index_def_dir) + if (is.na(src) || !file.exists(src)) { + stop(sprintf("%s cohort JSON not found for source %s in %s", label, source_id, index_def_dir)) } invisible(src) } cohort_json_exists <- function(source_id, index_def_dir) { - src <- file.path(index_def_dir, sprintf("%s.json", source_id)) - file.exists(src) + src <- resolve_index_definition_path(source_id, index_def_dir) + !is.na(src) && file.exists(src) } validate_positive_integer <- function(value, label) { @@ -2953,7 +3193,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me collect_text_value <- function(value, prompt, default = "") { current <- value %||% default if (!isTRUE(interactive)) return(current) - entered <- readline(sprintf("%s [%s]: ", prompt, current)) + entered <- readline_with_dialogue(sprintf("%s [%s]: ", prompt, current)) if (nzchar(trimws(entered))) entered else current } @@ -2969,7 +3209,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me } repeat { - entered <- trimws(readline(sprintf("Select option [%s]: ", match(current, choices)))) + entered <- trimws(readline_with_dialogue(sprintf("Select option [%s]: ", match(current, choices)))) if (!nzchar(entered)) return(current) option_idx <- suppressWarnings(as.integer(entered)) if (!is.na(option_idx) && option_idx >= 1 && option_idx <= length(choices)) { @@ -2986,7 +3226,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me if (!isTRUE(interactive)) return(current) repeat { - entered <- trimws(readline(sprintf("%s [%s]: ", prompt, current))) + entered <- trimws(readline_with_dialogue(sprintf("%s [%s]: ", prompt, current))) if (!nzchar(entered)) return(current) parsed <- suppressWarnings(as.integer(entered)) if (!is.na(parsed) && (is.null(min_value) || parsed >= min_value)) { @@ -3006,7 +3246,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me if (!isTRUE(interactive)) return(current) repeat { - entered <- trimws(readline(sprintf("%s [%s]: ", prompt, format(current, trim = TRUE, scientific = FALSE)))) + entered <- trimws(readline_with_dialogue(sprintf("%s [%s]: ", prompt, format(current, trim = TRUE, scientific = FALSE)))) if (!nzchar(entered)) return(current) parsed <- suppressWarnings(as.numeric(entered)) if (!is.na(parsed) && (is.null(min_value) || parsed >= min_value)) { @@ -3042,7 +3282,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me ) repeat { - entered <- tolower(trimws(readline("Press Enter after saving, or type 'r' to reopen the file: "))) + entered <- tolower(trimws(readline_with_dialogue("Press Enter after saving, or type 'r' to reopen the file: "))) if (identical(entered, "r")) { tryCatch( utils::file.edit(review_path), @@ -3466,7 +3706,8 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me default_intent <- studyIntent %||% cached_inputs$study_intent %||% "Compare a target exposure versus a comparator exposure on one or more outcomes using a cohort method design." if (isTRUE(interactive)) { - entered <- readline(sprintf("Study intent [%s]: ", default_intent)) + set_dialogue_context("study_intent", context = list(default_intent = default_intent)) + entered <- readline_with_dialogue(sprintf("Study intent [%s]: ", default_intent)) if (nzchar(trimws(entered))) { studyIntent <- entered } else { @@ -3524,7 +3765,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me default_selection <- paste(seq_along(defaults), collapse = ",") use_manual_outcome <- FALSE repeat { - entered <- trimws(readline(sprintf( + entered <- trimws(readline_with_dialogue(sprintf( "Keep outcome statements [%s] (comma-separated numbers, 0/none to enter manually, Enter keeps all): ", default_selection ))) @@ -3768,8 +4009,23 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me outcomeStatements <- outcome_statements_default outcomeStatement <- first_nonempty(outcomeStatements) } else { + set_dialogue_context("intent_split", "target", context = list( + target_statement = target_statement_default, + comparator_statement = comparator_statement_default, + outcome_statements = outcome_statements_default + )) targetStatement <- prompt_statement("Target", default = target_statement_default) + set_dialogue_context("intent_split", "comparator", context = list( + target_statement = targetStatement, + comparator_statement = comparator_statement_default, + outcome_statements = outcome_statements_default + )) comparatorStatement <- prompt_statement("Comparator", default = comparator_statement_default) + set_dialogue_context("intent_split", "outcome", context = list( + target_statement = targetStatement, + comparator_statement = comparatorStatement, + outcome_statements = outcome_statements_default + )) outcomeStatements <- prompt_outcome_statements(outcome_statements_default) outcomeStatement <- first_nonempty(outcomeStatements) } @@ -3859,7 +4115,10 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me cache_dir = incidence_selected_target_dir, label = "incidence target cohort selection" ) - ) + ), + recommendation_role = "target", + workflow_type = "cohort_methods", + exclude_metadata = list(executable_definition_status = list("codes_only")) ) targetCohortId <- if (length(target_rec$selected_ids) > 0) { @@ -3904,7 +4163,7 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me cohortIdBase <- cohortIdBase %||% cached_inputs$cohort_id_base %||% default_cohort_id_base cohortIdBase <- suppressWarnings(as.integer(cohortIdBase)) if (isTRUE(interactive)) { - entered <- trimws(readline(sprintf("Cohort ID base [%s]: ", cohortIdBase))) + entered <- trimws(readline_with_dialogue(sprintf("Cohort ID base [%s]: ", cohortIdBase))) if (nzchar(entered)) cohortIdBase <- suppressWarnings(as.integer(entered)) } cohortIdBase <- validate_positive_integer(cohortIdBase, "cohortIdBase") @@ -3968,7 +4227,10 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me cache_dir = NULL, label = NULL ) - ) + ), + recommendation_role = "comparator", + workflow_type = "cohort_methods", + exclude_metadata = list(executable_definition_status = list("codes_only")) ) comparatorCohortId <- if (length(comparator_rec$selected_ids) > 0) { @@ -4051,7 +4313,10 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me selected_cache_label = NULL, selected_cache_dir = NULL, cohort_method_cache = NULL, - incidence_cache = NULL + incidence_cache = NULL, + recommendation_role = "outcome", + workflow_type = "cohort_methods", + exclude_metadata = list(executable_definition_status = list("codes_only")) ) }) } else { @@ -4080,7 +4345,10 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me cache_dir = incidence_selected_outcome_dir, label = "incidence outcome cohort selection" ) - ) + ), + recommendation_role = "outcome", + workflow_type = "cohort_methods", + exclude_metadata = list(executable_definition_status = list("codes_only")) )) } outcome_recommendations <- lapply(seq_along(outcome_recs), function(i) { @@ -4500,12 +4768,36 @@ runStrategusCohortMethodsShell <- function(outputDir = "demo-strategus-cohort-me cat("The shell will collect each required section in order and ask for the analytic settings profile name last.\n") } + set_dialogue_context( + "analytic_settings_step_by_step", + "analytic_settings", + context = list( + study_intent = studyIntent, + target_statement = targetStatement, + comparator_statement = comparatorStatement, + outcome_statements = outcomeStatements, + comparison_label = comparisonLabel + ) + ) + step_by_step_io <- list( section_header = function(label) { + set_dialogue_context( + "analytic_settings_step_by_step", + "analytic_settings", + context = list( + section = label, + study_intent = studyIntent, + target_statement = targetStatement, + comparator_statement = comparatorStatement, + outcome_statements = outcomeStatements, + comparison_label = comparisonLabel + ) + ) cat(sprintf("\n[%s]\n", label)) }, text = function(prompt, default = "", allow_blank = FALSE) { - entered <- trimws(readline(sprintf("%s [%s]: ", prompt, default))) + entered <- trimws(readline_with_dialogue(sprintf("%s [%s]: ", prompt, default))) if (!nzchar(entered)) { if (isTRUE(allow_blank)) return(default) return(default) diff --git a/acp_agent/study_agent_acp/agent.py b/acp_agent/study_agent_acp/agent.py index 0f21611..e8cc7e3 100644 --- a/acp_agent/study_agent_acp/agent.py +++ b/acp_agent/study_agent_acp/agent.py @@ -19,6 +19,7 @@ PhenotypeRecommendationAdviceInput, PhenotypeRecommendationPlanInput, PhenotypeRecommendationsInput, + WorkflowContextDialogueInput, ) from study_agent_core.tools import ( cohort_methods_intent_split, @@ -29,6 +30,7 @@ phenotype_recommendation_plan, phenotype_recommendations, propose_concept_set_diff, + workflow_context_dialogue, ) from .llm_client import ( LLMCallResult, @@ -36,6 +38,7 @@ build_intent_split_prompt, build_recommendation_intent_facets_prompt, build_advice_prompt, + build_workflow_context_dialogue_prompt, build_keeper_concept_set_prompt, build_improvements_prompt, build_keeper_prompt, @@ -79,6 +82,7 @@ def __init__( "phenotype_improvements": phenotype_improvements, "phenotype_intent_split": phenotype_intent_split, "cohort_methods_intent_split": cohort_methods_intent_split, + "workflow_context_dialogue": workflow_context_dialogue, } self._schemas = { @@ -90,6 +94,7 @@ def __init__( "phenotype_improvements": PhenotypeImprovementsInput.model_json_schema(), "phenotype_intent_split": PhenotypeIntentSplitInput.model_json_schema(), "cohort_methods_intent_split": CohortMethodsIntentSplitInput.model_json_schema(), + "workflow_context_dialogue": WorkflowContextDialogueInput.model_json_schema(), "keeper_concept_sets_generate": KeeperConceptSetsGenerateInput.model_json_schema(), "keeper_profiles_generate": KeeperProfilesGenerateInput.model_json_schema(), } @@ -977,6 +982,8 @@ def _candidate_metadata_priority( intent_facets: Dict[str, Any], search_rank: int, study_intent: str = "", + recommendation_role: Optional[str] = None, + workflow_type: Optional[str] = None, ) -> Dict[str, Any]: topic_tokens = self._topic_tokens(intent_facets.get("condition_or_topic")) alias_tokens_list = [ @@ -1079,6 +1086,26 @@ def _candidate_metadata_priority( if intent_role == "medication_based": medication_text = any(token in combined_text for token in ("medication", "drug", "med codes", "insulin", "metformin", "antidiabetic", "meglitinide", "prescription", "therapy")) medication_signal = "has_code_system:medication" in signals_text or medication_text + recommendation_role_text = self._flatten_text(recommendation_role) + focus_stop_tokens = { + "new", "users", "user", "prior", "exposure", "index", "date", "days", "day", "before", + "after", "first", "prescription", "dispensing", "with", "without", "therapy", "treated", + "initiators", "initiator", "cohort", "patients", "patient", "use", "using", "the", "and", + "for", "from", "in", "medication", "drug", "newuser", "prioruse", "no", "of" + } + intent_focus_tokens = { + token for token in self._topic_tokens(study_intent) + if token not in focus_stop_tokens and not token.isdigit() + } + intent_focus_preview = sorted(intent_focus_tokens)[:6] + candidate_focus_text = " ".join( + part for part in ( + name_text, + self._flatten_text(row.get("primary_clinical_topic")), + retrieval_keywords, + ) if part + ) + candidate_focus_tokens = self._topic_tokens(candidate_focus_text) if "medication" in role or "drug" in role: score += 8.0 reasons.append({"kind": "role_match_medication", "delta": 8.0, "detail": row.get("phenotype_role") or ""}) @@ -1097,6 +1124,30 @@ def _candidate_metadata_priority( if any(token in role for token in ("procedure", "screen", "severity", "outcome")): score -= 3.5 reasons.append({"kind": "role_penalty_non_medication", "delta": -3.5, "detail": row.get("phenotype_role") or ""}) + if intent_focus_tokens and recommendation_role_text in {"target", "comparator"}: + focus_overlap = self._topic_overlap_score(intent_focus_tokens, candidate_focus_tokens) + if focus_overlap > 0.0: + delta = focus_overlap * 12.0 + score += delta + reasons.append({ + "kind": f"{recommendation_role_text}_focus_match", + "delta": round(delta, 4), + "detail": {"intent_tokens": intent_focus_preview}, + }) + else: + score -= 7.5 + reasons.append({ + "kind": f"{recommendation_role_text}_focus_mismatch", + "delta": -7.5, + "detail": {"intent_tokens": intent_focus_preview}, + }) + if workflow_type == "cohort_methods": + if recommendation_role_text == "comparator": + score += 1.5 + reasons.append({"kind": "workflow_comparator_bias", "delta": 1.5, "detail": workflow_type}) + elif recommendation_role_text == "target": + score += 1.0 + reasons.append({"kind": "workflow_target_bias", "delta": 1.0, "detail": workflow_type}) if care_setting and care_setting != "any": if candidate_care_setting and care_setting in candidate_care_setting: @@ -1153,11 +1204,74 @@ def _candidate_metadata_priority( "reasons": reasons, } + def _normalize_metadata_exclusions(self, exclude_metadata: Optional[Dict[str, Any]]) -> Dict[str, List[str]]: + normalized: Dict[str, List[str]] = {} + if not isinstance(exclude_metadata, dict): + return normalized + for key, raw_values in exclude_metadata.items(): + if key in (None, ""): + continue + values = raw_values if isinstance(raw_values, list) else [raw_values] + cleaned = [] + for value in values: + value_text = self._flatten_text(value) + if value_text: + cleaned.append(value_text) + if cleaned: + normalized[str(key)] = sorted(set(cleaned)) + return normalized + + def _candidate_exclusion_reason(self, row: Dict[str, Any], exclude_metadata: Dict[str, List[str]]) -> Optional[str]: + if not isinstance(row, dict) or not exclude_metadata: + return None + for key, disallowed_values in exclude_metadata.items(): + row_value = self._flatten_text(row.get(key)) + if row_value and row_value in set(disallowed_values or []): + return f"{key}={row_value}" + return None + + def _apply_metadata_exclusions( + self, + candidates: List[Dict[str, Any]], + exclude_metadata: Optional[Dict[str, Any]], + ) -> tuple[List[Dict[str, Any]], Dict[str, Any]]: + normalized = self._normalize_metadata_exclusions(exclude_metadata) + if not normalized: + return list(candidates or []), { + "requested": {}, + "excluded_ids": [], + "excluded_reasons": {}, + "remaining_count": len(candidates or []), + } + kept: List[Dict[str, Any]] = [] + excluded_ids: List[str] = [] + excluded_reasons: Dict[str, str] = {} + for row in candidates or []: + if not isinstance(row, dict): + continue + reason = self._candidate_exclusion_reason(row, normalized) + phenotype_id = str(row.get("phenotype_id") or "") + if reason: + if phenotype_id: + excluded_ids.append(phenotype_id) + excluded_reasons[phenotype_id] = reason + continue + kept.append(row) + diagnostics = { + "requested": normalized, + "excluded_ids": excluded_ids, + "excluded_reasons": excluded_reasons, + "remaining_count": len(kept), + } + return kept, diagnostics + def _rerank_planning_candidates( self, candidates: List[Dict[str, Any]], intent_facets: Dict[str, Any], study_intent: str = "", + recommendation_role: Optional[str] = None, + workflow_type: Optional[str] = None, ) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: ranked_rows: List[tuple[float, float, int, Dict[str, Any], Dict[str, Any]]] = [] for index, row in enumerate(candidates): @@ -1168,6 +1282,8 @@ def _rerank_planning_candidates( intent_facets=intent_facets, search_rank=index, study_intent=study_intent, + recommendation_role=recommendation_role, + workflow_type=workflow_type, ) metadata_score = float(priority.get("metadata_score") or 0.0) retrieval_score = float(priority.get("retrieval_score") or 0.0) @@ -1435,6 +1551,9 @@ def run_phenotype_recommendation_flow( max_results: Optional[int] = None, candidate_limit: Optional[int] = None, candidate_offset: Optional[int] = None, + recommendation_role: Optional[str] = None, + workflow_type: Optional[str] = None, + exclude_metadata: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: if not study_intent: return {"status": "error", "error": "missing study_intent"} @@ -1444,6 +1563,9 @@ def run_phenotype_recommendation_flow( top_k = int(os.getenv("LLM_RECOMMENDATION_TOP_K", "20")) if max_results is None: max_results = int(os.getenv("LLM_RECOMMENDATION_MAX_RESULTS", "3")) + recommendation_role = (str(recommendation_role or "").strip().lower() or None) + workflow_type = (str(workflow_type or "").strip().lower() or None) + exclude_metadata = exclude_metadata if isinstance(exclude_metadata, dict) else {} search_args = {"query": study_intent, "top_k": top_k} if candidate_offset is not None: @@ -1483,6 +1605,7 @@ def run_phenotype_recommendation_flow( } all_candidates = full.get("results") or [] + all_candidates, exclusion_diagnostics = self._apply_metadata_exclusions(all_candidates, exclude_metadata) if candidate_limit is None: candidate_limit = int(os.getenv("LLM_CANDIDATE_LIMIT", "5")) candidate_limit = max(0, int(candidate_limit)) @@ -1523,10 +1646,14 @@ def run_phenotype_recommendation_flow( "phenotype_recommendation: intent llm end " f"status={intent_llm_result.status} seconds={intent_llm_result.duration_seconds:.2f} parse_stage={intent_llm_result.parse_stage}" ) - intent_payload = llm_result_payload(intent_llm_result) or {} + intent_payload = llm_result_payload(intent_llm_result) or getattr(intent_llm_result, "parsed_content", None) or {} raw_intent_facets = intent_payload.get("intent_facets") intent_facets = raw_intent_facets if isinstance(raw_intent_facets, dict) else {} effective_intent_facets = self._effective_intent_facets(study_intent=study_intent, intent_facets=intent_facets) + if recommendation_role: + effective_intent_facets["recommendation_role"] = recommendation_role + if workflow_type: + effective_intent_facets["workflow_type"] = workflow_type raw_intent_notes = intent_payload.get("reasoning_notes") if isinstance(raw_intent_notes, list): intent_reasoning_notes = [str(note) for note in raw_intent_notes if note not in (None, "")] @@ -1567,6 +1694,8 @@ def run_phenotype_recommendation_flow( planning_hydrated, effective_intent_facets, study_intent=study_intent, + recommendation_role=recommendation_role, + workflow_type=workflow_type, ) planning_top_band = int(os.getenv("LLM_PLANNING_TOP_BAND", str(max((max_results or 0) + 2, 5)))) planning_top_band = max(1, min(planning_top_band, len(planning_ranked))) if planning_ranked else 0 @@ -1632,40 +1761,80 @@ def run_phenotype_recommendation_flow( ) selected_candidates = [row for row in hydrated_candidates[: max(0, max_results)] if isinstance(row, dict)] + strict_role_match_kind = None + role_match_candidate_ids: List[str] = [] + selected_role_match_ids: List[str] = [] + if ( + workflow_type == "cohort_methods" + and recommendation_role in {"target", "comparator"} + and self._flatten_text(effective_intent_facets.get("phenotype_role")) == "medication_based" + ): + strict_role_match_kind = f"{recommendation_role}_focus_match" + role_match_candidate_ids = [ + str(item.get("phenotype_id")) + for item in planning_rerank_diagnostics + if any( + isinstance(reason, dict) and reason.get("kind") == strict_role_match_kind + for reason in (item.get("reasons") or []) + ) + and item.get("phenotype_id") not in (None, "") + ] + if role_match_candidate_ids: + selected_candidates = [ + row for row in selected_candidates + if str(row.get("phenotype_id") or "") in set(role_match_candidate_ids) + ] + selected_role_match_ids = [str(row.get("phenotype_id") or "") for row in selected_candidates if row.get("phenotype_id") not in (None, "")] + else: + selected_candidates = [] compact_final_candidates = self._build_compact_final_candidates(selected_candidates) - self._log_debug("phenotype_recommendation: final prompt bundle fetch start") - prompt_bundle = self.call_tool( - name="phenotype_prompt_bundle", - arguments={"task": "phenotype_recommendations"}, - ) - self._log_debug(f"phenotype_recommendation: final prompt bundle fetch end status={prompt_bundle.get('status')}") - prompt_full = prompt_bundle.get("full_result") or {} - if prompt_bundle.get("status") != "ok" or prompt_full.get("error"): - return { - "status": "error", - "error": "phenotype_prompt_bundle_failed", - "details": prompt_bundle, - } + skip_final_reason = None + final_prompt = "" + if not compact_final_candidates: + skip_final_reason = "no_direct_role_match" if strict_role_match_kind else "no_viable_candidates_after_rerank" + self._log_debug(f"phenotype_recommendation: final llm skipped reason={skip_final_reason}") + llm_result = LLMCallResult( + status=f"skipped_{skip_final_reason}", + duration_seconds=0.0, + error=skip_final_reason, + parse_stage="skipped", + request_mode="chat_completions", + schema_valid=False, + ) + else: + self._log_debug("phenotype_recommendation: final prompt bundle fetch start") + prompt_bundle = self.call_tool( + name="phenotype_prompt_bundle", + arguments={"task": "phenotype_recommendations"}, + ) + self._log_debug(f"phenotype_recommendation: final prompt bundle fetch end status={prompt_bundle.get('status')}") + prompt_full = prompt_bundle.get("full_result") or {} + if prompt_bundle.get("status") != "ok" or prompt_full.get("error"): + return { + "status": "error", + "error": "phenotype_prompt_bundle_failed", + "details": prompt_bundle, + } - final_prompt = build_prompt( - overview=prompt_full.get("overview", ""), - spec=prompt_full.get("spec", ""), - output_schema=prompt_full.get("output_schema", {}), - study_intent=study_intent, - candidates=compact_final_candidates, - max_results=max_results, - task="phenotype_recommendations", - extra_dynamic={"intent_facets": effective_intent_facets}, - ) - self._log_debug( - f"phenotype_recommendation: final llm start prompt_chars={len(final_prompt)} candidate_count={len(compact_final_candidates)}" - ) - llm_result = self._call_llm(final_prompt, required_keys=["plan", "phenotype_recommendations"]) - self._log_debug( - "phenotype_recommendation: final llm end " - f"status={llm_result.status} seconds={llm_result.duration_seconds:.2f} parse_stage={llm_result.parse_stage}" - ) + final_prompt = build_prompt( + overview=prompt_full.get("overview", ""), + spec=prompt_full.get("spec", ""), + output_schema=prompt_full.get("output_schema", {}), + study_intent=study_intent, + candidates=compact_final_candidates, + max_results=max_results, + task="phenotype_recommendations", + extra_dynamic={"intent_facets": effective_intent_facets}, + ) + self._log_debug( + f"phenotype_recommendation: final llm start prompt_chars={len(final_prompt)} candidate_count={len(compact_final_candidates)}" + ) + llm_result = self._call_llm(final_prompt, required_keys=["plan", "phenotype_recommendations"]) + self._log_debug( + "phenotype_recommendation: final llm end " + f"status={llm_result.status} seconds={llm_result.duration_seconds:.2f} parse_stage={llm_result.parse_stage}" + ) catalog_rows = [] for row in selected_candidates: @@ -1707,8 +1876,12 @@ def run_phenotype_recommendation_flow( fallback_reason = None fallback_mode = None else: - fallback_reason = self._fallback_reason_for_llm(llm_result) if llm_payload is None else "llm_explanations_unusable" - fallback_mode = "stub" if llm_payload is None else core_result.get("mode") + if skip_final_reason: + fallback_reason = skip_final_reason + fallback_mode = core_result.get("mode") + else: + fallback_reason = self._fallback_reason_for_llm(llm_result) if llm_payload is None else "llm_explanations_unusable" + fallback_mode = "stub" if llm_payload is None else core_result.get("mode") if fallback_reason: self._log_debug(f"phenotype_recommendation: fallback chosen reason={fallback_reason} mode={fallback_mode}") @@ -1721,12 +1894,21 @@ def run_phenotype_recommendation_flow( diagnostics["planning_rerank"] = { "intent_facets_raw": intent_facets, "intent_facets_effective": effective_intent_facets, + "recommendation_role": recommendation_role, + "workflow_type": workflow_type, "candidate_count": len(planning_rerank_diagnostics), "planner_allowed_count": len(planning_candidates), "planner_allowed_ids": [row.get("phenotype_id") for row in planner_allowed_candidates if row.get("phenotype_id")], "shortlist_enforcement": shortlist_enforcement, "candidates": planning_rerank_diagnostics, } + diagnostics["candidate_exclusions"] = exclusion_diagnostics + diagnostics["role_match_gate"] = { + "required_kind": strict_role_match_kind, + "matched_candidate_ids": role_match_candidate_ids, + "selected_candidate_ids": selected_role_match_ids if selected_role_match_ids else [str(row.get("phenotype_id") or "") for row in selected_candidates if row.get("phenotype_id") not in (None, "")], + "skip_reason": skip_final_reason, + } diagnostics["final_validation"] = final_validation diagnostics["final_deterministic"] = final_deterministic diagnostics["final"] = final_diagnostics @@ -1742,6 +1924,8 @@ def run_phenotype_recommendation_flow( "fallback_mode": fallback_mode, "candidate_limit": candidate_limit, "candidate_offset": candidate_offset or 0, + "recommendation_role": recommendation_role, + "workflow_type": workflow_type, "candidate_count": len(hydrated_candidates), "candidate_count_before_truncation": pre_truncation_count, "plan_prompt_length_chars": len(plan_prompt), @@ -2156,6 +2340,72 @@ def run_cohort_methods_intent_split_flow( "diagnostics": self._llm_diagnostics(llm_result), } + def run_workflow_context_dialogue_flow( + self, + user_prompt: str, + study_intent: str = "", + workflow_type: str = "", + current_step: str = "", + current_role: str = "", + current_context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + if not user_prompt: + return {"status": "error", "error": "missing user_prompt"} + if self._mcp_client is None: + return {"status": "error", "error": "MCP client unavailable"} + prompt_bundle = self.call_tool( + name="workflow_context_dialogue", + arguments={}, + ) + prompt_full = prompt_bundle.get("full_result") or {} + if prompt_bundle.get("status") != "ok" or prompt_full.get("error"): + return { + "status": "error", + "error": "workflow_context_dialogue_prompt_failed", + "details": prompt_bundle, + } + + prompt = build_workflow_context_dialogue_prompt( + overview=prompt_full.get("overview", ""), + spec=prompt_full.get("spec", ""), + output_schema=prompt_full.get("output_schema", {}), + user_prompt=user_prompt, + study_intent=study_intent, + workflow_type=workflow_type, + current_step=current_step, + current_role=current_role, + current_context=current_context or {}, + ) + self._log_debug("workflow_context_dialogue: calling LLM") + llm_result = self._call_llm( + prompt, + required_keys=["answer", "current_step_guidance", "cautions", "suggested_next_actions"], + ) + self._log_debug( + "workflow_context_dialogue: LLM returned " + f"status={llm_result.status} parse_stage={llm_result.parse_stage}" + ) + llm_payload = llm_result_payload(llm_result) + core_result = workflow_context_dialogue( + user_prompt=user_prompt, + study_intent=study_intent, + workflow_type=workflow_type, + current_step=current_step, + current_role=current_role, + current_context=current_context or {}, + llm_result=llm_payload, + ) + + return { + "status": "ok", + "llm_used": llm_payload is not None, + "llm_status": llm_result.status, + "fallback_reason": None if llm_payload is not None else self._fallback_reason_for_llm(llm_result), + "fallback_mode": None if llm_payload is not None else core_result.get("mode"), + "dialogue": core_result, + "diagnostics": self._llm_diagnostics(llm_result), + } + def run_phenotype_improvements_flow( self, protocol_text: str, diff --git a/acp_agent/study_agent_acp/demo_shell.py b/acp_agent/study_agent_acp/demo_shell.py index 01a2e67..b4b1b36 100644 --- a/acp_agent/study_agent_acp/demo_shell.py +++ b/acp_agent/study_agent_acp/demo_shell.py @@ -136,6 +136,11 @@ class DemoSession: last_keeper_concepts: Optional[Path] = None last_keeper_review: Optional[Path] = None last_phenotype_name: str = "" + current_study_intent: str = "" + current_workflow_type: str = "" + current_step: str = "" + current_role: str = "" + current_context: Dict[str, Any] = field(default_factory=dict) def _extract_nested(payload: Dict[str, Any], *keys: str) -> Any: @@ -247,6 +252,7 @@ def handle_line(self, line: str) -> bool: handler = { "/phenotype-intent-split": self._handle_intent_split, "/phenotype-recommend": self._handle_recommend, + "/ohdsi": self._handle_workflow_dialogue, "/vocab-search-standard": self._handle_vocab_search, "/vocab-phoebe-related": self._handle_phoebe_related, "/keeper-generate-concepts": self._handle_keeper_generate_concepts, @@ -271,6 +277,11 @@ def _build_parsers(self) -> Dict[str, ShellArgumentParser]: "Recommend phenotype candidates for a study intent.", self._configure_recommend_parser, ), + "/ohdsi": _build_parser( + "/ohdsi", + "Ask a contextual workflow question using the current session state.", + lambda parser: parser.add_argument("text", nargs=argparse.REMAINDER), + ), "/vocab-search-standard": _build_parser( "/vocab-search-standard", "Search standard OMOP concepts for one or more semicolon-separated terms.", @@ -363,6 +374,7 @@ def _print_help(self) -> None: print("Commands:") print("/phenotype-intent-split ") print("/phenotype-recommend [--top-k N] [--max-results N] [--candidate-limit N] ") + print("/ohdsi ") print("/vocab-search-standard [--domains CSV] [--classes CSV] [--limit N] [--provider NAME] ") print("/vocab-phoebe-related [--relationships CSV] [--provider NAME] ") print("/keeper-generate-concepts [--domains CSV] [--candidate-limit N] [--min-record-count N] [--vocab-provider NAME] [--phoebe-provider NAME] [--output PATH] ") @@ -430,6 +442,14 @@ def _handle_intent_split(self, argv: Sequence[str]) -> None: print("questions:") for question in questions: print(f"- {question}") + self.session.current_study_intent = study_intent + self.session.current_workflow_type = "phenotype" + self.session.current_step = "intent_split" + self.session.current_role = "" + self.session.current_context = { + "target_statement": split.get("target_statement", ""), + "outcome_statement": split.get("outcome_statement", ""), + } self._print_llm_summary(result) print(f"saved: {artifact}") @@ -459,9 +479,46 @@ def _handle_recommend(self, argv: Sequence[str]) -> None: print(f"{idx}. phenotype_id={phenotype_id} name={phenotype_name}") if reasoning: print(f" {reasoning}") + self.session.current_study_intent = study_intent + self.session.current_workflow_type = "phenotype" + self.session.current_step = "phenotype_recommendation" + self.session.current_role = "" + self.session.current_context = { + "top_k": args.top_k, + "max_results": args.max_results, + "candidate_limit": args.candidate_limit, + "recommendation_count": len(recommendations), + } self._print_llm_summary(result) print(f"saved: {artifact}") + def _handle_workflow_dialogue(self, argv: Sequence[str]) -> None: + args = self._parse("/ohdsi", argv) + user_prompt = " ".join(args.text).strip() + if not user_prompt: + raise ValueError("missing dialogue question") + payload = { + "user_prompt": user_prompt, + "study_intent": self.session.current_study_intent, + "workflow_type": self.session.current_workflow_type, + "current_step": self.session.current_step, + "current_role": self.session.current_role, + "current_context": self.session.current_context, + } + result = self._post_flow("/flows/workflow_context_dialogue", payload) + self._require_ok(result) + dialogue = result.get("dialogue") or {} + print(f"status: {result.get('status')}") + print(dialogue.get("answer", "")) + for label, key in (("current step guidance", "current_step_guidance"), ("cautions", "cautions"), ("suggested next actions", "suggested_next_actions")): + items = dialogue.get(key) or [] + if not items: + continue + print(f"{label}:") + for item in items: + print(f"- {item}") + self._print_llm_summary(result) + def _handle_vocab_search(self, argv: Sequence[str]) -> None: args = self._parse("/vocab-search-standard", argv) raw_query_text = " ".join(args.queries).strip() diff --git a/acp_agent/study_agent_acp/llm_client.py b/acp_agent/study_agent_acp/llm_client.py index b96fec9..d9567bc 100644 --- a/acp_agent/study_agent_acp/llm_client.py +++ b/acp_agent/study_agent_acp/llm_client.py @@ -179,6 +179,48 @@ def build_recommendation_intent_facets_prompt( return "\n\n".join([s for s in sections if s]) +def build_workflow_context_dialogue_prompt( + overview: str, + spec: str, + output_schema: Dict[str, Any], + user_prompt: str, + study_intent: str = "", + workflow_type: str = "", + current_step: str = "", + current_role: str = "", + current_context: Optional[Dict[str, Any]] = None, +) -> str: + dynamic = { + "task": "workflow_context_dialogue", + "user_prompt": user_prompt, + "study_intent": study_intent, + "workflow_type": workflow_type, + "current_step": current_step, + "current_role": current_role, + "current_context": current_context or {}, + } + strict_rules = "\n\n".join( + [ + "STRICT OUTPUT RULES:", + spec, + "Return exactly ONE JSON object that matches the output schema.", + "Do NOT wrap output in markdown, code fences, or prose.", + "If uncertain, return required keys with empty arrays/strings.", + "Keep output under 10 KB.", + ] + ) + sections = [ + overview, + "OUTPUT SCHEMA (JSON):", + json.dumps(output_schema, ensure_ascii=True), + "Below is dynamic content to analyze. Do not act until after STRICT OUTPUT RULES.", + "DYNAMIC INPUT (JSON):", + json.dumps(dynamic, ensure_ascii=True), + strict_rules, + ] + return "\n\n".join([s for s in sections if s]) + + def build_advice_prompt( overview: str, spec: str, diff --git a/acp_agent/study_agent_acp/server.py b/acp_agent/study_agent_acp/server.py index 2077faf..ab98f1c 100644 --- a/acp_agent/study_agent_acp/server.py +++ b/acp_agent/study_agent_acp/server.py @@ -26,6 +26,7 @@ {"name": "phenotype_intent_split", "endpoint": "/flows/phenotype_intent_split"}, {"name": "cohort_methods_intent_split", "endpoint": "/flows/cohort_methods_intent_split"}, {"name": "cohort_methods_specifications_recommendation", "endpoint": "/flows/cohort_methods_specifications_recommendation"}, + {"name": "workflow_context_dialogue", "endpoint": "/flows/workflow_context_dialogue"}, ] SERVICE_REGISTRY_PATH = os.getenv("STUDY_AGENT_SERVICE_REGISTRY", "docs/SERVICE_REGISTRY.yaml") logger = logging.getLogger("study_agent.acp") @@ -284,6 +285,11 @@ def do_POST(self) -> None: candidate_offset = body.get("candidate_offset") if candidate_offset is not None: candidate_offset = int(candidate_offset) + recommendation_role = str(body.get("recommendation_role") or "").strip() or None + workflow_type = str(body.get("workflow_type") or "").strip() or None + exclude_metadata = body.get("exclude_metadata") + if not isinstance(exclude_metadata, dict): + exclude_metadata = None try: result = self.agent.run_phenotype_recommendation_flow( study_intent=study_intent, @@ -291,6 +297,9 @@ def do_POST(self) -> None: max_results=max_results, candidate_limit=candidate_limit, candidate_offset=candidate_offset, + recommendation_role=recommendation_role, + workflow_type=workflow_type, + exclude_metadata=exclude_metadata, ) except Exception as exc: if self.debug: @@ -301,6 +310,38 @@ def do_POST(self) -> None: _write_json(self, status, result) return + if self.path == "/flows/workflow_context_dialogue": + try: + body = _read_json(self) + except Exception as exc: + _write_json(self, 400, {"error": f"invalid_json: {exc}"}) + return + user_prompt = str(body.get("user_prompt") or body.get("prompt") or "").strip() + study_intent = str(body.get("study_intent") or "").strip() + workflow_type = str(body.get("workflow_type") or "").strip() + current_step = str(body.get("current_step") or "").strip() + current_role = str(body.get("current_role") or "").strip() + current_context = body.get("current_context") + if not isinstance(current_context, dict): + current_context = {} + try: + result = self.agent.run_workflow_context_dialogue_flow( + user_prompt=user_prompt, + study_intent=study_intent, + workflow_type=workflow_type, + current_step=current_step, + current_role=current_role, + current_context=current_context, + ) + except Exception as exc: + if self.debug: + logger.exception("flow_failed name=workflow_context_dialogue") + _write_json(self, 500, {"error": "flow_failed", "detail": str(exc) if self.debug else None}) + return + status = 200 if result.get("status") != "error" else 500 + _write_json(self, status, result) + return + if self.path == "/flows/cohort_methods_specifications_recommendation": try: body = _read_json(self) diff --git a/core/study_agent_core/models.py b/core/study_agent_core/models.py index 461b7e6..0481510 100644 --- a/core/study_agent_core/models.py +++ b/core/study_agent_core/models.py @@ -50,6 +50,16 @@ class CohortMethodsIntentSplitInput(BaseModel): llm_result: Optional[Dict[str, Any]] = None +class WorkflowContextDialogueInput(BaseModel): + user_prompt: str + study_intent: str = "" + workflow_type: str = "" + current_step: str = "" + current_role: str = "" + current_context: Dict[str, Any] = Field(default_factory=dict) + llm_result: Optional[Dict[str, Any]] = None + + class PhenotypeValidationReviewInput(BaseModel): disease_name: str = "" keeper_row: Dict[str, Any] = Field(default_factory=dict) @@ -276,6 +286,15 @@ class CohortMethodsIntentSplitOutput(BaseModel): mode: str +class WorkflowContextDialogueOutput(BaseModel): + plan: str + answer: str + current_step_guidance: List[str] = Field(default_factory=list) + cautions: List[str] = Field(default_factory=list) + suggested_next_actions: List[str] = Field(default_factory=list) + mode: str + + class PhenotypeValidationReviewOutput(BaseModel): label: str rationale: str diff --git a/core/study_agent_core/tools.py b/core/study_agent_core/tools.py index e0179ca..9a179ef 100644 --- a/core/study_agent_core/tools.py +++ b/core/study_agent_core/tools.py @@ -20,6 +20,8 @@ PhenotypeValidationReviewOutput, PhenotypeRecommendationsInput, PhenotypeRecommendationsOutput, + WorkflowContextDialogueInput, + WorkflowContextDialogueOutput, ) @@ -599,6 +601,65 @@ def phenotype_recommendation_advice( return _model_dump(output) +def workflow_context_dialogue( + user_prompt: str, + study_intent: str = "", + workflow_type: str = "", + current_step: str = "", + current_role: str = "", + current_context: Optional[Dict[str, Any]] = None, + llm_result: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + payload = WorkflowContextDialogueInput( + user_prompt=user_prompt, + study_intent=study_intent, + workflow_type=workflow_type, + current_step=current_step, + current_role=current_role, + current_context=current_context or {}, + llm_result=llm_result, + ) + + plan = "Answer the user's workflow question in the context of the current study-design step." + answer = "" + current_step_guidance: List[str] = [] + cautions: List[str] = [] + suggested_next_actions: List[str] = [] + mode = "llm" + + if payload.llm_result: + if payload.llm_result.get("plan"): + plan = str(payload.llm_result["plan"]) + answer = str(payload.llm_result.get("answer") or "") + if isinstance(payload.llm_result.get("current_step_guidance"), list): + current_step_guidance = [str(item) for item in payload.llm_result["current_step_guidance"]] + if isinstance(payload.llm_result.get("cautions"), list): + cautions = [str(item) for item in payload.llm_result["cautions"]] + if isinstance(payload.llm_result.get("suggested_next_actions"), list): + suggested_next_actions = [str(item) for item in payload.llm_result["suggested_next_actions"]] + else: + mode = "stub" + answer = "No LLM response is available for workflow guidance right now." + if payload.current_step: + current_step_guidance = [ + f"Continue the current workflow step ({payload.current_step}) after clarifying the question manually." + ] + suggested_next_actions = [ + "Restate the question with more concrete study-design detail.", + "Continue the workflow and revisit the question after gathering more context.", + ] + + output = WorkflowContextDialogueOutput( + plan=plan, + answer=answer, + current_step_guidance=current_step_guidance, + cautions=cautions, + suggested_next_actions=suggested_next_actions, + mode=mode, + ) + return _model_dump(output) + + def phenotype_intent_split( study_intent: str, llm_result: Optional[Dict[str, Any]] = None, diff --git a/mcp_server/prompts/workflow_dialogue/output_schema_workflow_context_dialogue.json b/mcp_server/prompts/workflow_dialogue/output_schema_workflow_context_dialogue.json new file mode 100644 index 0000000..9d66b3a --- /dev/null +++ b/mcp_server/prompts/workflow_dialogue/output_schema_workflow_context_dialogue.json @@ -0,0 +1,39 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "workflow_context_dialogue_output", + "type": "object", + "properties": { + "plan": { + "type": "string" + }, + "answer": { + "type": "string" + }, + "current_step_guidance": { + "type": "array", + "items": { + "type": "string" + } + }, + "cautions": { + "type": "array", + "items": { + "type": "string" + } + }, + "suggested_next_actions": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "required": [ + "plan", + "answer", + "current_step_guidance", + "cautions", + "suggested_next_actions" + ], + "additionalProperties": false +} \ No newline at end of file diff --git a/mcp_server/prompts/workflow_dialogue/overview_workflow_context_dialogue.md b/mcp_server/prompts/workflow_dialogue/overview_workflow_context_dialogue.md new file mode 100644 index 0000000..d2f32a5 --- /dev/null +++ b/mcp_server/prompts/workflow_dialogue/overview_workflow_context_dialogue.md @@ -0,0 +1,3 @@ +You are the OHDSI Assistant (ACP Model) for contextual workflow dialogue. +Answer the user's question using the supplied study-design context. +Do not claim that workflow state has changed. Provide advice only. \ No newline at end of file diff --git a/mcp_server/prompts/workflow_dialogue/spec_workflow_context_dialogue.md b/mcp_server/prompts/workflow_dialogue/spec_workflow_context_dialogue.md new file mode 100644 index 0000000..8bc4beb --- /dev/null +++ b/mcp_server/prompts/workflow_dialogue/spec_workflow_context_dialogue.md @@ -0,0 +1,21 @@ +Tool: workflow_context_dialogue +Output contract: +{ + "plan": "string <=300 chars", + "answer": "string <=1200 chars", + "current_step_guidance": ["string <=200 chars"], + "cautions": ["string <=200 chars"], + "suggested_next_actions": ["string <=200 chars"] +} + +### HEURISTICS/RULES +- Answer the user's question in the context of the provided study intent and current workflow step. +- Keep the answer advisory only; do not imply that any workflow choice or artifact has already changed. +- Use the current role and current_context only when they help answer the question. +- Prefer concrete guidance tied to the user's present step over general OHDSI background. +- If context is sparse, answer conservatively and mention what additional detail would sharpen the guidance. +- Use sparse bullets in current_step_guidance, cautions, and suggested_next_actions. + +Constraints: +- JSON only; no markdown/fences. +- Keep output < 10 KB. \ No newline at end of file diff --git a/mcp_server/study_agent_mcp/tools/__init__.py b/mcp_server/study_agent_mcp/tools/__init__.py index 926feba..05a2264 100644 --- a/mcp_server/study_agent_mcp/tools/__init__.py +++ b/mcp_server/study_agent_mcp/tools/__init__.py @@ -26,6 +26,7 @@ "study_agent_mcp.tools.keeper_concept_sets", "study_agent_mcp.tools.keeper_profiles", "study_agent_mcp.tools.cohort_methods_prompt_bundle", + "study_agent_mcp.tools.workflow_context_dialogue", ] diff --git a/mcp_server/study_agent_mcp/tools/workflow_context_dialogue.py b/mcp_server/study_agent_mcp/tools/workflow_context_dialogue.py new file mode 100644 index 0000000..b20741a --- /dev/null +++ b/mcp_server/study_agent_mcp/tools/workflow_context_dialogue.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import json +import os +from typing import Any, Dict + +from ._common import with_meta + + +_CACHE: Dict[str, Dict[str, Any]] = {} + + +def _prompt_dir() -> str: + return os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "prompts", "workflow_dialogue")) + + +def _load_text(path: str) -> str: + with open(path, "r", encoding="utf-8") as handle: + return handle.read().strip() + + +def _load_json(path: str) -> Dict[str, Any]: + with open(path, "r", encoding="utf-8") as handle: + return json.load(handle) + + +def _load_bundle() -> Dict[str, Any]: + cached = _CACHE.get("workflow_context_dialogue") + if cached is not None: + return cached + base = _prompt_dir() + payload = { + "task": "workflow_context_dialogue", + "overview": _load_text(os.path.join(base, "overview_workflow_context_dialogue.md")), + "spec": _load_text(os.path.join(base, "spec_workflow_context_dialogue.md")), + "output_schema": _load_json(os.path.join(base, "output_schema_workflow_context_dialogue.json")), + } + _CACHE["workflow_context_dialogue"] = payload + return payload + + +def register(mcp: object) -> None: + @mcp.tool(name="workflow_context_dialogue") + def workflow_context_dialogue_tool() -> Dict[str, Any]: + return with_meta(_load_bundle(), "workflow_context_dialogue") + + return None diff --git a/tests/test_acp_phenotype_flow.py b/tests/test_acp_phenotype_flow.py index d9dc47c..7c31ff5 100644 --- a/tests/test_acp_phenotype_flow.py +++ b/tests/test_acp_phenotype_flow.py @@ -336,3 +336,293 @@ def fake_llm(prompt, required_keys=None): assert rerank["candidates"][0]["metadata_score"] > rerank["candidates"][1]["metadata_score"] assert any(reason["kind"] == "role_match" for reason in rerank["candidates"][0]["reasons"]) assert any(reason["kind"] == "exclude_procedure" for reason in rerank["candidates"][1]["reasons"]) + + + +@pytest.mark.acp +def test_acp_flow_excludes_disallowed_metadata_before_llm(monkeypatch): + llm_calls = [] + + def fake_llm(prompt, required_keys=None): + llm_calls.append((prompt, tuple(required_keys or []))) + if len(llm_calls) == 1: + return { + "plan": "Extract recommendation intent facets.", + "intent_facets": { + "condition_or_topic": "test medication cohort", + "phenotype_role": "medication_based", + "care_setting": "any", + "population_cue": "adults", + }, + "reasoning_notes": ["Prefer executable medication phenotype."], + } + if len(llm_calls) == 2: + assert '"phenotype_id": "cipher:2"' not in prompt + return { + "plan": "Shortlist executable candidate only.", + "intent_facets": {"phenotype_role": "medication_based"}, + "shortlist_ids": ["ohdsi:1"], + "needs_more_search": False, + "reasoning_notes": ["Excluded disallowed metadata before planning."], + } + assert '"phenotype_id": "cipher:2"' not in prompt + return { + "plan": "Recommend executable candidate.", + "phenotype_recommendations": [ + {"phenotype_id": "ohdsi:1", "phenotype_name": "Alpha", "justification": "ok"} + ], + } + + monkeypatch.setattr(agent_module, "call_llm", fake_llm) + + client = StubMCPClient() + agent = StudyAgent(mcp_client=client) + result = agent.run_phenotype_recommendation_flow( + study_intent="test medication cohort", + top_k=5, + max_results=3, + candidate_limit=3, + exclude_metadata={"executable_definition_status": ["codes_only"]}, + ) + + assert result["status"] == "ok" + assert result["diagnostics"]["candidate_exclusions"]["requested"] == { + "executable_definition_status": ["codes_only"] + } + assert result["diagnostics"]["candidate_exclusions"]["excluded_ids"] == ["cipher:2"] + assert result["planning"]["shortlist_ids"] == ["ohdsi:1"] + recs = result["recommendations"]["phenotype_recommendations"] + assert [rec["phenotype_id"] for rec in recs] == ["ohdsi:1"] + fetch_ids = [args["phenotype_id"] for name, args in client.calls if name == "phenotype_fetch_summary"] + assert fetch_ids == ["ohdsi:1", "ohdsi:1"] + + +@pytest.mark.acp +def test_acp_flow_comparator_role_prefers_direct_exposure_match(monkeypatch): + llm_calls = [] + + class ComparatorStubMCPClient(StubMCPClient): + def call_tool(self, name, arguments): + self.calls.append((name, arguments)) + if name == "phenotype_search": + return { + "results": [ + { + "phenotype_id": "ohdsi:sglt2", + "name": "[P] New users of SGLT2 inhibitor", + "short_description": "SGLT2 new users", + "score": 0.99, + "executable_definition_status": "native_ohdsi", + "execution_readiness_score": 1.0, + }, + { + "phenotype_id": "ohdsi:glipizide", + "name": "[P] New users of glipizide", + "short_description": "Glipizide new users", + "score": 0.80, + "executable_definition_status": "native_ohdsi", + "execution_readiness_score": 1.0, + }, + ] + } + if name == "phenotype_prompt_bundle": + task = arguments["task"] + return { + "overview": f"overview {task}", + "spec": f"spec {task}", + "output_schema": {"type": "object", "title": task}, + } + if name == "phenotype_fetch_summary": + phenotype_id = arguments["phenotype_id"] + if phenotype_id == "ohdsi:sglt2": + return { + "content": { + "phenotype_id": "ohdsi:sglt2", + "name": "[P] New users of SGLT2 inhibitor", + "short_description": "SGLT2 new users", + "primary_clinical_topic": "SGLT2 inhibitors", + "phenotype_role": "medication_based", + "care_setting_scope": "mixed", + "population_scope": "adults with diabetes", + "retrieval_keywords": ["sglt2", "empagliflozin", "canagliflozin"], + "recommendation_summary": "Executable SGLT2 comparator cohort.", + } + } + if phenotype_id == "ohdsi:glipizide": + return { + "content": { + "phenotype_id": "ohdsi:glipizide", + "name": "[P] New users of glipizide", + "short_description": "Glipizide new users", + "primary_clinical_topic": "glipizide", + "phenotype_role": "medication_based", + "care_setting_scope": "mixed", + "population_scope": "adults with diabetes", + "retrieval_keywords": ["glipizide", "sulfonylurea"], + "recommendation_summary": "Executable glipizide comparator cohort.", + } + } + raise ValueError(f"unexpected tool {name}") + + def fake_llm(prompt, required_keys=None): + llm_calls.append((prompt, tuple(required_keys or []))) + if len(llm_calls) == 1: + return { + "plan": "Extract recommendation intent facets.", + "intent_facets": { + "condition_or_topic": "glipizide new users", + "phenotype_role": "medication_based", + "care_setting": "any", + "population_cue": "adults with diabetes", + }, + "reasoning_notes": ["Comparator should match the named exposure."], + } + if len(llm_calls) == 2: + assert prompt.index('"phenotype_id": "ohdsi:glipizide"') < prompt.index('"phenotype_id": "ohdsi:sglt2"') + return { + "plan": "Shortlist glipizide candidate first.", + "intent_facets": {"phenotype_role": "medication_based"}, + "shortlist_ids": ["ohdsi:glipizide"], + "needs_more_search": False, + "reasoning_notes": ["Direct comparator exposure match outranks adjacent drug class."], + } + return { + "plan": "Recommend glipizide candidate.", + "phenotype_recommendations": [ + {"phenotype_id": "ohdsi:glipizide", "phenotype_name": "[P] New users of glipizide", "justification": "ok"} + ], + } + + monkeypatch.setattr(agent_module, "call_llm", fake_llm) + + agent = StudyAgent(mcp_client=ComparatorStubMCPClient()) + result = agent.run_phenotype_recommendation_flow( + study_intent="New users of glipizide with no prior glipizide exposure in the 365 days before index date.", + top_k=5, + max_results=3, + candidate_limit=2, + recommendation_role="comparator", + workflow_type="cohort_methods", + ) + + assert result["status"] == "ok" + assert result["recommendation_role"] == "comparator" + assert result["workflow_type"] == "cohort_methods" + assert result["planning"]["shortlist_ids"] == ["ohdsi:glipizide"] + assert result["recommendations"]["phenotype_recommendations"][0]["phenotype_id"] == "ohdsi:glipizide" + rerank = result["diagnostics"]["planning_rerank"] + assert rerank["candidates"][0]["phenotype_id"] == "ohdsi:glipizide" + assert rerank["candidates"][1]["phenotype_id"] == "ohdsi:sglt2" + assert any(reason["kind"] == "comparator_focus_match" for reason in rerank["candidates"][0]["reasons"]) + assert any(reason["kind"] == "comparator_focus_mismatch" for reason in rerank["candidates"][1]["reasons"]) + + + +@pytest.mark.acp +def test_acp_flow_comparator_without_direct_match_returns_no_recommendations(monkeypatch): + llm_calls = [] + + class NoDirectComparatorStubMCPClient(StubMCPClient): + def call_tool(self, name, arguments): + self.calls.append((name, arguments)) + if name == "phenotype_search": + return { + "results": [ + { + "phenotype_id": "ohdsi:sglt2", + "name": "[P] New users of SGLT2 inhibitor", + "short_description": "SGLT2 new users", + "score": 0.91, + "executable_definition_status": "native_ohdsi", + "execution_readiness_score": 1.0, + }, + { + "phenotype_id": "ohdsi:dpp4", + "name": "[P] New users of DPP-4 inhibitors", + "short_description": "DPP4 new users", + "score": 0.88, + "executable_definition_status": "native_ohdsi", + "execution_readiness_score": 1.0, + }, + ] + } + if name == "phenotype_prompt_bundle": + task = arguments["task"] + return { + "overview": f"overview {task}", + "spec": f"spec {task}", + "output_schema": {"type": "object", "title": task}, + } + if name == "phenotype_fetch_summary": + phenotype_id = arguments["phenotype_id"] + if phenotype_id == "ohdsi:sglt2": + return { + "content": { + "phenotype_id": "ohdsi:sglt2", + "name": "[P] New users of SGLT2 inhibitor", + "short_description": "SGLT2 new users", + "primary_clinical_topic": "SGLT2 inhibitors", + "phenotype_role": "medication_based", + "care_setting_scope": "mixed", + "population_scope": "adults with diabetes", + "retrieval_keywords": ["sglt2", "empagliflozin"], + "recommendation_summary": "Executable SGLT2 cohort.", + } + } + if phenotype_id == "ohdsi:dpp4": + return { + "content": { + "phenotype_id": "ohdsi:dpp4", + "name": "[P] New users of DPP-4 inhibitors", + "short_description": "DPP4 new users", + "primary_clinical_topic": "DPP-4 inhibitors", + "phenotype_role": "medication_based", + "care_setting_scope": "mixed", + "population_scope": "adults with diabetes", + "retrieval_keywords": ["dpp4", "sitagliptin"], + "recommendation_summary": "Executable DPP4 cohort.", + } + } + raise ValueError(f"unexpected tool {name}") + + def fake_llm(prompt, required_keys=None): + llm_calls.append((prompt, tuple(required_keys or []))) + if len(llm_calls) == 1: + return { + "plan": "Extract recommendation intent facets.", + "intent_facets": { + "condition_or_topic": "glipizide new users", + "phenotype_role": "medication_based", + "care_setting": "any", + "population_cue": "adults with diabetes", + }, + "reasoning_notes": ["Comparator should match the named exposure."], + } + return { + "plan": "Shortlist adjacent diabetes medication cohorts.", + "intent_facets": {"phenotype_role": "medication_based"}, + "shortlist_ids": ["ohdsi:sglt2", "ohdsi:dpp4"], + "needs_more_search": False, + "reasoning_notes": ["No direct glipizide cohort was found."], + } + + monkeypatch.setattr(agent_module, "call_llm", fake_llm) + + agent = StudyAgent(mcp_client=NoDirectComparatorStubMCPClient()) + result = agent.run_phenotype_recommendation_flow( + study_intent="New users of glipizide with no prior glipizide exposure in the 365 days before index date.", + top_k=20, + max_results=3, + candidate_limit=10, + recommendation_role="comparator", + workflow_type="cohort_methods", + ) + + assert result["status"] == "ok" + assert result["llm_status"] == "skipped_no_direct_role_match" + assert result["fallback_reason"] == "no_direct_role_match" + assert result["recommendations"]["phenotype_recommendations"] == [] + assert result["diagnostics"]["role_match_gate"]["required_kind"] == "comparator_focus_match" + assert result["diagnostics"]["role_match_gate"]["matched_candidate_ids"] == [] + assert result["diagnostics"]["role_match_gate"]["skip_reason"] == "no_direct_role_match" + assert len(llm_calls) == 2 diff --git a/tests/test_acp_server.py b/tests/test_acp_server.py index 339d207..8dc0351 100644 --- a/tests/test_acp_server.py +++ b/tests/test_acp_server.py @@ -126,6 +126,8 @@ def call_tool(self, name, arguments): return {"overview": "overview", "spec": "spec", "output_schema": {"type": "object"}} if name == "cohort_methods_intent_split": return {"overview": "overview", "spec": "spec", "output_schema": {"type": "object"}} + if name == "workflow_context_dialogue": + return {"overview": "overview", "spec": "spec", "output_schema": {"type": "object"}} if name == "lint_prompt_bundle": return {"overview": "overview", "spec": "spec", "output_schema": {"type": "object"}} if name == "keeper_sanitize_row": @@ -1207,3 +1209,46 @@ def fake_llm(prompt, required_keys=None): assert signal_call["ingred_rxcui"] == "456" assert signal_call["report_lookup_key"] == {"primaryid": None, "isr": "6526923"} assert signal_call["adverse_event_meddra_id"] == "789" + + +@pytest.mark.acp +def test_flow_workflow_context_dialogue(monkeypatch): + agent = StudyAgent(mcp_client=StubMCPClient()) + + def fake_call_llm(prompt, required_keys=None): + assert "workflow_context_dialogue" in prompt + return LLMCallResult( + status="ok", + parsed_content={ + "plan": "answer in context", + "answer": "Washout reduces prevalent-user bias.", + "current_step_guidance": ["Keep the existing comparator step open while you decide."], + "cautions": ["Do not change cohort IDs yet."], + "suggested_next_actions": ["Confirm whether the design is new-user or prevalent-user."], + }, + content_text="{}", + parse_stage="chat_completions_content", + schema_valid=True, + ) + + monkeypatch.setattr(agent, "_call_llm", fake_call_llm) + + result = agent.run_workflow_context_dialogue_flow( + user_prompt="Why does the washout matter here?", + study_intent="Compare metformin versus sulfonylurea new users.", + workflow_type="cohort_methods", + current_step="comparator_recommendation", + current_role="comparator", + current_context={"statement": "New users of glipizide"}, + ) + + assert result["status"] == "ok" + assert result["dialogue"]["answer"] == "Washout reduces prevalent-user bias." + assert result["dialogue"]["current_step_guidance"] == ["Keep the existing comparator step open while you decide."] + + +@pytest.mark.acp +def test_flow_workflow_context_dialogue_missing_prompt(): + agent = StudyAgent(mcp_client=StubMCPClient()) + result = agent.run_workflow_context_dialogue_flow(user_prompt="") + assert result["error"] == "missing user_prompt" diff --git a/tests/test_cohort_methods_shell_recommendation_support.py b/tests/test_cohort_methods_shell_recommendation_support.py new file mode 100644 index 0000000..3c480e6 --- /dev/null +++ b/tests/test_cohort_methods_shell_recommendation_support.py @@ -0,0 +1,29 @@ +from pathlib import Path + +SOURCE = Path("R/OHDSIAssistant/R/strategus_cohort_methods_shell.R") + + +def test_shell_supports_namespaced_recommendation_ids_and_blocks_unsupported_selection() -> None: + source = SOURCE.read_text(encoding="utf-8") + + assert 'recommendation_identifier <- function(rec)' in source + assert 'recommendation_is_ohdsi_computable <- function(rec)' in source + assert 'grepl("^ohdsi:[0-9]+$", identifier)' in source + assert 'unsupported_recommendation_message <- function(rec, role_label)' in source + assert 'Descriptive phenotypes such as CIPHER recommendations are not yet convertible' in source + assert 'stop(unsupported_recommendation_message(' in source + + +def test_shell_displays_noncomputable_recommendation_note() -> None: + source = SOURCE.read_text(encoding="utf-8") + + assert 'recommendation_id_label(rec)' in source + assert 'Not directly computable in this workflow; descriptive phenotype conversion is not yet implemented.' in source + +def test_shell_resolves_namespaced_source_definition_filenames() -> None: + source = SOURCE.read_text(encoding="utf-8") + + assert 'resolve_index_definition_path <- function(source_id, index_def_dir)' in source + assert 'sprintf("ohdsi__%s.json", source_text)' in source + assert 'gsub(":", "__", source_text, fixed = TRUE)' in source + assert 'src <- resolve_index_definition_path(source_id, index_def_dir)' in source diff --git a/tests/test_demo_shell.py b/tests/test_demo_shell.py index 26e311c..e00166a 100644 --- a/tests/test_demo_shell.py +++ b/tests/test_demo_shell.py @@ -1,6 +1,11 @@ import pytest +from pathlib import Path + from study_agent_acp.demo_shell import ( + ACPClient, + DemoSession, + StudyAgentDemoShell, _extract_keeper_row, _infer_phenotype_name, _slugify, @@ -65,3 +70,43 @@ def test_infer_phenotype_name_prefers_top_level() -> None: def test_infer_phenotype_name_uses_nested_full_result() -> None: payload = {"full_result": {"phenotype_name": "Intracranial bleeding"}} assert _infer_phenotype_name(payload) == "Intracranial bleeding" + + +class _FakeClient(ACPClient): + def __init__(self): + super().__init__(base_url="http://127.0.0.1:8765") + self.last_post = None + + def post(self, path, payload): + self.last_post = (path, payload) + return { + "status": "ok", + "dialogue": { + "answer": "Use the current step context before changing cohort IDs.", + "current_step_guidance": ["Stay in the comparator recommendation step."], + "cautions": ["Do not overwrite cached selections yet."], + "suggested_next_actions": ["Review the comparator statement wording."], + }, + } + + +def test_demo_shell_ohdsi_command_uses_session_context(tmp_path, capsys) -> None: + client = _FakeClient() + session = DemoSession(output_dir=Path(tmp_path)) + session.current_study_intent = "Compare sitagliptin versus glipizide new users." + session.current_workflow_type = "cohort_methods" + session.current_step = "comparator_recommendation" + session.current_role = "comparator" + session.current_context = {"statement": "New users of glipizide"} + shell = StudyAgentDemoShell(client=client, session=session) + + shell.handle_line("/ohdsi why is this comparator weak?") + + assert client.last_post is not None + path, payload = client.last_post + assert path == "/flows/workflow_context_dialogue" + assert payload["current_step"] == "comparator_recommendation" + assert payload["current_role"] == "comparator" + assert payload["study_intent"] == "Compare sitagliptin versus glipizide new users." + out = capsys.readouterr().out + assert "Use the current step context before changing cohort IDs." in out diff --git a/tests/test_mcp_prompt_bundle.py b/tests/test_mcp_prompt_bundle.py index 90f8f3d..4a81e93 100644 --- a/tests/test_mcp_prompt_bundle.py +++ b/tests/test_mcp_prompt_bundle.py @@ -205,3 +205,17 @@ def test_case_causal_review_build_prompt_contains_allowed_domains() -> None: assert '"adverse_event_name": "Hepatic failure"' in payload["prompt"] assert '"allowed_domains": [' in payload["prompt"] assert '"candidate_items": [' in payload["prompt"] + + +@pytest.mark.mcp +def test_workflow_context_dialogue_bundle_schema() -> None: + from study_agent_mcp.tools import workflow_context_dialogue + + mcp = DummyMCP() + workflow_context_dialogue.register(mcp) + fn = mcp.tools["workflow_context_dialogue"] + payload = fn() + assert "overview" in payload + assert "spec" in payload + assert "output_schema" in payload + assert payload["output_schema"]["title"] == "workflow_context_dialogue_output" diff --git a/tests/test_mcp_tools_registry.py b/tests/test_mcp_tools_registry.py index dec44d0..0d1f739 100644 --- a/tests/test_mcp_tools_registry.py +++ b/tests/test_mcp_tools_registry.py @@ -57,4 +57,5 @@ def test_register_all_tools() -> None: "vocab_add_nonchildren", "vocab_fetch_concepts", "cohort_methods_prompt_bundle", + "workflow_context_dialogue", }