Skip to content

Commit 1d90b74

Browse files
committed
add crosswell comparison to crosswell reflection
1 parent 640ed55 commit 1d90b74

File tree

2 files changed

+597
-0
lines changed

2 files changed

+597
-0
lines changed

couple_inversion.jl

+298
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
using DrWatson
2+
@quickactivate "FNO"
3+
import Pkg; Pkg.instantiate()
4+
5+
using PyPlot
6+
using BSON
7+
using Flux, Random, FFTW, Zygote, NNlib
8+
using MAT, Statistics, LinearAlgebra
9+
using CUDA
10+
using ProgressMeter, JLD2
11+
using LineSearches
12+
13+
using JUDI, JUDI4Flux
14+
using JOLI
15+
using Printf
16+
17+
CUDA.culiteral_pow(::typeof(^), a::Complex{Float32}, b::Val{2}) = real(conj(a)*a)
18+
CUDA.sqrt(a::Complex) = cu(sqrt(a))
19+
Base.broadcasted(::typeof(sqrt), a::Base.Broadcast.Broadcasted) = Base.broadcast(sqrt, Base.materialize(a))
20+
21+
include("utils.jl");
22+
include("fno3dstruct.jl");
23+
24+
Random.seed!(3)
25+
26+
ntrain = 1000
27+
ntest = 100
28+
29+
BSON.@load "data/TrainedNet/2phasenet_200.bson" NN w batch_size Loss modes width learning_rate epochs gamma step_size;
30+
31+
n = (64,64)
32+
# dx, dy in m
33+
d = (1f0/64, 1f0/64) # in the training phase
34+
35+
nt = 51
36+
#dt = 20f0 # dt in day
37+
dt = 1f0/nt
38+
39+
perm = matread("data/data/perm.mat")["perm"];
40+
conc = matread("data/data/conc.mat")["conc"];
41+
42+
s = 4
43+
44+
x_train_ = convert(Array{Float32},perm[1:s:end,1:s:end,1:ntrain]);
45+
x_test_ = convert(Array{Float32},perm[1:s:end,1:s:end,end-ntest+1:end]);
46+
47+
nv = 11
48+
survey_indices = Int.(round.(range(1, stop=nt, length=nv)))
49+
50+
y_train_ = convert(Array{Float32},conc[survey_indices,1:s:end,1:s:end,1:ntrain]);
51+
y_test_ = convert(Array{Float32},conc[survey_indices,1:s:end,1:s:end,end-ntest+1:end]);
52+
53+
y_train_ = permutedims(y_train_,[2,3,1,4]);
54+
y_test = permutedims(y_test_,[2,3,1,4]);
55+
56+
x_normalizer = UnitGaussianNormalizer(x_train_);
57+
x_train_ = encode(x_normalizer,x_train_);
58+
x_test_ = encode(x_normalizer,x_test_);
59+
60+
y_normalizer = UnitGaussianNormalizer(y_train_);
61+
y_train = encode(y_normalizer,y_train_);
62+
63+
x = reshape(collect(range(d[1],stop=n[1]*d[1],length=n[1])), :, 1)
64+
z = reshape(collect(range(d[2],stop=n[2]*d[2],length=n[2])), 1, :)
65+
66+
grid = zeros(Float32,n[1],n[2],2)
67+
grid[:,:,1] = repeat(x',n[2])'
68+
grid[:,:,2] = repeat(z,n[1])
69+
70+
x_train = zeros(Float32,n[1],n[2],nv,4,ntrain)
71+
x_test = zeros(Float32,n[1],n[2],nv,4,ntest)
72+
73+
for i = 1:nv
74+
x_train[:,:,i,1,:] = deepcopy(x_train_)
75+
x_test[:,:,i,1,:] = deepcopy(x_test_)
76+
for j = 1:ntrain
77+
x_train[:,:,i,2,j] = grid[:,:,1]
78+
x_train[:,:,i,3,j] = grid[:,:,2]
79+
x_train[:,:,i,4,j] .= survey_indices[i]*dt
80+
end
81+
82+
for k = 1:ntest
83+
x_test[:,:,i,2,k] = grid[:,:,1]
84+
x_test[:,:,i,3,k] = grid[:,:,2]
85+
x_test[:,:,i,4,k] .= survey_indices[i]*dt
86+
end
87+
end
88+
89+
# value, x, y, t
90+
Flux.testmode!(NN, true);
91+
Flux.testmode!(NN.conv1.bn0);
92+
Flux.testmode!(NN.conv1.bn1);
93+
Flux.testmode!(NN.conv1.bn2);
94+
Flux.testmode!(NN.conv1.bn3);
95+
96+
nx, ny = n
97+
dx, dy = d
98+
x_test_1 = deepcopy(perm[1:s:end,1:s:end,1001]);
99+
y_test_1 = deepcopy(conc[:,1:s:end,1:s:end,1001]);
100+
101+
################ Forward -- generate data
102+
103+
sw = y_test_1[survey_indices,:,:,1]
104+
105+
##### Rock physics
106+
107+
function Patchy(sw::AbstractArray{Float32}, vp::AbstractArray{Float32}, vs::AbstractArray{Float32}, rho::AbstractArray{Float32}, phi::AbstractArray{Float32}; bulk_min = 36.6f9, bulk_fl1 = 2.735f9, bulk_fl2 = 0.125f9, ρw = 501.9f0, ρo = 1053.0f0)
108+
109+
bulk_sat1 = rho .* (vp.^2f0 - 4f0/3f0 .* vs.^2f0)
110+
shear_sat1 = rho .* (vs.^2f0)
111+
112+
patch_temp = bulk_sat1 ./(bulk_min .- bulk_sat1) -
113+
bulk_fl1 ./ phi ./ (bulk_min .- bulk_fl1) +
114+
bulk_fl2 ./ phi ./ (bulk_min .- bulk_fl2)
115+
116+
bulk_sat2 = bulk_min./(1f0./patch_temp .+ 1f0)
117+
118+
bulk_new = 1f0./( (1f0.-sw)./(bulk_sat1+4f0/3f0*shear_sat1)
119+
+ sw./(bulk_sat2+4f0/3f0*shear_sat1) ) - 4f0/3f0*shear_sat1
120+
121+
rho_new = rho + phi .* sw * (ρw - ρo)
122+
123+
Vp_new = sqrt.((bulk_new+4f0/3f0*shear_sat1)./rho_new)
124+
Vs_new = sqrt.((shear_sat1)./rho_new)
125+
126+
return Vp_new, Vs_new, rho_new
127+
end
128+
129+
n = (size(sw,3), size(sw,2))
130+
131+
vp = 3500 * ones(Float32,n)
132+
vs = vp ./ sqrt(3f0)
133+
phi = 0.25f0 * ones(Float32,n)
134+
135+
rho = 2200 * ones(Float32,n)
136+
137+
vp_stack = [(Patchy(sw[i,:,:]',vp,vs,rho,phi))[1] for i = 1:nv]
138+
139+
##### Wave equation
140+
141+
d = (15f0, 15f0)
142+
o = (0f0, 0f0)
143+
144+
extentx = (n[1]-1)*d[1]
145+
extentz = (n[2]-1)*d[2]
146+
147+
nsrc = 8
148+
nrec = n[2]
149+
150+
model = [Model(n, d, o, (1000f0 ./ vp_stack[i]).^2f0; nb = 75) for i = 1:nv]
151+
152+
timeS = timeR = 750f0
153+
dtS = dtR = 1f0
154+
ntS = Int(floor(timeS/dtS))+1
155+
ntR = Int(floor(timeR/dtR))+1
156+
157+
xsrc = convertToCell(range(5*d[1],stop=5*d[1],length=nsrc))
158+
ysrc = convertToCell(range(0f0,stop=0f0,length=nsrc))
159+
zsrc = convertToCell(range(15*d[2],stop=(n[2]-15)*d[2],length=nsrc))
160+
161+
xrec = range((n[1]-5)*d[1],stop=(n[1]-5)*d[1], length=nrec)
162+
yrec = 0f0
163+
zrec = range(d[2],stop=(n[2]-1)*d[2],length=nrec)
164+
165+
srcGeometry = Geometry(xsrc, ysrc, zsrc; dt=dtS, t=timeS)
166+
recGeometry = Geometry(xrec, yrec, zrec; dt=dtR, t=timeR, nsrc=nsrc)
167+
168+
f0 = 0.02f0 # kHz
169+
wavelet = ricker_wavelet(timeS, dtS, f0)
170+
q = judiVector(srcGeometry, wavelet)
171+
172+
ntComp = get_computational_nt(srcGeometry, recGeometry, model[end])
173+
info = Info(prod(n), nsrc, ntComp)
174+
175+
opt = Options(return_array=true)
176+
Pr = judiProjection(info, recGeometry)
177+
Ps = judiProjection(info, srcGeometry)
178+
179+
F = [Pr*judiModeling(info, model[i]; options=opt)*Ps' for i = 1:nv]
180+
181+
d_obs = [F[i]*q for i = 1:nv]
182+
183+
JLD2.@save "data/data/timelapse_$(nsrc)nsrc_$(nv)nv.jld2" d_obs
184+
185+
G = Forward(F[1],q)
186+
187+
x_perm = 20*ones(Float32,n[1],n[2],1)
188+
189+
grad_iterations = 20
190+
191+
function perm_to_tensor(x_perm,survey_indices,grid,dt)
192+
# input nx*ny, output nx*ny*nt*4*1
193+
nv = length(survey_indices)
194+
nx, ny = size(x_perm)
195+
x1 = reshape(x_perm,nx,ny,1,1,1)
196+
x2 = cat([x1 for i = 1:nv]...,dims=3)
197+
grid_1 = cat([reshape(grid[:,:,1],nx,ny,1,1,1) for i = 1:nv]...,dims=3)
198+
grid_2 = cat([reshape(grid[:,:,2],nx,ny,1,1,1) for i = 1:nv]...,dims=3)
199+
grid_t = cat([survey_indices[i]*dt*ones(Float32,nx,ny,1,1,1) for i = 1:nv]...,dims=3)
200+
x_out = cat(x2,grid_1,grid_2,grid_t,dims=4)
201+
return x_out
202+
end
203+
204+
function f(x_inv)
205+
println("evaluate f")
206+
@time begin
207+
sw = decode(y_normalizer,NN(perm_to_tensor(x_inv,survey_indices,grid,dt)))
208+
vp_stack = [(Patchy(sw[:,:,i,1]',vp,vs,rho,phi))[1] for i = 1:nv]
209+
m_stack = [(1000f0 ./ vp_stack[i]).^2f0 for i = 1:nv]
210+
d_predict = [G(m_stack[i]) for i = 1:nv]
211+
loss = 0.5f0 * norm(d_predict-d_obs)^2f0
212+
end
213+
return loss
214+
end
215+
216+
function g!(gvec, x_inv)
217+
println("evaluate g")
218+
p = params(x_inv)
219+
@time grads = gradient(p) do
220+
sw = decode(y_normalizer,NN(perm_to_tensor(x_inv,survey_indices,grid,dt)))
221+
vp_stack = [(Patchy(sw[:,:,i,1]',vp,vs,rho,phi))[1] for i = 1:nv]
222+
m_stack = [(1000f0 ./ vp_stack[i]).^2f0 for i = 1:nv]
223+
d_predict = [G(m_stack[i]) for i = 1:nv]
224+
loss = 0.5f0 * norm(d_predict-d_obs)^2f0
225+
return loss
226+
end
227+
copyto!(gvec, grads.grads[x_inv])
228+
end
229+
230+
function fg!(gvec, x_inv)
231+
println("evaluate f and g")
232+
p = params(x_inv)
233+
@time grads = gradient(p) do
234+
sw = decode(y_normalizer,NN(perm_to_tensor(x_inv,survey_indices,grid,dt)))
235+
vp_stack = [(Patchy(sw[:,:,i,1]',vp,vs,rho,phi))[1] for i = 1:nv]
236+
m_stack = [(1000f0 ./ vp_stack[i]).^2f0 for i = 1:nv]
237+
d_predict = [G(m_stack[i]) for i = 1:nv]
238+
global loss = 0.5f0 * norm(d_predict-d_obs)^2f0
239+
return loss
240+
end
241+
copyto!(gvec, grads.grads[x_inv])
242+
return loss
243+
end
244+
245+
x = zeros(Float32, nx, ny)
246+
x_init = decode(x_normalizer,reshape(x,nx,ny,1))[:,:,1]
247+
248+
ls = BackTracking(c_1=1f-4,iterations=10,maxstep=Inf32,order=3,ρ_hi=5f-1,ρ_lo=1f-1)
249+
Grad_Loss = zeros(Float32, grad_iterations+1)
250+
251+
T = Float32
252+
253+
Grad_Loss[1] = f(x)
254+
println("Initial function value: ", Grad_Loss[1])
255+
256+
figure();
257+
for j=1:grad_iterations
258+
259+
gvec = similar(x)::AbstractArray{T}
260+
fval = fg!(gvec, x)::T
261+
p = -gvec/norm(gvec, Inf)
262+
263+
# linesearch
264+
function ϕ(α)::T
265+
try
266+
fval = f(x .+ α.*p)
267+
catch e
268+
@assert typeof(e) == DomainError
269+
fval = T(Inf)
270+
end
271+
@show α, fval
272+
return fval
273+
end
274+
275+
α, fval = ls(ϕ, 1f0, fval, dot(gvec, p))
276+
277+
println("Coupled inversion iteration no: ",j,"; function value: ",fval)
278+
Grad_Loss[j+1] = fval
279+
280+
global x_inv = decode(x_normalizer,reshape(x,nx,ny,1))[:,:,1]
281+
imshow(x_inv,vmin=20,vmax=120);title("inversion by NN, $j iter");
282+
283+
# Update model and bound projection
284+
@. x = x + α*p::AbstractArray{T}
285+
end
286+
287+
figure(figsize=(20,12));
288+
subplot(1,3,1)
289+
imshow(x_init,vmin=20,vmax=120);title("initial permeability");
290+
subplot(1,3,2);
291+
imshow(x_inv,vmin=20,vmax=120);title("inversion by coupled NN, $grad_iterations iter");
292+
subplot(1,3,3);
293+
imshow(x_test_1,vmin=20,vmax=120);title("GT permeability");
294+
suptitle("$nv vintages, $nsrc each, crosswell")
295+
figure();
296+
plot(Grad_Loss)
297+
298+
JLD2.@save "result/coupleinv$(grad_iterations)_$(nv)nv_$(nsrc)nsrc.jld2" x_inv Grad_Loss

0 commit comments

Comments
 (0)