Skip to content

Conversation

@yaugenst-flex
Copy link
Collaborator

@yaugenst-flex yaugenst-flex commented Dec 14, 2025

Avoid materializing the large (x,y,z,theta,phi) contraction intermediate during far-field projection. Instead, fold trapezoid weights into a single weighted contraction, which reduces autograd tape size and peak memory.

The Cartesian far-field approximation calls this integral many times with single-angle (theta,phi) inputs; reuse precomputed trapezoid weights and skip einsum path optimization for the (1,1) angle case to avoid per-call planning overhead.

Changes:

  • Add _trapz_weights_1d and _far_field_integral helpers.
  • Use _far_field_integral in _far_fields_for_surface and reuse weights per call.
  • Refactor Cartesian far-field accumulation to build per-point fields and stack.
  • Add VJP regression tests against the reference 5D+trapezoid implementation.

Tested on the autograd metalens notebook and a couple of regression tests to make sure the results still match. Would be nice to get this in for 2.10 I guess.

Benchmarks

Note some of these are in log scale 😅

Code_Generated_Image-2 Code_Generated_Image-5 Code_Generated_Image-3 Code_Generated_Image-4

Greptile Overview

Greptile Summary

Refactored far-field projection to reduce memory usage in autograd backward pass by avoiding materialization of large intermediate tensors. Instead of computing the full (x,y,z,theta,phi) contraction and then applying trapezoidal integration, the implementation now folds trapezoid weights directly into a single weighted contraction via the new _far_field_integral helper.

Key changes:

  • Added _trapz_weights_1d to precompute trapezoidal integration weights
  • Added _far_field_integral that combines phase multiplication, current contraction, and weighted integration in one einsum operation
  • Refactored Cartesian and k-space projection loops to build per-point fields and stack them, avoiding add_at operations
  • Optimized single-angle case by skipping einsum path optimization when phase_0.shape[1] == 1 and phase_0.shape[2] == 1
  • Added comprehensive VJP regression tests to verify autograd correctness against the reference implementation

According to the PR benchmarks, this achieves significant memory reduction (shown on log scale) while maintaining numerical correctness.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • Score reflects well-tested memory optimization with comprehensive VJP regression tests verifying correctness. The refactoring avoids materializing large intermediate tensors while maintaining numerical equivalence to the reference implementation. Changes are purely performance-focused with no functional changes to the API or behavior. Only minor style issues related to missing docstrings for internal helper functions.
  • No files require special attention

Important Files Changed

File Analysis

Filename Score Overview
tidy3d/components/field_projection.py 5/5 Refactored far-field integral to fold trapezoid weights into contraction, avoiding large intermediate (x,y,z,theta,phi) tensor and reducing memory
tests/test_components/test_field_projection.py 5/5 Added VJP regression tests for new _far_field_integral helper to verify autograd backward pass correctness

Sequence Diagram

sequenceDiagram
    participant Client
    participant FieldProjector
    participant _far_fields_for_surface
    participant _far_field_integral
    participant _trapz_weights_1d
    participant autograd

    Client->>FieldProjector: project_fields(monitor)
    FieldProjector->>FieldProjector: _project_fields_cartesian/kspace
    
    Note over FieldProjector: Precompute surface currents<br/>and weights once per monitor
    FieldProjector->>_trapz_weights_1d: compute weights for x,y,z
    _trapz_weights_1d-->>FieldProjector: weights[x,y,z]
    
    loop for each observation point (x,y,z) or (ux,uy)
        Note over FieldProjector: Convert to spherical coords
        loop for each frequency
            FieldProjector->>_far_fields_for_surface: compute fields at (theta,phi,f)
            
            Note over _far_fields_for_surface: Compute phase factors
            _far_fields_for_surface->>_far_fields_for_surface: phase_0, phase_1, phase_2
            
            loop for each field component (E1,E2,H1,H2)
                _far_fields_for_surface->>_far_field_integral: currents, phases, weights
                
                Note over _far_field_integral: Single weighted einsum<br/>(avoids large intermediate)
                _far_field_integral->>autograd: einsum with folded weights
                autograd-->>_far_field_integral: integrated result
                
                _far_field_integral-->>_far_fields_for_surface: jm_i
            end
            
            Note over _far_fields_for_surface: Transform to spherical<br/>field components
            _far_fields_for_surface-->>FieldProjector: [Er,Etheta,Ephi,Hr,Htheta,Hphi]
        end
        
        Note over FieldProjector: Accumulate fields per point
    end
    
    Note over FieldProjector: Stack all point fields<br/>and reshape to grid
    FieldProjector-->>Client: FieldProjectionData
Loading

@yaugenst-flex yaugenst-flex self-assigned this Dec 14, 2025
@yaugenst-flex yaugenst-flex force-pushed the FXC-4540-decrease-local-field-projection-memory-overhead-especially-during-autograd-backward branch from 49aa82a to bb8954c Compare December 14, 2025 08:21
@yaugenst-flex yaugenst-flex marked this pull request as ready for review December 15, 2025 06:16
Copy link

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

@github-actions
Copy link
Contributor

github-actions bot commented Dec 15, 2025

Diff Coverage

Diff: origin/develop...HEAD, staged and unstaged changes

  • tidy3d/components/field_projection.py (90.7%): Missing lines 133,145,952-955,963-965

Summary

  • Total: 97 lines
  • Missing: 9 lines
  • Coverage: 90%

tidy3d/components/field_projection.py

Lines 129-137

  129     optimize = not (phase_0.shape[1] == 1 and phase_0.shape[2] == 1)
  130 
  131     if is_2d:
  132         if idx_integration_1d is None:
! 133             raise ValueError("Expected 'idx_integration_1d' for 2D far-field projection.")
  134 
  135         if idx_integration_1d == 0:
  136             equation = "xtp,ytp,zt,xyz,x->yztp"
  137             weight = weights[0]

Lines 141-149

  141         elif idx_integration_1d == 2:
  142             equation = "xtp,ytp,zt,xyz,z->xytp"
  143             weight = weights[2]
  144         else:
! 145             raise ValueError(f"Invalid 2D integration axis: '{idx_integration_1d}'.")
  146 
  147         return anp.einsum(
  148             equation,
  149             phase_0,

Lines 948-959

  948                     fields_by_freq.append(fields_sum)
  949 
  950                 point_fields.append(anp.stack(fields_by_freq, axis=1))
  951             else:
! 952                 x_obs, y_obs, z_obs = monitor.sph_2_car(monitor.proj_distance, theta, phi)
! 953                 fields_sum = anp.zeros((len(field_names), len(freqs)), dtype=complex)
! 954                 for surface, currents in surface_currents:
! 955                     fields_surface = self._fields_for_surface_exact(
  956                         x=x_obs,
  957                         y=y_obs,
  958                         z=z_obs,
  959                         surface=surface,

Lines 959-969

  959                         surface=surface,
  960                         currents=currents,
  961                         medium=medium,
  962                     )
! 963                     fields_surface = anp.reshape(fields_surface, fields_sum.shape)
! 964                     fields_sum = fields_sum + fields_surface
! 965                 point_fields.append(fields_sum)
  966 
  967         stacked_fields = anp.stack(point_fields, axis=0)
  968         stacked_fields = anp.reshape(
  969             stacked_fields, (len(ux), len(uy), len(field_names), len(freqs))

Copy link
Collaborator

@momchil-flex momchil-flex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Impressive!

Copy link
Contributor

@groberts-flex groberts-flex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @yaugenst-flex for the change! the performance looks great and I can confirm that when I try it in my example for far_fields_approx=True, I see a significant speedup and I no longer am seeing crashes from OOM.

I left some comments on the code - overall, I think there are some areas that could be a little more readable and some more comments/documentation in the code would make this easier to follow and tweak in the future. Nothing major in there, it seems like things are working well from the tests/profiling!

Copy link
Contributor

@groberts-flex groberts-flex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for making the changes @yaugenst-flex , looks good!

@daquinteroflex daquinteroflex force-pushed the FXC-4540-decrease-local-field-projection-memory-overhead-especially-during-autograd-backward branch from c649652 to a6ffc67 Compare December 16, 2025 13:22
@yaugenst-flex yaugenst-flex added this pull request to the merge queue Dec 16, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Dec 16, 2025
Avoid materializing the large (x,y,z,theta,phi) contraction intermediate during
far-field projection. Instead, fold trapezoid weights into a single weighted
contraction, which reduces autograd tape size and peak memory.

The Cartesian far-field approximation calls this integral many times with
single-angle (theta,phi) inputs; reuse precomputed trapezoid weights and skip
`einsum` path optimization for the (1,1) angle case to avoid per-call planning
overhead.

Changes:
- Add `_trapz_weights_1d` and `_far_field_integral` helpers.
- Use `_far_field_integral` in `_far_fields_for_surface` and reuse weights per call.
- Refactor Cartesian far-field accumulation to build per-point fields and stack.
- Add VJP regression tests against the reference 5D+trapezoid implementation.
@yaugenst-flex yaugenst-flex force-pushed the FXC-4540-decrease-local-field-projection-memory-overhead-especially-during-autograd-backward branch from a6ffc67 to c5c4cdc Compare December 16, 2025 14:48
@yaugenst-flex yaugenst-flex added this pull request to the merge queue Dec 16, 2025
Merged via the queue into develop with commit e9b6d42 Dec 16, 2025
20 checks passed
@yaugenst-flex yaugenst-flex deleted the FXC-4540-decrease-local-field-projection-memory-overhead-especially-during-autograd-backward branch December 16, 2025 15:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants