Skip to content

Commit 70b3321

Browse files
Issue #19, added tests for discos-keygen
1 parent 0701192 commit 70b3321

5 files changed

Lines changed: 136 additions & 28 deletions

File tree

.coveragerc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
[run]
22
concurrency = thread
33
source = discos_client
4-
omit = discos_client/cli.py
54

65
[paths]
76
discos_client =

.gitignore

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,9 @@ cython_debug/
182182
.abstra/
183183

184184
# Visual Studio Code
185-
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
185+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186186
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187-
# and can be added to the global gitignore or merged into this file. However, if you prefer,
187+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
188188
# you could uncomment the following to ignore the entire vscode folder
189189
# .vscode/
190190

@@ -206,4 +206,4 @@ marimo/_static/
206206
marimo/_lsp/
207207
__marimo__/
208208

209-
209+
**.swp
Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,64 @@
11
import os
2-
import sys
32
from pathlib import Path
43
from argparse import ArgumentParser
54
from platformdirs import user_config_dir
65
from zmq.auth import create_certificates
76

8-
base_config = Path(user_config_dir("discos"))
9-
target_dir = base_config / "rpc" / "client"
10-
KEY_FILENAME = "identity"
11-
full_path_public = target_dir / f"{KEY_FILENAME}.key"
12-
full_path_secret = target_dir / f"{KEY_FILENAME}.key_secret"
7+
8+
def get_config_paths():
9+
base_config = Path(user_config_dir("discos"))
10+
config_dir = base_config / "rpc" / "client"
11+
public = config_dir / "identity.key"
12+
secret = config_dir / "identity.key_secret"
13+
return config_dir, public, secret
1314

1415

1516
def create_discos_keys(overwrite):
17+
config_dir, public, secret = get_config_paths()
1618

17-
if full_path_secret.exists() and not overwrite:
19+
if secret.exists() and not overwrite:
1820
print("Kept previously created key pair. "
1921
"Use --overwrite to replace it.\n")
20-
return
22+
return 0
2123

2224
try:
23-
target_dir.mkdir(parents=True, exist_ok=True)
25+
config_dir.mkdir(parents=True, exist_ok=True)
2426
except OSError as e:
2527
print(f"Error creating the configuration directory: {e}")
26-
sys.exit(1)
28+
return 1
2729

28-
create_certificates(str(target_dir), KEY_FILENAME)
30+
create_certificates(str(config_dir), "identity")
2931

3032
if os.name == 'posix':
31-
full_path_secret.chmod(0o600)
32-
(target_dir / f"{KEY_FILENAME}.key").chmod(0o644)
33-
print(f"Key pair created in: '{target_dir}'.")
33+
public.chmod(0o644)
34+
secret.chmod(0o600)
35+
print(f"Key pair created in: '{config_dir}'.")
36+
return 0
3437

3538

3639
def print_discos_keys():
37-
if not full_path_public.exists():
40+
_, public, _ = get_config_paths()
41+
42+
if not public.exists():
3843
print("No key was generated yet.")
39-
return
44+
return 0
4045

41-
with open(full_path_public, "r", encoding="utf-8") as f:
46+
with open(public, "r", encoding="utf-8") as f:
4247
print(f.read())
43-
print(f"\nPath of the public key file: {full_path_public}")
44-
print(f"Remember to never share the '{KEY_FILENAME}.key_secret' file with "
48+
49+
print(f"\nPath of the public key file: {public}")
50+
print("Remember to never share the 'identity.key_secret' file with "
4551
"anyone.")
4652
print(
4753
"In order to be authorized to send command to any of the telescopes, "
48-
f"remember to send a copy of the '{KEY_FILENAME}.key' file to the "
54+
"remember to send a copy of the 'identity.key' file to the "
4955
"DISCOS team, asking for authorization. Your request will be taken "
5056
"into consideration and you will hear back from the team."
5157
)
58+
return 0
5259

5360

54-
def main():
61+
def keygen():
5562
parser = ArgumentParser(
5663
"DISCOS CURVE key pairs generator."
5764
)
@@ -69,5 +76,9 @@ def main():
6976
args = parser.parse_args()
7077

7178
if not args.show_only:
72-
create_discos_keys(args.overwrite)
73-
print_discos_keys()
79+
return_code = create_discos_keys(args.overwrite)
80+
81+
if return_code != 0:
82+
return return_code
83+
84+
return print_discos_keys()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ dependencies = [
2323
packages = ["discos_client"]
2424

2525
[project.scripts]
26-
discos-keygen = "discos_client.cli:main"
26+
discos-keygen = "discos_client.scripts:keygen"
2727

2828
[tool.setuptools.package-data]
2929
discos_client = ["schemas/**", "servers/**"]

tests/test_scripts.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import unittest
2+
import sys
3+
import shutil
4+
import tempfile
5+
from io import StringIO
6+
from pathlib import Path
7+
from unittest.mock import patch, MagicMock
8+
from platformdirs import user_config_dir
9+
from discos_client import scripts
10+
11+
12+
class TestKeygen(unittest.TestCase):
13+
14+
def setUp(self):
15+
self.test_dir = tempfile.mkdtemp()
16+
self.test_path = Path(self.test_dir)
17+
self.mock_target_dir = self.test_path / "rpc" / "client"
18+
self.mock_public = self.mock_target_dir / "identity.key"
19+
self.mock_secret = self.mock_target_dir / "identity.key_secret"
20+
21+
def tearDown(self):
22+
shutil.rmtree(self.test_dir)
23+
24+
def test_correct_paths(self):
25+
config_dir, public, secret = scripts.get_config_paths()
26+
expected_config_dir = \
27+
Path(user_config_dir("discos")) / "rpc" / "client"
28+
expected_public = expected_config_dir / "identity.key"
29+
expected_secret = expected_config_dir / "identity.key_secret"
30+
self.assertEqual(config_dir, expected_config_dir)
31+
self.assertEqual(public, expected_public)
32+
self.assertEqual(secret, expected_secret)
33+
34+
@patch("discos_client.scripts.get_config_paths")
35+
@patch("sys.stdout", new_callable=StringIO)
36+
@patch.object(sys, "argv", ["discos-keygen"])
37+
def test_keygen(self, mock_stdout, mock_paths):
38+
mock_paths.return_value = (
39+
self.mock_target_dir,
40+
self.mock_public,
41+
self.mock_secret
42+
)
43+
rc = scripts.keygen()
44+
self.assertEqual(rc, 0)
45+
self.assertTrue(self.mock_public.exists())
46+
self.assertTrue(self.mock_secret.exists())
47+
output = mock_stdout.getvalue()
48+
self.assertIn("Key pair created in", output)
49+
50+
@patch("discos_client.scripts.get_config_paths")
51+
@patch("sys.stdout", new_callable=StringIO)
52+
@patch.object(sys, "argv", ["discos-keygen"])
53+
def test_keygen_no_overwrite(self, mock_stdout, mock_paths):
54+
mock_paths.return_value = (
55+
self.mock_target_dir,
56+
self.mock_public,
57+
self.mock_secret
58+
)
59+
self.assertEqual(scripts.keygen(), 0)
60+
self.assertTrue(self.mock_public.exists())
61+
self.assertTrue(self.mock_secret.exists())
62+
output = mock_stdout.getvalue()
63+
self.assertIn("Key pair created in", output)
64+
self.assertEqual(scripts.keygen(), 0)
65+
output = mock_stdout.getvalue()
66+
self.assertIn("Kept previously created key pair", output)
67+
68+
@patch("discos_client.scripts.get_config_paths")
69+
@patch("sys.stdout", new_callable=StringIO)
70+
def test_print_keys(self, mock_stdout, mock_paths):
71+
mock_paths.return_value = (
72+
self.mock_target_dir,
73+
self.mock_public,
74+
self.mock_secret
75+
)
76+
scripts.print_discos_keys()
77+
output = mock_stdout.getvalue()
78+
self.assertIn("No key was generated yet.", output)
79+
80+
@patch("discos_client.scripts.get_config_paths")
81+
@patch("sys.stdout", new_callable=StringIO)
82+
@patch.object(sys, "argv", ["discos-keygen"])
83+
def test_mkdir_error(self, mock_stdout, mock_paths):
84+
mock_target_dir = MagicMock()
85+
mock_target_dir.mkdir.side_effect = OSError("Test error")
86+
mock_paths.return_value = (
87+
mock_target_dir,
88+
self.mock_public,
89+
self.mock_secret
90+
)
91+
rc = scripts.keygen()
92+
self.assertEqual(rc, 1)
93+
output = mock_stdout.getvalue()
94+
self.assertIn("Error creating the configuration directory", output)
95+
96+
97+
if __name__ == "__main__":
98+
unittest.main()

0 commit comments

Comments
 (0)