Skip to content

Commit 3959172

Browse files
Henry PeteetGitHub Enterprise
Henry Peteet
authored and
GitHub Enterprise
committed
Speed up pytest GitHub check (Unity-Technologies#15)
1 parent ee49d16 commit 3959172

File tree

16 files changed

+187
-67
lines changed

16 files changed

+187
-67
lines changed

.github/workflows/nightly.yml

+71
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,74 @@ jobs:
2424
- name: Run markdown checker
2525
run: |
2626
pre-commit run --hook-stage manual markdown-link-check-full --all-files
27+
full-pytest:
28+
if: ${{ github.server_url == 'https://github.cds.internal.unity3d.com' }}
29+
runs-on: [ self-hosted, Linux, X64 ]
30+
# TODO: Re-use pytest workflow once https://github.com/github/roadmap/issues/257 is done.
31+
# steps:
32+
# - uses: actions/checkout@v2
33+
# - uses: ./.github/workflows/pytest.yml
34+
# with:
35+
# # Run all tests.
36+
# pytest_markers: "not slow or slow"
37+
env:
38+
TEST_ENFORCE_BUFFER_KEY_TYPES: 1
39+
strategy:
40+
# If one test in the matrix fails we still want to run the others.
41+
fail-fast: false
42+
matrix:
43+
python-version: [3.7.x, 3.8.x, 3.9.x]
44+
include:
45+
- python-version: 3.7.x
46+
pip_constraints: test_constraints_min_version.txt
47+
- python-version: 3.8.x
48+
pip_constraints: test_constraints_mid_version.txt
49+
- python-version: 3.9.x
50+
pip_constraints: test_constraints_max_version.txt
51+
steps:
52+
- uses: actions/checkout@v2
53+
- name: Set up Python
54+
uses: actions/setup-python@v2
55+
with:
56+
python-version: ${{ matrix.python-version }}
57+
# Caching not supported on GitHub Enterprise
58+
# See https://github.com/actions/cache/issues/505
59+
# - name: Cache pip
60+
# uses: actions/cache@v2
61+
# with:
62+
# # This path is specific to Ubuntu
63+
# path: ~/.cache/pip
64+
# # Look to see if there is a cache hit for the corresponding requirements file
65+
# key: ${{ runner.os }}-pip-${{ hashFiles('ml-agents/setup.py', 'ml-agents-envs/setup.py', 'test_requirements.txt', matrix.pip_constraints) }}
66+
# restore-keys: |
67+
# ${{ runner.os }}-pip-
68+
# ${{ runner.os }}-
69+
- name: Display Python version
70+
run: python -c "import sys; print(sys.version)"
71+
- name: Install dependencies
72+
run: |
73+
python -m pip install --upgrade pip
74+
python -m pip install --upgrade setuptools
75+
python -m pip install --progress-bar=off -e ./ml-agents-envs -c ${{ matrix.pip_constraints }}
76+
python -m pip install --progress-bar=off -e ./ml-agents -c ${{ matrix.pip_constraints }}
77+
python -m pip install --progress-bar=off -r test_requirements.txt -c ${{ matrix.pip_constraints }}
78+
python -m pip install --progress-bar=off -e ./ml-agents-plugin-examples -c ${{ matrix.pip_constraints }}
79+
- name: Save python dependencies
80+
run: |
81+
pip freeze > pip_versions-${{ matrix.python-version }}.txt
82+
cat pip_versions-${{ matrix.python-version }}.txt
83+
- name: Run pytest
84+
run: |
85+
pytest --cov=ml-agents --cov=ml-agents-envs \
86+
--cov-report=html --junitxml=junit/test-results-${{ matrix.python-version }}.xml \
87+
-p no:warnings -v -n auto
88+
- name: Upload pytest test results
89+
uses: actions/upload-artifact@v2
90+
with:
91+
name: artifacts-${{ matrix.python-version }}
92+
path: |
93+
htmlcov
94+
pip_versions-${{ matrix.python-version }}.txt
95+
junit/test-results-${{ matrix.python-version }}.xml
96+
# Use always() to always run this step to publish test results when there are test failures
97+
if: ${{ always() }}

.github/workflows/pytest.yml

+25-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,19 @@ on:
1111
push:
1212
branches: [main]
1313
workflow_dispatch:
14+
inputs:
15+
pytest_markers:
16+
description: "Restrict which tests to run based on pytest markers"
17+
required: false
18+
default: "not slow"
19+
type: string
20+
workflow_call:
21+
inputs:
22+
pytest_markers:
23+
required: false
24+
# Hacky way to make sure we run all tests
25+
default: "slow or not slow"
26+
type: string
1427

1528
jobs:
1629
pytest:
@@ -63,8 +76,19 @@ jobs:
6376
run: |
6477
pip freeze > pip_versions-${{ matrix.python-version }}.txt
6578
cat pip_versions-${{ matrix.python-version }}.txt
79+
- name: Get pytest marker
80+
id: pytest_marker
81+
run: |
82+
if [ "${{ github.event.inputs.pytest_markers }}" != "" ]; then
83+
echo "::set-output name=markers::${{ github.event.inputs.pytest_markers }}"
84+
else
85+
echo "::set-output name=markers::not slow"
86+
fi
6687
- name: Run pytest
67-
run: pytest --cov=ml-agents --cov=ml-agents-envs --cov-report=html --junitxml=junit/test-results-${{ matrix.python-version }}.xml -p no:warnings -v
88+
run: |
89+
pytest --cov=ml-agents --cov=ml-agents-envs \
90+
--cov-report=html --junitxml=junit/test-results-${{ matrix.python-version }}.xml \
91+
-p no:warnings -v -m "${{ steps.pytest_marker.outputs.markers }}" -n auto
6892
- name: Upload pytest test results
6993
uses: actions/upload-artifact@v2
7094
with:

.yamato/pytest-gpu.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ pytest_gpu:
1212
python3 -u -m ml-agents.tests.yamato.setup_venv
1313
python3 -m pip install --progress-bar=off -r test_requirements.txt --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
1414
python3 -m pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
15-
python3 -m pytest -m "not check_environment_trains" --junitxml=junit/test-results.xml -p no:warnings
15+
python3 -m pytest -m "not slow" -n auto --junitxml=junit/test-results.xml -p no:warnings
1616
triggers:
1717
cancel_old_ci: true
1818
expression: |
File renamed without changes.

ml-agents-envs/mlagents_envs/registry/binary_utils.py

+26-19
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,20 @@
2424
BLOCK_SIZE = 8192
2525

2626

27-
def get_local_binary_path(name: str, url: str) -> str:
27+
def get_local_binary_path(name: str, url: str, tmp_dir: Optional[str] = None) -> str:
2828
"""
2929
Returns the path to the executable previously downloaded with the name argument. If
3030
None is found, the executable at the url argument will be downloaded and stored
3131
under name for future uses.
3232
:param name: The name that will be given to the folder containing the extracted data
3333
:param url: The URL of the zip file
34+
:param: tmp_dir: Optional override for the temporary directory to save binaries and zips in.
3435
"""
3536
NUMBER_ATTEMPTS = 5
36-
with FileLock(os.path.join(tempfile.gettempdir(), name + ".lock")):
37-
path = get_local_binary_path_if_exists(name, url)
37+
tmp_dir = tmp_dir or tempfile.gettempdir()
38+
lock = FileLock(os.path.join(tmp_dir, name + ".lock"))
39+
with lock:
40+
path = get_local_binary_path_if_exists(name, url, tmp_dir=tmp_dir)
3841
if path is None:
3942
logger.debug(
4043
f"Local environment {name} not found, downloading environment from {url}"
@@ -45,7 +48,7 @@ def get_local_binary_path(name: str, url: str) -> str:
4548
if path is not None:
4649
break
4750
try:
48-
download_and_extract_zip(url, name)
51+
download_and_extract_zip(url, name, tmp_dir=tmp_dir)
4952
except Exception:
5053
if attempt + 1 < NUMBER_ATTEMPTS:
5154
logger.warning(
@@ -54,7 +57,7 @@ def get_local_binary_path(name: str, url: str) -> str:
5457
)
5558
else:
5659
raise
57-
path = get_local_binary_path_if_exists(name, url)
60+
path = get_local_binary_path_if_exists(name, url, tmp_dir=tmp_dir)
5861

5962
if path is None:
6063
raise FileNotFoundError(
@@ -64,15 +67,16 @@ def get_local_binary_path(name: str, url: str) -> str:
6467
return path
6568

6669

67-
def get_local_binary_path_if_exists(name: str, url: str) -> Optional[str]:
70+
def get_local_binary_path_if_exists(name: str, url: str, tmp_dir: str) -> Optional[str]:
6871
"""
6972
Recursively searches for a Unity executable in the extracted files folders. This is
7073
platform dependent : It will only return a Unity executable compatible with the
7174
computer's OS. If no executable is found, None will be returned.
7275
:param name: The name/identifier of the executable
7376
:param url: The url the executable was downloaded from (for verification)
77+
:param: tmp_dir: Optional override for the temporary directory to save binaries and zips in.
7478
"""
75-
_, bin_dir = get_tmp_dir()
79+
_, bin_dir = get_tmp_dirs(tmp_dir)
7680
extension = None
7781

7882
if platform == "linux" or platform == "linux2":
@@ -100,27 +104,27 @@ def get_local_binary_path_if_exists(name: str, url: str) -> Optional[str]:
100104
return None
101105

102106

103-
def _get_tmp_dir_helper():
104-
TEMPDIR = "/tmp" if platform == "darwin" else tempfile.gettempdir()
107+
def _get_tmp_dir_helper(tmp_dir: Optional[str] = None) -> Tuple[str, str]:
108+
tmp_dir = tmp_dir or ("/tmp" if platform == "darwin" else tempfile.gettempdir())
105109
MLAGENTS = "ml-agents-binaries"
106110
TMP_FOLDER_NAME = "tmp"
107111
BINARY_FOLDER_NAME = "binaries"
108-
mla_directory = os.path.join(TEMPDIR, MLAGENTS)
112+
mla_directory = os.path.join(tmp_dir, MLAGENTS)
109113
if not os.path.exists(mla_directory):
110114
os.makedirs(mla_directory)
111115
os.chmod(mla_directory, 16877)
112-
zip_directory = os.path.join(TEMPDIR, MLAGENTS, TMP_FOLDER_NAME)
116+
zip_directory = os.path.join(tmp_dir, MLAGENTS, TMP_FOLDER_NAME)
113117
if not os.path.exists(zip_directory):
114118
os.makedirs(zip_directory)
115119
os.chmod(zip_directory, 16877)
116-
bin_directory = os.path.join(TEMPDIR, MLAGENTS, BINARY_FOLDER_NAME)
120+
bin_directory = os.path.join(tmp_dir, MLAGENTS, BINARY_FOLDER_NAME)
117121
if not os.path.exists(bin_directory):
118122
os.makedirs(bin_directory)
119123
os.chmod(bin_directory, 16877)
120-
return (zip_directory, bin_directory)
124+
return zip_directory, bin_directory
121125

122126

123-
def get_tmp_dir() -> Tuple[str, str]:
127+
def get_tmp_dirs(tmp_dir: Optional[str] = None) -> Tuple[str, str]:
124128
"""
125129
Returns the path to the folder containing the downloaded zip files and the extracted
126130
binaries. If these folders do not exist, they will be created.
@@ -130,21 +134,24 @@ def get_tmp_dir() -> Tuple[str, str]:
130134
# Should only be able to error out 3 times (once for each subdir).
131135
for _attempt in range(3):
132136
try:
133-
return _get_tmp_dir_helper()
137+
return _get_tmp_dir_helper(tmp_dir)
134138
except FileExistsError:
135139
continue
136-
return _get_tmp_dir_helper()
140+
return _get_tmp_dir_helper(tmp_dir)
137141

138142

139-
def download_and_extract_zip(url: str, name: str) -> None:
143+
def download_and_extract_zip(
144+
url: str, name: str, tmp_dir: Optional[str] = None
145+
) -> None:
140146
"""
141147
Downloads a zip file under a URL, extracts its contents into a folder with the name
142148
argument and gives chmod 755 to all the files it contains. Files are downloaded and
143149
extracted into special folders in the temp folder of the machine.
144150
:param url: The URL of the zip file
145151
:param name: The name that will be given to the folder containing the extracted data
152+
:param: tmp_dir: Optional override for the temporary directory to save binaries and zips in.
146153
"""
147-
zip_dir, bin_dir = get_tmp_dir()
154+
zip_dir, bin_dir = get_tmp_dirs(tmp_dir)
148155
url_hash = "-" + hashlib.md5(url.encode()).hexdigest()
149156
binary_path = os.path.join(bin_dir, name + url_hash)
150157
if os.path.exists(binary_path):
@@ -206,7 +213,7 @@ def load_remote_manifest(url: str) -> Dict[str, Any]:
206213
"""
207214
Converts a remote yaml file into a Python dictionary
208215
"""
209-
tmp_dir, _ = get_tmp_dir()
216+
tmp_dir, _ = get_tmp_dirs()
210217
try:
211218
request = urllib.request.urlopen(url, timeout=30)
212219
except urllib.error.HTTPError as e: # type: ignore

ml-agents-envs/mlagents_envs/registry/remote_registry_entry.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(
1616
darwin_url: Optional[str],
1717
win_url: Optional[str],
1818
additional_args: Optional[List[str]] = None,
19+
tmp_dir: Optional[str] = None,
1920
):
2021
"""
2122
A RemoteRegistryEntry is an implementation of BaseRegistryEntry that uses a
@@ -39,6 +40,7 @@ def __init__(
3940
self._darwin_url = darwin_url
4041
self._win_url = win_url
4142
self._add_args = additional_args
43+
self._tmp_dir_override = tmp_dir
4244

4345
def make(self, **kwargs: Any) -> BaseEnv:
4446
"""
@@ -58,7 +60,9 @@ def make(self, **kwargs: Any) -> BaseEnv:
5860
f"The entry {self.identifier} does not contain a valid url for this "
5961
"platform"
6062
)
61-
path = get_local_binary_path(self.identifier, url)
63+
path = get_local_binary_path(
64+
self.identifier, url, tmp_dir=self._tmp_dir_override
65+
)
6266
if "file_name" in kwargs:
6367
kwargs.pop("file_name")
6468
args: List[str] = []

ml-agents-envs/tests/test_registry.py

+9-13
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,15 @@
1-
import shutil
21
import os
2+
from pathlib import Path
3+
4+
import pytest
35

46
from mlagents_envs.registry import default_registry, UnityEnvRegistry
57
from mlagents_envs.registry.remote_registry_entry import RemoteRegistryEntry
6-
from mlagents_envs.registry.binary_utils import get_tmp_dir
78

89
BASIC_ID = "Basic"
910

1011

11-
def delete_binaries():
12-
tmp_dir, bin_dir = get_tmp_dir()
13-
shutil.rmtree(tmp_dir)
14-
shutil.rmtree(bin_dir)
15-
16-
17-
def create_registry():
12+
def create_registry(tmp_dir: str) -> UnityEnvRegistry:
1813
reg = UnityEnvRegistry()
1914
entry = RemoteRegistryEntry(
2015
BASIC_ID,
@@ -23,20 +18,21 @@ def create_registry():
2318
"https://storage.googleapis.com/mlagents-test-environments/1.0.0/linux/Basic.zip",
2419
"https://storage.googleapis.com/mlagents-test-environments/1.0.0/darwin/Basic.zip",
2520
"https://storage.googleapis.com/mlagents-test-environments/1.0.0/windows/Basic.zip",
21+
tmp_dir=tmp_dir,
2622
)
2723
reg.register(entry)
2824
return reg
2925

3026

31-
def test_basic_in_registry():
27+
@pytest.mark.parametrize("n_ports", [2])
28+
def test_basic_in_registry(base_port: int, tmp_path: Path) -> None:
3229
assert BASIC_ID in default_registry
3330
os.environ["TERM"] = "xterm"
34-
delete_binaries()
35-
registry = create_registry()
31+
registry = create_registry(str(tmp_path))
3632
for worker_id in range(2):
3733
assert BASIC_ID in registry
3834
env = registry[BASIC_ID].make(
39-
base_port=6002, worker_id=worker_id, no_graphics=True
35+
base_port=base_port, worker_id=worker_id, no_graphics=True
4036
)
4137
env.reset()
4238
env.step()

0 commit comments

Comments
 (0)