Skip to content

Commit a6ffc67

Browse files
yaugenst-flexdaquinteroflex
authored andcommitted
greg comments
1 parent a64d855 commit a6ffc67

File tree

2 files changed

+108
-93
lines changed

2 files changed

+108
-93
lines changed

tests/test_components/test_field_projection.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,7 @@ def primitive(currents_in):
705705
idx_u,
706706
idx_v,
707707
is_2d=False,
708-
idx_int_1d=None,
708+
idx_integration_1d=None,
709709
)
710710

711711
vjp_primitive, ans_primitive = make_vjp(primitive)(currents)
@@ -723,15 +723,15 @@ def primitive(currents_in):
723723
np.testing.assert_allclose(grad_primitive, grad_reference, rtol=1e-10, atol=1e-10)
724724

725725

726-
@pytest.mark.parametrize("idx_int_1d", [0, 1, 2])
727-
def test_far_field_integral_vjp_2d(idx_int_1d):
726+
@pytest.mark.parametrize("idx_integration_1d", [0, 1, 2])
727+
def test_far_field_integral_vjp_2d(idx_integration_1d):
728728
rng = np.random.default_rng(1)
729729
n_theta, n_phi = 5, 6
730730

731731
n_x, n_y, n_z = 1, 1, 1
732-
if idx_int_1d == 0:
732+
if idx_integration_1d == 0:
733733
n_x = 4
734-
elif idx_int_1d == 1:
734+
elif idx_integration_1d == 1:
735735
n_y = 4
736736
else:
737737
n_z = 4
@@ -752,7 +752,7 @@ def test_far_field_integral_vjp_2d(idx_int_1d):
752752

753753
def reference(currents_in):
754754
chunk = anp.einsum("xtp,ytp,zt,xyz->xyztp", phase_0, phase_1, phase_2, currents_in)
755-
return FieldProjector.trapezoid(chunk, pts[idx_int_1d], idx_int_1d)
755+
return FieldProjector.trapezoid(chunk, pts[idx_integration_1d], idx_integration_1d)
756756

757757
def primitive(currents_in):
758758
return _far_field_integral(
@@ -764,7 +764,7 @@ def primitive(currents_in):
764764
0,
765765
1,
766766
is_2d=True,
767-
idx_int_1d=idx_int_1d,
767+
idx_integration_1d=idx_integration_1d,
768768
)
769769

770770
vjp_primitive, ans_primitive = make_vjp(primitive)(currents)

tidy3d/components/field_projection.py

Lines changed: 101 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
from collections.abc import Iterable
6+
from itertools import product
67
from typing import Union
78

89
import autograd.numpy as anp
@@ -50,7 +51,18 @@
5051

5152

5253
def _trapz_weights_1d(points: np.ndarray) -> np.ndarray:
53-
"""Trapezoidal integration weights for `trapz(y, x=points)`."""
54+
"""Trapezoidal integration weights for `trapz(y, x=points)`.
55+
56+
Parameters
57+
----------
58+
points : np.ndarray
59+
1D array of integration points.
60+
61+
Returns
62+
-------
63+
np.ndarray
64+
Trapezoidal integration weights with shape ``(len(points),)``.
65+
"""
5466
points = np.asarray(points)
5567
num_points = points.size
5668
if num_points <= 1:
@@ -75,82 +87,92 @@ def _far_field_integral(
7587
idx_v: int,
7688
*,
7789
is_2d: bool,
78-
idx_int_1d: Union[int, None],
90+
idx_integration_1d: Union[int, None],
7991
weights: list[np.ndarray] | None = None,
8092
) -> np.ndarray:
93+
"""Evaluate the separable far-field surface/line integral.
94+
95+
This helper computes the near-to-far integral using precomputed separable phase factors
96+
and trapezoidal integration weights, with an implementation tailored for autograd.
97+
98+
Parameters
99+
----------
100+
currents : np.ndarray
101+
Complex surface current values on the monitor grid with shape ``(nx, ny, nz)``.
102+
phase_0 : np.ndarray
103+
Phase factor along the x-axis with shape ``(nx, n_theta, n_phi)``.
104+
phase_1 : np.ndarray
105+
Phase factor along the y-axis with shape ``(ny, n_theta, n_phi)``.
106+
phase_2 : np.ndarray
107+
Phase factor along the z-axis with shape ``(nz, n_theta)``.
108+
pts : list[np.ndarray]
109+
List of 1D coordinate arrays ``[x, y, z]`` matching the spatial axes of ``currents``.
110+
idx_u : int
111+
First surface axis index (0, 1, or 2) for 3D integration.
112+
idx_v : int
113+
Second surface axis index (0, 1, or 2) for 3D integration.
114+
is_2d : bool
115+
If ``True``, treat the source as a 1D line and integrate along ``idx_integration_1d``.
116+
idx_integration_1d : int | None
117+
Spatial axis index (0, 1, or 2) used for the 2D line integral.
118+
weights : list[np.ndarray] | None
119+
Optional trapezoidal weights for each axis. If ``None``, computed from ``pts``.
120+
121+
Returns
122+
-------
123+
np.ndarray
124+
Integrated values as an array with trailing axes ``(n_theta, n_phi)``.
125+
"""
81126
if weights is None:
82127
weights = [_trapz_weights_1d(pt) for pt in pts]
83128

84129
optimize = not (phase_0.shape[1] == 1 and phase_0.shape[2] == 1)
85130

86131
if is_2d:
87-
if idx_int_1d is None:
88-
raise ValueError("Expected 'idx_int_1d' for 2D far-field projection.")
89-
90-
if idx_int_1d == 0:
91-
return anp.einsum(
92-
"xtp,ytp,zt,xyz,x->yztp",
93-
phase_0,
94-
phase_1,
95-
phase_2,
96-
currents,
97-
weights[0],
98-
optimize=optimize,
99-
)
100-
if idx_int_1d == 1:
101-
return anp.einsum(
102-
"xtp,ytp,zt,xyz,y->xztp",
103-
phase_0,
104-
phase_1,
105-
phase_2,
106-
currents,
107-
weights[1],
108-
optimize=optimize,
109-
)
110-
if idx_int_1d == 2:
111-
return anp.einsum(
112-
"xtp,ytp,zt,xyz,z->xytp",
113-
phase_0,
114-
phase_1,
115-
phase_2,
116-
currents,
117-
weights[2],
118-
optimize=optimize,
119-
)
120-
raise ValueError(f"Invalid 2D integration axis: '{idx_int_1d}'.")
132+
if idx_integration_1d is None:
133+
raise ValueError("Expected 'idx_integration_1d' for 2D far-field projection.")
134+
135+
if idx_integration_1d == 0:
136+
equation = "xtp,ytp,zt,xyz,x->yztp"
137+
weight = weights[0]
138+
elif idx_integration_1d == 1:
139+
equation = "xtp,ytp,zt,xyz,y->xztp"
140+
weight = weights[1]
141+
elif idx_integration_1d == 2:
142+
equation = "xtp,ytp,zt,xyz,z->xytp"
143+
weight = weights[2]
144+
else:
145+
raise ValueError(f"Invalid 2D integration axis: '{idx_integration_1d}'.")
121146

122-
integrated_axes = {idx_u, idx_v}
123-
remaining_axis = ({0, 1, 2} - integrated_axes).pop()
124-
if remaining_axis == 0:
125-
return anp.einsum(
126-
"xtp,ytp,zt,xyz,y,z->xtp",
127-
phase_0,
128-
phase_1,
129-
phase_2,
130-
currents,
131-
weights[1],
132-
weights[2],
133-
optimize=optimize,
134-
)
135-
if remaining_axis == 1:
136147
return anp.einsum(
137-
"xtp,ytp,zt,xyz,x,z->ytp",
148+
equation,
138149
phase_0,
139150
phase_1,
140151
phase_2,
141152
currents,
142-
weights[0],
143-
weights[2],
153+
weight,
144154
optimize=optimize,
145155
)
156+
157+
integrated_axes = {idx_u, idx_v}
158+
remaining_axis = ({0, 1, 2} - integrated_axes).pop()
159+
if remaining_axis == 0:
160+
equation = "xtp,ytp,zt,xyz,y,z->xtp"
161+
weights_uv = (weights[1], weights[2])
162+
elif remaining_axis == 1:
163+
equation = "xtp,ytp,zt,xyz,x,z->ytp"
164+
weights_uv = (weights[0], weights[2])
165+
else:
166+
equation = "xtp,ytp,zt,xyz,x,y->ztp"
167+
weights_uv = (weights[0], weights[1])
168+
146169
return anp.einsum(
147-
"xtp,ytp,zt,xyz,x,y->ztp",
170+
equation,
148171
phase_0,
149172
phase_1,
150173
phase_2,
151174
currents,
152-
weights[0],
153-
weights[1],
175+
*weights_uv,
154176
optimize=optimize,
155177
)
156178

@@ -536,7 +558,7 @@ def _far_fields_for_surface(
536558
_, source_names = surface.monitor.pop_axis(("x", "y", "z"), axis=surface.axis)
537559

538560
# integration dimension for 2d far field projection
539-
idx_int_1d = None
561+
idx_integration_1d = None
540562
zero_dim = [dim for dim, size in enumerate(self.sim_data.simulation.size) if size == 0]
541563
if self.is_2d_simulation:
542564
# Ensure zero_dim has a single element since {zero_dim} expects a value
@@ -545,7 +567,7 @@ def _far_fields_for_surface(
545567

546568
zero_dim = zero_dim[0]
547569
integration_axis = {0, 1, 2} - {zero_dim, surface.axis}
548-
idx_int_1d = integration_axis.pop()
570+
idx_integration_1d = integration_axis.pop()
549571

550572
idx_u, idx_v = idx_uv
551573
cmp_1, cmp_2 = source_names
@@ -588,7 +610,7 @@ def _far_fields_for_surface(
588610
idx_u,
589611
idx_v,
590612
is_2d=self.is_2d_simulation,
591-
idx_int_1d=idx_int_1d,
613+
idx_integration_1d=idx_integration_1d,
592614
weights=weights,
593615
)
594616

@@ -803,15 +825,9 @@ def _project_fields_cartesian(
803825

804826
total_points = len(x) * len(y) * len(z)
805827

806-
def iter_coords():
807-
for _x in x:
808-
for _y in y:
809-
for _z in z:
810-
yield _x, _y, _z
811-
812828
point_fields = []
813829
for _x, _y, _z in track(
814-
iter_coords(),
830+
product(x, y, z),
815831
description="Computing projected fields",
816832
total=total_points,
817833
console=get_logging_console(),
@@ -905,14 +921,9 @@ def _project_fields_kspace(
905921

906922
total_points = len(ux) * len(uy)
907923

908-
def iter_coords():
909-
for _ux in ux:
910-
for _uy in uy:
911-
yield _ux, _uy
912-
913924
point_fields = []
914925
for _ux, _uy in track(
915-
iter_coords(),
926+
product(ux, uy),
916927
description="Computing projected fields",
917928
total=total_points,
918929
console=get_logging_console(),
@@ -924,30 +935,34 @@ def iter_coords():
924935
for idx_f, frequency in enumerate(freqs):
925936
fields_sum = anp.zeros((len(field_names),), dtype=complex)
926937
for surface, currents in surface_currents:
927-
_fields = self._far_fields_for_surface(
938+
fields_surface = self._far_fields_for_surface(
928939
frequency=frequency,
929940
theta=theta,
930941
phi=phi,
931942
surface=surface,
932943
currents=currents,
933944
medium=medium,
934945
)
935-
_fields = anp.reshape(_fields, fields_sum.shape)
936-
fields_sum = fields_sum + _fields * phase[idx_f]
946+
fields_surface = anp.reshape(fields_surface, fields_sum.shape)
947+
fields_sum = fields_sum + fields_surface * phase[idx_f]
937948
fields_by_freq.append(fields_sum)
938949

939950
point_fields.append(anp.stack(fields_by_freq, axis=1))
940-
continue
941-
942-
_x, _y, _z = monitor.sph_2_car(monitor.proj_distance, theta, phi)
943-
fields_sum = anp.zeros((len(field_names), len(freqs)), dtype=complex)
944-
for surface, currents in surface_currents:
945-
_fields = self._fields_for_surface_exact(
946-
x=_x, y=_y, z=_z, surface=surface, currents=currents, medium=medium
947-
)
948-
_fields = anp.reshape(_fields, fields_sum.shape)
949-
fields_sum = fields_sum + _fields
950-
point_fields.append(fields_sum)
951+
else:
952+
x_obs, y_obs, z_obs = monitor.sph_2_car(monitor.proj_distance, theta, phi)
953+
fields_sum = anp.zeros((len(field_names), len(freqs)), dtype=complex)
954+
for surface, currents in surface_currents:
955+
fields_surface = self._fields_for_surface_exact(
956+
x=x_obs,
957+
y=y_obs,
958+
z=z_obs,
959+
surface=surface,
960+
currents=currents,
961+
medium=medium,
962+
)
963+
fields_surface = anp.reshape(fields_surface, fields_sum.shape)
964+
fields_sum = fields_sum + fields_surface
965+
point_fields.append(fields_sum)
951966

952967
stacked_fields = anp.stack(point_fields, axis=0)
953968
stacked_fields = anp.reshape(

0 commit comments

Comments
 (0)