|
| 1 | +# author: Ziyi Yin, ziyi.yin@gatech.edu |
| 2 | +## solving min ||F(G^{-1}(z))-y||_2^2 + λ^2 ||z||_2^2 |
| 3 | + |
| 4 | +using DrWatson |
| 5 | +@quickactivate "FNO4CO2" |
| 6 | + |
| 7 | +using FNO4CO2 |
| 8 | +using PyPlot |
| 9 | +using Flux, Random |
| 10 | +using MAT, Statistics, LinearAlgebra |
| 11 | +using ProgressMeter, JLD2 |
| 12 | +using LineSearches |
| 13 | +using InvertibleNetworks |
| 14 | + |
| 15 | +Random.seed!(3) |
| 16 | + |
| 17 | +# load the FNO network |
| 18 | +JLD2.@load "../data/3D_FNO/batch_size=1_dt=0.02_ep=200_epochs=200_learning_rate=0.0001_modes=4_nt=51_ntrain=1000_nvalid=100_s=1_width=20.jld2"; |
| 19 | +NN = deepcopy(NN_save); |
| 20 | +Flux.testmode!(NN, true); |
| 21 | + |
| 22 | +# load the NF network |
| 23 | +JLD2.@load "../data/NFtrain/K=6_L=6_e=50_gab_l2=true_lr=0.001_lr_step=10_max_recursion=1_n_hidden=32_nepochs=500_noiseLev=0.02_λ=0.1.jld2"; |
| 24 | +G = NetworkMultiScaleHINT(1, n_hidden, L, K; |
| 25 | + split_scales=true, max_recursion=max_recursion, p2=0, k2=1, activation=SigmoidLayer(low=0.5f0,high=1.0f0), logdet=false); |
| 26 | +P_curr = get_params(G); |
| 27 | +for j=1:length(P_curr) |
| 28 | + P_curr[j].data = Params[j].data; |
| 29 | +end |
| 30 | + |
| 31 | +# forward to set up splitting, take the reverse for Asim formulation |
| 32 | +G(zeros(Float32,n[1],n[2],1,1)); |
| 33 | +G1 = InvertNetRev(G); |
| 34 | + |
| 35 | +# Define raw data directory |
| 36 | +mkpath(datadir("training-data")) |
| 37 | +perm_path = datadir("training-data", "perm_gridspacing15.0.mat") |
| 38 | +conc_path = datadir("training-data", "conc_gridspacing15.0.mat") |
| 39 | + |
| 40 | +# Download the dataset into the data directory if it does not exist |
| 41 | +if ~isfile(perm_path) |
| 42 | + run(`wget https://www.dropbox.com/s/eqre95eqggqkdq2/' |
| 43 | + 'perm_gridspacing15.0.mat -q -O $perm_path`) |
| 44 | +end |
| 45 | +if ~isfile(conc_path) |
| 46 | + run(`wget https://www.dropbox.com/s/b5zkp6cw60bd4lt/' |
| 47 | + 'conc_gridspacing15.0.mat -q -O $conc_path`) |
| 48 | +end |
| 49 | + |
| 50 | +perm = matread(perm_path)["perm"]; |
| 51 | +conc = matread(conc_path)["conc"]; |
| 52 | + |
| 53 | +# physics grid |
| 54 | +grid = gen_grid(n, d, nt, dt) |
| 55 | + |
| 56 | +# take a test sample |
| 57 | +x_true = perm[:,:,ntrain+nvalid+1]; # take a test sample |
| 58 | +y_true = conc[:,:,:,ntrain+nvalid+1]; |
| 59 | + |
| 60 | +# observation vintages |
| 61 | +nv = nt |
| 62 | +survey_indices = Int.(round.(range(1, stop=nt, length=nv))) |
| 63 | +yobs = permutedims(y_true[survey_indices,:,:,1:1],[2,3,1,4]); # ground truth CO2 concentration at these vintages |
| 64 | + |
| 65 | +## add noise |
| 66 | +noise_ = randn(Float32, size(yobs)) |
| 67 | +snr = 5f0 |
| 68 | +noise_ = noise_/norm(noise_) * norm(yobs) * 10f0^(-snr/20f0) |
| 69 | +σ = Float32.(norm(noise_)/sqrt(length(noise_))) |
| 70 | +yobs = yobs + noise_ |
| 71 | + |
| 72 | +# initial z |
| 73 | +x_init = 20f0 * ones(Float32, n); |
| 74 | +x_init[:,25:36] .= 120f0; |
| 75 | +z = vec(G(reshape(x_init, n[1], n[2], 1, 1))); |
| 76 | +@time y_init = relu01(NN(perm_to_tensor(G1(z)[:,:,1,1], grid, AN))); |
| 77 | + |
| 78 | +## weighting |
| 79 | +λ = 1f0; |
| 80 | + |
| 81 | +function S(x) |
| 82 | + return relu01(NN(perm_to_tensor(x, grid, AN)))[:,:,survey_indices,1]; |
| 83 | +end |
| 84 | + |
| 85 | +# function value |
| 86 | +function f(z) |
| 87 | + println("evaluate f") |
| 88 | + loss = .5f0/σ^2f0 * norm(S(G1(z)[:,:,1,1])-yobs)^2f0 |
| 89 | + return loss |
| 90 | +end |
| 91 | + |
| 92 | +# set up plots |
| 93 | +niterations = 50 |
| 94 | + |
| 95 | +hisloss = zeros(Float32, niterations+1) |
| 96 | +hismisfit = zeros(Float32, niterations+1) |
| 97 | +hisprior = zeros(Float32, niterations+1) |
| 98 | +ls = BackTracking(c_1=1f-4,iterations=10,maxstep=Inf32,order=3,ρ_hi=5f-1,ρ_lo=1f-1) |
| 99 | +α = 1f1; |
| 100 | +### backtracking line search |
| 101 | +prog = Progress(niterations) |
| 102 | +for j=1:niterations |
| 103 | + |
| 104 | + p = Flux.params(z) |
| 105 | + |
| 106 | + @time grads = gradient(p) do |
| 107 | + global misfit = f(z) |
| 108 | + global prior = λ^2f0 * norm(z)^2f0/length(z) |
| 109 | + global loss = misfit + prior |
| 110 | + @show misfit, prior, loss |
| 111 | + println("evaluate g") |
| 112 | + return loss |
| 113 | + end |
| 114 | + if j == 1 |
| 115 | + hisloss[1] = loss |
| 116 | + hismisfit[1] = misfit |
| 117 | + hisprior[1] = prior |
| 118 | + end |
| 119 | + g = grads.grads[z] |
| 120 | + gnorm = -g/norm(g, Inf) |
| 121 | + |
| 122 | + println("iteration no: ",j,"; function value: ",loss) |
| 123 | + |
| 124 | + # linesearch |
| 125 | + function ϕ(α) |
| 126 | + z1 = z .+ α .* gnorm |
| 127 | + global misfit = f(z1) |
| 128 | + global prior = λ^2f0 * norm(z1)^2f0/length(z1) |
| 129 | + global loss = misfit + prior |
| 130 | + @show misfit, prior, loss, α |
| 131 | + return loss |
| 132 | + end |
| 133 | + try |
| 134 | + global step, fval = ls(ϕ, α, loss, dot(g, gnorm)) |
| 135 | + catch e |
| 136 | + println("linesearch failed at iteration: ",j) |
| 137 | + global niterations = j |
| 138 | + hisloss[j+1] = loss |
| 139 | + hismisfit[j+1] = misfit |
| 140 | + hisprior[j+1] = prior |
| 141 | + break |
| 142 | + end |
| 143 | + global α = 1.2f0 * step |
| 144 | + |
| 145 | + hisloss[j+1] = loss |
| 146 | + hismisfit[j+1] = misfit |
| 147 | + hisprior[j+1] = prior |
| 148 | + |
| 149 | + # Update model and bound projection |
| 150 | + global z .+= step .* gnorm |
| 151 | + |
| 152 | + ProgressMeter.next!(prog; showvalues = [(:loss, fval), (:misfit, misfit), (:prior, prior), (:iter, j), (:steplength, step)]) |
| 153 | + |
| 154 | +end |
| 155 | + |
| 156 | +y_predict = S(G1(z)[:,:,1,1]); |
| 157 | + |
| 158 | +## compute true and plot |
| 159 | +SNR = -2f1 * log10(norm(x_true-G1(z)[:,:,1,1])/norm(x_true)) |
| 160 | +fig = figure(figsize=(20,12)); |
| 161 | +subplot(2,2,1); |
| 162 | +imshow(G1(z)[:,:,1,1]',vmin=20,vmax=120);title("inversion by NN, $(niterations) iter");colorbar(); |
| 163 | +subplot(2,2,2); |
| 164 | +imshow(x_true',vmin=20,vmax=120);title("GT permeability");colorbar(); |
| 165 | +subplot(2,2,3); |
| 166 | +imshow(x_init',vmin=20,vmax=120);title("initial permeability");colorbar(); |
| 167 | +subplot(2,2,4); |
| 168 | +imshow(5*abs.(x_true'-G1(z)[:,:,1,1]'),vmin=20,vmax=120);title("5X error, SNR=$SNR");colorbar(); |
| 169 | +suptitle("MAP (NF prior), snr=$snr") |
| 170 | +tight_layout() |
| 171 | + |
| 172 | +sim_name = "FNOinversion" |
| 173 | +exp_name = "2phaseflow-NFprior" |
| 174 | + |
| 175 | +save_dict = @strdict exp_name |
| 176 | +plot_path = plotsdir(sim_name, savename(save_dict; digits=6)) |
| 177 | + |
| 178 | +fig_name = @strdict nv niterations λ α snr |
| 179 | +safesave(joinpath(plot_path, savename(fig_name; digits=6)*"_3Dfno_inv.png"), fig); |
| 180 | + |
| 181 | +## loss |
| 182 | +fig = figure(figsize=(20,12)); |
| 183 | +subplot(3,1,1); |
| 184 | +plot(hisloss[1:niterations+1]);title("loss"); |
| 185 | +subplot(3,1,2); |
| 186 | +plot(hismisfit[1:niterations+1]);title("misfit"); |
| 187 | +subplot(3,1,3); |
| 188 | +plot(hisprior[1:niterations+1]);title("prior"); |
| 189 | +suptitle("MAP (NF prior), snr=$snr") |
| 190 | +tight_layout() |
| 191 | + |
| 192 | +safesave(joinpath(plot_path, savename(fig_name; digits=6)*"_3Dfno_loss.png"), fig); |
| 193 | + |
| 194 | +## data fitting |
| 195 | +fig = figure(figsize=(20,12)); |
| 196 | +for i = 1:5 |
| 197 | + subplot(4,5,i); |
| 198 | + imshow(y_init[:,:,10*i+1]', vmin=0, vmax=1); |
| 199 | + title("initial prediction at snapshot $(10*i+1)") |
| 200 | + subplot(4,5,i+5); |
| 201 | + imshow(yobs[:,:,10*i+1]', vmin=0, vmax=1); |
| 202 | + title("true at snapshot $(10*i+1)") |
| 203 | + subplot(4,5,i+10); |
| 204 | + imshow(y_predict[:,:,10*i+1]', vmin=0, vmax=1); |
| 205 | + title("predict at snapshot $(10*i+1)") |
| 206 | + subplot(4,5,i+15); |
| 207 | + imshow(5*abs.(yobs[:,:,10*i+1]'-y_predict[:,:,10*i+1]'), vmin=0, vmax=1); |
| 208 | + title("5X diff at snapshot $(10*i+1)") |
| 209 | +end |
| 210 | +suptitle("MAP (NF prior), snr=$snr") |
| 211 | +tight_layout() |
| 212 | +safesave(joinpath(plot_path, savename(fig_name; digits=6)*"_3Dfno_fit.png"), fig); |
0 commit comments