"""Regular FWI.
"""
import argparse
import numpy as np
import torch
import torch.optim as optim
from setup_seam import setup_seam
from smii.inversion.fwi import costjac
from smii.modeling.propagators.propagators import Scalar2D

def run_fwi(maxiter, lr, adam):
    """Run conventional FWI.

    Starting from the model built by ``_create_desired_init_model``, the
    cost and gradient are evaluated with ``costjac`` and the gradient is
    handed to a PyTorch optimizer, which updates the model in place.

    After every optimizer step the iteration number and the latest data and
    model costs are printed, and the current model is saved to
    ``tmp/conventional_model_<it>.npy``.

    Args:
        maxiter: Number of ``optimizer.step`` calls to perform.
        lr: Learning rate passed to the optimizer.
        adam: If true, use ``optim.Adam``; otherwise use ``optim.LBFGS``.

    Returns:
        A tuple ``(x, data_cost, model_cost, models)``:

        * ``x``: the final model as a NumPy array.
        * ``data_cost``: the cost returned by ``costjac`` at every closure
          evaluation.
        * ``model_cost``: ``np.linalg.norm(model_true - x)`` at every
          closure evaluation.
        * ``models``: a copy of the model at every closure evaluation,
          followed by a copy of the final model.

    Raises:
        ValueError: If ``setup_seam`` provides no sources, or if the leading
            dimensions of the loaded data are not
            ``(num_sources, num_receivers, nt)``.
    """

    # Create initial model
    model_init = _create_desired_init_model()
    model = torch.Tensor(model_init)
    model.requires_grad = True

    # Load true model and metadata
    seam_model_path = 'tmp/SEAM_Vp_Elastic_N23900_chop_interp.bin'
    seam = setup_seam(seam_model_path)
    model_true = seam['model_true']
    dx = seam['dx']
    dt = seam['dt']
    sources = seam['sources']
    receivers_x = seam['receivers_x']

    # Load data
    seam_data_path = 'tmp/seam_data.npy'
    data = np.load(seam_data_path)
    num_sources = len(sources)
    num_receivers = len(receivers_x)
    if num_sources == 0:
        raise ValueError('no sources available; cannot run FWI')
    nt = sources[0]['amplitude'].shape[1]

    # Validate with an explicit exception rather than ``assert`` so that the
    # check is still performed when Python runs with optimizations (-O).
    expected_shape = (num_sources, num_receivers, nt)
    if tuple(data.shape[:3]) != expected_shape:
        raise ValueError(
            'data in {} has shape {}, expected leading dimensions '
            '(num_sources, num_receivers, nt) = {}'.format(
                seam_data_path, data.shape, expected_shape))

    # Create dataset: one (source, receiver) pair per shot, where the
    # receiver holds that shot's recorded data and the receiver locations.
    dataset = [
        (source, {'amplitude': data[i], 'locations': receivers_x})
        for i, source in enumerate(sources)
    ]

    propagator = Scalar2D
    if adam:
        optimizer = optim.Adam([model], lr=lr)
    else:
        optimizer = optim.LBFGS([model], lr=lr)
    data_cost = []
    model_cost = []
    models = []

    def closure():
        """Evaluate cost and gradient at the current model.

        The optimizer may call this more than once per step (LBFGS does),
        so the history lists grow by one entry per evaluation rather than
        one per iteration.
        """
        # Produce numpy array. It shares memory with the tensor, so a copy
        # is stored to keep the snapshot from changing on later updates.
        x = model.data.numpy()
        models.append(x.copy())

        # Compute cost and gradient on model
        cost, grad = costjac(x.ravel(), dataset, dx, dt, propagator,
                             model_true.shape)
        data_cost.append(cost)
        model_cost.append(np.linalg.norm(model_true - x))

        # Hand the externally computed gradient to the optimizer: calling
        # backward on the leaf tensor stores it in ``model.grad``.
        optimizer.zero_grad()
        model.backward(torch.Tensor(grad.reshape(model.shape)))
        return cost

    for it in range(maxiter):
        optimizer.step(closure)
        # Report progress and checkpoint the model after every iteration.
        print(it, data_cost[-1], model_cost[-1])
        np.save('tmp/conventional_model_{}.npy'.format(it),
                model.data.numpy())

    x = model.data.numpy()
    models.append(x.copy())

    return x, data_cost, model_cost, models


def _create_desired_init_model(nz=64, nx=64, dx=100, v0=1490, dvdz=0.5):
    """Create an initial model whose velocity increases linearly with depth.

    With the default arguments the model starts at 1490 m/s at the top
    (water speed) and increases by 0.5 m/s for each meter in depth. Every
    column of the model is identical.

    Args:
        nz: Number of cells in depth (rows of the model).
        nx: Number of cells laterally (columns of the model).
        dx: Cell spacing in meters.
        v0: Velocity of the top row in m/s.
        dvdz: Velocity increase in m/s per meter of depth.

    Returns:
        A float32 NumPy array of shape ``(nz, nx)``.
    """
    # Build the depth profile from an integer cell index instead of
    # ``np.arange(start, stop, step)`` with a floating-point step: the
    # length of a float-step arange depends on rounding, whereas this form
    # always yields exactly ``nz`` rows.
    cell_index = np.arange(nz)
    depth_profile = (v0 + dvdz * dx * cell_index).astype(np.float32)

    # Repeat the single-column profile across all nx columns.
    model_init = np.tile(depth_profile.reshape([-1, 1]), [1, nx])

    return model_init


if __name__ == '__main__':
    # Command-line entry point: read the optimizer settings, run the
    # inversion, and store the final model on disk.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--niter', default=20, type=int,
        help='number of iterations of optimizer')
    cli.add_argument(
        '--lr', default=0.1, type=float,
        help='learning rate of optimizer')
    cli.add_argument(
        '--adam', action='store_true',
        help='use Adam optimizer instead of LBFGS')
    args = cli.parse_args()
    print(args)

    # Only the first item returned by run_fwi (the final model) is saved;
    # the cost histories and model snapshots are discarded here.
    outputs = run_fwi(args.niter, args.lr, args.adam)
    final_model = outputs[0]
    np.save('tmp/conventional_final_model.npy', final_model)
