Skip to content

Commit 03df383

Browse files
committed
Refactor the SLURM_NTASKS and SLURM_JOB_ID functionality out into separate utility functions, and add some more unit tests to increase code coverage
1 parent 410734b commit 03df383

File tree

3 files changed

+59
-25
lines changed

3 files changed

+59
-25
lines changed

src/slurmmanager.jl

+28-24
Original file line numberDiff line numberDiff line change
@@ -12,34 +12,38 @@ mutable struct SlurmManager <: ClusterManager
1212
srun_proc
1313

1414
function SlurmManager(; launch_timeout=60.0, srun_post_exit_sleep=0.01)
15+
ntasks = get_slurm_ntasks_int()
16+
jobid = get_slurm_jobid_int()
1517

16-
jobid =
17-
if "SLURM_JOB_ID" in keys(ENV)
18-
ENV["SLURM_JOB_ID"]
19-
elseif "SLURM_JOBID" in keys(ENV)
20-
ENV["SLURM_JOBID"]
21-
else
22-
error("""
23-
SlurmManager must be constructed inside a slurm allocation environemnt.
24-
SLURM_JOB_ID or SLURM_JOBID must be defined.
25-
""")
26-
end
27-
28-
ntasks =
29-
if "SLURM_NTASKS" in keys(ENV)
30-
ENV["SLURM_NTASKS"]
31-
else
32-
error("""
33-
SlurmManager must be constructed inside a slurm environment with a specified number of tasks.
34-
SLURM_NTASKS must be defined.
35-
""")
36-
end
18+
new(jobid, ntasks, launch_timeout, srun_post_exit_sleep, nothing)
19+
end
20+
end
3721

38-
jobid = parse(Int, jobid)
39-
ntasks = parse(Int, ntasks)
22+
function get_slurm_ntasks_int()
23+
if haskey(ENV, "SLURM_NTASKS")
24+
ntasks_str = ENV["SLURM_NTASKS"]
25+
else
26+
msg = "SlurmManager must be constructed inside a Slurm allocation environment." *
27+
"SLURM_NTASKS must be defined."
28+
error(msg)
29+
end
30+
ntasks_int = parse(Int, ntasks_str)::Int
31+
return ntasks_int
32+
end
4033

41-
new(jobid, ntasks, launch_timeout, srun_post_exit_sleep, nothing)
34+
function get_slurm_jobid_int()
35+
if haskey(ENV, "SLURM_JOB_ID")
36+
jobid_str = ENV["SLURM_JOB_ID"]
37+
elseif haskey(ENV, "SLURM_JOBID")
38+
jobid_str = ENV["SLURM_JOBID"]
39+
else
40+
msg = "SlurmManager must be constructed inside a Slurm allocation environment." *
41+
"SLURM_JOB_ID or SLURM_JOBID must be defined."
42+
error(msg)
4243
end
44+
45+
jobid_int = parse(Int, jobid_str)::Int
46+
return jobid_int
4347
end
4448

4549
@static if Base.VERSION >= v"1.9.0"

test/runtests.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import Distributed
66
import Test
77

88
# Bring some names into scope, just for convenience:
9-
using Test: @testset, @test
9+
using Test: @testset, @test, @test_throws, @test_logs
1010

1111
const original_JULIA_DEBUG = strip(get(ENV, "JULIA_DEBUG", ""))
1212
if isempty(original_JULIA_DEBUG)
@@ -16,6 +16,10 @@ else
1616
end
1717

1818
@testset "SlurmClusterManager.jl" begin
19+
@testset "Unit tests" begin
20+
include("unit.jl")
21+
end
22+
1923
# test that slurm is available
2024
@test !(Sys.which("sinfo") === nothing)
2125

test/unit.jl

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
@testset "get_slurm_ntasks_int()" begin
2+
x = withenv("SLURM_NTASKS" => "12") do
3+
SlurmClusterManager.get_slurm_ntasks_int()
4+
end
5+
@test x == 12
6+
7+
withenv("SLURM_NTASKS" => nothing) do
8+
@test_throws ErrorException SlurmClusterManager.get_slurm_ntasks_int()
9+
end
10+
end
11+
12+
@testset "get_slurm_jobid_int()" begin
13+
x = withenv("SLURM_JOB_ID" => "34", "SLURM_JOBID" => nothing) do
14+
SlurmClusterManager.get_slurm_jobid_int()
15+
end
16+
@test x == 34
17+
18+
x = withenv("SLURM_JOB_ID" => nothing, "SLURM_JOBID" => "56") do
19+
SlurmClusterManager.get_slurm_jobid_int()
20+
end
21+
@test x == 56
22+
23+
withenv("SLURM_JOB_ID" => nothing, "SLURM_JOBID" => nothing) do
24+
@test_throws ErrorException SlurmClusterManager.get_slurm_jobid_int()
25+
end
26+
end

0 commit comments

Comments
 (0)