-
Notifications
You must be signed in to change notification settings - Fork 66
perf: reduce memory usage in field projection autograd backward pass #3086
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
perf: reduce memory usage in field projection autograd backward pass #3086
Conversation
49aa82a to
bb8954c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 files reviewed, 2 comments
Diff CoverageDiff: origin/develop...HEAD, staged and unstaged changes
Summary
tidy3d/components/field_projection.pyLines 129-137 129 optimize = not (phase_0.shape[1] == 1 and phase_0.shape[2] == 1)
130
131 if is_2d:
132 if idx_integration_1d is None:
! 133 raise ValueError("Expected 'idx_integration_1d' for 2D far-field projection.")
134
135 if idx_integration_1d == 0:
136 equation = "xtp,ytp,zt,xyz,x->yztp"
137 weight = weights[0]Lines 141-149 141 elif idx_integration_1d == 2:
142 equation = "xtp,ytp,zt,xyz,z->xytp"
143 weight = weights[2]
144 else:
! 145 raise ValueError(f"Invalid 2D integration axis: '{idx_integration_1d}'.")
146
147 return anp.einsum(
148 equation,
149 phase_0,Lines 948-959 948 fields_by_freq.append(fields_sum)
949
950 point_fields.append(anp.stack(fields_by_freq, axis=1))
951 else:
! 952 x_obs, y_obs, z_obs = monitor.sph_2_car(monitor.proj_distance, theta, phi)
! 953 fields_sum = anp.zeros((len(field_names), len(freqs)), dtype=complex)
! 954 for surface, currents in surface_currents:
! 955 fields_surface = self._fields_for_surface_exact(
956 x=x_obs,
957 y=y_obs,
958 z=z_obs,
959 surface=surface,Lines 959-969 959 surface=surface,
960 currents=currents,
961 medium=medium,
962 )
! 963 fields_surface = anp.reshape(fields_surface, fields_sum.shape)
! 964 fields_sum = fields_sum + fields_surface
! 965 point_fields.append(fields_sum)
966
967 stacked_fields = anp.stack(point_fields, axis=0)
968 stacked_fields = anp.reshape(
969 stacked_fields, (len(ux), len(uy), len(field_names), len(freqs)) |
momchil-flex
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Impressive!
groberts-flex
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks @yaugenst-flex for the change! the performance looks great and I can confirm that when I try it in my example for far_fields_approx=True, I see a significant speedup and I no longer am seeing crashes from OOM.
I left some comments on the code - overall, I think there are some areas that could be a little more readable and some more comments/documentation in the code would make this easier to follow and tweak in the future. Nothing major in there, it seems like things are working well from the tests/profiling!
groberts-flex
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for making the changes @yaugenst-flex , looks good!
c649652 to
a6ffc67
Compare
Avoid materializing the large (x,y,z,theta,phi) contraction intermediate during far-field projection. Instead, fold trapezoid weights into a single weighted contraction, which reduces autograd tape size and peak memory. The Cartesian far-field approximation calls this integral many times with single-angle (theta,phi) inputs; reuse precomputed trapezoid weights and skip `einsum` path optimization for the (1,1) angle case to avoid per-call planning overhead. Changes: - Add `_trapz_weights_1d` and `_far_field_integral` helpers. - Use `_far_field_integral` in `_far_fields_for_surface` and reuse weights per call. - Refactor Cartesian far-field accumulation to build per-point fields and stack. - Add VJP regression tests against the reference 5D+trapezoid implementation.
a6ffc67 to
c5c4cdc
Compare
Avoid materializing the large (x,y,z,theta,phi) contraction intermediate during far-field projection. Instead, fold trapezoid weights into a single weighted contraction, which reduces autograd tape size and peak memory.
The Cartesian far-field approximation calls this integral many times with single-angle (theta,phi) inputs; reuse precomputed trapezoid weights and skip
einsumpath optimization for the (1,1) angle case to avoid per-call planning overhead.Changes:
_trapz_weights_1dand_far_field_integralhelpers._far_field_integralin_far_fields_for_surfaceand reuse weights per call.Tested on the autograd metalens notebook and a couple of regression tests to make sure the results still match. Would be nice to get this in for 2.10 I guess.
Benchmarks
Note some of these are in log scale 😅
Greptile Overview
Greptile Summary
Refactored far-field projection to reduce memory usage in autograd backward pass by avoiding materialization of large intermediate tensors. Instead of computing the full (x,y,z,theta,phi) contraction and then applying trapezoidal integration, the implementation now folds trapezoid weights directly into a single weighted contraction via the new
_far_field_integralhelper.Key changes:
_trapz_weights_1dto precompute trapezoidal integration weights_far_field_integralthat combines phase multiplication, current contraction, and weighted integration in one einsum operationadd_atoperationsphase_0.shape[1] == 1 and phase_0.shape[2] == 1According to the PR benchmarks, this achieves significant memory reduction (shown on log scale) while maintaining numerical correctness.
Confidence Score: 5/5
Important Files Changed
File Analysis
_far_field_integralhelper to verify autograd backward pass correctnessSequence Diagram
sequenceDiagram participant Client participant FieldProjector participant _far_fields_for_surface participant _far_field_integral participant _trapz_weights_1d participant autograd Client->>FieldProjector: project_fields(monitor) FieldProjector->>FieldProjector: _project_fields_cartesian/kspace Note over FieldProjector: Precompute surface currents<br/>and weights once per monitor FieldProjector->>_trapz_weights_1d: compute weights for x,y,z _trapz_weights_1d-->>FieldProjector: weights[x,y,z] loop for each observation point (x,y,z) or (ux,uy) Note over FieldProjector: Convert to spherical coords loop for each frequency FieldProjector->>_far_fields_for_surface: compute fields at (theta,phi,f) Note over _far_fields_for_surface: Compute phase factors _far_fields_for_surface->>_far_fields_for_surface: phase_0, phase_1, phase_2 loop for each field component (E1,E2,H1,H2) _far_fields_for_surface->>_far_field_integral: currents, phases, weights Note over _far_field_integral: Single weighted einsum<br/>(avoids large intermediate) _far_field_integral->>autograd: einsum with folded weights autograd-->>_far_field_integral: integrated result _far_field_integral-->>_far_fields_for_surface: jm_i end Note over _far_fields_for_surface: Transform to spherical<br/>field components _far_fields_for_surface-->>FieldProjector: [Er,Etheta,Ephi,Hr,Htheta,Hphi] end Note over FieldProjector: Accumulate fields per point end Note over FieldProjector: Stack all point fields<br/>and reshape to grid FieldProjector-->>Client: FieldProjectionData