Copy
import os
import typing
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pyasdf
import random

import salvus_flow.api
import salvus_flow.simple_config as sc
import salvus_mesh
from gradient_helper_functions import (
    get_unit_cube_mesh,
    test_material_gradient,
)
from salvus_flow._experimental import parameterization
from salvus_mesh import simple_mesh
from salvus_mesh.unstructured_mesh import UnstructuredMesh

Build mesh

Choose either a cartesian or spherical reference frame.

# The model order is drawn at random from {1, 2}.
model_order = random.choice([1, 2])
# The reference frame comes from the environment; spherical unless overridden.
reference_frame = os.environ.get("REFERENCE_FRAME", "spherical")

# Homogeneous TTI material: horizontal velocities are 1.2x the vertical ones.
p1 = parameterization.ElasticTti3D(
    rho=2.6e3,
    vsv=4.0e3,
    vsh=4.0e3 * 1.2,
    vpv=5.8e3,
    vph=5.8e3 * 1.2,
    eta=1.0,
)

# Only the cartesian setup receives a randomly drawn orientation vector.
if reference_frame == "cartesian":
    p1._axis = np.random.random_sample(size=3) * 2 * np.pi

n_elem_per_dim = 2

if reference_frame != "spherical":
    # Any frame other than "spherical" falls back to the unit cube.
    print("\nCUBE MESH\n")
    receiver_type = "block"
    mesh = get_unit_cube_mesh(
        dim=3,
        model_order=model_order,
        par=p1,
        deform=False,
        n_elem_per_dim=n_elem_per_dim,
    )
else:
    print("\nSPHERICAL MESH\n")
    receiver_type = "classic"
    sphere_grid = salvus_mesh.StructuredGrid3D.central_sphere(
        nelem_lat=3, r_outer=np.sqrt(3)
    )
    mesh = salvus_mesh.Skeleton(sphere_grid).get_unstructured_mesh()

    mesh.change_tensor_order(model_order, spherical=True)
    mesh.attach_field("fluid", np.zeros(mesh.nelem))

    # Every material parameter becomes a constant nodal field; the AXIS
    # entry is stored component-wise as global variables a_0, a_1, ...
    for name, val in p1.parameters():
        if name == "AXIS":
            for i, a in enumerate(val):
                mesh.attach_global_variable(f"a_{i}", a)
            continue

        unit_field = np.ones_like(mesh.get_element_nodes()[:, :, 0])
        mesh.attach_field(name, unit_field * val)

    mesh.find_side_sets(mode="spherical_full")
SPHERICAL MESH

# Offsets used to place the source and the receiver inside the domain.
pval_s = 0.15
pval_r = 0.55

# One receiver and one source, each on the main diagonal.
r_loc = [(pval_r,) * 3]
s_loc = [(1 - pval_s,) * 3]

if reference_frame != "spherical":

    print("\nCARTESIAN SOURCE\n")
    sources = [
        sc.source.cartesian.VectorPoint3D(
            x=sx,
            y=sy,
            z=sz,
            fx=1e21,
            fy=1e21,
            fz=1e21,
            source_time_function=sc.stf.Delta(),
        )
        for sx, sy, sz in s_loc
    ]

else:

    print("\nSEISMOLOGICAL SOURCE\n")
    # The position is drawn at random; s_loc only fixes the number of sources.
    sources = [
        sc.source.seismology.VectorPoint3D(
            fr=1e21,
            ft=1e21,
            fp=1e21,
            latitude=np.random.uniform(0, np.pi),
            longitude=np.random.uniform(0, 2 * np.pi),
            depth_in_m=1,
            radius_of_sphere_in_m=np.sqrt(3),
            source_time_function=sc.stf.Delta(),
        )
        for _ in s_loc
    ]

if reference_frame != "spherical":

    print("\nCARTESIAN RECEIVER\n")
    recs = [
        sc.receiver.cartesian.Point3D(
            x=px,
            y=py,
            z=pz,
            station_code=f"{idx:03d}",
            fields=["displacement"],
        )
        for idx, (px, py, pz) in enumerate(r_loc)
    ]

else:

    print("\nSEISMOLOGICAL RECEIVER\n")
    # Fixed position; presumably mirrors r_loc in spherical coordinates
    # -- TODO confirm.
    recs = [
        sc.receiver.seismology.Point3D(
            latitude=33.35,
            longitude=45,
            depth_in_m=np.sqrt(3) - 0.9526,
            radius_of_sphere_in_m=np.sqrt(3),
            station_code=f"{idx:03d}",
            fields=["displacement"],
        )
        for idx, _ in enumerate(r_loc)
    ]
SEISMOLOGICAL SOURCE


SEISMOLOGICAL RECEIVER

# Forward simulation on the mesh with the sources and receivers set up above.
w = sc.simulation.Waveform(mesh=mesh, sources=sources, receivers=recs)

# Timing.
start_time = 0.0
end_time = 5e-4
time_step = 5e-6 / n_elem_per_dim

wave_equation = w.physics.wave_equation
wave_equation.end_time_in_seconds = end_time
wave_equation.time_step_in_seconds = time_step
wave_equation.start_time_in_seconds = start_time

# For gradient computation.
volume_data = w.output.volume_data
volume_data.format = "hdf5"
volume_data.filename = "output.h5"
volume_data.fields = ["adjoint-checkpoint"]
volume_data.sampling_interval_in_time_steps = 10000
w.validate()

# Keep a copy of the mesh on disk, then run the forward problem locally.
mesh.write_h5("test.h5")
salvus_flow.api.run(
    input_file=w,
    site_name="local_f64",
    ranks=2,
    output_folder="fwd_output",
    overwrite=True,
    get_all=True,
    delete_remote_files=False,
    verbosity=2,
)
Uploading 1 files...
🚀  Submitted job_1910221515131961_4acc42a95b@local_f64
Job `job_1910221515131961_4acc42a95b` running on `local_f64` with 2 rank(s).
Site information:
  * Salvus version: 0.10.5
  * Floating point size: 64

-> Current Task: Time loop complete 🍿  Downloading 7 files...


 🎁  Downloaded 2.0 MB to 'fwd_output'.
* Downloaded 2.0 MB of results to `fwd_output`.
* Total run time: 1.19 seconds.
* Pure simulation time: 0.59 seconds.
<salvus_flow.sites.salvus_job.SalvusJob at 0x7f8976a5a390>
def compute_misfit(
    rec_file: Path, adj_src: typing.Union[Path, None] = None, rot_mat=None
):
    """
    Compute the energy misfit and, optionally, the corresponding adjoint source.

    The misfit is ``0.5 * time_step * sum(u ** 2)`` accumulated over the
    displacement components of every receiver found in ``rec_file``. The
    function reads the module-level settings ``reference_frame``,
    ``receiver_type``, ``time_step``, ``start_time`` and ``r_loc``.

    :param rec_file: File containing the receivers from the forward run.
    :param adj_src: File which, if provided, will contain the adjoint sources.
        The file is opened in write mode, so existing content is replaced.
    :param rot_mat: Optional 3x3 matrix applied to every adjoint sample before
        it is written for "classic" receivers. Defaults to the identity.
    :returns: A measure of misfit. Will write adjoint sources if adj_src is a
        valid file path.
    """

    # Honour a caller-supplied rotation; fall back to the identity otherwise.
    if rot_mat is None:
        rot_mat = np.eye(3)

    misfit = 0.0
    adj_out = h5py.File(adj_src, mode="w") if adj_src else None
    try:
        with pyasdf.ASDFDataSet(rec_file, mode="r") as fh:

            for rec in fh.waveforms:
                u = rec.displacement
                adj = np.empty((u[0].stats.npts, len(u)))

                # Cartesian runs use the traces in stored order; otherwise
                # the components are picked explicitly in Z, N, E order.
                if reference_frame == "cartesian":
                    components = list(u)
                else:
                    components = [
                        u.select(component=code)[0]
                        for code in ("Z", "N", "E")
                    ]

                for _i, cmp in enumerate(components):
                    misfit += time_step * (cmp.data * cmp.data).sum()
                    adj[:, _i] = -time_step * cmp.data

                if adj_out is None:
                    continue

                if receiver_type == "classic":
                    # One dataset per station, named after the station code.
                    rotated = np.matmul(rot_mat, adj[:, :, np.newaxis])
                    dset = adj_out.create_dataset(
                        name=u[0].stats.station,
                        data=rotated.squeeze(),
                    )
                    dset.attrs["starttime"] = start_time * 1e9
                    dset.attrs["sampling_rate"] = 1 / time_step

                else:
                    # NOTE: the group name is fixed, so a second receiver
                    # would try to create the same group again; this branch
                    # effectively supports a single receiver.
                    ncmp = 3
                    stf = np.empty((1, ncmp, adj.shape[0]))
                    stf[0] = adj.T
                    group = adj_out.create_group("adjoint_sources")
                    group.create_dataset(name="stf", data=stf)
                    group.create_dataset(
                        name="coordinates", data=np.atleast_2d(r_loc)
                    )
                    group.attrs["start_time_in_seconds"] = start_time
                    group.attrs["sampling_rate_in_hertz"] = 1 / time_step
                    # np.bytes_ is the same type np.string_ used to alias.
                    group.attrs["spatial-type"] = np.bytes_("vector")
    finally:
        # Close the output file even when reading or writing fails.
        if adj_out is not None:
            adj_out.close()

    return misfit * 0.5


# Misfit of the unperturbed model. The same call writes the adjoint sources
# to "adjoint_sources.h5", which the adjoint run reads afterwards. It only
# needs to run once.
rx, ry, rz = w.output.point_data.receiver[0].location
misfit_00 = compute_misfit(
    Path("fwd_output/receivers.h5"), Path("adjoint_sources.h5")
)

# Point sources for the adjoint run. The dataset name "000" is the station
# code given to the first receiver. These are only consumed when
# receiver_type is "classic"; the block format reads the file directly.
if reference_frame == "spherical":

    print("\nZNE ADJOINT SOURCE\n")
    # Same latitude/longitude/depth as the spherical receiver definition.
    adj_srcs = [
        sc.source.seismology.VectorPoint3DZNE(
            latitude=33.35,
            longitude=45,
            depth_in_m=np.sqrt(3) - 0.9526,
            radius_of_sphere_in_m=np.sqrt(3),
            fz=1,
            fn=1,
            fe=1,
            source_time_function=sc.stf.Custom(
                filename="adjoint_sources.h5", dataset_name="000"
            ),
        )
    ]

else:

    print("\nCARTESIAN ADJOINT SOURCE\n")
    # Placed at the first receiver location.
    adj_srcs = [
        sc.source.cartesian.VectorPoint3D(
            x=r_loc[0][0],
            y=r_loc[0][1],
            z=r_loc[0][2],
            fx=1,
            fy=1,
            fz=1,
            source_time_function=sc.stf.Custom(
                filename="adjoint_sources.h5", dataset_name="000"
            ),
        )
    ]
ZNE ADJOINT SOURCE

# Adjoint simulation on the same mesh as the forward run.
w_adjoint = sc.simulation.Waveform(mesh=mesh)

# "classic" receivers take explicit point sources; anything else uses the
# block description stored in the adjoint source file.
if receiver_type != "classic":
    w_adjoint.adjoint.point_source_block = {
        "filename": "adjoint_sources.h5",
        "groups": ["adjoint_sources"],
    }
else:
    w_adjoint.adjoint.point_source = adj_srcs

adjoint_settings = w_adjoint.adjoint
adjoint_settings.gradient.output_filename = "gradient.h5"
adjoint_settings.forward_meta_json_filename = "fwd_output/meta.json"
adjoint_settings.gradient.parameterization = "tti"

w_adjoint.validate()
job = salvus_flow.api.run(
    input_file=w_adjoint,
    site_name="local_f64",
    ranks=2,
    output_folder="adj_output",
    overwrite=True,
    get_all=True,
    delete_remote_files=False,
    verbosity=2,
)

# print(job.stdout)
Uploading 3 files...

🚀  Submitted job_1910221515452375_3038a9f80b@local_f64
Job `job_1910221515452375_3038a9f80b` running on `local_f64` with 2 rank(s).
Site information:
  * Salvus version: 0.10.5
  * Floating point size: 64

 🍿  Downloading 6 files...


 🎁  Downloaded 87.0 KB to 'adj_output'.
* Downloaded 87.0 KB of results to `adj_output`.
* Total run time: 158.52 seconds.
* Pure simulation time: 156.89 seconds.
g = None
# Gradient written by the adjoint run.
gradient = UnstructuredMesh.from_h5("adj_output/gradient.h5")

# Ten logarithmically spaced step sizes from 1e-11 to 1e-2.
h = np.logspace(-11, -2, 10)
pars = [name for name, _ in p1.parameters()]

nx, ny = 1, 1
f, a = plt.subplots(ny, nx, figsize=(5, 5))

# Perturb parameters one at a time.
banner_title = "Results ({}, order={})".format(reference_frame, model_order)
print("{:-^80}".format(banner_title))

# The axis entries are not tested.
skipped_parameters = {"AXIS", "TTISymmetryAxis"}

for _i, p in enumerate(pars):

    if p in skipped_parameters:
        continue

    error = test_material_gradient(
        model=mesh,
        gradient=gradient,
        parameter=p,
        all_h=h,
        simulation=w,
        misfit_function=compute_misfit,
        m0=misfit_00,
        ranks=2,
        quiet=True,
    )

    # Generate a plot
    a.set_title(f"Gradient Test TTI({reference_frame})")
    a.loglog(h, error, label=p)
    a.set_xlabel("$h$")
    a.set_ylabel("Relative Error")

    smallest_error = np.min(error)
    largest_error = np.max(error)
    print("{0:<20}{1:<30}{2:<30}".format(p, smallest_error, largest_error))

    # Make it a real test: the best error must be small, and it must be
    # strictly better than the error at both ends of the step-size range.
    assert smallest_error < 1e-5
    assert error[0] > smallest_error
    assert error[-1] > smallest_error


print("{:-^80}\n".format(""))
a.legend()
--------------------------Results (spherical, order=1)--------------------------
RHO                 1.324255154211428e-08         0.014802470346035062          
VPH                 3.015180844958499e-08         0.058775060329050895          
VPV                 3.648858190447956e-08         0.07937892924462987           
VSH                 4.638017386915694e-06         2.379334640358006             
VSV                 1.840476420601613e-07         0.16193821560590005           
ETA                 6.802918923181132e-08         0.029772111102424347          
--------------------------------------------------------------------------------

<matplotlib.legend.Legend at 0x7f8976392e80>
PAGE CONTENTS