Skip to content

Commit a64d855

Browse files
yaugenst-flex authored and daquinteroflex committed
perf: reduce memory usage in field projection autograd backward pass
Avoid materializing the large (x, y, z, theta, phi) contraction intermediate during far-field projection. Instead, fold the trapezoid weights into a single weighted contraction, which reduces autograd tape size and peak memory.

The Cartesian far-field approximation calls this integral many times with single-angle (theta, phi) inputs; reuse precomputed trapezoid weights and skip `einsum` path optimization for the (1, 1) angle case to avoid per-call planning overhead.

Changes:
- Add `_trapz_weights_1d` and `_far_field_integral` helpers.
- Use `_far_field_integral` in `_far_fields_for_surface` and reuse weights per call.
- Refactor Cartesian far-field accumulation to build per-point fields and stack.
- Add VJP regression tests against the reference 5D+trapezoid implementation.
1 parent 71ff71e commit a64d855

File tree

2 files changed

+331
-67
lines changed

2 files changed

+331
-67
lines changed

tests/test_components/test_field_projection.py

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22

33
from __future__ import annotations
44

5+
import autograd.numpy as anp
56
import numpy as np
67
import pydantic.v1 as pydantic
78
import pytest
9+
from autograd import make_vjp
810

911
import tidy3d as td
10-
from tidy3d.components.field_projection import FieldProjector
12+
from tidy3d.components.field_projection import FieldProjector, _far_field_integral
1113
from tidy3d.exceptions import DataError
1214

1315
MEDIUM = td.Medium(permittivity=3)
@@ -665,3 +667,116 @@ def test_2d_sim_with_proj_monitors_near():
665667
def test_trapezoid(array, pts, axes, expected):
666668
result = FieldProjector.trapezoid(array, pts, axes)
667669
assert np.allclose(result, expected)
670+
671+
672+
@pytest.mark.parametrize("idx_u, idx_v", [(0, 1), (0, 2), (1, 2)])
def test_far_field_integral_vjp_3d(idx_u, idx_v):
    """Compare `_far_field_integral` with the 5D einsum + trapezoid reference (3D case).

    For each pair of integration axes, both the forward value and the
    vector-Jacobian product with respect to the currents must agree between
    `_far_field_integral` and a reference that first builds the full
    ``xyztp`` einsum result and then applies `FieldProjector.trapezoid`.
    """
    rng = np.random.default_rng(0)

    def complex_normal(shape):
        # Draw the real part first, then the imaginary part, so the random
        # stream is consumed in a fixed order.
        real_part = rng.standard_normal(shape)
        imag_part = rng.standard_normal(shape)
        return real_part + 1j * imag_part

    n_x, n_y, n_z = 3, 4, 5
    n_theta, n_phi = 6, 7

    currents = complex_normal((n_x, n_y, n_z))
    # Strictly increasing, non-uniform sample points along each axis.
    pts = [np.cumsum(rng.random(count)) for count in (n_x, n_y, n_z)]
    phase_0 = complex_normal((n_x, n_theta, n_phi))
    phase_1 = complex_normal((n_y, n_theta, n_phi))
    phase_2 = complex_normal((n_z, n_theta))

    def primitive(currents_in):
        return _far_field_integral(
            currents_in,
            phase_0,
            phase_1,
            phase_2,
            pts,
            idx_u,
            idx_v,
            is_2d=False,
            idx_int_1d=None,
        )

    def reference(currents_in):
        # Build the full 5D integrand, then integrate over the two surface axes.
        integrand = anp.einsum(
            "xtp,ytp,zt,xyz->xyztp", phase_0, phase_1, phase_2, currents_in
        )
        int_axes = tuple(sorted((idx_u, idx_v)))
        int_pts = tuple(pts[axis] for axis in int_axes)
        return FieldProjector.trapezoid(integrand, int_pts, int_axes)

    vjp_primitive, ans_primitive = make_vjp(primitive)(currents)
    vjp_reference, ans_reference = make_vjp(reference)(currents)

    actual = np.asarray(ans_primitive)
    expected = np.asarray(ans_reference)
    np.testing.assert_allclose(actual, expected, rtol=1e-12, atol=1e-12)

    # Random complex cotangent with the shape of the output.
    g = complex_normal(expected.shape)
    grad_primitive = np.asarray(vjp_primitive(g))
    grad_reference = np.asarray(vjp_reference(g))
    np.testing.assert_allclose(grad_primitive, grad_reference, rtol=1e-10, atol=1e-10)
724+
725+
726+
@pytest.mark.parametrize("idx_int_1d", [0, 1, 2])
def test_far_field_integral_vjp_2d(idx_int_1d):
    """Compare `_far_field_integral` with the 5D einsum + trapezoid reference (2D case).

    Only the axis selected by ``idx_int_1d`` has more than one sample; the
    other two axes have length 1. Both the forward value and the
    vector-Jacobian product with respect to the currents must agree between
    `_far_field_integral` (called with ``is_2d=True``) and a reference that
    builds the full ``xyztp`` einsum result and integrates it along that
    single axis with `FieldProjector.trapezoid`.
    """
    rng = np.random.default_rng(1)

    def complex_normal(shape):
        # Draw the real part first, then the imaginary part, so the random
        # stream is consumed in a fixed order.
        real_part = rng.standard_normal(shape)
        imag_part = rng.standard_normal(shape)
        return real_part + 1j * imag_part

    n_theta, n_phi = 5, 6

    sizes = [1, 1, 1]
    # Any value other than 0 or 1 selects the last axis.
    sizes[idx_int_1d if idx_int_1d in (0, 1) else 2] = 4
    n_x, n_y, n_z = sizes

    currents = complex_normal((n_x, n_y, n_z))
    # Non-uniform points on the integrated axis; a single zero elsewhere.
    pts = [
        np.cumsum(rng.random(count)) if count > 1 else np.zeros((count,))
        for count in sizes
    ]
    phase_0 = complex_normal((n_x, n_theta, n_phi))
    phase_1 = complex_normal((n_y, n_theta, n_phi))
    phase_2 = complex_normal((n_z, n_theta))

    def primitive(currents_in):
        return _far_field_integral(
            currents_in,
            phase_0,
            phase_1,
            phase_2,
            pts,
            0,
            1,
            is_2d=True,
            idx_int_1d=idx_int_1d,
        )

    def reference(currents_in):
        # Build the full 5D integrand, then integrate along the single axis.
        integrand = anp.einsum(
            "xtp,ytp,zt,xyz->xyztp", phase_0, phase_1, phase_2, currents_in
        )
        return FieldProjector.trapezoid(integrand, pts[idx_int_1d], idx_int_1d)

    vjp_primitive, ans_primitive = make_vjp(primitive)(currents)
    vjp_reference, ans_reference = make_vjp(reference)(currents)

    actual = np.asarray(ans_primitive)
    expected = np.asarray(ans_reference)
    np.testing.assert_allclose(actual, expected, rtol=1e-12, atol=1e-12)

    # Random complex cotangent with the shape of the output.
    g = complex_normal(expected.shape)
    grad_primitive = np.asarray(vjp_primitive(g))
    grad_reference = np.asarray(vjp_reference(g))
    np.testing.assert_allclose(grad_primitive, grad_reference, rtol=1e-10, atol=1e-10)

0 commit comments

Comments
 (0)