diff --git a/labgrid/remote/client.py b/labgrid/remote/client.py index 27108a7b4..5a6c550c9 100755 --- a/labgrid/remote/client.py +++ b/labgrid/remote/client.py @@ -28,6 +28,9 @@ import attr import grpc +# TODO: drop if Python >= 3.11 guaranteed +from exceptiongroup import ExceptionGroup # pylint: disable=redefined-builtin + from .common import ( ResourceEntry, ResourceMatch, @@ -57,7 +60,8 @@ class Error(Exception): - pass + def __str__(self): + return f"Error: {' '.join(self.args)}" class UserError(Error): @@ -72,6 +76,13 @@ class InteractiveCommandError(Error): pass +class ErrorGroup(ExceptionGroup): + def __str__(self): + # TODO: drop pylint disable once https://github.com/pylint-dev/pylint/issues/8985 is fixed + errors_combined = "\n".join(f"- {' '.join(e.args)}" for e in self.exceptions) # pylint: disable=not-an-iterable + return f"{self.message}:\n{errors_combined}" + + @attr.s(eq=False) class ClientSession: """The ClientSession encapsulates all the actions a Client can invoke on @@ -478,6 +489,17 @@ def get_place(self, place=None): raise UserError(f"pattern {pattern} matches multiple places ({', '.join(places)})") return self.places[places[0]] + def get_place_names_from_env(self): + """Returns a list of RemotePlace names found in the environment config.""" + places = [] + for role_config in self.env.config.get_targets().values(): + resources, _ = target_factory.normalize_config(role_config) + remote_places = resources.get("RemotePlace", []) + for place in remote_places: + places.append(place) + + return places + def get_idle_place(self, place=None): place = self.get_place(place) if place.acquired: @@ -681,17 +703,31 @@ def check_matches(self, place): raise UserError(f"Match {match} has no matching remote resource") async def acquire(self): + errors = [] + places = self.get_place_names_from_env() if self.env else [self.args.place] + for place in places: + try: + await self._acquire_place(place) + except Error as e: + errors.append(e) + + if errors: + if len(errors) == 1: + raise errors[0] + raise ErrorGroup("Multiple errors occurred during acquire", errors) + + async def _acquire_place(self, place): """Acquire a place, marking it unavailable for other clients""" - place = self.get_place() + place = self.get_place(place) if place.acquired: host, user = place.acquired.split("/") allowhelp = f"'labgrid-client -p {place.name} allow {self.gethostname()}/{self.getuser()}' on {host}." if self.getuser() == user: if self.gethostname() == host: - raise UserError("You have already acquired this place.") + raise UserError(f"You have already acquired place {place.name}.") else: raise UserError( - f"You have already acquired this place on {host}. To work simultaneously, execute {allowhelp}" + f"You have already acquired place {place.name} on {host}. To work simultaneously, execute {allowhelp}" ) else: raise UserError( @@ -727,8 +763,22 @@ async def acquire(self): raise ServerError(e.details()) async def release(self): + errors = [] + places = self.get_place_names_from_env() if self.env else [self.args.place] + for place in places: + try: + await self._release_place(place) + except Error as e: + errors.append(e) + + if errors: + if len(errors) == 1: + raise errors[0] + raise ErrorGroup("Multiple errors occurred during release", errors) + + async def _release_place(self, place): """Release a previously acquired place""" - place = self.get_place() + place = self.get_place(place) if not place.acquired: raise UserError(f"place {place.name} is not acquired") _, user = place.acquired.split("/") @@ -2212,11 +2262,11 @@ def main(): if args.debug: traceback.print_exc(file=sys.stderr) exitcode = e.exitcode - except Error as e: + except (Error, ErrorGroup) as e: if args.debug: traceback.print_exc(file=sys.stderr) else: - print(f"{parser.prog}: error: {e}", file=sys.stderr) + print(f"{parser.prog}: {e}", file=sys.stderr) exitcode = 1 except KeyboardInterrupt: exitcode = 1 diff --git a/pyproject.toml b/pyproject.toml index 33caa6f31..424226533 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ classifiers = [ ] dependencies = [ "attrs>=21.4.0", + "exceptiongroup>=1.3.0", # TODO: drop if Python >= 3.11 guaranteed "grpcio>=1.64.1, <2.0.0", "grpcio-reflection>=1.64.1, <2.0.0", "protobuf>=5.27.0", diff --git a/tests/test_client.py b/tests/test_client.py index 34297e7eb..8d2cc1c9f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -9,22 +9,33 @@ def test_startup(coordinator): pass @pytest.fixture(scope='function') -def place(coordinator): - with pexpect.spawn('python -m labgrid.remote.client -p test create') as spawn: - spawn.expect(pexpect.EOF) - spawn.close() +def place(create_place): + create_place('test') + +@pytest.fixture(scope='function') +def create_place(coordinator): + place_names = [] + + def _create_place(place_name): + with pexpect.spawn(f'python -m labgrid.remote.client -p {place_name} create') as spawn: + spawn.expect(pexpect.EOF) assert spawn.exitstatus == 0, spawn.before.strip() - with pexpect.spawn('python -m labgrid.remote.client -p test set-tags board=123board') as spawn: - spawn.expect(pexpect.EOF) - spawn.close() + place_names.append(place_name) + + with pexpect.spawn(f'python -m labgrid.remote.client -p {place_name} set-tags board=123board') as spawn: + spawn.expect(pexpect.EOF) assert spawn.exitstatus == 0, spawn.before.strip() - yield + yield _create_place - with pexpect.spawn('python -m labgrid.remote.client -p test delete') as spawn: - spawn.expect(pexpect.EOF) - spawn.close() + for place_name in place_names: + # clean up + with pexpect.spawn(f'python -m labgrid.remote.client -p {place_name} release') as spawn: + spawn.expect(pexpect.EOF) + + with pexpect.spawn(f'python -m labgrid.remote.client -p {place_name} delete') as spawn: + spawn.expect(pexpect.EOF) assert spawn.exitstatus == 0, spawn.before.strip() @pytest.fixture(scope='function') @@ -151,6 +162,49 @@ def test_place_acquire(place): spawn.close() assert spawn.exitstatus == 0, spawn.before.strip() +def test_place_acquire_multiple(create_place, tmpdir): + # create multiple places + place_names = ['test1', 'test2'] + for place_name in place_names: + create_place(place_name) + + # create env config with multiple RemotePlaces + p = tmpdir.join('config.yaml') + p.write('targets:') + for place_name in place_names: + p.write( + f""" + {place_name}: + resources: + RemotePlace: + name: {place_name} + """, + mode='a', + ) + + # acquire all places in env config + with pexpect.spawn(f'python -m labgrid.remote.client -c {p} acquire') as spawn: + spawn.expect(pexpect.EOF) + assert spawn.exitstatus == 0, spawn.before.strip() + + # check 'who' + with pexpect.spawn('python -m labgrid.remote.client who') as spawn: + spawn.expect(pexpect.EOF) + for place_name in place_names: + assert place_name.encode('utf-8') in spawn.before + + assert spawn.exitstatus == 0, spawn.before.strip() + + # release all places in env config + with pexpect.spawn(f'python -m labgrid.remote.client -c {p} release') as spawn: + spawn.expect(pexpect.EOF) + assert spawn.exitstatus == 0, spawn.before.strip() + + # check 'who' again + with pexpect.spawn('python -m labgrid.remote.client who') as spawn: + spawn.expect('User.*Host.*Place.*Changed\r\n') + assert not spawn.before, spawn.before + def test_place_acquire_enforce(place): with pexpect.spawn('python -m labgrid.remote.client -p test add-match does/not/exist') as spawn: spawn.expect(pexpect.EOF)