fix(tidy3d): FXC-4878-fix-gradient-for-custom-pole-residue-eps#3171
Conversation
|
@greptile |
|
In the added numerical test, the angle to FD decreased from 40° to 2.4°. |
3e70e85 to
0f0ba01
Compare
Diff CoverageDiff: origin/develop...HEAD, staged and unstaged changes
Summary
tidy3d/components/medium.pyLines 1191-1204 1191 n = axis.size
1192 i0 = int(np.searchsorted(axis, vmin, side="left"))
1193 i1 = int(np.searchsorted(axis, vmax, side="right"))
1194 if i1 <= i0 and n:
! 1195 old = (i0, i1)
! 1196 if i1 < n:
! 1197 i1 = i0 + 1 # expand right
! 1198 elif i0 > 0:
! 1199 i0 = i1 - 1 # expand left
! 1200 log.warning(
1201 f"Empty bounds crop on '{name}' while computing CustomMedium parameter gradients "
1202 f"(adjoint field grid -> medium grid): bounds=[{vmin!r}, {vmax!r}], "
1203 f"grid=[{axis[0]!r}, {axis[-1]!r}] -> indices {old}; using ({i0}, {i1}).",
1204 log_once=True,Lines 1278-1291 1278 return field_values.sum(axis=0, keepdims=True)
1279
1280 # Ensure parameter coordinates are sorted for searchsorted-based binning.
1281 if np.any(param_coords_1d[1:] < param_coords_1d[:-1]):
! 1282 raise ValueError("Spatial coordinates must be sorted before computing derivatives.")
1283 param_coords_sorted = param_coords_1d
1284
1285 n_param = param_coords_sorted.size
1286 if method not in ALLOWED_INTERP_METHODS:
! 1287 raise ValueError(
1288 f"Unsupported interpolation method: {method!r}. "
1289 f"Choose one of: {', '.join(ALLOWED_INTERP_METHODS)}."
1290 )Lines 1309-1341 1309 return param_values
1310
1311 # linear
1312 # Find bracketing parameter indices for each field coordinate.
! 1313 param_index_upper = np.searchsorted(param_coords_sorted, field_coords_1d, side="right")
! 1314 param_index_upper = np.clip(param_index_upper, 1, n_param - 1)
! 1315 param_index_lower = param_index_upper - 1
1316
1317 # Compute interpolation fraction within the bracketing segment.
! 1318 segment_width = (
1319 param_coords_sorted[param_index_upper] - param_coords_sorted[param_index_lower]
1320 )
! 1321 segment_width = np.where(segment_width == 0, 1.0, segment_width)
! 1322 frac_upper = (field_coords_1d - param_coords_sorted[param_index_lower]) / segment_width
! 1323 frac_upper = np.clip(frac_upper, 0.0, 1.0)
1324
1325 # Weights per field sample (broadcast across the flattened trailing dimensions).
! 1326 w_lower = (1.0 - frac_upper)[:, None]
! 1327 w_upper = frac_upper[:, None]
1328
1329 # Accumulate contributions into both bracketing parameter indices.
! 1330 param_values_2d = npo.zeros(
1331 (n_param, field_values_2d.shape[1]), dtype=field_values.dtype
1332 )
! 1333 npo.add.at(param_values_2d, param_index_lower, field_values_2d * w_lower)
! 1334 npo.add.at(param_values_2d, param_index_upper, field_values_2d * w_upper)
1335
! 1336 param_values = param_values_2d.reshape((n_param,) + field_values.shape[1:])
! 1337 return param_values
1338
1339 def _interp_axis(
1340 arr: NDArray, axis: int, field_axis: NDArray, param_axis: NDArray
1341 ) -> NDArray:Lines 1355-1363 1355 freqs_da = np.asarray(E_der_dim.coords["f"])
1356 if component == "sigma":
1357 values = values.imag * (-1.0 / (2.0 * np.pi * freqs_da * EPSILON_0))
1358 elif component == "imag":
! 1359 values = values.imag
1360 elif component == "real":
1361 values = values.real
1362
1363 return values.sum(axis=-1).reshape(eps_shape) |
0f0ba01 to
c6cea2e
Compare
yaugenst-flex
left a comment
There was a problem hiding this comment.
Thanks @marcorudolphflex!
c6cea2e to
e8a7caa
Compare
| ) | ||
| return slice(i0, i1) | ||
|
|
||
| # usage |
There was a problem hiding this comment.
could this comment be expanded on or possibly removed?
| i1 = i0 + 1 # expand right | ||
| elif i0 > 0: | ||
| i0 = i1 - 1 # expand left | ||
| log.warning( |
There was a problem hiding this comment.
if this warning is hit, will it eventually lead to an error downstream or just a 0 gradient?
| field_coords = {k: field_coords[k][s] for k, s in (("x", sx), ("y", sy), ("z", sz))} | ||
| values = values[sx, sy, sz, :] | ||
|
|
||
| def _axis_sizes(coords: NDArray) -> NDArray: |
There was a problem hiding this comment.
in components/grid/grid.py we have a function called cell_sizes. could this or cell_size_meshgrid be used here?
groberts-flex
left a comment
There was a problem hiding this comment.
thanks for the change, looks good overall. left a few minor comments/questions
we missed to handle CustomPoleResidue when fixing this bug on custom medium gradients (correct backward interpolation) #3113
Note
Fixes gradient backpropagation for
CustomPoleResidueby applying the transpose-based spatial interpolation used for custom media, and adds targeted tests._derivative_field_cmp_customonce (moved to shared logic) and uses it inCustomPoleResidue._compute_derivativeswhen structuredSpatialDataArrayis provided; caches sorted coord data via_eps_inf_sorted/_poles_sorted_derivative_field_cmpagainst unstructured datasets and raisesNotImplementedErrorfor unstructuredCustomPoleResidueadjointstest_autograd_custom_pole_residue_numerical.pyvalidating adjoint vs finite-difference gradients; updates existing autograd test hooksCHANGELOG.mdentry for the fixWritten by Cursor Bugbot for commit e8a7caa. This will update automatically on new commits. Configure here.
Greptile Summary
This PR extends the gradient computation fix from PR #3113 to handle
CustomPoleResiduemedium by applying the same custom transpose-based interpolation approach used forCustomMedium.Key changes:
_derivative_field_cmp_custommethod fromCustomMediumtoAbstractCustomMediumclass to enable code reuse across bothCustomMediumandCustomPoleResidue_eps_inf_sorted,_poles_sorted) toCustomPoleResiduethat ensure spatial coordinates are sorted before gradient computation, preventing errors during interpolationCustomPoleResidue._compute_derivativesto use the custom derivative method wheneps_infis aSpatialDataArray, falling back to the scalar method otherwiseCustomPoleResiduewith pole-residue dispersion modelTechnical improvements:
searchsortedclips field data to intersection bounds before processingConfidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram participant Client as Autograd System participant CPR as CustomPoleResidue participant ACM as AbstractCustomMedium participant CDC as _compute_derivatives participant DFCC as _derivative_field_cmp_custom participant TIA as _transpose_interp_axis Client->>CPR: Request gradients for eps_inf/poles CPR->>CPR: Access cached _eps_inf_sorted, _poles_sorted CPR->>CDC: _compute_derivatives(derivative_info) Note over CDC: Check if eps_inf is SpatialDataArray alt eps_inf is SpatialDataArray loop For each dimension (x, y, z) CDC->>DFCC: _derivative_field_cmp_custom(E_der_map, eps_inf_sorted, dim, freqs, bounds, "complex") DFCC->>ACM: Extract field coordinates and values DFCC->>ACM: Apply bounds filtering (searchsorted) DFCC->>ACM: Compute volume element scaling (_axis_sizes) DFCC->>ACM: Multiply values by scale loop For each axis (x, y, z) DFCC->>TIA: _transpose_interp_axis(arr, field_axis, param_axis) alt method == "nearest" TIA->>TIA: Compute midpoints between param coords TIA->>TIA: Use searchsorted to find indices TIA->>TIA: Accumulate using npo.add.at else method == "linear" TIA->>TIA: Find upper/lower indices via searchsorted TIA->>TIA: Compute interpolation weights TIA->>TIA: Check segment_width == 0, use 1.0 if zero TIA->>TIA: Accumulate weighted contributions using npo.add.at end TIA-->>DFCC: Return interpolated array end DFCC->>DFCC: Extract complex component DFCC->>DFCC: Sum over frequencies DFCC-->>CDC: Return gradient array end CDC->>CDC: Sum dJ_deps_complex from all dimensions else eps_inf is scalar loop For each dimension (x, y, z) CDC->>ACM: _derivative_field_cmp(E_der_map, eps_inf_sorted, dim) ACM-->>CDC: Return scalar gradient end end CDC->>CDC: Extract sorted pole values from _poles_sorted loop For each frequency CDC->>CDC: _get_vjps_from_params(dJ_deps_complex, poles_vals, omega, paths) Note over CDC: Compute VJPs for eps_inf and pole coefficients (a, c) CDC->>CDC: Accumulate vjps_total end CDC-->>Client: Return vjps (gradient dictionary)