Skip to content

Commit 4d13af3

Browse files
authored
Merge pull request #2 from slimgroup/dev
Add NF prior, MAP
2 parents 87eb160 + b6a2735 commit 4d13af3

12 files changed

+1186
-366
lines changed

Manifest.toml

+42-162
Large diffs are not rendered by default.

Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
2222
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
2323
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2424
Seis4CCS = "808f26ee-2243-4b28-9f59-2530de529097"
25+
SetIntersectionProjection = "335f7d24-6316-57dd-9c3a-df470f2b739e"
2526
SlimPlotting = "f6d04670-764e-495b-a720-91c3c9a588ff"
2627
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2728
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

scripts/MAP.jl

+212
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
# author: Ziyi Yin, ziyi.yin@gatech.edu
2+
## solving min ||F(G^{-1}(z))-y||_2^2 + λ^2 ||z||_2^2
3+
4+
using DrWatson
5+
@quickactivate "FNO4CO2"
6+
7+
using FNO4CO2
8+
using PyPlot
9+
using Flux, Random
10+
using MAT, Statistics, LinearAlgebra
11+
using ProgressMeter, JLD2
12+
using LineSearches
13+
using InvertibleNetworks
14+
15+
Random.seed!(3)
16+
17+
# load the FNO network
18+
JLD2.@load "../data/3D_FNO/batch_size=1_dt=0.02_ep=200_epochs=200_learning_rate=0.0001_modes=4_nt=51_ntrain=1000_nvalid=100_s=1_width=20.jld2";
19+
NN = deepcopy(NN_save);
20+
Flux.testmode!(NN, true);
21+
22+
# load the NF network
23+
JLD2.@load "../data/NFtrain/K=6_L=6_e=50_gab_l2=true_lr=0.001_lr_step=10_max_recursion=1_n_hidden=32_nepochs=500_noiseLev=0.02_λ=0.1.jld2";
24+
G = NetworkMultiScaleHINT(1, n_hidden, L, K;
25+
split_scales=true, max_recursion=max_recursion, p2=0, k2=1, activation=SigmoidLayer(low=0.5f0,high=1.0f0), logdet=false);
26+
P_curr = get_params(G);
27+
for j=1:length(P_curr)
28+
P_curr[j].data = Params[j].data;
29+
end
30+
31+
# forward to set up splitting, take the reverse for Asim formulation
32+
G(zeros(Float32,n[1],n[2],1,1));
33+
G1 = InvertNetRev(G);
34+
35+
# Define raw data directory
36+
mkpath(datadir("training-data"))
37+
perm_path = datadir("training-data", "perm_gridspacing15.0.mat")
38+
conc_path = datadir("training-data", "conc_gridspacing15.0.mat")
39+
40+
# Download the dataset into the data directory if it does not exist
41+
if ~isfile(perm_path)
42+
run(`wget https://www.dropbox.com/s/eqre95eqggqkdq2/'
43+
'perm_gridspacing15.0.mat -q -O $perm_path`)
44+
end
45+
if ~isfile(conc_path)
46+
run(`wget https://www.dropbox.com/s/b5zkp6cw60bd4lt/'
47+
'conc_gridspacing15.0.mat -q -O $conc_path`)
48+
end
49+
50+
perm = matread(perm_path)["perm"];
51+
conc = matread(conc_path)["conc"];
52+
53+
# physics grid
54+
grid = gen_grid(n, d, nt, dt)
55+
56+
# take a test sample
57+
x_true = perm[:,:,ntrain+nvalid+1]; # take a test sample
58+
y_true = conc[:,:,:,ntrain+nvalid+1];
59+
60+
# observation vintages
61+
nv = nt
62+
survey_indices = Int.(round.(range(1, stop=nt, length=nv)))
63+
yobs = permutedims(y_true[survey_indices,:,:,1:1],[2,3,1,4]); # ground truth CO2 concentration at these vintages
64+
65+
## add noise
66+
noise_ = randn(Float32, size(yobs))
67+
snr = 5f0
68+
noise_ = noise_/norm(noise_) * norm(yobs) * 10f0^(-snr/20f0)
69+
σ = Float32.(norm(noise_)/sqrt(length(noise_)))
70+
yobs = yobs + noise_
71+
72+
# initial z
73+
x_init = 20f0 * ones(Float32, n);
74+
x_init[:,25:36] .= 120f0;
75+
z = vec(G(reshape(x_init, n[1], n[2], 1, 1)));
76+
@time y_init = relu01(NN(perm_to_tensor(G1(z)[:,:,1,1], grid, AN)));
77+
78+
## weighting
79+
λ = 1f0;
80+
81+
function S(x)
82+
return relu01(NN(perm_to_tensor(x, grid, AN)))[:,:,survey_indices,1];
83+
end
84+
85+
# function value
86+
function f(z)
87+
println("evaluate f")
88+
loss = .5f0/σ^2f0 * norm(S(G1(z)[:,:,1,1])-yobs)^2f0
89+
return loss
90+
end
91+
92+
# set up plots
93+
niterations = 50
94+
95+
hisloss = zeros(Float32, niterations+1)
96+
hismisfit = zeros(Float32, niterations+1)
97+
hisprior = zeros(Float32, niterations+1)
98+
ls = BackTracking(c_1=1f-4,iterations=10,maxstep=Inf32,order=3,ρ_hi=5f-1,ρ_lo=1f-1)
99+
α = 1f1;
100+
### backtracking line search
101+
prog = Progress(niterations)
102+
for j=1:niterations
103+
104+
p = Flux.params(z)
105+
106+
@time grads = gradient(p) do
107+
global misfit = f(z)
108+
global prior = λ^2f0 * norm(z)^2f0/length(z)
109+
global loss = misfit + prior
110+
@show misfit, prior, loss
111+
println("evaluate g")
112+
return loss
113+
end
114+
if j == 1
115+
hisloss[1] = loss
116+
hismisfit[1] = misfit
117+
hisprior[1] = prior
118+
end
119+
g = grads.grads[z]
120+
gnorm = -g/norm(g, Inf)
121+
122+
println("iteration no: ",j,"; function value: ",loss)
123+
124+
# linesearch
125+
function ϕ(α)
126+
z1 = z .+ α .* gnorm
127+
global misfit = f(z1)
128+
global prior = λ^2f0 * norm(z1)^2f0/length(z1)
129+
global loss = misfit + prior
130+
@show misfit, prior, loss, α
131+
return loss
132+
end
133+
try
134+
global step, fval = ls(ϕ, α, loss, dot(g, gnorm))
135+
catch e
136+
println("linesearch failed at iteration: ",j)
137+
global niterations = j
138+
hisloss[j+1] = loss
139+
hismisfit[j+1] = misfit
140+
hisprior[j+1] = prior
141+
break
142+
end
143+
global α = 1.2f0 * step
144+
145+
hisloss[j+1] = loss
146+
hismisfit[j+1] = misfit
147+
hisprior[j+1] = prior
148+
149+
# Update model and bound projection
150+
global z .+= step .* gnorm
151+
152+
ProgressMeter.next!(prog; showvalues = [(:loss, fval), (:misfit, misfit), (:prior, prior), (:iter, j), (:steplength, step)])
153+
154+
end
155+
156+
y_predict = S(G1(z)[:,:,1,1]);
157+
158+
## compute true and plot
159+
SNR = -2f1 * log10(norm(x_true-G1(z)[:,:,1,1])/norm(x_true))
160+
fig = figure(figsize=(20,12));
161+
subplot(2,2,1);
162+
imshow(G1(z)[:,:,1,1]',vmin=20,vmax=120);title("inversion by NN, $(niterations) iter");colorbar();
163+
subplot(2,2,2);
164+
imshow(x_true',vmin=20,vmax=120);title("GT permeability");colorbar();
165+
subplot(2,2,3);
166+
imshow(x_init',vmin=20,vmax=120);title("initial permeability");colorbar();
167+
subplot(2,2,4);
168+
imshow(5*abs.(x_true'-G1(z)[:,:,1,1]'),vmin=20,vmax=120);title("5X error, SNR=$SNR");colorbar();
169+
suptitle("MAP (NF prior), snr=$snr")
170+
tight_layout()
171+
172+
sim_name = "FNOinversion"
173+
exp_name = "2phaseflow-NFprior"
174+
175+
save_dict = @strdict exp_name
176+
plot_path = plotsdir(sim_name, savename(save_dict; digits=6))
177+
178+
fig_name = @strdict nv niterations λ α snr
179+
safesave(joinpath(plot_path, savename(fig_name; digits=6)*"_3Dfno_inv.png"), fig);
180+
181+
## loss
182+
fig = figure(figsize=(20,12));
183+
subplot(3,1,1);
184+
plot(hisloss[1:niterations+1]);title("loss");
185+
subplot(3,1,2);
186+
plot(hismisfit[1:niterations+1]);title("misfit");
187+
subplot(3,1,3);
188+
plot(hisprior[1:niterations+1]);title("prior");
189+
suptitle("MAP (NF prior), snr=$snr")
190+
tight_layout()
191+
192+
safesave(joinpath(plot_path, savename(fig_name; digits=6)*"_3Dfno_loss.png"), fig);
193+
194+
## data fitting
195+
fig = figure(figsize=(20,12));
196+
for i = 1:5
197+
subplot(4,5,i);
198+
imshow(y_init[:,:,10*i+1]', vmin=0, vmax=1);
199+
title("initial prediction at snapshot $(10*i+1)")
200+
subplot(4,5,i+5);
201+
imshow(yobs[:,:,10*i+1]', vmin=0, vmax=1);
202+
title("true at snapshot $(10*i+1)")
203+
subplot(4,5,i+10);
204+
imshow(y_predict[:,:,10*i+1]', vmin=0, vmax=1);
205+
title("predict at snapshot $(10*i+1)")
206+
subplot(4,5,i+15);
207+
imshow(5*abs.(yobs[:,:,10*i+1]'-y_predict[:,:,10*i+1]'), vmin=0, vmax=1);
208+
title("5X diff at snapshot $(10*i+1)")
209+
end
210+
suptitle("MAP (NF prior), snr=$snr")
211+
tight_layout()
212+
safesave(joinpath(plot_path, savename(fig_name; digits=6)*"_3Dfno_fit.png"), fig);

0 commit comments

Comments
 (0)