Skip to content

Commit a5dc90d

Browse files
authored
Merge pull request #47 from gadget-framework/g3_iterative_default_grouping
g3_iterative: Add g3_iterative_default_grouping
2 parents 52721f2 + 8c20652 commit a5dc90d

File tree

4 files changed

+118
-2
lines changed

4 files changed

+118
-2
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ export(g3_data)
1414
export(g3_fit)
1515
export(g3_init_guess)
1616
export(g3_iterative)
17+
export(g3_iterative_default_grouping)
1718
export(g3_iterative_setup)
1819
export(g3_jitter)
1920
export(g3_leaveout)

R/g3_iterative.R

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
#' @export
7676
g3_iterative <- function(gd, wgts = 'WGTS',
7777
model, params.in,
78-
grouping = list(),
78+
grouping = g3_iterative_default_grouping(params.in),
7979
use_parscale = TRUE,
8080
method = 'BFGS',
8181
control = list(),
@@ -374,6 +374,54 @@ g3_iterative <- function(gd, wgts = 'WGTS',
374374
return(params_final)
375375
}
376376

377+
# Generate default grouping, combine all fleet likelihoods into one group
378+
# NB: For this to work, nll_names need to be in the form (nll_source)_(nll_dist), where (nll_dist) matches one of the (nll_dist_names)
379+
#' @param params.in Initial parameters to use with the model
380+
#' @param nll_dist_names Character vector of postfixes to consider when looking for groupings
381+
#' @return
382+
#' \subsection{g3_iterative_default_grouping}{
383+
#' A list of component groups to component names, as required by the \var{grouping} parameter
384+
#' }
385+
#' @details
386+
#' \subsection{g3_iterative_default_grouping}{
387+
#' This assumes that your likelihood component names are of the form ``(nll_group)_(nll_dist)``,
388+
#' where ``(nll_dist)`` matches one of the regexes in \var{nll_dist_names}.
389+
#' For example, ``afleet_ldist``, ``afleet_aldist``, ``bfleet_ldist``. ``afleet`` & ``bfleet`` will be the groups used.
390+
#' }
391+
#' @rdname g3_iterative
392+
#' @export
393+
g3_iterative_default_grouping <- function (params.in, nll_dist_names = c("ldist", "aldist", "matp", "sexdist", "SI", "len\\d+SI")) {
394+
# Extract all likelihood component weight names from params.in
395+
weight_re <- paste0(
396+
"^",
397+
"(?<dist>.dist|.sparse)_",
398+
"(?<function>surveyindices_log|[a-z]+)_",
399+
"(?<nll_source>.+)_",
400+
"(?<nll_dist>", paste0(nll_dist_names, collapse = "|"), ")_",
401+
"weight$"
402+
)
403+
404+
# Break up names into a data.frame of param_name -> regex groups
405+
weight_names <- grep(weight_re, rownames(params.in), value = TRUE, perl = TRUE)
406+
weight_parts <- as.data.frame(do.call(rbind, regmatches(weight_names, regexec(weight_re, weight_names, perl = TRUE))))
407+
names(weight_parts)[[1]] <- "param_name"
408+
weight_parts$value <- params.in[weight_parts$param_name, "value"]
409+
410+
# Remove any zero-weighted parameters
411+
zero_value <- weight_parts[weight_parts$value == 0, "param_name"]
412+
if (length(zero_value) > 0) {
413+
warning("Parameters ", paste(zero_value, collapse = ", ") , " have a value of 0, removing from grouping")
414+
weight_parts <- weight_parts[weight_parts$value > 0,]
415+
}
416+
417+
# Group rows together into a list of nll_source -> vector of (nll_source)_(nll_dist)
418+
sapply(
419+
unique(weight_parts$nll_source),
420+
function (nll_source) paste0(nll_source, "_", weight_parts[weight_parts$nll_source == nll_source, "nll_dist"]),
421+
simplify = FALSE
422+
)
423+
}
424+
377425
#' @title Initial parameters for iterative re-weighting
378426
#' @param lik_out A likelihood summary dataframe. The output of g3_lik_out(model, param)
379427
#' @param grouping A list describing how to group likelihood components for iterative re-weighting

man/g3_iterative.Rd

Lines changed: 19 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
library(unittest)
2+
library(gadgetutils)
3+
4+
library(unittest)
5+
6+
# Convert a string into a data.frame
7+
table_string <- function (text, ...) {
8+
out <- read.table(
9+
text = text,
10+
blank.lines.skip = TRUE,
11+
header = TRUE,
12+
stringsAsFactors = FALSE,
13+
...)
14+
rownames(out) <- out$switch
15+
return(out)
16+
}
17+
18+
ok(ut_cmp_identical(g3_iterative_default_grouping(table_string('
19+
switch value
20+
cdist_sumofsquares_comm_ldist_weight 1
21+
cdist_sumofsquares_comm_aldist_weight 1
22+
cdist_sumofsquares_comm_argle_weight 1
23+
cdist_sumofsquares_comm_matp_weight 1
24+
cdist_sumofsquares_fgn_ldist_weight 1
25+
cdist_sumofsquares_fgn_aldist_weight 1
26+
cdist_surveyindices_log_surv_si_weight 1
27+
'), nll_dist_names = c("ldist", "aldist", "matp", "si")), list(
28+
# NB: argle is missing
29+
comm = c("comm_ldist", "comm_aldist", "comm_matp"),
30+
fgn = c("fgn_ldist", "fgn_aldist"),
31+
# NB: parsed the awkward surveyindices_log
32+
surv = c("surv_si")
33+
)))
34+
35+
ok(ut_cmp_identical(suppressWarnings(g3_iterative_default_grouping(table_string('
36+
switch value
37+
cdist_sumofsquares_comm_ldist_weight 1
38+
cdist_sumofsquares_comm_aldist_weight 1
39+
cdist_sumofsquares_comm_argle_weight 1
40+
cdist_sumofsquares_comm_matp_weight 1
41+
cdist_sumofsquares_fgn_ldist_weight 1
42+
cdist_sumofsquares_fgn_aldist_weight 0
43+
cdist_surveyindices_log_surv_si_weight 1
44+
'), nll_dist_names = c("ldist", "aldist", "matp", "si"))), list(
45+
comm = c("comm_ldist", "comm_aldist", "comm_matp"),
46+
# NB: zero-weighted doesn't count
47+
fgn = c("fgn_ldist"),
48+
surv = c("surv_si")
49+
)))

0 commit comments

Comments
 (0)