Skip to content

fix(tidy3d): FXC-4878-fix-gradient-for-custom-pole-residue-eps#3171

Merged
marcorudolphflex merged 1 commit into
developfrom
FXC-4878-fix-gradient-for-custom-pole-residue-eps
Jan 16, 2026
Merged

fix(tidy3d): FXC-4878-fix-gradient-for-custom-pole-residue-eps#3171
marcorudolphflex merged 1 commit into
developfrom
FXC-4878-fix-gradient-for-custom-pole-residue-eps

Conversation

@marcorudolphflex
Copy link
Copy Markdown
Contributor

@marcorudolphflex marcorudolphflex commented Jan 15, 2026

we missed to handle CustomPoleResidue when fixing this bug on custom medium gradients (correct backward interpolation) #3113


Note

Fixes gradient backpropagation for CustomPoleResidue by applying the transpose-based spatial interpolation used for custom media, and adds targeted tests.

  • Implements _derivative_field_cmp_custom once (moved to shared logic) and uses it in CustomPoleResidue._compute_derivatives when structured SpatialDataArray is provided; caches sorted coord data via _eps_inf_sorted/_poles_sorted
  • Guards _derivative_field_cmp against unstructured datasets and raises NotImplementedError for unstructured CustomPoleResidue adjoints
  • Adds numerical test test_autograd_custom_pole_residue_numerical.py validating adjoint vs finite-difference gradients; updates existing autograd test hooks
  • Updates CHANGELOG.md entry for the fix

Written by Cursor Bugbot for commit e8a7caa. This will update automatically on new commits. Configure here.

Greptile Summary

This PR extends the gradient computation fix from PR #3113 to handle CustomPoleResidue medium by applying the same custom transpose-based interpolation approach used for CustomMedium.

Key changes:

  • Moved _derivative_field_cmp_custom method from CustomMedium to AbstractCustomMedium class to enable code reuse across both CustomMedium and CustomPoleResidue
  • Added cached properties (_eps_inf_sorted, _poles_sorted) to CustomPoleResidue that ensure spatial coordinates are sorted before gradient computation, preventing errors during interpolation
  • Updated CustomPoleResidue._compute_derivatives to use the custom derivative method when eps_inf is a SpatialDataArray, falling back to the scalar method otherwise
  • Added comprehensive numerical test file validating adjoint gradients against finite-difference gradients for CustomPoleResidue with pole-residue dispersion model

Technical improvements:

  • The custom implementation correctly handles the transpose (adjoint) operation needed for gradient backpropagation by accumulating field values onto parameter grid points using weighted contributions
  • Proper boundary handling via searchsorted clips field data to intersection bounds before processing
  • Volume elements computed from field coordinates and applied before interpolation fix normalization issues from the previous xarray-based approach

Confidence Score: 4/5

  • This PR is safe to merge with minor suggestions for improvement
  • The implementation is mathematically sound and includes comprehensive numerical validation tests. The refactor addresses real bugs in gradient computation by extending the fix from PR fix(tidy3d): FXC-4641-fix-gradients-in-custom-medium #3113 to CustomPoleResidue. Score reduced from 5 to 4 due to one style suggestion regarding floating-point comparison (using == instead of np.isclose for norm checks) that could improve robustness, though this is a minor issue in the test code rather than production logic.
  • No files require special attention - the implementation is solid with good test coverage

Important Files Changed

Filename Overview
tests/test_components/autograd/numerical/test_autograd_custom_pole_residue_numerical.py New comprehensive numerical test validating CustomPoleResidue adjoint gradients against finite differences; well-structured with good parametrization
tidy3d/components/medium.py Moved _derivative_field_cmp_custom to AbstractCustomMedium for code reuse; added cached sorted properties and updated CustomPoleResidue gradient computation to use custom derivative method

Sequence Diagram

sequenceDiagram
    participant Client as Autograd System
    participant CPR as CustomPoleResidue
    participant ACM as AbstractCustomMedium
    participant CDC as _compute_derivatives
    participant DFCC as _derivative_field_cmp_custom
    participant TIA as _transpose_interp_axis
    
    Client->>CPR: Request gradients for eps_inf/poles
    CPR->>CPR: Access cached _eps_inf_sorted, _poles_sorted
    CPR->>CDC: _compute_derivatives(derivative_info)
    
    Note over CDC: Check if eps_inf is SpatialDataArray
    
    alt eps_inf is SpatialDataArray
        loop For each dimension (x, y, z)
            CDC->>DFCC: _derivative_field_cmp_custom(E_der_map, eps_inf_sorted, dim, freqs, bounds, "complex")
            DFCC->>ACM: Extract field coordinates and values
            DFCC->>ACM: Apply bounds filtering (searchsorted)
            DFCC->>ACM: Compute volume element scaling (_axis_sizes)
            DFCC->>ACM: Multiply values by scale
            
            loop For each axis (x, y, z)
                DFCC->>TIA: _transpose_interp_axis(arr, field_axis, param_axis)
                
                alt method == "nearest"
                    TIA->>TIA: Compute midpoints between param coords
                    TIA->>TIA: Use searchsorted to find indices
                    TIA->>TIA: Accumulate using npo.add.at
                else method == "linear"
                    TIA->>TIA: Find upper/lower indices via searchsorted
                    TIA->>TIA: Compute interpolation weights
                    TIA->>TIA: Check segment_width == 0, use 1.0 if zero
                    TIA->>TIA: Accumulate weighted contributions using npo.add.at
                end
                
                TIA-->>DFCC: Return interpolated array
            end
            
            DFCC->>DFCC: Extract complex component
            DFCC->>DFCC: Sum over frequencies
            DFCC-->>CDC: Return gradient array
        end
        CDC->>CDC: Sum dJ_deps_complex from all dimensions
    else eps_inf is scalar
        loop For each dimension (x, y, z)
            CDC->>ACM: _derivative_field_cmp(E_der_map, eps_inf_sorted, dim)
            ACM-->>CDC: Return scalar gradient
        end
    end
    
    CDC->>CDC: Extract sorted pole values from _poles_sorted
    
    loop For each frequency
        CDC->>CDC: _get_vjps_from_params(dJ_deps_complex, poles_vals, omega, paths)
        Note over CDC: Compute VJPs for eps_inf and pole coefficients (a, c)
        CDC->>CDC: Accumulate vjps_total
    end
    
    CDC-->>Client: Return vjps (gradient dictionary)
Loading

@marcorudolphflex
Copy link
Copy Markdown
Contributor Author

@greptile

@marcorudolphflex
Copy link
Copy Markdown
Contributor Author

marcorudolphflex commented Jan 15, 2026

In the added numerical test, the angle to FD decreased from 40° to 2.4°.
I also ran the numerical tests for Custom Medium to be sure no regression happened.

@marcorudolphflex marcorudolphflex marked this pull request as ready for review January 15, 2026 15:42
Copy link
Copy Markdown

@greptile-apps greptile-apps Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

@marcorudolphflex marcorudolphflex force-pushed the FXC-4878-fix-gradient-for-custom-pole-residue-eps branch from 3e70e85 to 0f0ba01 Compare January 15, 2026 16:47
@github-actions
Copy link
Copy Markdown
Contributor

Diff Coverage

Diff: origin/develop...HEAD, staged and unstaged changes

  • tidy3d/components/medium.py (78.9%): Missing lines 1195-1200,1282,1287,1313-1315,1318,1321-1323,1326-1327,1330,1333-1334,1336-1337,1359

Summary

  • Total: 109 lines
  • Missing: 23 lines
  • Coverage: 78%

tidy3d/components/medium.py

Lines 1191-1204

  1191             n = axis.size
  1192             i0 = int(np.searchsorted(axis, vmin, side="left"))
  1193             i1 = int(np.searchsorted(axis, vmax, side="right"))
  1194             if i1 <= i0 and n:
! 1195                 old = (i0, i1)
! 1196                 if i1 < n:
! 1197                     i1 = i0 + 1  # expand right
! 1198                 elif i0 > 0:
! 1199                     i0 = i1 - 1  # expand left
! 1200                 log.warning(
  1201                     f"Empty bounds crop on '{name}' while computing CustomMedium parameter gradients "
  1202                     f"(adjoint field grid -> medium grid): bounds=[{vmin!r}, {vmax!r}], "
  1203                     f"grid=[{axis[0]!r}, {axis[-1]!r}] -> indices {old}; using ({i0}, {i1}).",
  1204                     log_once=True,

Lines 1278-1291

  1278                 return field_values.sum(axis=0, keepdims=True)
  1279 
  1280             # Ensure parameter coordinates are sorted for searchsorted-based binning.
  1281             if np.any(param_coords_1d[1:] < param_coords_1d[:-1]):
! 1282                 raise ValueError("Spatial coordinates must be sorted before computing derivatives.")
  1283             param_coords_sorted = param_coords_1d
  1284 
  1285             n_param = param_coords_sorted.size
  1286             if method not in ALLOWED_INTERP_METHODS:
! 1287                 raise ValueError(
  1288                     f"Unsupported interpolation method: {method!r}. "
  1289                     f"Choose one of: {', '.join(ALLOWED_INTERP_METHODS)}."
  1290                 )

Lines 1309-1341

  1309                 return param_values
  1310 
  1311             # linear
  1312             # Find bracketing parameter indices for each field coordinate.
! 1313             param_index_upper = np.searchsorted(param_coords_sorted, field_coords_1d, side="right")
! 1314             param_index_upper = np.clip(param_index_upper, 1, n_param - 1)
! 1315             param_index_lower = param_index_upper - 1
  1316 
  1317             # Compute interpolation fraction within the bracketing segment.
! 1318             segment_width = (
  1319                 param_coords_sorted[param_index_upper] - param_coords_sorted[param_index_lower]
  1320             )
! 1321             segment_width = np.where(segment_width == 0, 1.0, segment_width)
! 1322             frac_upper = (field_coords_1d - param_coords_sorted[param_index_lower]) / segment_width
! 1323             frac_upper = np.clip(frac_upper, 0.0, 1.0)
  1324 
  1325             # Weights per field sample (broadcast across the flattened trailing dimensions).
! 1326             w_lower = (1.0 - frac_upper)[:, None]
! 1327             w_upper = frac_upper[:, None]
  1328 
  1329             # Accumulate contributions into both bracketing parameter indices.
! 1330             param_values_2d = npo.zeros(
  1331                 (n_param, field_values_2d.shape[1]), dtype=field_values.dtype
  1332             )
! 1333             npo.add.at(param_values_2d, param_index_lower, field_values_2d * w_lower)
! 1334             npo.add.at(param_values_2d, param_index_upper, field_values_2d * w_upper)
  1335 
! 1336             param_values = param_values_2d.reshape((n_param,) + field_values.shape[1:])
! 1337             return param_values
  1338 
  1339         def _interp_axis(
  1340             arr: NDArray, axis: int, field_axis: NDArray, param_axis: NDArray
  1341         ) -> NDArray:

Lines 1355-1363

  1355         freqs_da = np.asarray(E_der_dim.coords["f"])
  1356         if component == "sigma":
  1357             values = values.imag * (-1.0 / (2.0 * np.pi * freqs_da * EPSILON_0))
  1358         elif component == "imag":
! 1359             values = values.imag
  1360         elif component == "real":
  1361             values = values.real
  1362 
  1363         return values.sum(axis=-1).reshape(eps_shape)

@marcorudolphflex marcorudolphflex force-pushed the FXC-4878-fix-gradient-for-custom-pole-residue-eps branch from 0f0ba01 to c6cea2e Compare January 16, 2026 08:00
Copy link
Copy Markdown
Collaborator

@yaugenst-flex yaugenst-flex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment thread poetry.lock
@marcorudolphflex marcorudolphflex force-pushed the FXC-4878-fix-gradient-for-custom-pole-residue-eps branch from c6cea2e to e8a7caa Compare January 16, 2026 11:56
@marcorudolphflex marcorudolphflex added this pull request to the merge queue Jan 16, 2026
)
return slice(i0, i1)

# usage
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could this comment be expanded on or possibly removed?

i1 = i0 + 1 # expand right
elif i0 > 0:
i0 = i1 - 1 # expand left
log.warning(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this warning is hit, will it eventually lead to an error downstream or just a 0 gradient?

Merged via the queue into develop with commit 795f66e Jan 16, 2026
63 of 67 checks passed
@marcorudolphflex marcorudolphflex deleted the FXC-4878-fix-gradient-for-custom-pole-residue-eps branch January 16, 2026 13:19
field_coords = {k: field_coords[k][s] for k, s in (("x", sx), ("y", sy), ("z", sz))}
values = values[sx, sy, sz, :]

def _axis_sizes(coords: NDArray) -> NDArray:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in components/grid/grid.py we have a function called cell_sizes. could this or cell_size_meshgrid be used here?

Copy link
Copy Markdown
Contributor

@groberts-flex groberts-flex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the change, looks good overall. left a few minor comments/questions

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants