import os
import typing
from pathlib import Path
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pyasdf
import random
from salvus.flow import api
import salvus.flow.simple_config as sc
import salvus.mesh
from salvus.mesh import mesh_block
from salvus.mesh.mesh_block import MeshBlockCollection
from gradient_helper_functions import (
get_unit_cube_mesh,
test_material_gradient,
)
from salvus.mesh import simple_mesh
from salvus.mesh.unstructured_mesh import UnstructuredMesh
import parameterization
# Polynomial order of the model. Set MODEL_ORDER to reproduce a specific run;
# otherwise pick 1 or 2 at random so both orders get exercised over time.
model_order = int(os.environ.get("MODEL_ORDER", random.choice([1, 2])))
reference_frame = os.environ.get("REFERENCE_FRAME", "spherical")
site_name = os.environ.get("SITE_NAME", "local_f64")
print(f"Running gradient test on site: {site_name}")
# Reference TTI model; vsh/vph are 1.2x vsv/vpv so the anisotropy is
# non-trivial.
p1 = parameterization.ElasticTti3D(
    rho=2.6e3, vsh=4.0e3 * 1.2, vsv=4.0e3, vph=5.8e3 * 1.2, vpv=5.8e3, eta=1.0
)
# In the Cartesian frame, use a fixed symmetry-axis vector that is not aligned
# with any coordinate axis (the spherical frame keeps the model's default).
if reference_frame == "cartesian":
    p1._axis = np.array([1.5, 0.75, 3.0])
# Output: Running gradient test on site: local_f64
n_elem_per_dim = 2
if reference_frame == "spherical":
    print("\nSPHERICAL MESH\n")
    receiver_type = "classic"
    try:
        sphere = mesh_block._generators.spherical.central_sphere_3d(
            element_count_lat=3, r_outer=np.sqrt(3)
        )
    except TypeError as e:
        # Backwards compatibility with 0.12.X, which takes `nelem_lat` instead
        # of `element_count_lat`. Only an argument error triggers the fallback,
        # so a genuine meshing failure is not masked by a second failing call.
        print(e)
        sphere = mesh_block._generators.spherical.central_sphere_3d(
            nelem_lat=3, r_outer=np.sqrt(3)
        )
    mesh = MeshBlockCollection(sphere).get_unstructured_mesh()
    mesh.attach_global_variable("reference_frame", "spherical")
    mesh.change_tensor_order(model_order, interpolation_mode="spherical")
    mesh.attach_field("fluid", np.zeros(mesh.nelem))
    # One value per element node. The node layout does not change while
    # fields are attached, so build the all-ones template once.
    node_template = np.ones_like(mesh.get_element_nodes()[:, :, 0])
    for name, val in p1.parameters():
        if name == "AXIS":
            # The symmetry axis is stored as global scalars a_0, a_1, ...
            for i, a in enumerate(val):
                mesh.attach_global_variable(f"a_{i}", a)
        else:
            # Homogeneous model: every node gets the same value.
            mesh.attach_field(name, node_template * val)
    mesh.find_side_sets(mode="spherical_full")
else:
    print("\nCUBE MESH\n")
    receiver_type = "block"
    mesh = get_unit_cube_mesh(
        dim=3,
        model_order=model_order,
        par=p1,
        deform=False,
        n_elem_per_dim=n_elem_per_dim,
    )
# Output: SPHERICAL MESH
# Source (s_loc) and receiver (r_loc) positions as (x, y, z) tuples, one entry
# per source / receiver. The Cartesian constructors read them directly.
pval_s = 0.15
pval_r = 0.55
r_loc = [(pval_r, pval_r, pval_r)]
s_loc = [(1 - pval_s, 1 - pval_s, 1 - pval_s)]

if reference_frame != "spherical":
    print("\nCARTESIAN SOURCE\n")
    sources = [
        sc.source.cartesian.VectorPoint3D(
            x=sx,
            y=sy,
            z=sz,
            fx=1e21,
            fy=1e21,
            fz=1e21,
            source_time_function=sc.stf.Delta(),
        )
        for sx, sy, sz in s_loc
    ]
else:
    print("\nSEISMOLOGICAL SOURCE\n")
    # One source per s_loc entry, but every one sits at the same fixed
    # geographic position; the tuple coordinates are unused in this frame.
    sources = [
        sc.source.seismology.VectorPoint3D(
            latitude=-47.5,
            longitude=-112.7,
            depth_in_m=1,
            radius_of_sphere_in_m=np.sqrt(3),
            fr=1e21,
            ft=1e21,
            fp=1e21,
            source_time_function=sc.stf.Delta(),
        )
        for _ in s_loc
    ]
if reference_frame != "spherical":
    print("\nCARTESIAN RECEIVER\n")
    # Station codes are the zero-padded receiver index: "000", "001", ...
    recs = [
        sc.receiver.cartesian.Point3D(
            x=px,
            y=py,
            z=pz,
            station_code=f"{idx:03d}",
            fields=["displacement"],
        )
        for idx, (px, py, pz) in enumerate(r_loc)
    ]
else:
    print("\nSEISMOLOGICAL RECEIVER\n")
    # One receiver per r_loc entry, all at the same fixed geographic position.
    recs = [
        sc.receiver.seismology.Point3D(
            latitude=33.35,
            longitude=45,
            depth_in_m=np.sqrt(3) - 0.9526,
            radius_of_sphere_in_m=np.sqrt(3),
            station_code=f"{idx:03d}",
            fields=["displacement"],
        )
        for idx, _ in enumerate(r_loc)
    ]
# Output:
# SEISMOLOGICAL SOURCE
# SEISMOLOGICAL RECEIVER
w = sc.simulation.Waveform(mesh=mesh, sources=sources, receivers=recs)

# Timing; the time step shrinks as the number of elements per dimension grows.
start_time = 0.0
end_time = 5e-4
time_step = 5e-6 / n_elem_per_dim
w.physics.wave_equation.end_time_in_seconds = end_time
w.physics.wave_equation.time_step_in_seconds = time_step
w.physics.wave_equation.start_time_in_seconds = start_time

# Volume output requested for the later gradient computation.
volume_settings = {
    "format": "hdf5",
    "filename": "output.h5",
    "fields": ["adjoint-checkpoint"],
    "sampling_interval_in_time_steps": 10000,
}
for setting, value in volume_settings.items():
    setattr(w.output.volume_data, setting, value)
w.validate()
mesh.write_h5("test.h5")

if reference_frame == "spherical":
    # Sanity check that the mesh carries the frame attached during meshing.
    assert mesh.global_strings["reference_frame"] == "spherical"

forward_run_options = dict(
    ranks=2,
    input_file=w,
    get_all=True,
    site_name=site_name,
    overwrite=True,
    delete_remote_files=False,
    output_folder="fwd_output",
    verbosity=0,
)
api.run(**forward_run_options)
# Output: <salvus.flow.executors.salvus_job.SalvusJob at 0x799c85bb7510>
def compute_misfit(
    rec_file: Path, adj_src: typing.Union[Path, None] = None, rot_mat=None
):
    """
    Compute the energy misfit of the recorded displacement and, optionally,
    write the corresponding adjoint sources.

    The misfit is ``0.5 * sum(time_step * u**2)`` over every sample, component
    and receiver in ``rec_file``. The adjoint source written for each receiver
    is ``-time_step * u``, the negative derivative of that misfit with respect
    to ``u``.

    Reads the module-level settings ``reference_frame``, ``receiver_type``,
    ``time_step``, ``start_time`` and ``r_loc``.

    :param rec_file: ASDF file containing the receivers from the forward run.
    :param adj_src: If provided, path of the HDF5 file the adjoint sources are
        written to. Any existing file at that path is replaced. ``str`` paths
        are accepted as well.
    :param rot_mat: Optional 3x3 matrix applied to every adjoint-source sample
        of ``classic`` receivers. Defaults to the identity.
    :returns: The energy misfit.
    """
    if rot_mat is None:
        rot_mat = np.eye(3)
    if adj_src is not None:
        adj_src = Path(adj_src)
        if adj_src.is_file():
            adj_src.unlink()
    adj_out = h5py.File(adj_src, mode="w") if adj_src is not None else None
    misfit = 0.0
    try:
        with pyasdf.ASDFDataSet(rec_file, mode="r") as fh:
            for rec in fh.waveforms:
                u = rec.displacement
                if reference_frame == "cartesian":
                    # Components in the order they are stored.
                    traces = list(u)
                else:
                    # Fixed Z, N, E component order.
                    traces = [
                        u.select(component=code)[0] for code in ("Z", "N", "E")
                    ]
                # One column per component actually used, so no column is
                # left uninitialised.
                adj = np.empty((u[0].stats.npts, len(traces)))
                for column, trace in enumerate(traces):
                    misfit += time_step * (trace.data * trace.data).sum()
                    adj[:, column] = -time_step * trace.data
                if adj_out is not None:
                    _write_adjoint_source(
                        adj_out, u[0].stats.station, adj, rot_mat
                    )
    finally:
        # Close the output file even if reading the receivers fails.
        if adj_out is not None:
            adj_out.close()
    return misfit * 0.5


def _write_adjoint_source(adj_out, station, adj, rot_mat):
    """
    Write one receiver's adjoint source ``adj`` (``npts x 3``) to ``adj_out``.

    ``classic`` receivers get one rotated dataset named after the station.
    Otherwise a single ``adjoint_sources`` group is written for a block point
    source located at ``r_loc``.
    """
    if receiver_type == "classic":
        # (3, 3) @ (npts, 3, 1) broadcasts to (npts, 3, 1): rotate each sample.
        rotated = np.matmul(rot_mat, adj[:, :, np.newaxis]).squeeze()
        dset = adj_out.create_dataset(name=station, data=rotated)
        # NOTE(review): the 1e9 factor suggests nanoseconds; confirm.
        dset.attrs["starttime"] = start_time * 1e9
        dset.attrs["sampling_rate"] = 1 / time_step
        return
    ncmp = 3
    stf = np.empty((1, ncmp, adj.shape[0]))
    stf[0] = adj.T
    # The group name is fixed and the coordinates come from r_loc, so this
    # layout supports a single receiver only. A second receiver would fail in
    # create_group because the group already exists.
    group = adj_out.create_group("adjoint_sources")
    group.create_dataset(name="stf", data=stf)
    group.create_dataset(name="coordinates", data=np.atleast_2d(r_loc))
    group.attrs["start_time_in_seconds"] = start_time
    group.attrs["sampling_rate_in_hertz"] = 1 / time_step
    # np.bytes_ instead of np.string_, which was removed in NumPy 2.0.
    group.attrs["spatial_type"] = np.bytes_("vector")
# Misfit of the unperturbed model. Passing a path also writes the adjoint
# sources that drive the adjoint simulation below.
misfit_00 = compute_misfit(
    Path("fwd_output/receivers.h5"), Path("adjoint_sources.h5")
)
if reference_frame == "spherical":
    print("\nZNE ADJOINT SOURCE\n")
    # Same position as the seismological receiver defined above; keep the two
    # in sync. Dataset "000" is the station code of the first receiver.
    adj_srcs = [
        sc.source.seismology.VectorPoint3DZNE(
            latitude=33.35,
            longitude=45,
            depth_in_m=np.sqrt(3) - 0.9526,
            radius_of_sphere_in_m=np.sqrt(3),
            fz=1,
            fn=1,
            fe=1,
            source_time_function=sc.stf.Custom(
                filename="adjoint_sources.h5", dataset_name="000"
            ),
        )
    ]
else:
    print("\nCARTESIAN ADJOINT SOURCE\n")
    # NOTE: the Cartesian frame uses "block" receivers, so the adjoint run
    # reads the "adjoint_sources" group instead and these point sources are
    # not passed to it.
    adj_srcs = [
        sc.source.cartesian.VectorPoint3D(
            x=r_loc[0][0],
            y=r_loc[0][1],
            z=r_loc[0][2],
            fx=1,
            fy=1,
            fz=1,
            source_time_function=sc.stf.Custom(
                filename="adjoint_sources.h5", dataset_name="000"
            ),
        )
    ]
# Output: ZNE ADJOINT SOURCE
# Adjoint simulation driven by the sources in adjoint_sources.h5; it reads the
# forward run's metadata and writes the TTI gradient to gradient.h5.
w_adjoint = sc.simulation.Waveform(mesh=mesh)
if receiver_type == "classic":
    # One point source per receiver, each reading its own dataset.
    w_adjoint.adjoint.point_source = adj_srcs
else:
    # Block receivers: the sources come from the "adjoint_sources" group.
    w_adjoint.adjoint.point_source_block = {
        "filename": "adjoint_sources.h5",
        "groups": ["adjoint_sources"],
    }
w_adjoint.adjoint.gradient.output_filename = "gradient.h5"
w_adjoint.adjoint.forward_meta_json_filename = "fwd_output/meta.json"
w_adjoint.adjoint.gradient.parameterization = "tti"
w_adjoint.validate()
# The job handle is kept for interactive inspection (e.g. job.stdout).
job = api.run(
    ranks=2,
    get_all=True,
    overwrite=True,
    site_name=site_name,
    delete_remote_files=False,
    input_file=w_adjoint,
    output_folder="adj_output",
    verbosity=0,
)
gradient = UnstructuredMesh.from_h5("adj_output/gradient.h5")
# Finite-difference step lengths, log-spaced from 1e-11 to 1e-2.
h = np.logspace(-11, -2, 10)
# The first pass tries only three of them to keep the number of runs down.
coarse_h = [h[0], h[3], h[-1]]
pars = [name for name, _ in p1.parameters()]
nx = 1
ny = 1
f, a = plt.subplots(ny, nx, figsize=(5, 5))
a.set_title(f"Gradient Test TTI({reference_frame})")
a.set_xlabel("$h$")
a.set_ylabel("Relative Error")


def _run_gradient_test(parameter, steps):
    """Return test_material_gradient's error for ``parameter`` at each step."""
    return test_material_gradient(
        all_h=steps,
        model=mesh,
        gradient=gradient,
        parameter=parameter,
        simulation=w,
        misfit_function=compute_misfit,
        site_name=site_name,
        m0=misfit_00,
        ranks=2,
        quiet=True,
    )


def _check_gradient_error(error):
    """
    Raise AssertionError unless the error curve has a minimum below 1e-5
    that lies strictly inside the tested step range.

    Raises explicitly instead of using ``assert`` so the check (and the
    fallback that depends on it) still runs under ``python -O``.
    """
    min_error = np.min(error)
    if not (min_error < 1e-5 and error[0] > min_error and error[-1] > min_error):
        raise AssertionError(
            f"gradient test failed: min={min_error}, "
            f"first={error[0]}, last={error[-1]}"
        )


# Perturb parameters one at a time.
print(
    "{:-^80}".format(
        "Results ({}, order={})".format(reference_frame, model_order)
    )
)
for p in pars:
    # Skip the symmetry-axis entries; they are not tested here.
    if p in {"AXIS", "TTISymmetryAxis"}:
        continue
    # First try to get away with just three runs.
    error = _run_gradient_test(p, coarse_h)
    used_h = coarse_h
    try:
        _check_gradient_error(error)
    except AssertionError:
        # Run again with all step lengths; a failure this time is final.
        # used_h is only updated once the rerun succeeds, so it always matches
        # `error` in the plot below.
        error = _run_gradient_test(p, h)
        used_h = h
        _check_gradient_error(error)
    finally:
        # Plot and report the latest result whether or not the check passed.
        a.loglog(used_h, error, label=p)
        print(
            "{0:<20}{1:<30}{2:<10}{3:<30}{4:<30}".format(
                p, np.min(error), np.argmin(error), error[0], error[-1]
            )
        )
print("{:-^80}\n".format(""))
a.legend()
# Output (columns: parameter, min error, index of min, error at first h,
# error at last h):
# --------------------------Results (spherical, order=1)--------------------------
# RHO                 3.840132793153242e-08         1         1.6358837567118707e-05        0.014802470346048527
# VPH                 1.5433767180329438e-07        1         1.2845484864329721e-05        0.10946294717766121
# VPV                 1.626920021787245e-07         1         1.005089242785446e-05         0.09986034125618534
# VSH                 1.1071537105065033e-07        1         6.145018335847246e-05         0.013084950532207264
# VSV                 6.836167641500901e-08         1         2.9593340023741472e-05        0.07252946165574206
# ETA                 4.7104371354806476e-07        1         2.335744631044868e-06         0.00793039432580769
# --------------------------------------------------------------------------------
# <matplotlib.legend.Legend at 0x799c6bc1d510>