diff --git a/R/geyserEnrichment.R b/R/geyserEnrichment.R
index 2ab033c..1be8ffc 100644
--- a/R/geyserEnrichment.R
+++ b/R/geyserEnrichment.R
@@ -21,6 +21,10 @@
 #' *`"group"`* – natural sort of group labels;
 #' *`NULL`* – keep original ordering.
 #' @param facet.by Optional metadata column used to facet the plot.
+#' @param summarise.by Optional metadata column used to summarise data.
+#' @param summary.stat Optional method used to summarize expression within each
+#' group defined by \code{summarise.by}. One of: \code{"mean"} (default),
+#' \code{"median"}, \code{"max"}, \code{"sum"}, or \code{"geometric"}.
 #' @param scale Logical; if `TRUE` scores are centered/scaled (Z‑score) prior
 #' to plotting.
 #' @param palette Character. Any palette from \code{\link[grDevices]{hcl.pals}}.
@@ -50,6 +54,8 @@ geyserEnrichment <- function(input.data,
                              order.by = NULL,
                              scale = FALSE,
                              facet.by = NULL,
+                             summarise.by = NULL,
+                             summary.stat = "mean",
                              palette = "inferno") {
   ## ---- 0) Sanity checks -----------------------------------------------------
   if (missing(gene.set) || length(gene.set) != 1L)
@@ -61,10 +67,27 @@ geyserEnrichment <- function(input.data,
   if (identical(color.by, "group"))
     color.by <- group.by
 
-  ## ---- 1) Build tidy data.frame -------------------------------------------
+  if (!is.null(summarise.by) && (identical(summarise.by, group.by) ||
+                                 identical(summarise.by, facet.by)))
+    stop("'summarise.by' cannot be the same as 'group.by' or 'facet.by'. ",
+         "Please choose a different metadata column.")
+
+  ## ---- 1) Resolve summary.stat (keyword or function) to a callable ---------
+  summary_fun <- .match_summary_fun(summary.stat)
+
+  ## ---- 2) Build tidy data.frame -------------------------------------------
   enriched <- .prepData(input.data, assay, gene.set, group.by,
-                        split.by = NULL, facet.by = facet.by)
+                        split.by = summarise.by, facet.by = facet.by, color.by = color.by)
 
+  ## Optionally summarise data with **base aggregate()** ----------------------
+  if (!is.null(summarise.by)) {
+    grp_cols <- setdiff(c(summarise.by, group.by, facet.by, color.by), gene.set) # de-duplicated; never group by the score column
+    enriched <- aggregate(enriched[gene.set],
+                          by = enriched[grp_cols],
+                          FUN = summary_fun,
+                          simplify = TRUE) # a 'SIMPLIFY' arg would be forwarded to FUN via '...'
+  }
+
   ## Optionally Z‑transform ----------------------------------------------------
   if (scale)
     enriched[[gene.set]] <- as.numeric(scale(enriched[[gene.set]]))
@@ -73,12 +96,19 @@ geyserEnrichment <- function(input.data,
   if (!is.null(order.by))
     enriched <- .orderFunction(enriched, order.by, group.by)
 
-  ## ---- 2) Plot --------------------------------------------------------------
-  plt <- ggplot(enriched, aes(x = .data[[group.by]],
-                              y = .data[[gene.set]],
-                              colour = .data[[color.by]])) +
+  ## ---- 3) Plot --------------------------------------------------------------
+  if (!is.null(color.by))
+    plt <- ggplot(enriched, aes(x = .data[[group.by]],
+                                y = .data[[gene.set]],
+                                group = .data[[group.by]],
+                                colour = .data[[color.by]]))
+  else
+    plt <- ggplot(enriched, aes(x = .data[[group.by]],
+                                y = .data[[gene.set]],
+                                group = .data[[group.by]])) # 'group' must be mapped inside aes()
+
   # Raw points --------------------------------------------------------------
-    geom_jitter(width = 0.25, size = 1.5, alpha = 0.6, na.rm = TRUE) +
+  plt <- plt + geom_jitter(width = 0.25, size = 1.5, alpha = 0.6, na.rm = TRUE) +
   # White base interval + median point -------------------------------------
   stat_pointinterval(interval_size_range = c(2, 3),
                      fatten_point = 1.4,
@@ -97,10 +127,11 @@ geyserEnrichment <- function(input.data,
     theme(legend.direction = "horizontal",
           legend.position = "bottom")
 
-  ## ---- 3) Colour scale ------------------------------------------------------
-  plt <- .colorby(enriched, plt, color.by, palette, type = "color")
+  ## ---- 4) Colour scale ------------------------------------------------------
+  if (!is.null(color.by))
+    plt <- .colorby(enriched, plt, color.by, palette, type = "color")
 
-  ## ---- 4) Facetting ---------------------------------------------------------
+  ## ---- 5) Facetting ---------------------------------------------------------
   if (!is.null(facet.by))
     plt <- plt + facet_grid(as.formula(paste(".~", facet.by)))
 
diff --git a/R/heatmapEnrichment.R b/R/heatmapEnrichment.R
index 821f982..cc42e24 100644
--- a/R/heatmapEnrichment.R
+++ b/R/heatmapEnrichment.R
@@ -17,9 +17,9 @@
 #' @param facet.by Optional metadata column used to facet the plot.
 #' @param scale If \code{TRUE}, Z‑transforms each gene‑set column **after**
 #' summarization.
-#' @param summary.stat Method used to summarize expression within each
-#* group: one of `"mean"` (default), `"median"`, `"max"`,
-#*`"sum"`, or `"geometric"`
+#' @param summary.stat Optional method used to summarize expression within each
+#' group. One of: \code{"mean"} (default), \code{"median"}, \code{"max"},
+#' \code{"sum"}, or \code{"geometric"}.
 #' @param palette Character. Any palette from \code{\link[grDevices]{hcl.pals}}.
 #'
 #' @return A \code{ggplot2} object.
@@ -47,22 +47,6 @@ heatmapEnrichment <- function(input.data,
                               palette = "inferno") {
 
   # ---------- 1. helper to match summary function -------------------------
-  .match_summary_fun <- function(fun) {
-    if (is.function(fun)) return(fun)
-    if (!is.character(fun) || length(fun) != 1)
-      stop("'summary.stat' must be a single character keyword or a function")
-    kw <- tolower(fun)
-    fn <- switch(kw,
-                 mean = base::mean,
-                 median = stats::median,
-                 sum = base::sum,
-                 sd = stats::sd,
-                 max = base::max,
-                 min = base::min,
-                 geometric = function(x) exp(mean(log(x + 1e-6))),
-                 stop("Unsupported summary keyword: ", fun))
-    fn
-  }
   summary_fun <- .match_summary_fun(summary.stat)
 
   # ---------- 2. pull / tidy data -----------------------------------------
@@ -70,7 +54,8 @@ heatmapEnrichment <- function(input.data,
   df <- .prepData(input.data, assay, gene.set.use,
                   group.by = group.by,
                   split.by = NULL,
-                  facet.by = facet.by)
+                  facet.by = facet.by,
+                  color.by = NULL)
 
   # Which columns contain gene-set scores?
   if (identical(gene.set.use, "all"))
diff --git a/R/ridgeEnrichment.R b/R/ridgeEnrichment.R
index 9f38826..85d8584 100644
--- a/R/ridgeEnrichment.R
+++ b/R/ridgeEnrichment.R
@@ -61,7 +61,7 @@ ridgeEnrichment <- function(input.data,
 
   ## ---- 1 build long data.frame ---------------------------------------
   df <- .prepData(input.data, assay, gene.set.use, group.by,
-                  split.by = NULL, facet.by = facet.by)
+                  split.by = NULL, facet.by = facet.by, color.by = color.by)
 
   ## optional scaling (Z-transform per gene-set) -------------------------
   if (scale)
diff --git a/R/scatterEnrichment.R b/R/scatterEnrichment.R
index 1f8338b..2a19f40 100644
--- a/R/scatterEnrichment.R
+++ b/R/scatterEnrichment.R
@@ -69,7 +69,8 @@ scatterEnrichment <- function(input.data,
   gene.set <- c(x.axis, y.axis)
 
   ## ---- 1 Assemble long data-frame -----------------------------------------
-  enriched <- .prepData(input.data, assay, gene.set, group.by, NULL, facet.by)
+  enriched <- .prepData(input.data, assay, gene.set, group.by, NULL, facet.by,
+                        color.by = NULL)
 
   if (scale) {
     enriched[, gene.set] <- apply(enriched[, gene.set, drop = FALSE], 2, scale)
diff --git a/R/splitEnrichment.R b/R/splitEnrichment.R
index c081b5d..7226b9b 100644
--- a/R/splitEnrichment.R
+++ b/R/splitEnrichment.R
@@ -56,7 +56,8 @@ splitEnrichment <- function(input.data,
   if (is.null(group.by)) group.by <- "ident"
 
   # Prepare tidy data with relevant metadata columns
-  enriched <- .prepData(input.data, assay, gene.set.use, group.by, split.by, facet.by)
+  enriched <- .prepData(input.data, assay, gene.set.use, group.by, split.by,
+                        facet.by, color.by = NULL)
 
   # Determine the number of levels in the splitting variable
   split.levels <- unique(enriched[[split.by]])
diff --git a/R/utils.R b/R/utils.R
index c713793..39a22aa 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -41,10 +41,10 @@
 # DATA.frame BUILDERS ---------------------------------------------------------
 # -----------------------------------------------------------------------------
 .makeDFfromSCO <- function(input.data, assay = "escape", gene.set = NULL,
-                           group.by = NULL, split.by = NULL, facet.by = NULL) {
+                           group.by = NULL, split.by = NULL, facet.by = NULL, color.by = NULL) {
   if (is.null(assay))
     stop("Please provide assay name")
-  cols <- unique(c(group.by, split.by, facet.by))
+  cols <- unique(c(group.by, split.by, facet.by, color.by))
   cnts <- .cntEval(input.data, assay = assay, type = "data")
 
   if (length(gene.set) == 1 && gene.set == "all")
@@ -62,18 +62,18 @@
   df
 }
 
-.prepData <- function(input.data, assay, gene.set, group.by, split.by, facet.by) {
+.prepData <- function(input.data, assay, gene.set, group.by, split.by, facet.by, color.by = NULL) {
   if (.is_seurat_or_sce(input.data)) {
-    df <- .makeDFfromSCO(input.data, assay, gene.set, group.by, split.by, facet.by)
+    df <- .makeDFfromSCO(input.data, assay, gene.set, group.by, split.by, facet.by, color.by)
     if (identical(gene.set, "all")) {
-      gene.set <- setdiff(colnames(df), c(group.by, split.by, facet.by))
+      gene.set <- setdiff(colnames(df), c(group.by, split.by, facet.by, color.by))
     }
   } else { # assume plain data.frame / matrix
     if (identical(gene.set, "all"))
-      gene.set <- setdiff(colnames(input.data), c(group.by, split.by, facet.by))
-    df <- input.data[, c(gene.set, group.by, split.by, facet.by), drop = FALSE]
+      gene.set <- setdiff(colnames(input.data), c(group.by, split.by, facet.by, color.by))
+    df <- input.data[, unique(c(gene.set, group.by, split.by, facet.by, color.by)), drop = FALSE] # unique(): color.by may repeat group.by
   }
-  colnames(df) <- c(gene.set, group.by, split.by, facet.by)
+  colnames(df) <- unique(c(gene.set, group.by, split.by, facet.by, color.by))
   df
 }
 
@@ -443,4 +443,21 @@ utils::globalVariables(c(
   "gene.set.query", "index"
 ))
 
+# Resolve 'summary.stat' (a keyword string or a function) to a summary function
+.match_summary_fun <- function(fun) {
+  if (is.function(fun)) return(fun)
+  if (!is.character(fun) || length(fun) != 1)
+    stop("'summary.stat' must be a single character keyword or a function")
+  kw <- tolower(fun)
+  fn <- switch(kw,
+               mean = base::mean,
+               median = stats::median,
+               sum = base::sum,
+               sd = stats::sd,
+               max = base::max,
+               min = base::min,
+               geometric = function(x) exp(mean(log(x + 1e-6))),
+               stop("Unsupported summary keyword: ", fun))
+  fn
+}
 
diff --git a/man/geyserEnrichment.Rd b/man/geyserEnrichment.Rd
index db5370a..bd2f151 100644
--- a/man/geyserEnrichment.Rd
+++ b/man/geyserEnrichment.Rd
@@ -13,6 +13,8 @@ geyserEnrichment(
   order.by = NULL,
   scale = FALSE,
   facet.by = NULL,
+  summarise.by = NULL,
+  summary.stat = "mean",
   palette = "inferno"
 )
 }
@@ -44,6 +46,12 @@ to plotting.}
 
 \item{facet.by}{Optional metadata column used to facet the plot.}
 
+\item{summarise.by}{Optional metadata column used to summarise data.}
+
+\item{summary.stat}{Optional method used to summarize expression within each
+group defined by \code{summarise.by}. One of: \code{"mean"} (default),
+\code{"median"}, \code{"max"}, \code{"sum"}, or \code{"geometric"}.}
+
 \item{palette}{Character. Any palette from \code{\link[grDevices]{hcl.pals}}.}
 }
 \value{
diff --git a/man/heatmapEnrichment.Rd b/man/heatmapEnrichment.Rd
index ea10eda..722c4f2 100644
--- a/man/heatmapEnrichment.Rd
+++ b/man/heatmapEnrichment.Rd
@@ -38,7 +38,9 @@ are ordered by Ward‑linkage hierarchical clustering (Euclidean distance).}
 \item{scale}{If \code{TRUE}, Z‑transforms each gene‑set column **after**
 summarization.}
 
-\item{summary.stat}{Method used to summarize expression within each}
+\item{summary.stat}{Optional method used to summarize expression within each
+group. One of: \code{"mean"} (default), \code{"median"}, \code{"max"},
+\code{"sum"}, or \code{"geometric"}.}
 
 \item{palette}{Character. Any palette from \code{\link[grDevices]{hcl.pals}}.}
 }
diff --git a/tests/testthat/test-splitEnrichment.R b/tests/testthat/test-splitEnrichment.R
index 556b58b..7c068be 100644
--- a/tests/testthat/test-splitEnrichment.R
+++ b/tests/testthat/test-splitEnrichment.R
@@ -70,7 +70,8 @@ test_that("order.by = 'mean' reorders x-axis levels by descending mean", {
     gene.set = "Tcells",
     group.by = "ident",
     split.by = "groups",
-    facet.by = NULL
+    facet.by = NULL,
+    color.by = NULL
   )
 
   expected <- enr %>%