diff --git a/DESCRIPTION b/DESCRIPTION index 1993d2f..f28e533 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -17,7 +17,8 @@ Suggests: knitr, rmarkdown, testthat (>= 3.0.0), - tidyr + tidyr, + withr Config/testthat/edition: 3 Depends: R (>= 3.5) @@ -58,4 +59,5 @@ Collate: 'plot_posterior.R' 'stan_summaries.R' 'stanmodels.R' + 'utils.R' 'utils_misc.R' diff --git a/NAMESPACE b/NAMESPACE index 37c898c..4cb733b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -9,6 +9,7 @@ export(distribution.negative_binomial) export(distribution.normal) export(distribution.point_mass) export(distribution.poisson) +export(distribution.uniform) export(estimate_mixture_of_two_normals) export(logistic) export(logit) @@ -21,6 +22,9 @@ export(rename_params_cmdstanfile_to_rstan) export(simulate_mixture_of_two_normals) export(stanfit_to_dt) export(stanfit_to_matrix) +export(utils.class) +export(utils.class.interface) +export(utils.class.interface.implements) export(utils.uniroot.vectorized) import(Rcpp) import(methods) diff --git a/R/distribution_R6_class.R b/R/distribution_R6_class.R index ff6ba92..47be0eb 100644 --- a/R/distribution_R6_class.R +++ b/R/distribution_R6_class.R @@ -45,7 +45,7 @@ distribution.interface <- utils.class.interface( # distribution.abstract.class ###################################################################/ #' Class: `distribution.abstract.class` -#' @description Base class for all derived distributions +#' @description Base class for derived distributions #' #' @param x vector of quantiles. #' @param q vector of quantiles. @@ -62,6 +62,7 @@ distribution.interface <- utils.class.interface( #' @field param_names The names of all distribution parameters #' @field params Named list of distribution parameters #' @field interfaces The list of available class interfaces + distribution.abstract.class <- utils.class( "distribution.abstract.class", interfaces = list( distribution.interface ), diff --git a/R/utils.R b/R/utils.R new file mode 100644 index 0000000..a4e2ceb --- /dev/null +++ b/R/utils.R @@ -0,0 +1,208 @@ +###################################################################/ +# Name: utils.uniroot.vectorized +# Description: vectorized version of the Brent Root algorithm +# useful for when solving many similar optimization problems +# where evaluation of the objective can be calculated far more +# efficiently when vectorized +# Args: f - function to optimize over with single input (vector) and outputs a vector of values +# lower - vector of left hand boundary +# upper - vector of right hand boundary +# Return: +# Author: Rob Hinch +###################################################################/ +#' Vectorised uniroot +#' +#' A vectorised version of the Brent root algorithm. Useful for solving many +#' similar optimisation problems where evaluation of the objective function can +#' be calculated more efficiently when vectorised +#' +#' @param f the objective function to optimise over. Must take a single vector +#' as input and output a vector of values +#' @param lower vector of left hand end points of the intervals to optimise over +#' @param upper vector of right hand end points of the intervals to optimise over +#' @param tol the desired accuracy (convergence tolerance) +#' @param itmax the maximum number of iterations +#' @param eps XXXXX TODO +#' +#' @return vector of roots for each interval \code{[lower, upper ]} +#' +#' @examples +#' f <- function( x ) ( x^2 - 1 ) * ( x^2 - 2 ) +#' lower <- c( -2, -1.2, 0, 1.2 ) +#' upper <- c( -1.2, 0, 1.2, 2 ) +#' utils.uniroot.vectorized( f, lower, upper ) +#' +#' @export + +utils.uniroot.vectorized = function( f, lower, upper, tol = 1e-8, itmax = 100, eps = 1e-10 ){ + # check the initial bracket + f_lower <- f( lower ) + f_upper <- f( upper ) + + if ( any( sign( f_lower * f_upper ) == 1 ) ) + stop( "all roots must be bracketed" ) + + # main bracketing loop + d <- rep( NA, length( lower ) ) + e <- rep( NA, length( lower ) ) + c <- upper + fc <- f_upper + for( ii in 1:itmax ) + { + flip <- ( fc > 0 & f_upper > 0 ) | (fc < 0 & f_upper < 0 ) + c <- ifelse( flip, lower, c ) + fc <- ifelse( flip, f_lower, fc ) + d <- ifelse( flip, upper-lower, d ) + e <- ifelse( flip, d, e ) + + closer <- abs( fc ) < abs( f_upper ) + lower <- ifelse( closer, upper, lower ) + upper <- ifelse( closer, c, upper ) + c <- ifelse( closer, lower, c ) + f_lower <- ifelse( closer, f_upper, f_lower ) + f_upper <- ifelse( closer, fc, f_upper ) + fc <- ifelse( closer, f_lower, fc ) + + # convergence check + tol1 <- 2 * eps * abs(upper) + 0.5 * tol + xm <- 0.5 * (c-upper) + + if( min( ( abs( xm ) < tol1 ) | f_upper == 0 ) ) + return( upper ) + + # Attempt inverse quadratic interpolation + s <- f_upper/f_lower + q <- f_lower/fc + r <- f_upper/fc + p <- ifelse( lower == c, 2 * xm * s, s * ( 2 * xm * q * (q-r) - (upper-lower) * (r-1) ) ) + q <- ifelse( lower == c, 1-s, ( q - 1 ) * ( r - 1 ) * (s - 1 ) ) + + # Check whether in bounds. + q <- ifelse( p > 0, -q, q ) + p <- abs( p ) + + # accept interpolation + accept <- ( abs(e) >= tol1 ) & + ( abs(f_lower) > abs(f_upper) ) & + ( 2 * p < pmin( 3 * xm * q - abs( tol1*q ), abs( e*q ) ) ) + e <- ifelse( accept, d, xm ) + d <- ifelse( accept, p/q, xm ) + + # Move last best guess to lower. + lower <- upper + f_lower <- f_upper + upper <- ifelse( abs( d ) > tol1, upper + d, upper + tol1 * sign( xm ) ) + f_upper <- f( upper ) + } + + stop( "exceeding maximum iterations" ) + return( upper ) +} + +###################################################################/ +# Name: utils.optimise.vectorized +# Description: vectorized version of the Brent Minimisation algorithm +# useful for when solving many similar optimization problems +# where evaluation of the objective can be calculated far more +# efficiently when vectorized +# Args: f - function to optimize over with single input (vector) and outputs a vector of values +# a - vector of left hand boundary +# v - vector of best initial guesses +# b - vector of right hand boundary +# Return: TRUE/FALSE +# Author: Rob Hinch +###################################################################/ +utils.optimise.vectorized = function( f, a, v, b, tol = 1e-6, maximum = F, itmax = 100 ) +{ + # constants + cgold = 0.3819660 + zeps = 1e-10 + + # initial data check + if( sum( a > v ) > 0 | sum( v > b ) > 0 ) + stop( "The initial value and boundary must saisfy a ( tol2 - 0.5 * ( b - a ) ) ) == 0 ) + break + + # fit parabola for all + r = ( x - w ) * ( fx - fv ) + q = ( x - v ) * ( fx - fw ) + p = ( x - v ) * q - ( x - w ) * r + q = 2 * ( q - r ) + p = -p * sign( q ) + q = abs( q ) + etemp =e + e = d + dd = p / ( q + zeps ); # avoid divide by 0 + u = x + dd + d = dd + ( u - a < tol2 | b - u < tol2 ) * ( ( xm - x > 0 ) * tol - dd ) + + # next find golden ratio point + eg = -x + a + ( x < xm ) * ( b - a ) + dg = eg * cgold + + # decide whether to pick the paroabola min or golden ratio point + absEtemp = abs( etemp ) + gr = ( absEtemp > tol1 ) & ( 2 * abs( p ) < q * absEtemp ) & ( dd > (a-x) ) & ( dd < ( b - x ) ) + e = eg + gr * ( e -eg) + d = dg + gr * ( d - dg ) + + # move to the next point, if distance is smaller than tolerance then move by tolerance + u = x + d + ( abs( d ) < tol1 ) * ( sign( d ) * tol1 - d ) + + # single function evaluation at next point + fu = f( u ) + if( maximum ) + fu = - fu + + # finally book-keeping to see which point to update + cond1 = fu >= fx + cond2 = u < x + cond3 = fu > fw & w != x + cond4 = cond1 | cond2 + cond5 = !( cond1 & cond2 ) + cond6 = cond1 | !cond2 + cond7 = !cond1 | cond2 + cond8 = cond1 & cond3 + + a = x + cond4 * ( -x + u + cond5 * ( a - u ) ) + b = x + cond6 * ( -x + u + cond7 * ( b - u ) ) + v = w + cond8 * ( u - w ) + fv = fw + cond8 * ( fu - fw ) + w = x + cond1 * ( -x + u + cond3 * ( w -u ) ) + fw = fx + cond1 * ( -fx + fu + cond3 * ( fw -fu ) ) + x = u + cond1 * ( -u + x ) + fx = fu + cond1 * ( -fu + fx ) + } + if( maximum ) + return( list( maximum = x, objective = - fx ) ) + else + return( list( minimum = x, objective = fx ) ) +} \ No newline at end of file diff --git a/R/utils_R6.R b/R/utils_R6.R index 5222edb..d76158c 100644 --- a/R/utils_R6.R +++ b/R/utils_R6.R @@ -1,24 +1,102 @@ -# TODO: Documentation for the base class +# Returns a character vector containing function arguments without a set default +# value set, e.g. +# .get_required_args( function( x, y = 1 ) NULL ) +# returns +# "x", +# but +# .get_required_args( function( x, y ) NULL ) +# returns +# c( "x", "y" ) .get_required_args <- function( func ) { args <- formals( func ) rArgs <- unlist( lapply( args, function( x ) ifelse( length(x)==1, x == "", FALSE ) ) ) - if( !length( rArgs ) ) + if( !length( rArgs ) ) # Equivalent to if ( length( rArgs ) == 0 ) return( c() ) rArgs <- names( args )[ which( rArgs ) ] rArgs <- rArgs[ which( rArgs != "..." ) ] return( rArgs ) } -###################################################################################/ + +# Helper function to validate that all methods of type `method_type` defined on +# the interface are defined correctly on the class +# +# @param iMethod_list list of methods defined on the interface to validate +# against +# @param method_list list of methods defined on the class to be validated +# @param iName name of the interface contributing methods iMethod_list +# @param method_type type of method being validated; used for informative +# error messages +# +# Variable names with prefix i are related to the interface not the defined +# class. + +.validate_interface_methods <- function( method_list, iMethod_list, iName, + error_type = "public method" ){ + # For each method defined on the interface of type `method_type`, check that + # the class defines a method with the same name and the same set of required + # arguments + methNames <- names( method_list ) + for ( iMethod in iMethod_list ){ + if ( !is.null( iMethod ) ){ + for ( iMethName in names( iMethod ) ){ + # clone method must exist on R6 class and does not need to be checked + if ( iMethName == "clone" ) next + + # Check iMethName is defined on class (with any set of arguments) + if ( !( iMethName %in% methNames ) ){ + stop( sprintf( "must implement %s %s on interface %s", + error_type, iMethName, iName)) + } + + # Check required arguments for interface public method + iArgs <- formalArgs( iMethod[[ iMethName ]] ) + r_iArgs <- .get_required_args( iMethod[[ iMethName ]] ) + + # Check required arguments for new class public method + args <- formalArgs( method_list[[ iMethName ]] ) + r_args <- .get_required_args( method_list[[ iMethName ]] ) + + if( length( r_iArgs ) ) { + if( !all( r_iArgs %in% args ) ) + stop( sprintf( "incorrect arguments for %s %s on interface %s", + error_type, iMethName, iName ) ) + } + if( length( r_args ) ) { + if( !all( r_args %in% iArgs ) ) + stop( sprintf( "incorrect arguments for %s %s on interface %s", + error_type, iMethName, iName ) ) + } + } + } + } +} + +################################################################################/ # utils.class # -# add interfaces to R6 class infrastructure # Author: Rob Hinch -###################################################################################/ -##### NOTE: All derived R6 classes using interfaces should include this file via -##### Roxygen using the include tag: #' @include R6_util_class.R to update the -##### collate field in DESCRIPTION +################################################################################/ +##### NOTE: All derived R6 classes using interfaces in mastiff should include +##### this file via Roxygen using the include tag: #' @include R6_util_class.R +##### to update the collate field in DESCRIPTION. +##### +##### Typically this is only for safety, but if a derived class is defined with +##### a name alphabetically before R6_util_class.R and included in another file, +##### the collate order might matter, e.g. R6_a_new_class.R might break the +##### collate order if #' @include R6_a_new_class.R is ever used. + +#' Class: utils.class +#' +#' @description R6 object extending [R6::R6Class()] to include interfaces. +#' +#' @inheritParams R6::R6Class +#' @param interfaces An optional list of interfaces implemented for the derived +#' class. +#' +#' @export + utils.class = function( classname = NULL, public = list(), @@ -32,17 +110,17 @@ utils.class = function( lock_class = FALSE, cloneable = TRUE, parent_env = parent.frame() -) -{ +){ # check to see an inherited class has been created by utils.class if( !is.null( inherit ) ){ if( inherit$inherit != "utils.class.parent" ) - stop( "inherited classes must be created by utils.class (i.e. must inherited utils.class.class)" ) + stop( "inherited classes must be created by utils.class (i.e. must inherit utils.class.class)" ) } else{ inherit = utils.class.class } - # create an environment in the parent_env which just contains the name of the inherited generator + # create an environment in the parent_env which just contains the name of the + # inherited generator envir = new.env( parent = parent_env ) utils.class.parent = inherit assign( "utils.class.parent", utils.class.parent, envir = envir ) @@ -64,83 +142,37 @@ utils.class = function( interfaceNames = c() nInterfaces = length( interfaces ) if( nInterfaces ){ - publicNames <- names( publicMethods ) - privateNames <- names( privateMethods ) - activeNames <- names( activeMethods ) - for( k in 1:nInterfaces ){ if( interfaces[[ k ]]$inherit != "utils.class.interface.class" ) - stop( "interfaces must be inherited from utils.class.interface.class" ) + stop( "interfaces must be created by utils.class.interface (i.e. must inherit utils.class.interface.class" ) iName = interfaces[[ k ]]$classname - # check public methods first - for( iPublic in list( interfaces[[ k ]]$public_methods, interfaces[[ k ]]$public_fields ) ) - if( !is.null( iPublic ) ) - if( length( iPublic ) ) - for( j in 1:length( iPublic ) ){ - iMethName =names( iPublic )[ j ] - if( iMethName == "clone" ) - next(); - if( !( iMethName %in% publicNames ) ) - stop( sprintf( "must implement public method %s on interface %s", iMethName, iName ) ) - args <- formalArgs( publicMethods[[ iMethName ]] ) - iArgs <- formalArgs( iPublic[[ iMethName ]] ) - r_args <- .get_required_args( publicMethods[[ iMethName ]] ) - r_iArgs <- .get_required_args( iPublic[[ iMethName ]] ) - if( length( r_iArgs ) ) { - if( !all( r_iArgs %in% args ) ) - stop( sprintf( "incorrect arguments for private method %s on interface %s", iMethName, iName ) ) - } - if( length( r_args ) ) { - if( !all( r_args %in% iArgs ) ) - stop( sprintf( "incorrect arguments for private method %s on interface %s", iMethName, iName ) ) - } - } + # Validate public methods + .validate_interface_methods( + method_list = publicMethods, + iMethod_list = list( interfaces[[ k ]]$public_methods, + interfaces[[ k ]]$public_fields ), + iName, + error_type = "public method" + ) - # check private methods - for( iPrivate in list( interfaces[[ k ]]$private_methods, interfaces[[ k ]]$private_fields ) ) - if( !is.null( iPrivate ) ) - if( length( iPrivate ) ) - for( j in 1:length( iPrivate ) ){ - iMethName = names( iPrivate )[ j ] - if( !( iMethName %in% privateNames ) ) - stop( sprintf( "must implement private method %s on interface %s", iMethName, iName ) ) - args <- formalArgs( privateMethods[[ iMethName ]] ) - iArgs <- formalArgs( iPrivate[[ iMethName ]] ) - r_args <- .get_required_args( privateMethods[[ iMethName ]] ) - r_iArgs <- .get_required_args( iPrivate[[ iMethName ]] ) - if( length( r_iArgs ) ) { - if( !all( r_iArgs %in% args ) ) - stop( sprintf( "incorrect arguments for private method %s on interface %s", iMethName, iName ) ) - } - if( length( r_args ) ) { - if( !all( r_args %in% iArgs ) ) - stop( sprintf( "incorrect arguments for private method %s on interface %s", iMethName, iName ) ) - } - } + # Validate private methods + .validate_interface_methods( + method_list = privateMethods, + iMethod_list = list( interfaces[[ k ]]$private_methods, + interfaces[[ k ]]$private_fields ), + iName, + error_type = "private method" + ) - # check active methods - iActive = interfaces[[ k ]]$active - if( !is.null( iActive ) ) - if( length( iActive ) ) - for( j in 1:length( iActive ) ) - { - iMethName = names( iActive )[ j ] - if( !( iMethName %in% activeNames ) ) - stop( sprintf( "must implement active field %s on interface %s", iMethName, iName ) ) - args <- formalArgs( activeMethods[[ iMethName ]] ) - iArgs <- formalArgs( iActive[[ iMethName ]] ) - r_iArgs <- .get_required_args( iActive[[ iMethName ]] ) - if( length( r_iArgs ) ) { - if( !all( r_iArgs %in% args ) ) - stop( sprintf( "incorrect arguments for private method %s on interface %s", iMethName, iName ) ) - } - if( length( r_args ) ) { - if( !all( r_args %in% iArgs ) ) - stop( sprintf( "incorrect arguments for private method %s on interface %s", iMethName, iName ) ) - } - } + # Validate active methods + .validate_interface_methods( + method_list = activeMethods, + iMethod_list = list( interfaces[[ k ]]$active ), + iName, + error_type = "active field" + ) interfaceNames[ length( interfaceNames ) + 1 ] = iName } @@ -161,11 +193,11 @@ utils.class = function( parent_env = envir ) ) } -###################################################################################/ -# utils.class.class #### +################################################################################/ +# utils.class.class # # add interfaces to R6 class infrastructure -###################################################################################/ +################################################################################/ utils.class.class = R6::R6Class( "utils.class.class", private = list( @@ -182,22 +214,41 @@ utils.class.class = R6::R6Class( ) ) -###################################################################################/ -# utils.class.interface.class #### +################################################################################/ +# utils.class.interface.class # # add interfaces to R6 class infrastructure -###################################################################################/ +################################################################################/ +#' Class: `utils.class.interface.class` +#' +#' @description R6 class acting as base interface class. + utils.class.interface.class = R6::R6Class( "utils.class.interface.class", public = list( + ############################################################################/ + # is.interface + ############################################################################/ + #' @description Logical function indicating whether an object is an + #' interface. is.interface = function() return( TRUE ) ) ) -###################################################################################/ +################################################################################/ # utils.class.interface # add interfaces to R6 class infrastructure -###################################################################################/ +################################################################################/ +#' utils.class.interface +#' +#' Constructor function for objects of class [utils.class.interface.class] +#' +#' @param interfacename Name of the interface. The interface name is useful +#' primarily for S3 method dispatch. +#' @inheritParams R6::R6Class +#' +#' @export + utils.class.interface = function( interfacename = NULL, public = list(), @@ -212,16 +263,26 @@ utils.class.interface = function( inherit = utils.class.interface.class ) ) } -###################################################################################/ +################################################################################/ # utils.class.interface.implements # checks to see if an interface has been implemented # check private internal variable directly to prevent accidental name mismatches -###################################################################################/ +################################################################################/ +#' utils.class.interface.implements +#' +#' @description Checks to see whether interface `interfaceName` has been +#' implemented on object `object`. +#' +#' +#' @param object R6 object of class `utils.class`. +#' @param interfaceName Name of an interface to check for `object`. +#' +#' @export + utils.class.interface.implements = function( object, interfaceName -) -{ +){ if( !R6::is.R6( object ) | !inherits( object, "utils.class.class") ) stop( "object must be from a class generated by utils.class()" ) @@ -230,211 +291,3 @@ utils.class.interface.implements = function( return( length( intersect( object$.__enclos_env__$private$.INTERNAL_INTERFACES, interfaceName ) ) == 1 ) } - -###################################################################/ -# Name: utils.uniroot.vectorized -# Description: vectorized version of the Brent Root algorithm -# useful for when solving many similar optimization problems -# where evaluation of the objective can be calculated far more -# efficiently when vectorized -# Args: f - function to optimize over with single input (vector) and outputs a vector of values -# lower - vector of left hand boundary -# upper - vector of right hand boundary -# Return: -# Author: Rob Hinch -###################################################################/ -#' Vectorised uniroot -#' -#' A vectorised version of the Brent root algorithm. Useful for solving many -#' similar optimisation problems where evaluation of the objective function can -#' be calculated more efficiently when vectorised -#' -#' @param f the objective function to optimise over. Must take a single vector -#' as input and output a vector of values -#' @param lower vector of left hand end points of the intervals to optimise over -#' @param upper vector of right hand end points of the intervals to optimise over -#' @param tol the desired accuracy (convergence tolerance) -#' @param itmax the maximum number of iterations -#' @param eps XXXXX TODO -#' -#' @return vector of roots for each interval \code{[lower, upper ]} -#' -#' @examples -#' f <- function( x ) ( x^2 - 1 ) * ( x^2 - 2 ) -#' lower <- c( -2, -1.2, 0, 1.2 ) -#' upper <- c( -1.2, 0, 1.2, 2 ) -#' utils.uniroot.vectorized( f, lower, upper ) -#' -#' @export -utils.uniroot.vectorized = function( f, lower, upper, tol = 1e-8, itmax = 100, eps = 1e-10 ){ - # check the initial bracket - f_lower <- f( lower ) - f_upper <- f( upper ) - - if ( any( sign( f_lower * f_upper ) == 1 ) ) - stop( "all roots must be bracketed" ) - - # main bracketing loop - d <- rep( NA, length( lower ) ) - e <- rep( NA, length( lower ) ) - c <- upper - fc <- f_upper - for( ii in 1:itmax ) - { - flip <- ( fc > 0 & f_upper > 0 ) | (fc < 0 & f_upper < 0 ) - c <- ifelse( flip, lower, c ) - fc <- ifelse( flip, f_lower, fc ) - d <- ifelse( flip, upper-lower, d ) - e <- ifelse( flip, d, e ) - - closer <- abs( fc ) < abs( f_upper ) - lower <- ifelse( closer, upper, lower ) - upper <- ifelse( closer, c, upper ) - c <- ifelse( closer, lower, c ) - f_lower <- ifelse( closer, f_upper, f_lower ) - f_upper <- ifelse( closer, fc, f_upper ) - fc <- ifelse( closer, f_lower, fc ) - - # convergence check - tol1 <- 2 * eps * abs(upper) + 0.5 * tol - xm <- 0.5 * (c-upper) - - if( min( ( abs( xm ) < tol1 ) | f_upper == 0 ) ) - return( upper ) - - # Attempt inverse quadratic interpolation - s <- f_upper/f_lower - q <- f_lower/fc - r <- f_upper/fc - p <- ifelse( lower == c, 2 * xm * s, s * ( 2 * xm * q * (q-r) - (upper-lower) * (r-1) ) ) - q <- ifelse( lower == c, 1-s, ( q - 1 ) * ( r - 1 ) * (s - 1 ) ) - - # Check whether in bounds. - q <- ifelse( p > 0, -q, q ) - p <- abs( p ) - - # accept interpolation - accept <- ( abs(e) >= tol1 ) & - ( abs(f_lower) > abs(f_upper) ) & - ( 2 * p < pmin( 3 * xm * q - abs( tol1*q ), abs( e*q ) ) ) - e <- ifelse( accept, d, xm ) - d <- ifelse( accept, p/q, xm ) - - # Move last best guess to lower. - lower <- upper - f_lower <- f_upper - upper <- ifelse( abs( d ) > tol1, upper + d, upper + tol1 * sign( xm ) ) - f_upper <- f( upper ) - } - - stop( "exceeding maximum iterations" ) - return( upper ) -} - -###################################################################/ -# Name: utils.optimise.vectorized -# Description: vectorized version of the Brent Minimisation algorithm -# useful for when solving many similar optimization problems -# where evaluation of the objective can be calculated far more -# efficiently when vectorized -# Args: f - function to optimize over with single input (vector) and outputs a vector of values -# a - vector of left hand boundary -# v - vector of best initial guesses -# b - vector of right hand boundary -# Return: TRUE/FALSE -# Author: Rob Hinch -###################################################################/ -utils.optimise.vectorized = function( f, a, v, b, tol = 1e-6, maximum = F, itmax = 100 ) -{ - # constants - cgold = 0.3819660 - zeps = 1e-10 - - # initial data check - if( sum( a > v ) > 0 | sum( v > b ) > 0 ) - stop( "The initial value and boundary must saisfy a ( tol2 - 0.5 * ( b - a ) ) ) == 0 ) - break - - # fit parabola for all - r = ( x - w ) * ( fx - fv ) - q = ( x - v ) * ( fx - fw ) - p = ( x - v ) * q - ( x - w ) * r - q = 2 * ( q - r ) - p = -p * sign( q ) - q = abs( q ) - etemp =e - e = d - dd = p / ( q + zeps ); # avoid divide by 0 - u = x + dd - d = dd + ( u - a < tol2 | b - u < tol2 ) * ( ( xm - x > 0 ) * tol - dd ) - - # next find golden ratio point - eg = -x + a + ( x < xm ) * ( b - a ) - dg = eg * cgold - - # decide whether to pick the paroabola min or golden ratio point - absEtemp = abs( etemp ) - gr = ( absEtemp > tol1 ) & ( 2 * abs( p ) < q * absEtemp ) & ( dd > (a-x) ) & ( dd < ( b - x ) ) - e = eg + gr * ( e -eg) - d = dg + gr * ( d - dg ) - - # move to the next point, if distance is smaller than tolerance then move by tolerance - u = x + d + ( abs( d ) < tol1 ) * ( sign( d ) * tol1 - d ) - - # single function evaluation at next point - fu = f( u ) - if( maximum ) - fu = - fu - - # finally book-keeping to see which point to update - cond1 = fu >= fx - cond2 = u < x - cond3 = fu > fw & w != x - cond4 = cond1 | cond2 - cond5 = !( cond1 & cond2 ) - cond6 = cond1 | !cond2 - cond7 = !cond1 | cond2 - cond8 = cond1 & cond3 - - a = x + cond4 * ( -x + u + cond5 * ( a - u ) ) - b = x + cond6 * ( -x + u + cond7 * ( b - u ) ) - v = w + cond8 * ( u - w ) - fv = fw + cond8 * ( fu - fw ) - w = x + cond1 * ( -x + u + cond3 * ( w -u ) ) - fw = fx + cond1 * ( -fx + fu + cond3 * ( fw -fu ) ) - x = u + cond1 * ( -u + x ) - fx = fu + cond1 * ( -fu + fx ) - } - if( maximum ) - return( list( maximum = x, objective = - fx ) ) - else - return( list( minimum = x, objective = fx ) ) -} diff --git a/man/distribution.abstract.class.Rd b/man/distribution.abstract.class.Rd index 276ddba..6c9c196 100644 --- a/man/distribution.abstract.class.Rd +++ b/man/distribution.abstract.class.Rd @@ -4,7 +4,7 @@ \alias{distribution.abstract.class} \title{Class: \code{distribution.abstract.class}} \description{ -Base class for all derived distributions +Base class for derived distributions } \section{Super class}{ \code{mastiff::utils.class.class} -> \code{distribution.abstract.class} diff --git a/man/distribution.continuous.uniform.class.Rd b/man/distribution.continuous.uniform.class.Rd index e7f7601..5f97bda 100644 --- a/man/distribution.continuous.uniform.class.Rd +++ b/man/distribution.continuous.uniform.class.Rd @@ -4,7 +4,7 @@ \alias{distribution.continuous.uniform.class} \title{Class: \code{distribution.continuous.uniform.class}} \description{ -Derived class for an uniformly-distributed random variable on \code{[lower, upper]} +Derived class for an uniformly-distributed random variable on \code{[min, max]} } \section{Super classes}{ \code{mastiff::utils.class.class} -> \code{\link[mastiff:distribution.abstract.class]{mastiff::distribution.abstract.class}} -> \code{\link[mastiff:distribution.continuous.class]{mastiff::distribution.continuous.class}} -> \code{distribution.continuous.uniform.class} @@ -14,13 +14,13 @@ Derived class for an uniformly-distributed random variable on \code{[lower, uppe \describe{ \item{\code{interfaces}}{The list of available class interfaces} -\item{\code{mean}}{The mean of an exponential distribution with rate \verb{$params$rate}.} +\item{\code{mean}}{The mean of a uniform random variable on \code{[min, max]}.} -\item{\code{sd}}{The standard deviation of an exponential distribution with rate -\verb{$params$rate}.} +\item{\code{sd}}{The standard deviation of a uniform random variable on +\code{[min, max]}.} -\item{\code{var}}{The variance of an exponential distribution with rate -\verb{$params$rate}.} +\item{\code{var}}{The variance of a uniform random variable on \code{[min, +max]}.} } \if{html}{\out{}} } @@ -41,13 +41,15 @@ Derived class for an uniformly-distributed random variable on \code{[lower, uppe \subsection{Method \code{new()}}{ Create a new object of class \code{distribution.continuous.exponential.class} \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{distribution.continuous.uniform.class$new(rate = 1)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{distribution.continuous.uniform.class$new(min = 0, max = 1)}\if{html}{\out{
}} } \subsection{Arguments}{ \if{html}{\out{
}} \describe{ -\item{\code{rate}}{The rate of the exponential distribution} +\item{\code{min}}{The lower bound of the uniform distribution} + +\item{\code{max}}{The max bound of the uniform distribution} } \if{html}{\out{
}} } @@ -56,8 +58,8 @@ Create a new object of class \code{distribution.continuous.exponential.class} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-distribution.continuous.uniform.class-d}{}}} \subsection{Method \code{d()}}{ -Density function for an exponential random variable with -rate \code{params$rate}. +Density function for a uniform random variable on +\code{[min, max]}. \subsection{Usage}{ \if{html}{\out{
}}\preformatted{distribution.continuous.uniform.class$d(x, log = FALSE)}\if{html}{\out{
}} } @@ -76,8 +78,8 @@ rate \code{params$rate}. \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-distribution.continuous.uniform.class-p}{}}} \subsection{Method \code{p()}}{ -Cumulative density function for an exponential random -variable with rate \code{params$rate}. +Cumulative density function for a uniform random variable on +\code{[min, max]}. \subsection{Usage}{ \if{html}{\out{
}}\preformatted{distribution.continuous.uniform.class$p(q, lower.tail = TRUE, log.p = FALSE)}\if{html}{\out{
}} } @@ -99,7 +101,8 @@ otherwise, \eqn{P[X>x]}.} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-distribution.continuous.uniform.class-q}{}}} \subsection{Method \code{q()}}{ -Quantile function for an exponential random variable with +Quantile function for a uniform random variable on +\code{[min, max]}. rate \code{params$rate}. \subsection{Usage}{ \if{html}{\out{
}}\preformatted{distribution.continuous.uniform.class$q(p, lower.tail = TRUE, log.p = FALSE)}\if{html}{\out{
}} @@ -122,8 +125,8 @@ otherwise, \eqn{P[X>x]}.} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-distribution.continuous.uniform.class-r}{}}} \subsection{Method \code{r()}}{ -Generates random deviates for an exponential random variable -with rate \code{params$rate}. +Generates random deviates for a uniform random variable on +\code{[min, max]}. \subsection{Usage}{ \if{html}{\out{
}}\preformatted{distribution.continuous.uniform.class$r(n)}\if{html}{\out{
}} } diff --git a/man/distribution.exponential.Rd b/man/distribution.exponential.Rd index 498d053..3318a5d 100644 --- a/man/distribution.exponential.Rd +++ b/man/distribution.exponential.Rd @@ -4,20 +4,14 @@ \alias{distribution.exponential} \title{distribution.exponential} \usage{ -distribution.exponential(rate = 1) - distribution.exponential(rate = 1) } \arguments{ \item{rate}{vector of rates} } \value{ -An object of class [\link{distribution.continuous.exponential.class}] - An object of class [\link{distribution.continuous.exponential.class}] } \description{ -Constructor function for an object of class \code{distribution.continuous.exponential.class} - Constructor function for an object of class \code{distribution.continuous.exponential.class} } diff --git a/man/distribution.uniform.Rd b/man/distribution.uniform.Rd new file mode 100644 index 0000000..286c8c4 --- /dev/null +++ b/man/distribution.uniform.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/distribution_continuous.R +\name{distribution.uniform} +\alias{distribution.uniform} +\title{distribution.exponential} +\usage{ +distribution.uniform(min = 0, max = 1) +} +\arguments{ +\item{min}{The lower bound of the uniform distribution} + +\item{max}{The max bound of the uniform distribution} +} +\value{ +An object of class [\link{distribution.continuous.uniform.class}] +} +\description{ +Constructor function for an object of class \code{distribution.continuous.uniform.class} +} diff --git a/man/utils.class.Rd b/man/utils.class.Rd new file mode 100644 index 0000000..7d984c3 --- /dev/null +++ b/man/utils.class.Rd @@ -0,0 +1,67 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils_R6.R +\name{utils.class} +\alias{utils.class} +\title{Class: utils.class} +\usage{ +utils.class( + classname = NULL, + public = list(), + private = list(), + active = list(), + inherit = NULL, + interfaces = list(), + lock_objects = TRUE, + class = TRUE, + portable = TRUE, + lock_class = FALSE, + cloneable = TRUE, + parent_env = parent.frame() +) +} +\arguments{ +\item{classname}{Name of the class. The class name is useful primarily for S3 +method dispatch.} + +\item{public}{A list of public members, which can be functions (methods) and +non-functions (fields).} + +\item{private}{An optional list of private members, which can be functions +and non-functions.} + +\item{active}{An optional list of active binding functions.} + +\item{inherit}{A R6ClassGenerator object to inherit from; in other words, a +superclass. This is captured as an unevaluated expression which is +evaluated in \code{parent_env} each time an object is instantiated.} + +\item{interfaces}{An optional list of interfaces implemented for the derived +class.} + +\item{lock_objects}{Should the environments of the generated objects be +locked? If locked, new members can't be added to the objects.} + +\item{class}{Should a class attribute be added to the object? Default is +\code{TRUE}. If \code{FALSE}, the objects will simply look like +environments, which is what they are.} + +\item{portable}{If \code{TRUE} (the default), this class will work with +inheritance across different packages. Note that when this is enabled, +fields and members must be accessed with \code{self$x} or +\code{private$x}; they can't be accessed with just \code{x}.} + +\item{lock_class}{If \code{TRUE}, it won't be possible to add more members to +the generator object with \code{$set}. If \code{FALSE} (the default), then +it will be possible to add more members with \code{$set}. The methods +\code{$is_locked}, \code{$lock}, and \code{$unlock} can be used to query +and change the locked state of the class.} + +\item{cloneable}{If \code{TRUE} (the default), the generated objects will +have method named \code{$clone}, which makes a copy of the object.} + +\item{parent_env}{An environment to use as the parent of newly-created +objects.} +} +\description{ +R6 object extending \code{\link[R6:R6Class]{R6::R6Class()}} to include interfaces. +} diff --git a/man/utils.class.interface.Rd b/man/utils.class.interface.Rd new file mode 100644 index 0000000..0f5f6ff --- /dev/null +++ b/man/utils.class.interface.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils_R6.R +\name{utils.class.interface} +\alias{utils.class.interface} +\title{utils.class.interface} +\usage{ +utils.class.interface( + interfacename = NULL, + public = list(), + private = list(), + active = list() +) +} +\arguments{ +\item{interfacename}{Name of the interface. The interface name is useful +primarily for S3 method dispatch.} + +\item{public}{A list of public members, which can be functions (methods) and +non-functions (fields).} + +\item{private}{An optional list of private members, which can be functions +and non-functions.} + +\item{active}{An optional list of active binding functions.} +} +\description{ +Constructor function for objects of class \link{utils.class.interface.class} +} diff --git a/man/utils.class.interface.class.Rd b/man/utils.class.interface.class.Rd new file mode 100644 index 0000000..f198298 --- /dev/null +++ b/man/utils.class.interface.class.Rd @@ -0,0 +1,44 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils_R6.R +\name{utils.class.interface.class} +\alias{utils.class.interface.class} +\title{Class: \code{utils.class.interface.class}} +\description{ +R6 class acting as base interface class. +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-utils.class.interface.class-is.interface}{\code{utils.class.interface.class$is.interface()}} +\item \href{#method-utils.class.interface.class-clone}{\code{utils.class.interface.class$clone()}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-utils.class.interface.class-is.interface}{}}} +\subsection{Method \code{is.interface()}}{ +Logical function indicating whether an object is an +interface. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{utils.class.interface.class$is.interface()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-utils.class.interface.class-clone}{}}} +\subsection{Method \code{clone()}}{ +The objects of this class are cloneable with this method. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{utils.class.interface.class$clone(deep = FALSE)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{deep}}{Whether to make a deep clone.} +} +\if{html}{\out{
}} +} +} +} diff --git a/man/utils.class.interface.implements.Rd b/man/utils.class.interface.implements.Rd new file mode 100644 index 0000000..86e6ee3 --- /dev/null +++ b/man/utils.class.interface.implements.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils_R6.R +\name{utils.class.interface.implements} +\alias{utils.class.interface.implements} +\title{utils.class.interface.implements} +\usage{ +utils.class.interface.implements(object, interfaceName) +} +\arguments{ +\item{object}{R6 object of class \code{utils.class}.} + +\item{interfaceName}{Name of an interface to check for \code{object}.} +} +\description{ +Checks to see whether interface \code{interfaceName} has been +implemented on object \code{object}. +} diff --git a/man/utils.uniroot.vectorized.Rd b/man/utils.uniroot.vectorized.Rd index 036cc9a..2dab47b 100644 --- a/man/utils.uniroot.vectorized.Rd +++ b/man/utils.uniroot.vectorized.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/utils_R6.R +% Please edit documentation in R/utils.R \name{utils.uniroot.vectorized} \alias{utils.uniroot.vectorized} \title{Vectorised uniroot} diff --git a/tests/testthat/test-distribution_continuous.R b/tests/testthat/test-distribution_continuous.R index 41a4c77..bc5c3d4 100644 --- a/tests/testthat/test-distribution_continuous.R +++ b/tests/testthat/test-distribution_continuous.R @@ -38,7 +38,7 @@ test_that( "Default $p() and $q() return the correct CDF and quantile function o }, d = function( x, log = FALSE ) stats::dunif( x, min = private$.params$min, max = private$.params$max, log = log ), r = function( n ) stats::runif( n, min = private$.params$min, max = private$.params$max ) ) - )$new( min = unif_min, max = unif_max ) + )$new( min = unif_min, max = unif_max ) unif_class <- distribution.uniform( min = unif_min, max = unif_max ) test_p_q( unif_test_class, unif_class ) @@ -61,7 +61,7 @@ test_that( "Default $p() and $q() return the correct CDF and quantile function o }, d = function( x, log = FALSE ) stats::dexp( x, rate = private$.params$rate, log = log ), r = function( n ) stats::rexp( n, rate = private$.params$rate ) ) - )$new( rate = exp_rate ) + )$new( rate = exp_rate ) exp_class <- distribution.exponential( rate = exp_rate ) test_p_q( exp_test_class, exp_class ) @@ -123,120 +123,126 @@ test_that( "Default $p() and $q() return the correct CDF and quantile function o }) test_that( "distribution.exponential constructs a valid class", { - n <- 1e5 - tol <- 3 / sqrt( n ) - - X <- distribution.exponential( rate = 1 ) - - # Test that density is correct for initial rate parameter - expect_equal( X$d( x = 0 ), 1 ) - expect_equal( mean( X$r( n ) ), X$mean, tolerance = tol ) - - # Test that $params can be updated via named list - expect_no_error( X$params <- list( rate = 10 ) ) - expect_equal({ - X$params <- list( rate = 10 ) - X$d( x = 0 ) - }, 10 ) - expect_equal({ - X$params <- list( rate = 10 ) - mean( X$r( n ) ) - }, X$mean, tolerance = tol ) - - # Test that elements of $params can be updated by name - expect_no_error( X$params$rate <- 1 ) - - # Test that invalid values of $params fail (via private$.check_params()) - expect_error( X$params$rate <- -1 ) - expect_error( X$params <- list( rate = -1 ) ) - expect_error( X$params$rate <- 'a' ) - expect_error( X$params <- list( rate = 'a' ) ) - - # Test that incorrectly named list $params is rejected - expect_error( X$params <- list( foo = 1 ) ) - expect_error( X$params <- list( 1 ) ) - expect_error( X$params <- list( rate = 1, - foo = 1 ) ) + withr::with_seed( 123, { + n <- 1e5 + tol <- 3 / sqrt( n ) + + X <- distribution.exponential( rate = 1 ) + + # Test that density is correct for initial rate parameter + expect_equal( X$d( x = 0 ), 1 ) + expect_equal( mean( X$r( n ) ), X$mean, tolerance = tol ) + + # Test that $params can be updated via named list + expect_no_error( X$params <- list( rate = 10 ) ) + expect_equal({ + X$params <- list( rate = 10 ) + X$d( x = 0 ) + }, 10 ) + expect_equal({ + X$params <- list( rate = 10 ) + mean( X$r( n ) ) + }, X$mean, tolerance = tol ) + + # Test that elements of $params can be updated by name + expect_no_error( X$params$rate <- 1 ) + + # Test that invalid values of $params fail (via private$.check_params()) + expect_error( X$params$rate <- -1 ) + expect_error( X$params <- list( rate = -1 ) ) + expect_error( X$params$rate <- 'a' ) + expect_error( X$params <- list( rate = 'a' ) ) + + # Test that incorrectly named list $params is rejected + expect_error( X$params <- list( foo = 1 ) ) + expect_error( X$params <- list( 1 ) ) + expect_error( X$params <- list( rate = 1, + foo = 1 ) ) + }) }) test_that( "distribution.gamma constructs a valid class", { - n <- 1e5 - tol <- 3 / sqrt( n ) - - X <- distribution.gamma( shape = 1, rate = 1, scale = 1 ) - - # Test that distribution can be initialised with exactly one of rate or scale - expect_no_error( X <- distribution.gamma( shape = 1, rate = 0.5 ) ) - expect_no_error( X <- distribution.gamma( shape = 1, scale = 0.5 ) ) - - # ...and initialisation fails if rate != 1 / scale - expect_error( X <- distribution.gamma( shape = 1, rate = 1, scale = 100 ) ) - - # Test that density is correct for initial rate parameter - expect_equal( X$d( x = 1 ), - X$params$rate * exp( - X$params$rate ) ) - expect_equal( mean( X$r( n ) ), X$mean, tolerance = tol ) - - # Test that $params can be updated via named list - expect_no_error( X$params <- list( shape = 0.5, - rate = 10 ) ) - expect_equal({ - X$params <- list( shape = 0.5, - rate = 10 ) - mean( X$r( n ) ) - }, X$mean, tolerance = tol ) - - # Test that elements of $params can be updated by name - expect_no_error( X$params$rate <- 10 ) - expect_no_error( X$params$scale <- 10 ) - - # Test that invalid values of $params fail (via private$.check_params()) - expect_error( X$params$shape <- -1 ) - expect_error( X$params$shape <- 'a' ) - expect_error( X$params$rate <- -1 ) - expect_error( X$params$rate <- 'a' ) - expect_error( X$params$scale <- -1 ) - expect_error( X$params$scale <- 'a' ) - - - # Test that incorrectly named list $params is rejected - expect_error( X$params <- list( foo = 1 ) ) - expect_error( X$params <- list( 1 ) ) - expect_error( X$params <- list( rate = 1, - foo = 1 ) ) + withr::with_seed( 123, { + n <- 1e5 + tol <- 3 / sqrt( n ) + + X <- distribution.gamma( shape = 1, rate = 1, scale = 1 ) + + # Test that distribution can be initialised with exactly one of rate or scale + expect_no_error( X <- distribution.gamma( shape = 1, rate = 0.5 ) ) + expect_no_error( X <- distribution.gamma( shape = 1, scale = 0.5 ) ) + + # ...and initialisation fails if rate != 1 / scale + expect_error( X <- distribution.gamma( shape = 1, rate = 1, scale = 100 ) ) + + # Test that density is correct for initial rate parameter + expect_equal( X$d( x = 1 ), + X$params$rate * exp( - X$params$rate ) ) + expect_equal( mean( X$r( n ) ), X$mean, tolerance = tol ) + + # Test that $params can be updated via named list + expect_no_error( X$params <- list( shape = 0.5, + rate = 10 ) ) + expect_equal({ + X$params <- list( shape = 0.5, + rate = 10 ) + mean( X$r( n ) ) + }, X$mean, tolerance = tol ) + + # Test that elements of $params can be updated by name + expect_no_error( X$params$rate <- 10 ) + expect_no_error( X$params$scale <- 10 ) + + # Test that invalid values of $params fail (via private$.check_params()) + expect_error( X$params$shape <- -1 ) + expect_error( X$params$shape <- 'a' ) + expect_error( X$params$rate <- -1 ) + expect_error( X$params$rate <- 'a' ) + expect_error( X$params$scale <- -1 ) + expect_error( X$params$scale <- 'a' ) + + + # Test that incorrectly named list $params is rejected + expect_error( X$params <- list( foo = 1 ) ) + expect_error( X$params <- list( 1 ) ) + expect_error( X$params <- list( rate = 1, + foo = 1 ) ) + }) }) test_that( "distribution.normal constructs a valid class", { - n <- 1e5 - tol <- 3 / sqrt( n ) - - X <- distribution.normal( mean = 0, sd = 1 ) - - # Test that density is correct for initial rate parameter - expect_equal( X$d( x = 0 ), 1 / sqrt( 2 * pi ) ) - expect_equal( mean( X$r( n ) ), X$mean, tolerance = tol ) - - # Test that $params can be updated via named list - expect_no_error( X$params <- list( mean = 1, - sd = 1 ) ) - expect_equal({ - X$params <- list( mean = 10, - sd = 1 ) - X$d( x = 10 ) - }, 1 / sqrt( 2 * pi ) ) - - # Test that elements of $params can be updated by name - expect_no_error( X$params$mean <- 1 ) - expect_no_error( X$params$sd <- 10 ) - - # Test that invalid values of $params fail (via private$.check_params()) - expect_error( X$params$mean <- 'a' ) - expect_error( X$params$sd <- -1 ) - expect_error( X$params$sd <- 'a' ) - - # Test that incorrectly named list $params is rejected - expect_error( X$params <- list( foo = 1 ) ) - expect_error( X$params <- list( 1 ) ) - expect_error( X$params <- list( rate = 1, - foo = 1 ) ) + withr::with_seed( 123, { + n <- 1e5 + tol <- 3 / sqrt( n ) + + X <- distribution.normal( mean = 0, sd = 1 ) + + # Test that density is correct for initial rate parameter + expect_equal( X$d( x = 0 ), 1 / sqrt( 2 * pi ) ) + expect_equal( mean( X$r( n ) ), X$mean, tolerance = tol ) + + # Test that $params can be updated via named list + expect_no_error( X$params <- list( mean = 1, + sd = 1 ) ) + expect_equal({ + X$params <- list( mean = 10, + sd = 1 ) + X$d( x = 10 ) + }, 1 / sqrt( 2 * pi ) ) + + # Test that elements of $params can be updated by name + expect_no_error( X$params$mean <- 1 ) + expect_no_error( X$params$sd <- 10 ) + + # Test that invalid values of $params fail (via private$.check_params()) + expect_error( X$params$mean <- 'a' ) + expect_error( X$params$sd <- -1 ) + expect_error( X$params$sd <- 'a' ) + + # Test that incorrectly named list $params is rejected + expect_error( X$params <- list( foo = 1 ) ) + expect_error( X$params <- list( 1 ) ) + expect_error( X$params <- list( rate = 1, + foo = 1 ) ) + }) }) diff --git a/tests/testthat/test-distribution_discrete.R b/tests/testthat/test-distribution_discrete.R index 684f8d3..b02b55c 100644 --- a/tests/testthat/test-distribution_discrete.R +++ b/tests/testthat/test-distribution_discrete.R @@ -53,116 +53,122 @@ test_that( "Default $p() and $q() return the correct CDF and quantile function o }) test_that( "distribution.binomial constructs a valid class", { - n <- 1e5 - tol <- 3 / sqrt( n ) - - X <- distribution.binomial( size = 10, - prob = 0.5 ) - - # Test that density is correct for initial rate parameter - expect_equal( X$d( x = 0 ), 0.5^10 ) - expect_equal( mean( X$r( n ) ), X$mean, tolerance = tol ) - - # Test that $params can be updated via named list - expect_no_error( X$params <- list( size = 5, - prob = 0.8 ) ) - expect_equal({ - X$params <- list( size = 5, - prob = 0.8 ) - X$d( x = 0 ) - }, 0.2^5 ) - expect_equal({ - X$params <- list( size = 5, - prob = 0.8 ) - mean( X$r( n ) ) - }, X$mean, tolerance = tol ) - - # Test that elements of $params can be updated by name - expect_no_error( X$params$size <- 1 ) - expect_no_error( X$params$prob <- 0 ) - - # Test that invalid values of $params fail (via private$.check_params()) - expect_error( X$params$size <- -1 ) - expect_error( X$params$size <- 0.5 ) - expect_error( X$params$size <- 'a' ) - expect_error( X$params$prob <- -1 ) - expect_error( X$params$prob <- 2 ) - expect_error( X$params$prob <- 'a' ) - - # Test that incorrectly named list $params is rejected - expect_error( X$params <- list( foo = 1 ) ) - expect_error( X$params <- list( 1 ) ) + withr::with_seed( 123, { + n <- 1e5 + tol <- 3 / sqrt( n ) + + X <- distribution.binomial( size = 10, + prob = 0.5 ) + + # Test that density is correct for initial rate parameter + expect_equal( X$d( x = 0 ), 0.5^10 ) + expect_equal( mean( X$r( n ) ), X$mean, tolerance = tol ) + + # Test that $params can be updated via named list + expect_no_error( X$params <- list( size = 5, + prob = 0.8 ) ) + expect_equal({ + X$params <- list( size = 5, + prob = 0.8 ) + X$d( x = 0 ) + }, 0.2^5 ) + expect_equal({ + X$params <- list( size = 5, + prob = 0.8 ) + mean( X$r( n ) ) + }, X$mean, tolerance = tol ) + + # Test that elements of $params can be updated by name + expect_no_error( X$params$size <- 1 ) + expect_no_error( X$params$prob <- 0 ) + + # Test that invalid values of $params fail (via private$.check_params()) + expect_error( X$params$size <- -1 ) + expect_error( X$params$size <- 0.5 ) + expect_error( X$params$size <- 'a' ) + expect_error( X$params$prob <- -1 ) + expect_error( X$params$prob <- 2 ) + expect_error( X$params$prob <- 'a' ) + + # Test that incorrectly named list $params is rejected + expect_error( X$params <- list( foo = 1 ) ) + expect_error( X$params <- list( 1 ) ) + }) }) test_that( "distribution.poisson constructs a valid class", { - n <- 1e5 - tol <- 3 / sqrt( n ) - - X <- distribution.poisson( lambda = 1 ) - - # Test that density is correct for initial rate parameter - expect_equal( X$d( x = 0 ), exp( -1 ) ) - expect_equal( mean( X$r( n ) ), X$mean, tolerance = tol ) - - # Test that $params can be updated via named list - expect_no_error( X$params <- list( lambda = 5 ) ) - expect_equal({ - X$params <- list( lambda = 5 ) - X$d( x = 0 ) - }, exp( -5 ) ) - expect_equal({ - X$params <- list( lambda = 5 ) - mean( X$r( n ) ) - }, X$mean, tolerance = tol ) - - # Test that elements of $params can be updated by name - expect_no_error( X$params$lambda <- 1 ) - - # Test that invalid values of $params fail (via private$.check_params()) - expect_error( X$params$lambda <- -1 ) - expect_error( X$params$lambda <- 'a' ) - - # Test that incorrectly named list $params is rejected - expect_error( X$params <- list( foo = 1 ) ) - expect_error( X$params <- list( 1 ) ) + withr::with_seed( 123, { + n <- 1e5 + tol <- 3 / sqrt( n ) + + X <- distribution.poisson( lambda = 1 ) + + # Test that density is correct for initial rate parameter + expect_equal( X$d( x = 0 ), exp( -1 ) ) + expect_equal( mean( X$r( n ) ), X$mean, tolerance = tol ) + + # Test that $params can be updated via named list + expect_no_error( X$params <- list( lambda = 5 ) ) + expect_equal({ + X$params <- list( lambda = 5 ) + X$d( x = 0 ) + }, exp( -5 ) ) + expect_equal({ + X$params <- list( lambda = 5 ) + mean( X$r( n ) ) + }, X$mean, tolerance = tol ) + + # Test that elements of $params can be updated by name + expect_no_error( X$params$lambda <- 1 ) + + # Test that invalid values of $params fail (via private$.check_params()) + expect_error( X$params$lambda <- -1 ) + expect_error( X$params$lambda <- 'a' ) + + # Test that incorrectly named list $params is rejected + expect_error( X$params <- list( foo = 1 ) ) + expect_error( X$params <- list( 1 ) ) + }) }) test_that( "distribution.negative_binomial constructs a valid class", { - n <- 1e5 - tol <- 3 / sqrt( n ) - - X <- distribution.negative_binomial( size = 10, prob = 0.5 ) - - # # Test that density is correct for initial rate parameter - expect_equal( X$d( x = 0 ), 0.5^10 ) - expect_equal( mean( X$r( n ) ), X$mean, tolerance = tol ) - - # Test that $params can be updated via named list - expect_no_error( X$params <- list( size = 10, - prob = 0.5 ) ) - expect_no_error( X$params <- list( size = 10, - mu = 10 ) ) - expect_no_error( X$params <- list( size = 10, - prob = 0.5, - mu = 10 ) ) - - #Test that $params throws an error is prob and mu are both set and are - #inconsistent - expect_error( X$params <- list( size = 10, - prob = 0.5, - mu = 0 ) ) - - # Test that $params throws an error if size is not set - expect_error( X$params <- list( prob = 0.5 ) ) - - # Test that invalid values of $params fail (via private$.check_params()) - expect_error( X$params$prob <- -1 ) - expect_error( X$params$prob <- 'a' ) - expect_error( X$params$mu <- -1 ) - expect_error( X$params$mu <- 'a' ) - expect_error( X$params$size <- -1 ) - expect_error( X$params$size <- 0.5 ) - expect_error( X$params$size <- 'a' ) + withr::with_seed( 123, { + n <- 1e5 + tol <- 3 / sqrt( n ) + + X <- distribution.negative_binomial( size = 10, prob = 0.5 ) + + # # Test that density is correct for initial rate parameter + expect_equal( X$d( x = 0 ), 0.5^10 ) + expect_equal( mean( X$r( n ) ), X$mean, tolerance = tol ) + + # Test that $params can be updated via named list + expect_no_error( X$params <- list( size = 10, + prob = 0.5 ) ) + expect_no_error( X$params <- list( size = 10, + mu = 10 ) ) + expect_no_error( X$params <- list( size = 10, + prob = 0.5, + mu = 10 ) ) + + #Test that $params throws an error is prob and mu are both set and are + #inconsistent + expect_error( X$params <- list( size = 10, + prob = 0.5, + mu = 0 ) ) + + # Test that $params throws an error if size is not set + expect_error( X$params <- list( prob = 0.5 ) ) + + # Test that invalid values of $params fail (via private$.check_params()) + expect_error( X$params$prob <- -1 ) + expect_error( X$params$prob <- 'a' ) + expect_error( X$params$mu <- -1 ) + expect_error( X$params$mu <- 'a' ) + expect_error( X$params$size <- -1 ) + expect_error( X$params$size <- 0.5 ) + expect_error( X$params$size <- 'a' ) + }) }) test_that( "distribution.point_mass constructs a valid class", { diff --git a/tests/testthat/test-R6_util_class.R b/tests/testthat/test-utils.R similarity index 98% rename from tests/testthat/test-R6_util_class.R rename to tests/testthat/test-utils.R index 71c4592..1842afe 100644 --- a/tests/testthat/test-R6_util_class.R +++ b/tests/testthat/test-utils.R @@ -1,9 +1,9 @@ test_that( "utils.uniroot.vectorised() returns the same root as stats::uniroot()", { tol <- 1e-8 - + # Quartic function with roots at +/- 1 and +/- sqrt( 2 ) test_func <- function( x ) ( x^2 - 1 ) * ( x^2 - 2 ) - + intervals <- list( c( - 2, -1.2 ), c( -1.2, 0 ), c( 0, 1.2 ), @@ -14,16 +14,16 @@ test_that( "utils.uniroot.vectorised() returns the same root as stats::uniroot() interval = intervals[[ idx]], tol = tol )$root } - + interval_lower <- sapply( intervals, min, na.rm = TRUE ) interval_upper <- sapply( intervals, max, na.rm = TRUE ) - + mastiff_uniroot <- utils.uniroot.vectorized( f = test_func, lower = interval_lower, upper = interval_upper, tol = tol ) - + expect_equal( mastiff_uniroot, stats_uniroot, tolerance = tol) diff --git a/tests/testthat/test-utils_R6.R b/tests/testthat/test-utils_R6.R index d545ee1..6122b57 100644 --- a/tests/testthat/test-utils_R6.R +++ b/tests/testthat/test-utils_R6.R @@ -1,24 +1,194 @@ -test_that( "utils.uniroot.vectorised() returns the same root as stats::uniroot()", { - # Quartic function with roots at +/- 1 and +/- sqrt( 2 ) - test_func <- function( x ) ( x - 1 ) * ( x + 1 ) * ( x^2 - 2 ) - - intervals <- list( c( - 2, -1.2 ), - c( -1.2, 0 ), - c( 0, 1.2 ), - c( 1.2, 2 ) ) - stats_uniroot <- numeric( length( intervals ) ) - for ( idx in seq_along( intervals ) ){ - stats_uniroot[ idx ] <- stats::uniroot( test_func, - interval = intervals[[ idx]] )$root - } - - interval_lower <- sapply( intervals, min, na.rm = TRUE ) - interval_upper <- sapply( intervals, max, na.rm = TRUE ) - - mastiff_uniroot <- utils.uniroot.vectorized( - f = test_func, - lower = interval_lower, - upper = interval_upper ) - - expect_true( TRUE ) -}) \ No newline at end of file +test_that( "Test interface and class names", { + interfaceA <- utils.class.interface( + interfacename = "test_interface", + private = list( funcA = function( x ) return( T ) ), + public = list( funcB = function( x, y ) return( T ), + funcE = function( x ) return( T ) ), + active = list( funcC = function( x ) return( T ) ) + ) + expect_equal( interfaceA$classname, "test_interface" ) + + classA <- utils.class( + classname = "test_class", + private = list( funcA = function( x ) return( T ) ), + public = list( funcB = function( x,y ) return( T ), funcE = function( x ) return( T ) ), + active = list( funcC = function( x ) return( T ) ), + interfaces = list( interfaceA ) + ) + + expect_equal( classA$classname, "test_class" ) + expect_equal( utils.class.interface.implements( classA$new(), "test_interface"), TRUE ) +} ) + +test_that( "Test class requires the methods on the interface", { + interfaceA <- utils.class.interface( + interfacename = "test_interface", + private = list( funcA = function( x ) return( T ) ), + public = list( funcB = function( x, y ) return( T ), + funcE = function( x ) return( T ) ), + active = list( funcC = function( x ) return( T ) ) + ) + + # Incorrect private method name + expect_error( utils.class( + classname = "test_class", + private = list( funcD = function( x ) return( T ) ), + public = list( funcB = function( x, y ) return( T ), + funcE = function( x ) return( T ) ), + active = list( funcC = function( x ) return( T ) ), + interfaces = list( interfaceA ) + ) ) + + # Incorrect public method name + expect_error( utils.class( + classname = "test_class", + private = list( funcA = function( x ) return( T ) ), + public = list( funcB = function( x, y ) return( T ), + funcD = function( x ) return( T ) ), + active = list( funcC = function( x ) return( T ) ), + interfaces = list( interfaceA ) + ) ) + + # Incorrect public method funcE signature + expect_error( utils.class( + classname = "test_class", + private = list( funcA = function( x ) return( T ) ), + public = list( funcB = function( x ) return( T ), + funcE = function( x ) return( T ) ), + active = list( funcC = function( x ) return( T ) ), + interfaces = list( interfaceA ) + ) ) + + # Missing public method funcE + expect_error( utils.class( + classname = "test_class", + private = list( funcA = function( x ) return( T ) ), + public = list( funcB = function( x, y ) return( T ) ), + active = list( funcC = function( x ) return( T ) ), + interfaces = list( interfaceA ) + ) ) + + # Incorrect active binding name + expect_error( utils.class( + classname = "test_class", + private = list( funcA = function( x ) return( T ) ), + public = list( funcB = function( x, y ) return( T ), + funcE = function( x ) return( T ) ), + active = list( funcD = function( x ) return( T ) ), + interfaces = list( interfaceA ) + ) ) +} ) + +test_that( "Test class with 2 interfcaes", { + interfaceA <- utils.class.interface( + interfacename = "test_interface", + private = list( funcA = function( x ) return( T ) ), + public = list( funcB = function( x, y ) return( T ), + funcE = function( x ) return( T ) ), + active = list( funcC = function( x ) return( T ) ) + ) + + interfaceB <- utils.class.interface( + interfacename = "test_interface2", + public = list( funcF = function( x ) return( T ) ), + ) + + classF <- utils.class( + classname = "test_class", + private = list( funcA = function( x ) return( T ) ), + public = list( funcB = function( x, y ) return( T ), + funcE = function( x ) return( T ), + funcF = function( x ) return( T ) ), + active = list( funcC = function( x ) return( T ) ), + interfaces = list( interfaceA, interfaceB ) + ) + + expect_equal( R6::is.R6Class( classF ), TRUE ) + + # Missing public method funcE for interfaceA + expect_error( + utils.class( + classname = "test_class", + private = list( funcA = function( x ) return( T ) ), + public = list( funcB = function( x, y ) return( T ), + funcF = function( x ) return( T ) ), + active = list( funcC = function( x ) return( T ) ), + interfaces = list( interfaceA, interfaceB ) + ) ) + + # Missing public method funcF for interfaceB + expect_error( + utils.class( + classname = "test_class", + private = list( funcA = function( x ) return( T ) ), + public = list( funcB = function( x, y ) return( T ), + funcE = function( x ) return( T ) ), + active = list( funcC = function( x ) return( T ) ), + interfaces = list( interfaceA, interfaceB ) + ) ) +}) + +test_that("Test class method arguments can be defined in any order", { + interfaceC <- utils.class.interface( + interfacename = "test_interface3", + public = list( funcH = function( x, y ) return( T ) ), + ) + + # Define method arguments in the same order as interface + expect_no_error( utils.class( + classname = "test_class", + public = list( funcH = function( x, y ) return( T ) ), + interfaces = list( interfaceC ) + ) + ) + + # Define method arguments in different order to interface + expect_no_error( utils.class( + classname = "test_class", + public = list( funcH = function( y, x ) return( T ) ), + interfaces = list( interfaceC ) + ) ) +}) + +test_that("Test class methods match interface arguments exactly", { + interfaceC <- utils.class.interface( + interfacename = "test_interface3", + public = list( funcH = function( x ) return( T ) ), + ) + + # Additional argument y in public method funcH + expect_error( + utils.class( + classname = "test_class", + public = list( funcH = function( x, y ) return( T ) ), + interfaces = list( interfaceC ) + ) + ) + + expect_error( + utils.class( + classname = "test_class", + public = list( funcH = function( ) return( T ) ), + interfaces = list( interfaceC ) + ) + ) +}) + +test_that("Test interface on derived class checks base methods", { + interfaceA <- utils.class.interface( + interfacename = "test_interface", + private = list( funcA = function( x ) return( T ) ), + public = list( funcB = function( x, y ) return( T ), + funcE = function( x ) return( T ) ), + active = list( funcC = function( x ) return( T ) ) + ) + + expect_no_error( + utils.class( + classname = "test_class", + private = list( funcA = function( x ) return( T ) ), + public = list( funcB = function( x, y ) return( T ), + funcE = function( x ) return( T ) ), + active = list( funcC = function( x ) return( T ) ) + ) ) +} ) \ No newline at end of file