Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 116 additions & 1 deletion tests/test_components/test_field_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

from __future__ import annotations

import autograd.numpy as anp
import numpy as np
import pydantic.v1 as pydantic
import pytest
from autograd import make_vjp

import tidy3d as td
from tidy3d.components.field_projection import FieldProjector
from tidy3d.components.field_projection import FieldProjector, _far_field_integral
from tidy3d.exceptions import DataError

MEDIUM = td.Medium(permittivity=3)
Expand Down Expand Up @@ -665,3 +667,116 @@ def test_2d_sim_with_proj_monitors_near():
def test_trapezoid(array, pts, axes, expected):
result = FieldProjector.trapezoid(array, pts, axes)
assert np.allclose(result, expected)


@pytest.mark.parametrize("idx_u, idx_v", [(0, 1), (0, 2), (1, 2)])
def test_far_field_integral_vjp_3d(idx_u, idx_v):
rng = np.random.default_rng(0)
n_x, n_y, n_z = 3, 4, 5
n_theta, n_phi = 6, 7

currents = rng.standard_normal((n_x, n_y, n_z)) + 1j * rng.standard_normal((n_x, n_y, n_z))
pts = [
np.cumsum(rng.random(n_x)),
np.cumsum(rng.random(n_y)),
np.cumsum(rng.random(n_z)),
]
phase_0 = rng.standard_normal((n_x, n_theta, n_phi)) + 1j * rng.standard_normal(
(n_x, n_theta, n_phi)
)
phase_1 = rng.standard_normal((n_y, n_theta, n_phi)) + 1j * rng.standard_normal(
(n_y, n_theta, n_phi)
)
phase_2 = rng.standard_normal((n_z, n_theta)) + 1j * rng.standard_normal((n_z, n_theta))

def reference(currents_in):
chunk = anp.einsum("xtp,ytp,zt,xyz->xyztp", phase_0, phase_1, phase_2, currents_in)
axes = tuple(sorted((idx_u, idx_v)))
pts_int = tuple(pts[axis] for axis in axes)
return FieldProjector.trapezoid(chunk, pts_int, axes)

def primitive(currents_in):
return _far_field_integral(
currents_in,
phase_0,
phase_1,
phase_2,
pts,
idx_u,
idx_v,
is_2d=False,
idx_integration_1d=None,
)

vjp_primitive, ans_primitive = make_vjp(primitive)(currents)
vjp_reference, ans_reference = make_vjp(reference)(currents)

np.testing.assert_allclose(
np.asarray(ans_primitive), np.asarray(ans_reference), rtol=1e-12, atol=1e-12
)

g = rng.standard_normal(np.asarray(ans_reference).shape) + 1j * rng.standard_normal(
np.asarray(ans_reference).shape
)
grad_primitive = np.asarray(vjp_primitive(g))
grad_reference = np.asarray(vjp_reference(g))
np.testing.assert_allclose(grad_primitive, grad_reference, rtol=1e-10, atol=1e-10)


@pytest.mark.parametrize("idx_integration_1d", [0, 1, 2])
def test_far_field_integral_vjp_2d(idx_integration_1d):
rng = np.random.default_rng(1)
n_theta, n_phi = 5, 6

n_x, n_y, n_z = 1, 1, 1
if idx_integration_1d == 0:
n_x = 4
elif idx_integration_1d == 1:
n_y = 4
else:
n_z = 4

currents = rng.standard_normal((n_x, n_y, n_z)) + 1j * rng.standard_normal((n_x, n_y, n_z))
pts = [
np.cumsum(rng.random(n_x)) if n_x > 1 else np.zeros((n_x,)),
np.cumsum(rng.random(n_y)) if n_y > 1 else np.zeros((n_y,)),
np.cumsum(rng.random(n_z)) if n_z > 1 else np.zeros((n_z,)),
]
phase_0 = rng.standard_normal((n_x, n_theta, n_phi)) + 1j * rng.standard_normal(
(n_x, n_theta, n_phi)
)
phase_1 = rng.standard_normal((n_y, n_theta, n_phi)) + 1j * rng.standard_normal(
(n_y, n_theta, n_phi)
)
phase_2 = rng.standard_normal((n_z, n_theta)) + 1j * rng.standard_normal((n_z, n_theta))

def reference(currents_in):
chunk = anp.einsum("xtp,ytp,zt,xyz->xyztp", phase_0, phase_1, phase_2, currents_in)
return FieldProjector.trapezoid(chunk, pts[idx_integration_1d], idx_integration_1d)

def primitive(currents_in):
return _far_field_integral(
currents_in,
phase_0,
phase_1,
phase_2,
pts,
0,
1,
is_2d=True,
idx_integration_1d=idx_integration_1d,
)

vjp_primitive, ans_primitive = make_vjp(primitive)(currents)
vjp_reference, ans_reference = make_vjp(reference)(currents)

np.testing.assert_allclose(
np.asarray(ans_primitive), np.asarray(ans_reference), rtol=1e-12, atol=1e-12
)

g = rng.standard_normal(np.asarray(ans_reference).shape) + 1j * rng.standard_normal(
np.asarray(ans_reference).shape
)
grad_primitive = np.asarray(vjp_primitive(g))
grad_reference = np.asarray(vjp_reference(g))
np.testing.assert_allclose(grad_primitive, grad_reference, rtol=1e-10, atol=1e-10)
Loading