Skip to content

Commit 1cd8a82

Browse files
test autograd
1 parent ae541e1 commit 1cd8a82

File tree

3 files changed

+34
-12
lines changed

3 files changed

+34
-12
lines changed

tests/test_components/autograd/test_autograd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = True) -> None:
662662
# args = [("polyslab", "mode")]
663663

664664

665-
def get_functions(structure_key: str, monitor_key: str) -> typing.Callable:
665+
def get_functions(structure_key: str, monitor_key: str) -> dict[str, typing.Callable]:
666666
if structure_key == ALL_KEY:
667667
structure_keys = structure_keys_
668668
else:

tests/test_web/test_simulation_cache.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
import pytest
66

77
import tidy3d as td
8+
from tests.test_components.autograd.test_autograd import ALL_KEY, get_functions, params0
89
from tidy3d import config
9-
from tidy3d.web import Job, common, run_async, download
10-
from tidy3d.web import Job, run_async
10+
from tidy3d.web import Job, common, run_async
1111
from tidy3d.web.api import webapi as web
1212
from tidy3d.web.api.container import WebContainer
1313
from tidy3d.web.cache import (
@@ -96,7 +96,7 @@ def _extract_simulation(kwargs):
9696

9797
def _fake_upload(**kwargs):
9898
counters["upload"] += 1
99-
task_id = f"{MOCK_TASK_ID}{kwargs["simulation"]._hash_self()}"
99+
task_id = f"{MOCK_TASK_ID}{kwargs['simulation']._hash_self()}"
100100
sim = _extract_simulation(kwargs)
101101
if sim is not None:
102102
TASK_TO_SIM[task_id] = sim
@@ -171,7 +171,9 @@ def _test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path):
171171
sim2 = basic_simulation.updated_copy(shutoff=1e-4)
172172
sim3 = basic_simulation.updated_copy(shutoff=1e-3)
173173

174-
data = run_async({"task1": basic_simulation, "task2": sim2}, use_cache=True, path_dir=str(tmp_path))
174+
data = run_async(
175+
{"task1": basic_simulation, "task2": sim2}, use_cache=True, path_dir=str(tmp_path)
176+
)
175177
data_task1 = data["task1"] # access to store in cache
176178
data_task2 = data["task2"] # access to store in cache
177179
assert counters["download"] == 2
@@ -186,7 +188,9 @@ def _test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path):
186188
assert len(cache) == 2
187189

188190
_reset_counters(counters)
189-
data = run_async({"task1": basic_simulation, "task3": sim3}, use_cache=True, path_dir=str(tmp_path))
191+
data = run_async(
192+
{"task1": basic_simulation, "task3": sim3}, use_cache=True, path_dir=str(tmp_path)
193+
)
190194

191195
data_task1 = data["task1"]
192196
data_task2 = data["task3"] # access to store in cache
@@ -197,7 +201,7 @@ def _test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path):
197201
assert len(cache) == 3
198202

199203

200-
def _test_job_run_cache(monkeypatch, tmp_path_factory, basic_simulation):
204+
def _test_job_run_cache(monkeypatch, basic_simulation):
201205
counters = _patch_run_pipeline(monkeypatch)
202206
cache = resolve_simulation_cache(use_cache=True)
203207
cache.clear()
@@ -214,6 +218,27 @@ def _test_job_run_cache(monkeypatch, tmp_path_factory, basic_simulation):
214218
assert counters["download"] == 0
215219

216220

221+
@pytest.mark.parametrize("structure_key", ["polyslab"])
222+
@pytest.mark.parametrize("monitor_key", ["mode"])
223+
def _test_autograd_cache(monkeypatch, structure_key, monitor_key):
224+
counters = _patch_run_pipeline(monkeypatch)
225+
cache = resolve_simulation_cache(use_cache=True)
226+
cache.clear()
227+
228+
functions = get_functions(ALL_KEY, "mode")
229+
make_sim = functions["sim"]
230+
sim = make_sim(params0)
231+
web.run(sim, use_cache=True)
232+
assert counters["download"] == 1
233+
assert len(cache) == 1
234+
235+
_reset_counters(counters)
236+
sim = make_sim(params0)
237+
web.run(sim, use_cache=True)
238+
assert counters["download"] == 0
239+
assert len(cache) == 1
240+
241+
217242
def _test_load_cache_hit(monkeypatch, tmp_path, basic_simulation, fake_data):
218243
get_cache().clear()
219244
counters = _patch_run_pipeline(monkeypatch)
@@ -323,4 +348,5 @@ def test_cache_end_to_end(monkeypatch, tmp_path, tmp_path_factory, basic_simulat
323348
_test_cache_eviction_by_entries(monkeypatch, tmp_path_factory, basic_simulation)
324349
_test_cache_eviction_by_size(monkeypatch, tmp_path_factory, basic_simulation)
325350
_test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path)
326-
_test_job_run_cache(monkeypatch, tmp_path_factory, basic_simulation)
351+
_test_job_run_cache(monkeypatch, basic_simulation)
352+
_test_autograd_cache(monkeypatch, basic_simulation)

tidy3d/web/api/container.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,11 @@
1818

1919
from tidy3d.components.base import Tidy3dBaseModel, cached_property
2020
from tidy3d.components.mode.mode_solver import ModeSolver
21-
from tidy3d.components.mode.simulation import ModeSimulation
2221
from tidy3d.components.types import annotate_type
2322
from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType
2423
from tidy3d.exceptions import DataError
2524
from tidy3d.log import get_logging_console, log
2625
from tidy3d.web.api import webapi as web
27-
from tidy3d.web.api.tidy3d_stub import Tidy3dStub, Tidy3dStubData
28-
from tidy3d.web.api.webapi import _get_simulation_data_from_cache_entry, get_reduced_simulation
29-
from tidy3d.web.cache import TMP_BATCH_PREFIX, CacheEntry, resolve_simulation_cache
3026
from tidy3d.web.api.tidy3d_stub import Tidy3dStub
3127
from tidy3d.web.api.webapi import (
3228
restore_simulation_if_cached,

0 commit comments

Comments
 (0)