Adjoint #100

Open · wants to merge 7 commits into base: master
142 changes: 142 additions & 0 deletions examples/example_adjoint.py
@@ -0,0 +1,142 @@
import sys
from time import time

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

sys.path.append("../flowpm/")
sys.path.append("../../../DifferentiableHOS/")
sys.path.append("../../DifferentiableHOS/")
import flowpm
import flowpm.scipy.interpolate as interpolate
from flowpm.tfpower import linear_matter_power
import tfpm
import jax_cosmo as jc

cosmology = flowpm.cosmology.Planck15()
cosmo = cosmology

from flowpm.utils import white_noise, c2r3d, r2c3d, cic_paint, cic_readout

fftk = flowpm.kernels.fftk
laplace_kernel = flowpm.kernels.laplace_kernel
gradient_kernel = flowpm.kernels.gradient_kernel

box_size = 100.0
nc = 8
nsteps = 4
stages = np.linspace(0.1, 1.0, nsteps, endpoint=True)
B = 1
pm_nc_factor = B
print("\n FOR %d steps\n" % nsteps)

klin = tf.constant(np.logspace(-4, 1, 512), dtype=tf.float32)
pk = linear_matter_power(cosmology, klin)
pk_fun = lambda x: tf.cast(
    tf.reshape(
        interpolate.interp_tf(
            tf.reshape(tf.cast(x, tf.float32), [-1]), klin, pk),
        x.shape,
    ),
    tf.complex64,
)
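# Quick check (illustrative, not in the original script): pk_fun
# interpolates the tabulated linear P(k) at arbitrary k and casts to
# complex64, as flowpm.linear_field expects.
print("P(k) at k = [0.1, 1.0] : ", pk_fun(tf.constant([0.1, 1.0])))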

# GenIC and data
ic = flowpm.linear_field([nc, nc, nc], [box_size, box_size, box_size],
                         pk_fun,
                         batch_size=1)

data = tf.random.uniform(ic.shape)


##############################################
### First, simple forward model and gradients with backprop
@tf.function
def pmsim(initial_conditions):
  print("gen pm graph")
  initial_state = flowpm.lpt_init(cosmology, initial_conditions, 0.1)
  state = flowpm.nbody(
      cosmology, initial_state, stages, [nc, nc, nc], pm_nc_factor=B)
  return state


@tf.function
def lossfunc(x, x0):
  field = flowpm.cic_paint(tf.zeros_like(ic), x)
  l = tf.reduce_sum((field - x0)**2)
  return l


@tf.function
def gradloss(x, x0):
  with tf.GradientTape() as tape:
    tape.watch(x)
    loss = lossfunc(x, x0)
  grad = tape.gradient(loss, x)
  return grad


@tf.function
def backprop_grads(ic, data):
  print("gen grad graph")
  with tf.GradientTape() as tape:
    tape.watch(ic)
    state = pmsim(ic)
    loss = lossfunc(state[0], data)
  grad = tape.gradient(loss, ic)
  return loss, grad


# Warm up: build the forward graph once before timing anything
_ = pmsim(ic * np.random.uniform())
l1, backpropgrad = backprop_grads(ic, data)
print("loss 1 : ", l1)

##############################################
## Second, the same gradients via the adjoint equations


@tf.function
def gradicadj(ic, x0):
  print("gen adjoint graph")
  state = tf.stop_gradient(pmsim(ic))
  loss = lossfunc(state[0], data)
  print("Loss : ", loss)
  adjx, adjv = 0.0 * gradloss(state[0], data), -1.0 * gradloss(state[0], data)
  adj = tf.stop_gradient(
      tfpm.adjoint(cosmo, state, adjx, adjv, stages[::-1].astype(np.float32),
                   [nc, nc, nc]))
  # adj packs [x, v, f, adjx, adjv] along axis 0 (see tfpm.adjoint)
  state, adjx, adjv = adj[:3], adj[3:4], adj[4:5]
  grad = tfpm.adjoint_lptinit(cosmo, ic, -adjx, -adjv, a0=0.1)
  return loss, grad, state


start = time()
l2, adjgrad, state = gradicadj(ic, data)
print("time for making graph and first run adjoint : ", time() - start)
print("loss 2 ", l2)

# print(adjgrad/backpropgrad)
print(
    "adjoint gradients and backprop are close : ",
    np.allclose(adjgrad, backpropgrad, atol=1e-3),
)

##############################################
##### Time testing
niters = 10
start = time()
for i in range(niters):
  pmsim(ic * np.random.uniform())
print("time for %d forward : " % niters, time() - start)

start = time()
for i in range(niters):
  backprop_grads(ic * np.random.uniform(), data)
print("time for %d grads : " % niters, time() - start)

start = time()
for i in range(niters):
  gradicadj(ic * np.random.uniform(), data)
print("time for %d grads adjoint : " % niters, time() - start)
200 changes: 178 additions & 22 deletions flowpm/tfpm.py
@@ -270,7 +270,15 @@ def apply_pgd(x, delta_k, alpha, kl, ks, kvec=None, name="ApplyPGD"):
    return f


def kick(cosmo, state, ai, ac, af, dtype=tf.float32, name="Kick", **kwargs):
def kick(cosmo,
         state,
         ai,
         ac,
         af,
         return_factor=False,
         dtype=tf.float32,
         name="Kick",
         **kwargs):
  """Kick the particles given the state
  Parameters
  ----------
@@ -289,10 +297,21 @@ def kick(cosmo, state, ai, ac, af, dtype=tf.float32, name="Kick", **kwargs):
    shape = state.shape
    update = tf.scatter_nd(indices, update, shape)
    state = tf.add(state, update)
    if return_factor:
      return state, fac
    else:
      return state


def drift(cosmo, state, ai, ac, af, dtype=tf.float32, name="Drift", **kwargs):
def drift(cosmo,
          state,
          ai,
          ac,
          af,
          return_factor=False,
          dtype=tf.float32,
          name="Drift",
          **kwargs):
  """Drift the particles given the state
  Parameters
  ----------
@@ -311,7 +330,40 @@ def drift(cosmo, state, ai, ac, af, dtype=tf.float32, name="Drift", **kwargs):
    shape = state.shape
    update = tf.scatter_nd(indices, update, shape)
    state = tf.add(state, update)
    if return_factor:
      return state, fac
    else:
      return state
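# Note (illustrative): the adjoint sweep in `adjoint` below consumes these
# returned prefactors so that the reverse pass reuses exactly the factors
# applied in the forward run, e.g.
#   state, facv = kick(cosmo, state, p, f, ah, return_factor=True)
#   state, facx = drift(cosmo, state, x, p, a1, return_factor=True)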


def forcex(cosmo,
           x,
           nc,
           pm_nc_factor=1,
           dtype=tf.float32,
           name="Forcex",
           **kwargs):
  """Computes the long-range PM force update for particles at positions x.
  Same computation as in `force`, but takes positions directly so that it
  can be differentiated with respect to them (see `gradforcev`).
  """
  with tf.name_scope(name):

    shape = x.get_shape()
    batch_size = shape[1]
    ncf = [n * pm_nc_factor for n in nc]

    rho = tf.zeros([batch_size] + ncf)
    wts = tf.ones((batch_size, nc[0] * nc[1] * nc[2]))
    nbar = nc[0] * nc[1] * nc[2] / (ncf[0] * ncf[1] * ncf[2])

    rho = cic_paint(rho, tf.multiply(x[0], pm_nc_factor), wts)
    rho = tf.multiply(rho,
                      1. / nbar)  # I am not sure why this is not needed here
    delta_k = r2c3d(rho, norm=ncf[0] * ncf[1] * ncf[2])
    fac = tf.cast(1.5 * cosmo.Omega_m, dtype=dtype)
    update = apply_longrange(
        tf.multiply(x[0], pm_nc_factor), delta_k, split=0, factor=fac)

    update = update / pm_nc_factor
    update = tf.expand_dims(update, axis=0)
    return update


def force(cosmo,
@@ -337,24 +389,7 @@ def force(
  with tf.name_scope(name):
    state = tf.convert_to_tensor(state, name="state")

    shape = state.get_shape()
    batch_size = shape[1]
    ncf = [n * pm_nc_factor for n in nc]

    rho = tf.zeros([batch_size] + ncf)
    wts = tf.ones((batch_size, nc[0] * nc[1] * nc[2]))
    nbar = nc[0] * nc[1] * nc[2] / (ncf[0] * ncf[1] * ncf[2])

    rho = cic_paint(rho, tf.multiply(state[0], pm_nc_factor), wts)
    rho = tf.multiply(rho,
                      1. / nbar)  # I am not sure why this is not needed here
    delta_k = r2c3d(rho, norm=ncf[0] * ncf[1] * ncf[2])
    fac = tf.cast(1.5 * cosmo.Omega_m, dtype=dtype)
    update = apply_longrange(
        tf.multiply(state[0], pm_nc_factor), delta_k, split=0, factor=fac)

    update = tf.expand_dims(update, axis=0) / pm_nc_factor

    update = forcex(cosmo, state[0:1], nc, pm_nc_factor)
    indices = tf.constant([[2]])
    shape = state.shape
    update = tf.scatter_nd(indices, update, shape)
@@ -501,3 +536,124 @@ def nbody(cosmo,
    return intermediate_states
  else:
    return state


def gradforcev(cosmo, x, adjx, nc):
  """
  Internal function to combine the backprop gradient of the force with the adjoint gradient
  Parameters:
  -----------
  cosmo: cosmology
    Cosmological parameter object
  x: tensor (1, batch_size, npart, 3)
    Current position of the particles
  adjx: tensor (1, batch_size, npart, 3)
    Current adjoint gradient with respect to position
  nc: int, or list of ints
    Number of cells
  Returns
  -------
  d_adjv: tensor (1, batch_size, npart, 3)
    Update to the adjoint gradient of velocity
  """
  if isinstance(nc, int):
    nc = [nc, nc, nc]
  with tf.GradientTape() as tape:
    tape.watch(x)
    f = forcex(cosmo, x, nc)
  d_adjv = tape.gradient(f, x, output_gradients=adjx)
  return d_adjv
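# Illustration (not part of the module) of the output_gradients mechanism
# used above: tape.gradient(f, x, output_gradients=w) returns the
# vector-Jacobian product w^T (df/dx). For an elementwise toy f = sin(x):
#   x = tf.constant([1.0, 2.0, 3.0])
#   w = tf.constant([0.1, 0.2, 0.3])  # plays the role of adjx
#   with tf.GradientTape() as tape:
#     tape.watch(x)
#     f = tf.sin(x)
#   vjp = tape.gradient(f, x, output_gradients=w)  # equals w * cos(x)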


def adjoint_lptinit(cosmo, ic, adjx, adjv, a0, order=2):
  """
  Propagate the adjoint gradients backwards through the LPT initialization
  Parameters:
  -----------
  cosmo: cosmology
    Cosmological parameter object
  ic: tensor (batch_size, nc, nc, nc)
    Initial density field
  adjx: tensor (1, batch_size, npart, 3)
    Adjoint gradient of the objective with respect to the initial position
  adjv: tensor (1, batch_size, npart, 3)
    Adjoint gradient of the objective with respect to the initial velocity
  a0: float
    scale factor at which the initial LPT conditions are generated
  order: int, 1 or 2
    order of the perturbative expansion for the LPT initial conditions
  Returns
  -------
  grad: tensor (batch_size, nc, nc, nc)
    Gradient with respect to the initial conditions
  """
  with tf.GradientTape(persistent=True) as tape:
    tape.watch(ic)
    state = lpt_init(cosmo, ic, a0, order=order)
    x, v, f = tf.split(state, 3, 0)
  gradx = tape.gradient(x, ic, output_gradients=adjv)
  gradv = tape.gradient(v, ic, output_gradients=adjx)
  return gradx + gradv
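# Usage sketch (mirroring examples/example_adjoint.py above): once the
# reverse sweep reaches a0, close the adjoint pass through lpt_init with
#   grad_ic = adjoint_lptinit(cosmo, ic, -adjx, -adjv, a0=0.1)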


def adjoint(cosmo, state, adjx, adjv, stages, nc, pm_nc_factor=1):
  """
  Integrate the adjoint equations for the evolution backwards, starting from the final state
  Parameters:
  -----------
  cosmo: cosmology
    Cosmological parameter object
  state: tensor (3, batch_size, npart, 3)
    Final state after the nbody simulation
  adjx: tensor (1, batch_size, npart, 3)
    Adjoint gradient of the objective with respect to position
  adjv: tensor (1, batch_size, npart, 3)
    Adjoint gradient of the objective with respect to velocity
  stages: array
    Array of scale factors in the reverse direction, 1 -> 0
  nc: int, or list of ints
    Number of cells
  pm_nc_factor: int
    Upsampling factor for force computation
  Returns
  -------
  state_and_adjoint: tensor (5, batch_size, npart, 3)
    State integrated back to the initial conditions, concatenated with adjx and adjv
  """
  if pm_nc_factor != 1:
    raise NotImplementedError
  ai = stages[0]
  if isinstance(nc, int):
    nc = [nc, nc, nc]
  x, p, f = ai, ai, ai
  intermediate_states = []
  for i in range(len(stages) - 1):
    a0 = stages[i]
    a1 = stages[i + 1]
    ah = (a0 * a1)**0.5

    # Kick
    state, facv = kick(cosmo, state, p, f, ah, return_factor=True)
    p = ah
    # Update adjoint
    d_adjv = tf.stop_gradient(
        gradforcev(cosmo, state[0:1], adjx, nc) * facv * -1.0)
    adjv = tf.stop_gradient(adjv + d_adjv)
    # Drift step
    state, facx = drift(cosmo, state, x, p, a1, return_factor=True)
    x = a1
    # Update adjoint
    d_adjx = tf.stop_gradient(adjv * facx * -1.0)
    adjx = tf.stop_gradient(adjx + d_adjx)
    # Force
    state = tf.stop_gradient(force(cosmo, state, nc))
    f = a1
    # Kick step
    state, facv = kick(cosmo, state, p, f, a1, return_factor=True)
    p = a1
    # Update adjoint
    d_adjv = tf.stop_gradient(
        gradforcev(cosmo, state[0:1], adjx, nc) * facv * -1.0)
    adjv = tf.stop_gradient(adjv + d_adjv)

  return tf.concat([state, adjx, adjv], 0)
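Reading the reverse sweep above, each kick and drift applies the discrete adjoint updates with the same leapfrog prefactors as the forward integration. Schematically (a sketch of the updates as implemented, with $F$ the PM force and $\mathrm{fac}_v$, $\mathrm{fac}_x$ the factors returned by `kick`/`drift` with `return_factor=True`):

$$
\mathrm{adjv} \leftarrow \mathrm{adjv} - \mathrm{fac}_v \left(\frac{\partial F}{\partial x}\right)^{\top} \mathrm{adjx},
\qquad
\mathrm{adjx} \leftarrow \mathrm{adjx} - \mathrm{fac}_x\, \mathrm{adjv}.
$$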