33from __future__ import annotations
44
55from collections .abc import Iterable
6+ from itertools import product
67from typing import Union
78
89import autograd .numpy as anp
5051
5152
5253def _trapz_weights_1d (points : np .ndarray ) -> np .ndarray :
53- """Trapezoidal integration weights for `trapz(y, x=points)`."""
54+ """Trapezoidal integration weights for `trapz(y, x=points)`.
55+
56+ Parameters
57+ ----------
58+ points : np.ndarray
59+ 1D array of integration points.
60+
61+ Returns
62+ -------
63+ np.ndarray
64+ Trapezoidal integration weights with shape ``(len(points),)``.
65+ """
5466 points = np .asarray (points )
5567 num_points = points .size
5668 if num_points <= 1 :
@@ -75,82 +87,92 @@ def _far_field_integral(
7587 idx_v : int ,
7688 * ,
7789 is_2d : bool ,
78- idx_int_1d : Union [int , None ],
90+ idx_integration_1d : Union [int , None ],
7991 weights : list [np .ndarray ] | None = None ,
8092) -> np .ndarray :
93+ """Evaluate the separable far-field surface/line integral.
94+
95+ This helper computes the near-to-far integral using precomputed separable phase factors
96+ and trapezoidal integration weights, with an implementation tailored for autograd.
97+
98+ Parameters
99+ ----------
100+ currents : np.ndarray
101+ Complex surface current values on the monitor grid with shape ``(nx, ny, nz)``.
102+ phase_0 : np.ndarray
103+ Phase factor along the x-axis with shape ``(nx, n_theta, n_phi)``.
104+ phase_1 : np.ndarray
105+ Phase factor along the y-axis with shape ``(ny, n_theta, n_phi)``.
106+ phase_2 : np.ndarray
107+ Phase factor along the z-axis with shape ``(nz, n_theta)``.
108+ pts : list[np.ndarray]
109+ List of 1D coordinate arrays ``[x, y, z]`` matching the spatial axes of ``currents``.
110+ idx_u : int
111+ First surface axis index (0, 1, or 2) for 3D integration.
112+ idx_v : int
113+ Second surface axis index (0, 1, or 2) for 3D integration.
114+ is_2d : bool
115+ If ``True``, treat the source as a 1D line and integrate along ``idx_integration_1d``.
116+ idx_integration_1d : int | None
117+ Spatial axis index (0, 1, or 2) used for the 2D line integral.
118+ weights : list[np.ndarray] | None
119+ Optional trapezoidal weights for each axis. If ``None``, computed from ``pts``.
120+
121+ Returns
122+ -------
123+ np.ndarray
124+ Integrated values as an array with trailing axes ``(n_theta, n_phi)``.
125+ """
81126 if weights is None :
82127 weights = [_trapz_weights_1d (pt ) for pt in pts ]
83128
84129 optimize = not (phase_0 .shape [1 ] == 1 and phase_0 .shape [2 ] == 1 )
85130
86131 if is_2d :
87- if idx_int_1d is None :
88- raise ValueError ("Expected 'idx_int_1d' for 2D far-field projection." )
89-
90- if idx_int_1d == 0 :
91- return anp .einsum (
92- "xtp,ytp,zt,xyz,x->yztp" ,
93- phase_0 ,
94- phase_1 ,
95- phase_2 ,
96- currents ,
97- weights [0 ],
98- optimize = optimize ,
99- )
100- if idx_int_1d == 1 :
101- return anp .einsum (
102- "xtp,ytp,zt,xyz,y->xztp" ,
103- phase_0 ,
104- phase_1 ,
105- phase_2 ,
106- currents ,
107- weights [1 ],
108- optimize = optimize ,
109- )
110- if idx_int_1d == 2 :
111- return anp .einsum (
112- "xtp,ytp,zt,xyz,z->xytp" ,
113- phase_0 ,
114- phase_1 ,
115- phase_2 ,
116- currents ,
117- weights [2 ],
118- optimize = optimize ,
119- )
120- raise ValueError (f"Invalid 2D integration axis: '{ idx_int_1d } '." )
132+ if idx_integration_1d is None :
133+ raise ValueError ("Expected 'idx_integration_1d' for 2D far-field projection." )
134+
135+ if idx_integration_1d == 0 :
136+ equation = "xtp,ytp,zt,xyz,x->yztp"
137+ weight = weights [0 ]
138+ elif idx_integration_1d == 1 :
139+ equation = "xtp,ytp,zt,xyz,y->xztp"
140+ weight = weights [1 ]
141+ elif idx_integration_1d == 2 :
142+ equation = "xtp,ytp,zt,xyz,z->xytp"
143+ weight = weights [2 ]
144+ else :
145+ raise ValueError (f"Invalid 2D integration axis: '{ idx_integration_1d } '." )
121146
122- integrated_axes = {idx_u , idx_v }
123- remaining_axis = ({0 , 1 , 2 } - integrated_axes ).pop ()
124- if remaining_axis == 0 :
125- return anp .einsum (
126- "xtp,ytp,zt,xyz,y,z->xtp" ,
127- phase_0 ,
128- phase_1 ,
129- phase_2 ,
130- currents ,
131- weights [1 ],
132- weights [2 ],
133- optimize = optimize ,
134- )
135- if remaining_axis == 1 :
136147 return anp .einsum (
137- "xtp,ytp,zt,xyz,x,z->ytp" ,
148+ equation ,
138149 phase_0 ,
139150 phase_1 ,
140151 phase_2 ,
141152 currents ,
142- weights [0 ],
143- weights [2 ],
153+ weight ,
144154 optimize = optimize ,
145155 )
156+
157+ integrated_axes = {idx_u , idx_v }
158+ remaining_axis = ({0 , 1 , 2 } - integrated_axes ).pop ()
159+ if remaining_axis == 0 :
160+ equation = "xtp,ytp,zt,xyz,y,z->xtp"
161+ weights_uv = (weights [1 ], weights [2 ])
162+ elif remaining_axis == 1 :
163+ equation = "xtp,ytp,zt,xyz,x,z->ytp"
164+ weights_uv = (weights [0 ], weights [2 ])
165+ else :
166+ equation = "xtp,ytp,zt,xyz,x,y->ztp"
167+ weights_uv = (weights [0 ], weights [1 ])
168+
146169 return anp .einsum (
147- "xtp,ytp,zt,xyz,x,y->ztp" ,
170+ equation ,
148171 phase_0 ,
149172 phase_1 ,
150173 phase_2 ,
151174 currents ,
152- weights [0 ],
153- weights [1 ],
175+ * weights_uv ,
154176 optimize = optimize ,
155177 )
156178
@@ -536,7 +558,7 @@ def _far_fields_for_surface(
536558 _ , source_names = surface .monitor .pop_axis (("x" , "y" , "z" ), axis = surface .axis )
537559
538560 # integration dimension for 2d far field projection
539- idx_int_1d = None
561+ idx_integration_1d = None
540562 zero_dim = [dim for dim , size in enumerate (self .sim_data .simulation .size ) if size == 0 ]
541563 if self .is_2d_simulation :
542564 # Ensure zero_dim has a single element since {zero_dim} expects a value
@@ -545,7 +567,7 @@ def _far_fields_for_surface(
545567
546568 zero_dim = zero_dim [0 ]
547569 integration_axis = {0 , 1 , 2 } - {zero_dim , surface .axis }
548- idx_int_1d = integration_axis .pop ()
570+ idx_integration_1d = integration_axis .pop ()
549571
550572 idx_u , idx_v = idx_uv
551573 cmp_1 , cmp_2 = source_names
@@ -588,7 +610,7 @@ def _far_fields_for_surface(
588610 idx_u ,
589611 idx_v ,
590612 is_2d = self .is_2d_simulation ,
591- idx_int_1d = idx_int_1d ,
613+ idx_integration_1d = idx_integration_1d ,
592614 weights = weights ,
593615 )
594616
@@ -803,15 +825,9 @@ def _project_fields_cartesian(
803825
804826 total_points = len (x ) * len (y ) * len (z )
805827
806- def iter_coords ():
807- for _x in x :
808- for _y in y :
809- for _z in z :
810- yield _x , _y , _z
811-
812828 point_fields = []
813829 for _x , _y , _z in track (
814- iter_coords ( ),
830+ product ( x , y , z ),
815831 description = "Computing projected fields" ,
816832 total = total_points ,
817833 console = get_logging_console (),
@@ -905,14 +921,9 @@ def _project_fields_kspace(
905921
906922 total_points = len (ux ) * len (uy )
907923
908- def iter_coords ():
909- for _ux in ux :
910- for _uy in uy :
911- yield _ux , _uy
912-
913924 point_fields = []
914925 for _ux , _uy in track (
915- iter_coords ( ),
926+ product ( ux , uy ),
916927 description = "Computing projected fields" ,
917928 total = total_points ,
918929 console = get_logging_console (),
@@ -924,30 +935,34 @@ def iter_coords():
924935 for idx_f , frequency in enumerate (freqs ):
925936 fields_sum = anp .zeros ((len (field_names ),), dtype = complex )
926937 for surface , currents in surface_currents :
927- _fields = self ._far_fields_for_surface (
938+ fields_surface = self ._far_fields_for_surface (
928939 frequency = frequency ,
929940 theta = theta ,
930941 phi = phi ,
931942 surface = surface ,
932943 currents = currents ,
933944 medium = medium ,
934945 )
935- _fields = anp .reshape (_fields , fields_sum .shape )
936- fields_sum = fields_sum + _fields * phase [idx_f ]
946+ fields_surface = anp .reshape (fields_surface , fields_sum .shape )
947+ fields_sum = fields_sum + fields_surface * phase [idx_f ]
937948 fields_by_freq .append (fields_sum )
938949
939950 point_fields .append (anp .stack (fields_by_freq , axis = 1 ))
940- continue
941-
942- _x , _y , _z = monitor .sph_2_car (monitor .proj_distance , theta , phi )
943- fields_sum = anp .zeros ((len (field_names ), len (freqs )), dtype = complex )
944- for surface , currents in surface_currents :
945- _fields = self ._fields_for_surface_exact (
946- x = _x , y = _y , z = _z , surface = surface , currents = currents , medium = medium
947- )
948- _fields = anp .reshape (_fields , fields_sum .shape )
949- fields_sum = fields_sum + _fields
950- point_fields .append (fields_sum )
951+ else :
952+ x_obs , y_obs , z_obs = monitor .sph_2_car (monitor .proj_distance , theta , phi )
953+ fields_sum = anp .zeros ((len (field_names ), len (freqs )), dtype = complex )
954+ for surface , currents in surface_currents :
955+ fields_surface = self ._fields_for_surface_exact (
956+ x = x_obs ,
957+ y = y_obs ,
958+ z = z_obs ,
959+ surface = surface ,
960+ currents = currents ,
961+ medium = medium ,
962+ )
963+ fields_surface = anp .reshape (fields_surface , fields_sum .shape )
964+ fields_sum = fields_sum + fields_surface
965+ point_fields .append (fields_sum )
951966
952967 stacked_fields = anp .stack (point_fields , axis = 0 )
953968 stacked_fields = anp .reshape (
0 commit comments