-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add some unit tests, and complete handling of scalar and list env var…
…iables
- Loading branch information
Showing
3 changed files
with
259 additions
and
80 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
from enum import Enum | ||
from typing import Optional, List | ||
|
||
class EnvVarOp (Enum): | ||
PREPEND=1 | ||
APPEND=2 | ||
SET=3 | ||
|
||
class EnvVarKind (Enum): | ||
SCALAR=2 | ||
LIST=2 | ||
|
||
list_variables = { | ||
"ACLOCAL_PATH", | ||
"CMAKE_PREFIX_PATH", | ||
"CPATH", | ||
"LD_LIBRARY_PATH", | ||
"LIBRARY_PATH", | ||
"MANPATH", | ||
"PATH", | ||
"PKG_CONFIG_PATH", | ||
} | ||
|
||
class EnvVarError(Exception): | ||
"""Exception raised when there is an error with environment variable manipulation.""" | ||
|
||
def __init__(self, message): | ||
self.message = message | ||
super().__init__(self.message) | ||
|
||
def __str__(self): | ||
return self.message | ||
|
||
def is_env_value_list(v): | ||
return isinstance(v, list) and all(isinstance(item, str) for item in v) | ||
|
||
class ListEnvVarUpdate(): | ||
def __init__(self, value: List[str], op: EnvVarOp): | ||
# strip white space from each entry | ||
self._value = [v.strip() for v in value] | ||
self._op = op | ||
|
||
@property | ||
def op(self): | ||
return self._op | ||
|
||
@property | ||
def value(self): | ||
return self._value | ||
|
||
def __repr__(self): | ||
return f"env.ListEnvVarUpdate({self.value}, {self.op})" | ||
|
||
def __str__(self): | ||
return f"({self.value}, {self.op})" | ||
|
||
class EnvVar: | ||
def __init__(self, name: str): | ||
self._name = name | ||
|
||
@property | ||
def name(self): | ||
return self._name | ||
|
||
class ListEnvVar(EnvVar): | ||
def __init__(self, name: str, value: List[str], op: EnvVarOp): | ||
super().__init__(name) | ||
|
||
self._updates = [ListEnvVarUpdate(value, op)] | ||
|
||
def update(self, value: List[str], op:EnvVarOp): | ||
self._updates.append(ListEnvVarUpdate(value, op)) | ||
|
||
@property | ||
def updates(self): | ||
return self._updates | ||
|
||
def concat(self, other: 'ListEnvVar'): | ||
self._updates += other.updates | ||
|
||
# Given the current value, return the value that should be set | ||
def get_value(self, current: Optional[str]): | ||
v = current | ||
|
||
# if the variable is currently not set, first initialise it as empty. | ||
if v is None: | ||
if len(self._updates)==0: | ||
return None | ||
v = "" | ||
|
||
for update in self._updates: | ||
joined = ":".join(update.value) | ||
if v == "" or update.op==EnvVarOp.SET: | ||
v = joined | ||
elif update.op==EnvVarOp.APPEND: | ||
v = ":".join([v, joined]) | ||
elif update.op==EnvVarOp.PREPEND: | ||
v = ":".join([joined, v]) | ||
else: | ||
raise EnvVarError(f"Internal error: implement the operation {update.op}"); | ||
# strip any leading/trailing ":" | ||
v = v.strip(':') | ||
|
||
return v | ||
|
||
def __repr__(self): | ||
return f"env.ListEnvVar(\"{self.name}\", {self._updates})" | ||
|
||
def __str__(self): | ||
return f"(\"{self.name}\": [{','.join([str(u) for u in self._updates])}])" | ||
|
||
|
||
class ScalarEnvVar(EnvVar): | ||
def __init__(self, name: str, value: Optional[str]): | ||
super().__init__(name) | ||
self._value = value | ||
|
||
@property | ||
def value(self): | ||
return self._value | ||
|
||
@property | ||
def is_null(self): | ||
return self.value is None | ||
|
||
def update(self, value: Optional[str]): | ||
self._value = value | ||
|
||
def __repr__(self): | ||
return f"env.ScalarEnvVar(\"{self.name}\", \"{self.value}\")" | ||
|
||
def __str__(self): | ||
return f"(\"{self.name}\": \"{self.value}\")" | ||
|
||
class Env: | ||
def __init__(self): | ||
self._vars = {} | ||
|
||
def apply(self, var: EnvVar): | ||
self._vars[var.name] = var | ||
|
||
# returns true if the environment variable with name is a list variable, | ||
# e.g. PATH, LD_LIBRARY_PATH, PKG_CONFIG_PATH, etc. | ||
def is_list_var(name: str) -> bool: | ||
return name in list_variables | ||
|
||
class Env: | ||
|
||
def __init__(self): | ||
self._lists = {} | ||
self._scalars = {} | ||
|
||
@property | ||
def lists(self): | ||
return self._listself._lists | ||
|
||
@property | ||
def scalars(self): | ||
return self._scalars | ||
|
||
def set_scalar(self, var: ScalarEnvVar): | ||
self._scalars[var.name] = var | ||
|
||
def set_list(self, var: ListEnvVar): | ||
if var.name in self._lists.keys(): | ||
old = self._lists[var.name] | ||
self._lists[var.name] = old.concat(var) | ||
else: | ||
self._lists[var.name] = var |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import shutil | ||
import unittest | ||
|
||
import env | ||
|
||
class TestListEnvVars(unittest.TestCase): | ||
|
||
def test_env_var_name(self): | ||
# test that the name is set correctly | ||
self.assertEqual("PATH", env.ListEnvVar("PATH", ["/foo/bin"], env.EnvVarOp.SET).name) | ||
|
||
def test_env_var_list_shor(self): | ||
self.assertEqual( | ||
"/foo/bin:/bar/bin", | ||
env.ListEnvVar("PATH", ["/foo/bin", "/bar/bin"], env.EnvVarOp.SET).get_value(None)) | ||
self.assertEqual( | ||
"/foo/bin:/bar/bin", | ||
env.ListEnvVar("PATH", ["/foo/bin", "/bar/bin"], env.EnvVarOp.SET).get_value("")) | ||
self.assertEqual( | ||
"/foo/bin:/bar/bin", | ||
env.ListEnvVar("PATH", ["/foo/bin", "/bar/bin"], env.EnvVarOp.SET).get_value("/wombat")) | ||
self.assertEqual( | ||
"/foo/bin:/bar/bin", | ||
env.ListEnvVar("PATH", ["/foo/bin", "/bar/bin"], env.EnvVarOp.PREPEND).get_value(None)) | ||
self.assertEqual( | ||
"/foo/bin:/bar/bin", | ||
env.ListEnvVar("PATH", ["/foo/bin", "/bar/bin"], env.EnvVarOp.PREPEND).get_value("")) | ||
self.assertEqual( | ||
"/foo/bin:/bar/bin:/wombat", | ||
env.ListEnvVar("PATH", ["/foo/bin", "/bar/bin"], env.EnvVarOp.PREPEND).get_value("/wombat")) | ||
self.assertEqual( | ||
"/foo/bin:/bar/bin", | ||
env.ListEnvVar("PATH", ["/foo/bin", "/bar/bin"], env.EnvVarOp.APPEND).get_value(None)) | ||
self.assertEqual( | ||
"/foo/bin:/bar/bin", | ||
env.ListEnvVar("PATH", ["/foo/bin", "/bar/bin"], env.EnvVarOp.APPEND).get_value("")) | ||
self.assertEqual( | ||
"/wombat:/foo/bin:/bar/bin", | ||
env.ListEnvVar("PATH", ["/foo/bin", "/bar/bin"], env.EnvVarOp.APPEND).get_value("/wombat")) | ||
|
||
def test_env_var_list_long(self): | ||
v = env.ListEnvVar("PATH", ["c"], env.EnvVarOp.PREPEND) | ||
v.update(["a","b"], env.EnvVarOp.PREPEND) | ||
v.update(["e","f"], env.EnvVarOp.APPEND) | ||
self.assertEqual("a:b:c:d:e:f", v.get_value("d")) | ||
self.assertEqual("a:b:c:e:f", v.get_value(None)) | ||
|
||
v = env.ListEnvVar("PATH", ["c"], env.EnvVarOp.SET) | ||
v.update(["a","b"], env.EnvVarOp.PREPEND) | ||
v.update(["e","f"], env.EnvVarOp.APPEND) | ||
self.assertEqual("a:b:c:e:f", v.get_value("d")) | ||
|
||
self.assertEqual("a:b:c:e:f", v.get_value(None)) | ||
|
||
def test_env_var_list_concat(self): | ||
v = env.ListEnvVar("PATH", ["a"], env.EnvVarOp.PREPEND) | ||
x = env.ListEnvVar("PATH", ["c"], env.EnvVarOp.APPEND) | ||
v.concat(x) | ||
self.assertEqual("a:b:c", v.get_value("b")) | ||
|
||
def test_env_var_scalars(self): | ||
v = env.ScalarEnvVar("HOME", "/users/bob") | ||
self.assertEqual("HOME", v.name) | ||
self.assertEqual("/users/bob", v.value) | ||
|
||
v = env.ScalarEnvVar("HOME", None) | ||
self.assertEqual("HOME", v.name) | ||
self.assertEqual(None, v.value) | ||
|
||
class TestEnvVarCategories(unittest.TestCase): | ||
|
||
def test_env_var_categories(self): | ||
self.assertTrue(env.is_list_var("LD_LIBRARY_PATH")) | ||
self.assertTrue(env.is_list_var("PKG_CONFIG_PATH")) | ||
self.assertFalse(env.is_list_var("HOME")) | ||
|
||
class TestEnv(unittest.TestCase): | ||
|
||
def test_env(self): | ||
e = env.Env() | ||
|
||
e.set_scalar(env.ScalarEnvVar("HOME", "/users/bob")) | ||
e.set_scalar(env.ScalarEnvVar("VISIBLE", "")) | ||
|
||
print() | ||
print(e.scalars) | ||
print() | ||
|
||
#for name, value in e.scalars: | ||
#print(f"{name}: {value}") |