Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SourceCodeMcCormick"
uuid = "a7283dc5-4ecf-47fb-a95b-1412723fc960"
authors = ["Robert Gottlieb <[email protected]>"]
version = "0.5.0"
version = "0.5.1"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down
91 changes: 77 additions & 14 deletions src/kernel_writer/kernel_write.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ kgen(num::Num, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwr
kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwrite::Bool=false, splitting::Symbol=:default, affine_quadratic::Bool=true) = kgen(num, gradlist, raw_outputs, constants, overwrite, splitting, affine_quadratic)
function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, constants::Vector{Num}, overwrite::Bool, splitting::Symbol, affine_quadratic::Bool)
# Create a hash of the expression and check if the function already exists
expr_hash = string(hash(num+sum(gradlist)), base=62)
expr_hash = string(hash(string(num)*string(gradlist)), base=62)
if (overwrite==false) && (isfile(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl")))
try func_name = eval(Meta.parse("f_"*expr_hash))
return func_name
Expand Down Expand Up @@ -102,9 +102,6 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
elseif splitting==:high # Formerly default
split_point = 1500
max_size = 2000
# elseif splitting==:high # More splitting
# split_point = 1000
# max_size = 1200
elseif splitting==:max # Extremely small
split_point = 500
max_size = 750
Expand All @@ -116,7 +113,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
sparsity = detect_sparsity(factored, gradlist)

# Decide if the kernel needs to be split
if (n_vars[end] < 31) && (n_lines[end] <= max_size)
if (n_vars[end] < 31) && ((n_lines[end] <= max_size) || (findfirst(x -> x > split_point, n_lines)==length(n_lines)))
# Complexity is fairly low; only a single kernel needed
create_kernel!(expr_hash, 1, num, get_name.(gradlist), func_outputs, constants, factored, sparsity)
push!(kernel_nums, 1)
Expand All @@ -130,7 +127,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
while !complete
# Determine which line to break at
line_ID = findfirst(x -> x > split_point, n_lines)
vars_ID = findfirst(x -> x == 31, n_vars)
vars_ID = findfirst(x -> (x == 30) || (x == 31), n_vars)
if isnothing(vars_ID)
new_ID = line_ID
elseif isnothing(line_ID)
Expand Down Expand Up @@ -188,7 +185,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
n_lines = complexity(factored)
n_vars = var_counts(factored)

# If the total number of lines (not including the final line) is below 2000
# If the total number of lines (not including the final line) is below the max size
# and the number of variables is below 32, we can make the final kernel and be done
if (n_vars[end] < 32) && (all(n_lines[1:end-1] .<= max_size))
create_kernel!(expr_hash, kernel_count, extract(factored), get_name.(gradlist), func_outputs, constants, factored, sparsity)
Expand Down Expand Up @@ -328,7 +325,12 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
file = open(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"), "a")

# Put in the preamble.
write(file, preamble_string(expr_hash, ["OUT"; string.(vars)], 1, 1, length(gradlist)))
if isempty(vars)
write(file, preamble_string(expr_hash, ["OUT";], 1, 1, length(gradlist)))
else
write(file, preamble_string(expr_hash, ["OUT"; string.(vars)], 1, 1, length(gradlist)))
end


# Depending on the format of the expression, compose the kernel differently
if typeof(expr) <: Real
Expand Down Expand Up @@ -360,9 +362,9 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
end
end
else # There must be two elements in the dictionary
binary_vars = string.(get_name.(keys(key.dict)))
binary_vars = string.(get_name.(keys(expr.dict)))
binary_vars = binary_vars[sort_vars(binary_vars)]
write(file, SCMC_quadaff_binary(vars..., expr.coeff, varlist))
write(file, SCMC_quadaff_binary(binary_vars..., expr.coeff, varlist))
end

elseif exprtype(expr)==ADD
Expand Down Expand Up @@ -394,7 +396,13 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
# EAGO already does this and bypasses the need to calculate relaxations.
# But, for compatibility with McCormick-style relaxations in ParBB,
# it's easier to simply calculate what ParBB is expecting.)
write(file, postamble_quadaff(string.(vars), varlist))
if isempty(varlist)
write(file, postamble_quadaff(String[], String[]))
elseif isempty(vars)
write(file, postamble_quadaff(String[], varlist))
else
write(file, postamble_quadaff(string.(vars), varlist))
end
close(file)

# Include this kernel so SCMC knows what it is
Expand All @@ -403,7 +411,13 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
# Add onto the file the "main" CPU function that calls the kernel
blocks = Int32(CUDA.attribute(CUDA.device(), CUDA.DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT))
file = open(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"), "a")
write(file, outro(expr_hash, [1], [string.(vars)], ["OUT"], blocks, get_name.(gradlist)))
if isempty(gradlist)
write(file, outro(expr_hash, [1], [String[]], ["OUT"], blocks, Symbol[]))
elseif isempty(vars)
write(file, outro(expr_hash, [1], [String[]], ["OUT"], blocks, get_name.(gradlist)))
else
write(file, outro(expr_hash, [1], [string.(vars)], ["OUT"], blocks, get_name.(gradlist)))
end
close(file)

# Include the file again to get the final kernel
Expand Down Expand Up @@ -731,6 +745,7 @@ end
# 7) log(inv(x1)) = -log(x1) [EAGO paper]
# 8) CONST1*CONST2*x1 = (CONST1*CONST2)*x1
# 9) 1 / (1 + exp(-x)) = Sigmoid(x)
# 10) sin(x) = cos(x - pi/2)
#
# Forms that aren't relevant yet:
# 1) (a^x1)^b = (a^b)^x1 [EAGO paper] (Can't do powers besides integers)
Expand Down Expand Up @@ -826,7 +841,7 @@ function perform_substitutions(old_factored::Vector{Equation})
end
end
# Create a factorization of this new expr
new_factorization = factor(new_expr)
new_factorization = factor(new_expr, split_div=true)
# Scan through the new factorization to see if we can merge elements
# with the original factored list
done = false
Expand Down Expand Up @@ -1191,7 +1206,7 @@ function perform_substitutions(old_factored::Vector{Equation})
new_expr *= arg
end
# Create a factorization of this new expr
new_factorization = factor(new_expr)
new_factorization = factor(new_expr, split_div=true)


# Scan through the new factorization to see if we can merge elements
Expand Down Expand Up @@ -1315,6 +1330,38 @@ function perform_substitutions(old_factored::Vector{Equation})
end
end
end

# 10) sin(x) = cos(x - pi/2)
if exprtype(factored[index0].rhs)==TERM
if factored[index0].rhs.f==sin
# We found sin(arg). Check if (arg - pi/2) exists,
# and if so, also check if cos(arg - pi/2) exists.
scan_flag = true
index1 = findfirst(x -> isequal(x.rhs, arguments(factored[index0].rhs)[] - pi/2), factored)
if !isnothing(index1)
index2 = findfirst(x -> isequal(x.rhs, cos(factored[index1].lhs)), factored)
if !isnothing(index2)
# cos(arg - pi/2) exists already (index2). Remove all reference to index0 and replace with index2
for i in eachindex(factored)
@eval $factored[$i] = $factored[$i].lhs ~ substitute($factored[$i].rhs, Dict($factored[$index0].lhs => $factored[$index2].lhs))
end
deleteat!(factored, index0)
else
# arg - pi/2 exists already (index1), but not cos(arg - pi/2). Change
# index0 to be cos of index1.lhs instead of sin of arg
@eval $factored[$index0] = $factored[$index0].lhs ~ cos($factored[$index1].lhs)
end
else
# (arg - pi/2) doesn't exist, so we need to create it
newsym = gensym(:aux)
newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end])
newvar = genvar(newsym)
insert!(factored, index0, Equation(Symbolics.value(newvar), arguments(factored[index0].rhs)[] - pi/2))
@eval $factored[$index0+1] = $factored[$index0+1].lhs ~ cos($newvar)
end
break
end
end
end
end

Expand Down Expand Up @@ -1511,6 +1558,8 @@ function write_operation(file::IOStream, RHS::BasicSymbolic{Real}, inputs::Vecto
write(file, SCMC_sigmoid_kernel(inputs..., gradlist, sparsity))
elseif RHS.f==sqrt
write(file, SCMC_float_power_kernel(inputs..., 0.5, gradlist, sparsity))
elseif RHS.f==cos
write(file, SCMC_cos_kernel(inputs..., gradlist, sparsity))
else
close(file)
error("Some function was used that we can't handle yet ($RHS)")
Expand Down Expand Up @@ -1845,6 +1894,10 @@ function _complexity(complexity::Vector{Int}, factorized::Vector{Equation}, star
else
total_lines += 190
end
new_ID = findfirst(x -> isequal(x.lhs, RHS.base), factorized)
if !isnothing(new_ID)
total_lines += _complexity(complexity, factorized, new_ID)
end
elseif exprtype(RHS) == TERM
if RHS.f==exp
total_lines += 212 # Ranges from 212--310
Expand All @@ -1866,6 +1919,16 @@ function _complexity(complexity::Vector{Int}, factorized::Vector{Equation}, star
end
elseif RHS.f==sqrt
total_lines += 190
new_ID = findfirst(x -> isequal(x.lhs, RHS.arguments[1]), factorized)
if !isnothing(new_ID)
total_lines += _complexity(complexity, factorized, new_ID)
end
elseif RHS.f==cos || RHS.f==sin
total_lines += 300
new_ID = findfirst(x -> isequal(x.lhs, RHS.arguments[1]), factorized)
if !isnothing(new_ID)
total_lines += _complexity(complexity, factorized, new_ID)
end
else
error("Unknown function")
end
Expand Down
Loading
Loading