Functional programming: Sectioning a function to a smaller domain

Søren Højsgaard

Sectioning a function domain with section_fun()

The section_fun utility in doBy creates a new function by fixing some arguments of an existing function. The result is a section of the original function, defined only on the remaining arguments.

For example, if you have:

$$ f(x,y) = x + y $$

then fixing \(x=10\) yields:

$$ f_x(y) = 10 + y $$

In R terms, section_fun lets you programmatically create such specialized versions.


How section_fun works

section_fun() offers three ways to fix arguments:

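  1. Defaults (method = "def") – Inserts the fixed values as defaults in the argument list, moving the fixed arguments to the end of the list. This is the method used when none is specified.
  2. Substitution (method = "sub") – Removes the fixed arguments from the argument list and assigns their values at the top of the function body.
  3. Environment (method = "env") – Stores the fixed values in an auxiliary environment and returns a wrapper that calls the original function with them.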

Example:

fun  <- function(a, b, c=4, d=9) {  # toy example: a and b are required, c and d have defaults
    a + b + c + d                   # the value is the sum of the four arguments
}
fun_def <- section_fun(fun, list(b=7, d=10))
fun_def
#> function (a, c = 4, b = 7, d = 10) 
#> {
#>     a + b + c + d
#> }
fun_body <- section_fun(fun, list(b=7, d=10), method="sub")
fun_body
#> function (a, c = 4) 
#> {
#>     b = 7
#>     d = 10
#>     a + b + c + d
#> }
fun_env <- section_fun(fun, list(b=7, d=10), method = "env")
fun_env
#> function (a, c = 4) 
#> {
#>     . <- "use get_section(function_name) to see section"
#>     . <- "use get_fun(function_name) to see original function"
#>     args <- arg_getter()
#>     do.call(fun, args)
#> }
#> <environment: 0x5d58c7af6838>

You can inspect the environment-based section:

get_section(fun_env) 
#> $b
#> [1] 7
#> 
#> $d
#> [1] 10
## same as: attr(fun_env, "arg_env")$args 
get_fun(fun_env) 
#> <srcref: file "" chars 1:9 to 3:1>
## same as: environment(fun_env)$fun

Example evaluations:

fun(a=10, b=7, c=5, d=10)
#> [1] 32
fun_def(a=10, c=5)
#> [1] 32
fun_body(a=10, c=5)
#> [1] 32
fun_env(a=10, c=5)
#> [1] 32

Benchmarking example

Suppose you want to benchmark a function for different input values without writing repetitive code:

# Invert the n x n symmetric Toeplitz matrix whose first row is 1, 2, ..., n.
# Serves as the workload whose run time is measured for different n.
# NOTE(review): 1:n counts down when n = 0 (giving c(1, 0)), so n is
# presumably meant to be a positive integer -- confirm before reusing elsewhere.
inv_toep <- function(n) {
    solve(toeplitz(1:n))
}

Instead of typing the following:

# Hand-written version: every problem size has to be spelled out as its own
# call, which becomes tedious as the number of sizes grows.
microbenchmark(
    inv_toep(4), inv_toep(8), inv_toep(16),
    times=3
)

you can create specialized versions programmatically:

# Problem sizes to benchmark.
n.vec <- c(4, 8, 16)

# For each size, fix the argument n of inv_toep so that the resulting
# function can be called without any arguments.
fun_list <- lapply(n.vec, function(size) {
    section_fun(inv_toep, list(n = size))
})
fun_list
#> [[1]]
#> function (n = 4) 
#> {
#>     solve(toeplitz(1:n))
#> }
#> 
#> [[2]]
#> function (n = 8) 
#> {
#>     solve(toeplitz(1:n))
#> }
#> 
#> [[3]]
#> function (n = 16) 
#> {
#>     solve(toeplitz(1:n))
#> }

Inspect and evaluate:

fun_list[[1]]
#> function (n = 4) 
#> {
#>     solve(toeplitz(1:n))
#> }
fun_list[[1]]()
#>      [,1] [,2] [,3] [,4]
#> [1,] -0.4  0.5  0.0  0.1
#> [2,]  0.5 -1.0  0.5  0.0
#> [3,]  0.0  0.5 -1.0  0.5
#> [4,]  0.1  0.0  0.5 -0.4

To use these functions with microbenchmark, we need them as unevaluated call expressions:

# Turn a list of functions into a list of unevaluated zero-argument calls.
#
# fun_list: a list of functions, each callable without arguments.
# Returns a list of the same length (names are kept) in which every element
# is a call object that, when evaluated, invokes the corresponding function.
bquote_list <- function(fun_list) {
    # .(f) splices the function object itself into the call f()
    as_zero_arg_call <- function(f) bquote(.(f)())
    lapply(fun_list, as_zero_arg_call)
}

We get:

bq_fun_list <- bquote_list(fun_list)
bq_fun_list
#> [[1]]
#> (function (n = 4) 
#> {
#>     solve(toeplitz(1:n))
#> })()
#> 
#> [[2]]
#> (function (n = 8) 
#> {
#>     solve(toeplitz(1:n))
#> })()
#> 
#> [[3]]
#> (function (n = 16) 
#> {
#>     solve(toeplitz(1:n))
#> })()
bq_fun_list[[1]]
#> (function (n = 4) 
#> {
#>     solve(toeplitz(1:n))
#> })()
eval(bq_fun_list[[1]])
#>      [,1] [,2] [,3] [,4]
#> [1,] -0.4  0.5  0.0  0.1
#> [2,]  0.5 -1.0  0.5  0.0
#> [3,]  0.0  0.5 -1.0  0.5
#> [4,]  0.1  0.0  0.5 -0.4

Now run:

# Hand the pre-built calls to microbenchmark through `list =` instead of
# typing each expression; the neval column of the result shows that every
# expression is evaluated 5 times.
microbenchmark(
  list = bq_fun_list,
  times = 5
)
#> Unit: microseconds
#>                                                 expr   min    lq mean median
#>   (function (n = 4)  {     solve(toeplitz(1:n)) })()  8.93  9.04 23.3   10.1
#>   (function (n = 8)  {     solve(toeplitz(1:n)) })() 10.04 10.60 12.0   12.4
#>  (function (n = 16)  {     solve(toeplitz(1:n)) })() 15.96 16.35 18.7   17.1
#>    uq  max neval cld
#>  12.9 75.4     5   a
#>  13.2 13.7     5   a
#>  19.1 25.1     5   a

The code below benchmarks the three sectioning methods against each other in terms of speed.

# Problem sizes for the method comparison: 20, 40, 60, 80.
n.vec <- seq(20, 80, by = 20)

# For every size, build one sectioned copy of inv_toep per method and label it
# "<method><size>" so that the rows of the benchmark table are self-explanatory.
fun_def <- setNames(
    lapply(n.vec, function(size) {
        section_fun(inv_toep, list(n = size), method = "def")
    }),
    paste0("def", n.vec)
)
fun_body <- setNames(
    lapply(n.vec, function(size) {
        section_fun(inv_toep, list(n = size), method = "sub")
    }),
    paste0("body", n.vec)
)
fun_env <- setNames(
    lapply(n.vec, function(size) {
        section_fun(inv_toep, list(n = size), method = "env")
    }),
    paste0("env", n.vec)
)

# Convert all sectioned functions into zero-argument calls for microbenchmark.
bq_fun_list <- bquote_list(c(fun_def, fun_body, fun_env))
bq_fun_list |> head()
#> $def20
#> (function (n = 20) 
#> {
#>     solve(toeplitz(1:n))
#> })()
#> 
#> $def40
#> (function (n = 40) 
#> {
#>     solve(toeplitz(1:n))
#> })()
#> 
#> $def60
#> (function (n = 60) 
#> {
#>     solve(toeplitz(1:n))
#> })()
#> 
#> $def80
#> (function (n = 80) 
#> {
#>     solve(toeplitz(1:n))
#> })()
#> 
#> $body20
#> (function () 
#> {
#>     n = 20
#>     solve(toeplitz(1:n))
#> })()
#> 
#> $body40
#> (function () 
#> {
#>     n = 40
#>     solve(toeplitz(1:n))
#> })()

# Time every sectioned function (three methods for each problem size).
# NOTE(review): times = 2 gives only two runs per expression, so the mean and
# max columns are easily dominated by a single slow run; consider a larger
# value before drawing conclusions about the methods' relative speed.
mb <- microbenchmark(
  list  = bq_fun_list,
  times = 2
)
mb
#> Unit: microseconds
#>    expr   min    lq  mean median     uq    max neval cld
#>   def20  29.3  29.3  38.7   38.7   48.1   48.1     2   a
#>   def40  76.4  76.4  80.7   80.7   85.1   85.1     2   a
#>   def60 168.9 168.9 181.9  181.9  195.0  195.0     2   a
#>   def80 343.8 343.8 344.6  344.6  345.5  345.5     2   a
#>  body20  42.9  42.9 427.8  427.8  812.6  812.6     2   a
#>  body40 104.6 104.6 535.5  535.5  966.4  966.4     2   a
#>  body60 267.4 267.4 691.1  691.1 1114.9 1114.9     2   a
#>  body80 518.9 518.9 961.4  961.4 1404.0 1404.0     2   a
#>   env20  28.3  28.3  41.0   41.0   53.6   53.6     2   a
#>   env40  84.4  84.4 620.9  620.9 1157.5 1157.5     2   a
#>   env60 184.5 184.5 232.9  232.9  281.4  281.4     2   a
#>   env80 345.0 345.0 353.5  353.5  362.1  362.1     2   a