Version:

Gradient Test TTI Media

Copy
import os
import typing
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pyasdf
import random

from salvus.flow import api
import salvus.flow.simple_config as sc
import salvus.mesh
from salvus.mesh import mesh_block
from salvus.mesh.mesh_block import MeshBlockCollection
from gradient_helper_functions import (
    get_unit_cube_mesh,
    test_material_gradient,
)
from salvus.mesh import simple_mesh
from salvus.mesh.unstructured_mesh import UnstructuredMesh

import parameterization

Build mesh

Choose either a Cartesian or a spherical reference frame (read from the `REFERENCE_FRAME` environment variable; the default is spherical).
# Interpolation order of the model fields, drawn at random so that repeated
# runs exercise both first- and second-order models.
model_order = random.choice([1, 2])

# Run configuration, overridable through environment variables.
reference_frame = os.environ.get("REFERENCE_FRAME", "spherical")
site_name = os.environ.get("SITE_NAME", "local_f64")
print(f"Running gradient test on site: {site_name}")

# Reference TTI material: the horizontal velocities are 20% faster than the
# vertical ones, so the medium is anisotropic.
vs_ref, vp_ref = 4.0e3, 5.8e3
p1 = parameterization.ElasticTti3D(
    rho=2.6e3,
    vsh=vs_ref * 1.2,
    vsv=vs_ref,
    vph=vp_ref * 1.2,
    vpv=vp_ref,
    eta=1.0,
)

# For Cartesian runs, override the symmetry axis with a fixed direction that
# is not aligned with any coordinate axis (it is hard-coded, not random).
if reference_frame == "cartesian":
    p1._axis = np.array([1.5, 0.75, 3.0])
Running gradient test on site: local_f64
# Elements per dimension of the Cartesian unit-cube mesh. The forward time
# step is also divided by this value.
n_elem_per_dim = 2
if reference_frame == "spherical":
    print("\nSPHERICAL MESH\n")
    # "classic" receivers: compute_misfit writes one adjoint-source dataset per
    # station, and the adjoint run uses point sources.
    receiver_type = "classic"
    try:
        # Central sphere of radius sqrt(3), matching radius_of_sphere_in_m
        # used for the sources and receivers.
        mesh = MeshBlockCollection(
            mesh_block._generators.spherical.central_sphere_3d(
                element_count_lat=3, r_outer=np.sqrt(3)
            )
        ).get_unstructured_mesh()
    except Exception as e:
        # Backwards compatibility with 0.12.X
        # (older releases take `nelem_lat` instead of `element_count_lat`).
        # NOTE(review): this catches any exception, not only the keyword
        # mismatch; consider narrowing (presumably to TypeError) — confirm.
        print(e)
        mesh = MeshBlockCollection(
            mesh_block._generators.spherical.central_sphere_3d(
                nelem_lat=3, r_outer=np.sqrt(3)
            )
        ).get_unstructured_mesh()

    mesh.attach_global_variable("reference_frame", "spherical")
    mesh.change_tensor_order(model_order, interpolation_mode="spherical")
    # Fluid flag of 0 on every element, presumably marking an all-solid mesh.
    mesh.attach_field("fluid", np.zeros(mesh.nelem))

    # Scalar material parameters become constant nodal fields. The shape is
    # taken from the element nodes' first coordinate (presumably
    # (nelem, nodes_per_element)). The symmetry axis is stored instead as
    # the global variables a_0, a_1, a_2.
    for name, val in p1.parameters():
        if name != "AXIS":
            mesh.attach_field(
                name, np.ones_like(mesh.get_element_nodes()[:, :, 0]) * val
            )
        else:
            for i, a in enumerate(val):
                mesh.attach_global_variable(f"a_{i}", a)

    mesh.find_side_sets(mode="spherical_full")

else:
    print("\nCUBE MESH\n")
    # "block" receivers: compute_misfit writes all adjoint sources into a
    # single HDF5 group, and the adjoint run reads them as a block.
    receiver_type = "block"
    # Undeformed unit-cube mesh populated with the parameters of p1.
    mesh = get_unit_cube_mesh(
        dim=3,
        model_order=model_order,
        par=p1,
        deform=False,
        n_elem_per_dim=n_elem_per_dim,
    )

SPHERICAL MESH

# Cartesian receiver and source coordinates: each is one point whose three
# coordinates are equal (0.55 for the receiver, 0.85 for the source).
pval_s = 0.15
pval_r = 0.55

r_loc = [(pval_r,) * 3]
s_loc = [(1 - pval_s,) * 3]

if reference_frame == "spherical":
    # The spherical setup uses fixed geographic positions; r_loc and s_loc
    # only determine how many receivers/sources are created.
    print("\nSEISMOLOGICAL SOURCE\n")
    sources = [
        sc.source.seismology.VectorPoint3D(
            latitude=-47.5,
            longitude=-112.7,
            depth_in_m=1,
            radius_of_sphere_in_m=np.sqrt(3),
            fr=1e21,
            ft=1e21,
            fp=1e21,
            source_time_function=sc.stf.Delta(),
        )
        for _ in s_loc
    ]

    print("\nSEISMOLOGICAL RECEIVER\n")
    recs = [
        sc.receiver.seismology.Point3D(
            latitude=33.35,
            longitude=45,
            depth_in_m=np.sqrt(3) - 0.9526,
            radius_of_sphere_in_m=np.sqrt(3),
            station_code=f"{idx:03d}",
            fields=["displacement"],
        )
        for idx, _ in enumerate(r_loc)
    ]
else:
    print("\nCARTESIAN SOURCE\n")
    sources = [
        sc.source.cartesian.VectorPoint3D(
            x=sx,
            y=sy,
            z=sz,
            fx=1e21,
            fy=1e21,
            fz=1e21,
            source_time_function=sc.stf.Delta(),
        )
        for sx, sy, sz in s_loc
    ]

    print("\nCARTESIAN RECEIVER\n")
    recs = [
        sc.receiver.cartesian.Point3D(
            x=px,
            y=py,
            z=pz,
            station_code=f"{idx:03d}",
            fields=["displacement"],
        )
        for idx, (px, py, pz) in enumerate(r_loc)
    ]

SEISMOLOGICAL SOURCE


SEISMOLOGICAL RECEIVER

w = sc.simulation.Waveform(mesh=mesh, sources=sources, receivers=recs)

# Time axis of the forward simulation; the step shrinks as the Cartesian
# mesh gets finer.
start_time = 0.0
end_time = 5e-4
time_step = 5e-6 / n_elem_per_dim
for attr, value in (
    ("end_time_in_seconds", end_time),
    ("time_step_in_seconds", time_step),
    ("start_time_in_seconds", start_time),
):
    setattr(w.physics.wave_equation, attr, value)

# The adjoint run needs adjoint checkpoints (HDF5 volume output) to compute
# the gradient.
for attr, value in (
    ("format", "hdf5"),
    ("filename", "output.h5"),
    ("fields", ["adjoint-checkpoint"]),
    ("sampling_interval_in_time_steps", 10000),
):
    setattr(w.output.volume_data, attr, value)
w.validate()

mesh.write_h5("test.h5")

if reference_frame == "spherical":
    # The mesh must still carry the frame attached when it was built.
    assert mesh.global_strings["reference_frame"] == "spherical"

forward_run_kwargs = {
    "ranks": 2,
    "site_name": site_name,
    "get_all": True,
    "overwrite": True,
    "delete_remote_files": False,
    "verbosity": 0,
}
api.run(input_file=w, output_folder="fwd_output", **forward_run_kwargs)
<salvus.flow.executors.salvus_job.SalvusJob at 0x75cceee32110>
def compute_misfit(
    rec_file: Path, adj_src: typing.Union[Path, None] = None, rot_mat=None
):
    """
    Compute the energy misfit of the recorded displacement and, optionally,
    write the corresponding adjoint sources.

    The misfit is 0.5 * sum over receivers and components of
    ``time_step * sum(u**2)``. The adjoint source of each component is
    ``-time_step * u``.

    Relies on the module-level settings ``reference_frame``,
    ``receiver_type``, ``time_step``, ``start_time`` and ``r_loc``.

    :param rec_file: ASDF file containing the receivers from the forward run.
    :param adj_src: If provided, path of an HDF5 file that is (re)created and
        filled with the adjoint sources.
    :param rot_mat: Optional 3x3 matrix applied to every adjoint-source
        sample in the "classic" receiver layout. Defaults to the identity.
        It is not used for the "block" layout.
    :returns: The misfit value. Writes the adjoint sources if ``adj_src`` is
        given.
    """
    # Previously this was unconditionally reset to the identity inside the
    # loop, silently ignoring a caller-supplied rotation.
    if rot_mat is None:
        rot_mat = np.eye(3)

    misfit = 0.0
    if adj_src is not None and adj_src.is_file():
        adj_src.unlink()
    adj_out = h5py.File(adj_src, mode="w") if adj_src is not None else None
    try:
        with pyasdf.ASDFDataSet(rec_file, mode="r") as fh:
            for rec in fh.waveforms:
                u = rec.displacement
                adj = np.empty((u[0].stats.npts, len(u)))

                if reference_frame == "cartesian":
                    components = u
                else:
                    # Spherical receivers record Z/N/E; fix the column order.
                    components = [
                        u.select(component=code)[0] for code in ("Z", "N", "E")
                    ]
                for _i, cmp in enumerate(components):
                    misfit += time_step * (cmp.data * cmp.data).sum()
                    adj[:, _i] = -time_step * cmp.data

                if adj_out is None:
                    continue

                if receiver_type == "classic":
                    # One dataset per station, named after the station code
                    # (the adjoint source refers to it via dataset_name).
                    rotated = np.matmul(rot_mat, adj[:, :, np.newaxis]).squeeze()
                    dset = adj_out.create_dataset(
                        name=u[0].stats.station, data=rotated
                    )
                    # Start time scaled by 1e9 (presumably nanoseconds).
                    dset.attrs["starttime"] = start_time * 1e9
                    dset.attrs["sampling_rate"] = 1 / time_step
                else:
                    # Block layout: stf of shape (n_sources, n_components,
                    # n_samples). NOTE: the group name is fixed, so this
                    # layout supports a single receiver only.
                    stf = np.empty((1, 3, adj.shape[0]))
                    stf[0] = adj.T
                    group = adj_out.create_group("adjoint_sources")
                    group.create_dataset(name="stf", data=stf)
                    group.create_dataset(
                        name="coordinates", data=np.atleast_2d(r_loc)
                    )
                    group.attrs["start_time_in_seconds"] = start_time
                    group.attrs["sampling_rate_in_hertz"] = 1 / time_step
                    # np.bytes_ replaces np.string_, which NumPy 2.0 removed.
                    group.attrs["spatial_type"] = np.bytes_("vector")
    finally:
        # Always release the HDF5 handle, even if reading/writing failed.
        if adj_out is not None:
            adj_out.close()

    return misfit * 0.5


rx, ry, rz = w.output.point_data.receiver[0].location

# Misfit of the forward run. As a side effect this writes
# adjoint_sources.h5, which the adjoint sources below read back.
misfit_00 = compute_misfit(
    Path("fwd_output/receivers.h5"), Path("adjoint_sources.h5")
)

if reference_frame != "spherical":
    print("\nCARTESIAN ADJOINT SOURCE\n")
    # Cartesian runs use "block" receivers, whose adjoint run reads the file
    # directly, so this list appears to be unused there.
    adj_x, adj_y, adj_z = r_loc[0]
    adj_srcs = [
        sc.source.cartesian.VectorPoint3D(
            x=adj_x,
            y=adj_y,
            z=adj_z,
            fx=1,
            fy=1,
            fz=1,
            source_time_function=sc.stf.Custom(
                filename="adjoint_sources.h5", dataset_name="000"
            ),
        )
    ]
else:
    print("\nZNE ADJOINT SOURCE\n")
    # Injected at the receiver position; the time function comes from the
    # station dataset "000" written by compute_misfit.
    adj_srcs = [
        sc.source.seismology.VectorPoint3DZNE(
            latitude=33.35,
            longitude=45,
            depth_in_m=np.sqrt(3) - 0.9526,
            radius_of_sphere_in_m=np.sqrt(3),
            fz=1,
            fn=1,
            fe=1,
            source_time_function=sc.stf.Custom(
                filename="adjoint_sources.h5", dataset_name="000"
            ),
        )
    ]

ZNE ADJOINT SOURCE

# Adjoint run: inject the adjoint sources and compute the gradient with
# respect to the TTI parameters, using the forward run's meta data.
w_adjoint = sc.simulation.Waveform(mesh=mesh)

if receiver_type == "classic":
    # Point adjoint sources, each reading its time function from a station
    # dataset in adjoint_sources.h5.
    w_adjoint.adjoint.point_source = adj_srcs
else:
    # All adjoint sources are read as one block from the group written by
    # compute_misfit.
    w_adjoint.adjoint.point_source_block = {
        "filename": "adjoint_sources.h5",
        "groups": ["adjoint_sources"],
    }

w_adjoint.adjoint.gradient.output_filename = "gradient.h5"
w_adjoint.adjoint.forward_meta_json_filename = "fwd_output/meta.json"
w_adjoint.adjoint.gradient.parameterization = "tti"

w_adjoint.validate()
job = api.run(
    ranks=2,
    get_all=True,
    overwrite=True,
    site_name=site_name,
    delete_remote_files=False,
    input_file=w_adjoint,
    output_folder="adj_output",
    verbosity=0,
)

gradient = UnstructuredMesh.from_h5("adj_output/gradient.h5")

# Finite-difference step lengths from 1e-11 to 1e-2.
h = np.logspace(-11, -2, 10)
pars = [name for name, _ in p1.parameters()]


def _gradient_errors(parameter, step_lengths):
    """Run the gradient test for one parameter at the given step lengths."""
    return test_material_gradient(
        all_h=step_lengths,
        model=mesh,
        gradient=gradient,
        parameter=parameter,
        simulation=w,
        misfit_function=compute_misfit,
        site_name=site_name,
        m0=misfit_00,
        ranks=2,
        quiet=True,
    )


def _passes(error):
    """Return True if the error curve dips below 1e-5 at an interior minimum."""
    smallest = np.min(error)
    return bool(smallest < 1e-5 and error[0] > smallest and error[-1] > smallest)


f, a = plt.subplots(1, 1, figsize=(5, 5))
a.set_title(f"Gradient Test TTI({reference_frame})")
a.set_xlabel("$h$")
a.set_ylabel("Relative Error")

# Perturb parameters one at a time.
print(
    "{:-^80}".format(
        "Results ({}, order={})".format(reference_frame, model_order)
    )
)
quick_h = [h[0], h[3], h[-1]]
for p in pars:
    # The symmetry axis is not a perturbable material field.
    if p in {"AXIS", "TTISymmetryAxis"}:
        continue

    # First try to get away with just three runs; only on a failed check
    # (not on arbitrary errors) fall back to the full sweep of step lengths.
    used_h = quick_h
    error = _gradient_errors(p, used_h)
    if not _passes(error):
        used_h = h
        error = _gradient_errors(p, used_h)

    # Plot and report before asserting so a failing parameter is visible.
    a.loglog(used_h, error, label=p)
    print(
        "{0:<20}{1:<30}{2:<10}{3:<30}{4:<30}".format(
            p, np.min(error), np.argmin(error), error[0], error[-1]
        )
    )

    assert np.min(error) < 1e-5
    assert error[0] > np.min(error)
    assert error[-1] > np.min(error)

print("{:-^80}\n".format(""))
a.legend()
--------------------------Results (spherical, order=1)--------------------------
RHO                 5.889787696027617e-08         1         1.635883756670877e-05         0.014802470346028435          
VPH                 1.4516113676786748e-07        1         3.6689503789444004e-06        0.10946294717764225           
VPV                 1.4884298322110887e-07        1         1.0050892429100884e-05        0.09986034125621442           
VSH                 7.903175933516226e-08         1         2.9766573858619308e-05        0.01308495053214165           
VSV                 6.836167491694316e-08         1         1.6500990511313224e-05        0.07252946165574045           
ETA                 3.67449215153235e-07          1         0.00010593024004455591        0.00793039432590833           
--------------------------------------------------------------------------------

<matplotlib.legend.Legend at 0x75ccee2f5e10>
PAGE CONTENTS