|
| 1 | +using DrWatson |
| 2 | +@quickactivate "FNO" |
| 3 | +import Pkg; Pkg.instantiate() |
| 4 | + |
| 5 | +using PyPlot |
| 6 | +using BSON |
| 7 | +using Flux, Random, FFTW, Zygote, NNlib |
| 8 | +using MAT, Statistics, LinearAlgebra |
| 9 | +using CUDA |
| 10 | +using ProgressMeter, JLD2 |
| 11 | +using LineSearches |
| 12 | + |
| 13 | +using JUDI, JUDI4Flux |
| 14 | +using JOLI |
| 15 | +using Printf |
| 16 | + |
| 17 | +CUDA.culiteral_pow(::typeof(^), a::Complex{Float32}, b::Val{2}) = real(conj(a)*a) |
| 18 | +CUDA.sqrt(a::Complex) = cu(sqrt(a)) |
| 19 | +Base.broadcasted(::typeof(sqrt), a::Base.Broadcast.Broadcasted) = Base.broadcast(sqrt, Base.materialize(a)) |
| 20 | + |
| 21 | +include("utils.jl"); |
| 22 | +include("fno3dstruct.jl"); |
| 23 | + |
| 24 | +Random.seed!(3) |
| 25 | + |
| 26 | +ntrain = 1000 |
| 27 | +ntest = 100 |
| 28 | + |
| 29 | +BSON.@load "data/TrainedNet/2phasenet_200.bson" NN w batch_size Loss modes width learning_rate epochs gamma step_size; |
| 30 | + |
| 31 | +n = (64,64) |
| 32 | + # dx, dy in m |
| 33 | +d = (1f0/64, 1f0/64) # in the training phase |
| 34 | + |
| 35 | +nt = 51 |
| 36 | +#dt = 20f0 # dt in day |
| 37 | +dt = 1f0/nt |
| 38 | + |
| 39 | +perm = matread("data/data/perm.mat")["perm"]; |
| 40 | +conc = matread("data/data/conc.mat")["conc"]; |
| 41 | + |
| 42 | +s = 4 |
| 43 | + |
| 44 | +x_train_ = convert(Array{Float32},perm[1:s:end,1:s:end,1:ntrain]); |
| 45 | +x_test_ = convert(Array{Float32},perm[1:s:end,1:s:end,end-ntest+1:end]); |
| 46 | + |
| 47 | +nv = 11 |
| 48 | +survey_indices = Int.(round.(range(1, stop=nt, length=nv))) |
| 49 | + |
| 50 | +y_train_ = convert(Array{Float32},conc[survey_indices,1:s:end,1:s:end,1:ntrain]); |
| 51 | +y_test_ = convert(Array{Float32},conc[survey_indices,1:s:end,1:s:end,end-ntest+1:end]); |
| 52 | + |
| 53 | +y_train_ = permutedims(y_train_,[2,3,1,4]); |
| 54 | +y_test = permutedims(y_test_,[2,3,1,4]); |
| 55 | + |
| 56 | +x_normalizer = UnitGaussianNormalizer(x_train_); |
| 57 | +x_train_ = encode(x_normalizer,x_train_); |
| 58 | +x_test_ = encode(x_normalizer,x_test_); |
| 59 | + |
| 60 | +y_normalizer = UnitGaussianNormalizer(y_train_); |
| 61 | +y_train = encode(y_normalizer,y_train_); |
| 62 | + |
| 63 | +x = reshape(collect(range(d[1],stop=n[1]*d[1],length=n[1])), :, 1) |
| 64 | +z = reshape(collect(range(d[2],stop=n[2]*d[2],length=n[2])), 1, :) |
| 65 | + |
| 66 | +grid = zeros(Float32,n[1],n[2],2) |
| 67 | +grid[:,:,1] = repeat(x',n[2])' |
| 68 | +grid[:,:,2] = repeat(z,n[1]) |
| 69 | + |
| 70 | +x_train = zeros(Float32,n[1],n[2],nv,4,ntrain) |
| 71 | +x_test = zeros(Float32,n[1],n[2],nv,4,ntest) |
| 72 | + |
| 73 | +for i = 1:nv |
| 74 | + x_train[:,:,i,1,:] = deepcopy(x_train_) |
| 75 | + x_test[:,:,i,1,:] = deepcopy(x_test_) |
| 76 | + for j = 1:ntrain |
| 77 | + x_train[:,:,i,2,j] = grid[:,:,1] |
| 78 | + x_train[:,:,i,3,j] = grid[:,:,2] |
| 79 | + x_train[:,:,i,4,j] .= survey_indices[i]*dt |
| 80 | + end |
| 81 | + |
| 82 | + for k = 1:ntest |
| 83 | + x_test[:,:,i,2,k] = grid[:,:,1] |
| 84 | + x_test[:,:,i,3,k] = grid[:,:,2] |
| 85 | + x_test[:,:,i,4,k] .= survey_indices[i]*dt |
| 86 | + end |
| 87 | +end |
| 88 | + |
| 89 | +# value, x, y, t |
| 90 | +Flux.testmode!(NN, true); |
| 91 | +Flux.testmode!(NN.conv1.bn0); |
| 92 | +Flux.testmode!(NN.conv1.bn1); |
| 93 | +Flux.testmode!(NN.conv1.bn2); |
| 94 | +Flux.testmode!(NN.conv1.bn3); |
| 95 | + |
| 96 | +nx, ny = n |
| 97 | +dx, dy = d |
| 98 | +x_test_1 = deepcopy(perm[1:s:end,1:s:end,1001]); |
| 99 | +y_test_1 = deepcopy(conc[:,1:s:end,1:s:end,1001]); |
| 100 | + |
| 101 | +################ Forward -- generate data |
| 102 | + |
| 103 | +sw = y_test_1[survey_indices,:,:,1] |
| 104 | + |
| 105 | +##### Rock physics |
| 106 | + |
| 107 | +function Patchy(sw::AbstractArray{Float32}, vp::AbstractArray{Float32}, vs::AbstractArray{Float32}, rho::AbstractArray{Float32}, phi::AbstractArray{Float32}; bulk_min = 36.6f9, bulk_fl1 = 2.735f9, bulk_fl2 = 0.125f9, ρw = 501.9f0, ρo = 1053.0f0) |
| 108 | + |
| 109 | + bulk_sat1 = rho .* (vp.^2f0 - 4f0/3f0 .* vs.^2f0) |
| 110 | + shear_sat1 = rho .* (vs.^2f0) |
| 111 | + |
| 112 | + patch_temp = bulk_sat1 ./(bulk_min .- bulk_sat1) - |
| 113 | + bulk_fl1 ./ phi ./ (bulk_min .- bulk_fl1) + |
| 114 | + bulk_fl2 ./ phi ./ (bulk_min .- bulk_fl2) |
| 115 | + |
| 116 | + bulk_sat2 = bulk_min./(1f0./patch_temp .+ 1f0) |
| 117 | + |
| 118 | + bulk_new = 1f0./( (1f0.-sw)./(bulk_sat1+4f0/3f0*shear_sat1) |
| 119 | + + sw./(bulk_sat2+4f0/3f0*shear_sat1) ) - 4f0/3f0*shear_sat1 |
| 120 | + |
| 121 | + rho_new = rho + phi .* sw * (ρw - ρo) |
| 122 | + |
| 123 | + Vp_new = sqrt.((bulk_new+4f0/3f0*shear_sat1)./rho_new) |
| 124 | + Vs_new = sqrt.((shear_sat1)./rho_new) |
| 125 | + |
| 126 | + return Vp_new, Vs_new, rho_new |
| 127 | +end |
| 128 | + |
| 129 | +n = (size(sw,3), size(sw,2)) |
| 130 | + |
| 131 | +vp = 3500 * ones(Float32,n) |
| 132 | +vs = vp ./ sqrt(3f0) |
| 133 | +phi = 0.25f0 * ones(Float32,n) |
| 134 | + |
| 135 | +rho = 2200 * ones(Float32,n) |
| 136 | + |
| 137 | +vp_stack = [(Patchy(sw[i,:,:]',vp,vs,rho,phi))[1] for i = 1:nv] |
| 138 | + |
| 139 | +##### Wave equation |
| 140 | + |
| 141 | +d = (15f0, 15f0) |
| 142 | +o = (0f0, 0f0) |
| 143 | + |
| 144 | +extentx = (n[1]-1)*d[1] |
| 145 | +extentz = (n[2]-1)*d[2] |
| 146 | + |
| 147 | +nsrc = 8 |
| 148 | +nrec = n[2] |
| 149 | + |
| 150 | +model = [Model(n, d, o, (1000f0 ./ vp_stack[i]).^2f0; nb = 75) for i = 1:nv] |
| 151 | + |
| 152 | +timeS = timeR = 750f0 |
| 153 | +dtS = dtR = 1f0 |
| 154 | +ntS = Int(floor(timeS/dtS))+1 |
| 155 | +ntR = Int(floor(timeR/dtR))+1 |
| 156 | + |
| 157 | +xsrc = convertToCell(range(5*d[1],stop=5*d[1],length=nsrc)) |
| 158 | +ysrc = convertToCell(range(0f0,stop=0f0,length=nsrc)) |
| 159 | +zsrc = convertToCell(range(15*d[2],stop=(n[2]-15)*d[2],length=nsrc)) |
| 160 | + |
| 161 | +xrec = range((n[1]-5)*d[1],stop=(n[1]-5)*d[1], length=nrec) |
| 162 | +yrec = 0f0 |
| 163 | +zrec = range(d[2],stop=(n[2]-1)*d[2],length=nrec) |
| 164 | + |
| 165 | +srcGeometry = Geometry(xsrc, ysrc, zsrc; dt=dtS, t=timeS) |
| 166 | +recGeometry = Geometry(xrec, yrec, zrec; dt=dtR, t=timeR, nsrc=nsrc) |
| 167 | + |
| 168 | +f0 = 0.02f0 # kHz |
| 169 | +wavelet = ricker_wavelet(timeS, dtS, f0) |
| 170 | +q = judiVector(srcGeometry, wavelet) |
| 171 | + |
| 172 | +ntComp = get_computational_nt(srcGeometry, recGeometry, model[end]) |
| 173 | +info = Info(prod(n), nsrc, ntComp) |
| 174 | + |
| 175 | +opt = Options(return_array=true) |
| 176 | +Pr = judiProjection(info, recGeometry) |
| 177 | +Ps = judiProjection(info, srcGeometry) |
| 178 | + |
| 179 | +F = [Pr*judiModeling(info, model[i]; options=opt)*Ps' for i = 1:nv] |
| 180 | + |
| 181 | +d_obs = [F[i]*q for i = 1:nv] |
| 182 | + |
| 183 | +JLD2.@save "data/data/timelapse_$(nsrc)nsrc_$(nv)nv.jld2" d_obs |
| 184 | + |
| 185 | +G = Forward(F[1],q) |
| 186 | + |
| 187 | +x_perm = 20*ones(Float32,n[1],n[2],1) |
| 188 | + |
| 189 | +grad_iterations = 20 |
| 190 | + |
| 191 | +function perm_to_tensor(x_perm,survey_indices,grid,dt) |
| 192 | + # input nx*ny, output nx*ny*nt*4*1 |
| 193 | + nv = length(survey_indices) |
| 194 | + nx, ny = size(x_perm) |
| 195 | + x1 = reshape(x_perm,nx,ny,1,1,1) |
| 196 | + x2 = cat([x1 for i = 1:nv]...,dims=3) |
| 197 | + grid_1 = cat([reshape(grid[:,:,1],nx,ny,1,1,1) for i = 1:nv]...,dims=3) |
| 198 | + grid_2 = cat([reshape(grid[:,:,2],nx,ny,1,1,1) for i = 1:nv]...,dims=3) |
| 199 | + grid_t = cat([survey_indices[i]*dt*ones(Float32,nx,ny,1,1,1) for i = 1:nv]...,dims=3) |
| 200 | + x_out = cat(x2,grid_1,grid_2,grid_t,dims=4) |
| 201 | + return x_out |
| 202 | +end |
| 203 | + |
| 204 | +function f(x_inv) |
| 205 | + println("evaluate f") |
| 206 | + @time begin |
| 207 | + sw = decode(y_normalizer,NN(perm_to_tensor(x_inv,survey_indices,grid,dt))) |
| 208 | + vp_stack = [(Patchy(sw[:,:,i,1]',vp,vs,rho,phi))[1] for i = 1:nv] |
| 209 | + m_stack = [(1000f0 ./ vp_stack[i]).^2f0 for i = 1:nv] |
| 210 | + d_predict = [G(m_stack[i]) for i = 1:nv] |
| 211 | + loss = 0.5f0 * norm(d_predict-d_obs)^2f0 |
| 212 | + end |
| 213 | + return loss |
| 214 | +end |
| 215 | + |
| 216 | +function g!(gvec, x_inv) |
| 217 | + println("evaluate g") |
| 218 | + p = params(x_inv) |
| 219 | + @time grads = gradient(p) do |
| 220 | + sw = decode(y_normalizer,NN(perm_to_tensor(x_inv,survey_indices,grid,dt))) |
| 221 | + vp_stack = [(Patchy(sw[:,:,i,1]',vp,vs,rho,phi))[1] for i = 1:nv] |
| 222 | + m_stack = [(1000f0 ./ vp_stack[i]).^2f0 for i = 1:nv] |
| 223 | + d_predict = [G(m_stack[i]) for i = 1:nv] |
| 224 | + loss = 0.5f0 * norm(d_predict-d_obs)^2f0 |
| 225 | + return loss |
| 226 | + end |
| 227 | + copyto!(gvec, grads.grads[x_inv]) |
| 228 | +end |
| 229 | + |
| 230 | +function fg!(gvec, x_inv) |
| 231 | + println("evaluate f and g") |
| 232 | + p = params(x_inv) |
| 233 | + @time grads = gradient(p) do |
| 234 | + sw = decode(y_normalizer,NN(perm_to_tensor(x_inv,survey_indices,grid,dt))) |
| 235 | + vp_stack = [(Patchy(sw[:,:,i,1]',vp,vs,rho,phi))[1] for i = 1:nv] |
| 236 | + m_stack = [(1000f0 ./ vp_stack[i]).^2f0 for i = 1:nv] |
| 237 | + d_predict = [G(m_stack[i]) for i = 1:nv] |
| 238 | + global loss = 0.5f0 * norm(d_predict-d_obs)^2f0 |
| 239 | + return loss |
| 240 | + end |
| 241 | + copyto!(gvec, grads.grads[x_inv]) |
| 242 | + return loss |
| 243 | +end |
| 244 | + |
| 245 | +x = zeros(Float32, nx, ny) |
| 246 | +x_init = decode(x_normalizer,reshape(x,nx,ny,1))[:,:,1] |
| 247 | + |
| 248 | +ls = BackTracking(c_1=1f-4,iterations=10,maxstep=Inf32,order=3,ρ_hi=5f-1,ρ_lo=1f-1) |
| 249 | +Grad_Loss = zeros(Float32, grad_iterations+1) |
| 250 | + |
| 251 | +T = Float32 |
| 252 | + |
| 253 | +Grad_Loss[1] = f(x) |
| 254 | +println("Initial function value: ", Grad_Loss[1]) |
| 255 | + |
| 256 | +figure(); |
| 257 | +for j=1:grad_iterations |
| 258 | + |
| 259 | + gvec = similar(x)::AbstractArray{T} |
| 260 | + fval = fg!(gvec, x)::T |
| 261 | + p = -gvec/norm(gvec, Inf) |
| 262 | + |
| 263 | + # linesearch |
| 264 | + function ϕ(α)::T |
| 265 | + try |
| 266 | + fval = f(x .+ α.*p) |
| 267 | + catch e |
| 268 | + @assert typeof(e) == DomainError |
| 269 | + fval = T(Inf) |
| 270 | + end |
| 271 | + @show α, fval |
| 272 | + return fval |
| 273 | + end |
| 274 | + |
| 275 | + α, fval = ls(ϕ, 1f0, fval, dot(gvec, p)) |
| 276 | + |
| 277 | + println("Coupled inversion iteration no: ",j,"; function value: ",fval) |
| 278 | + Grad_Loss[j+1] = fval |
| 279 | + |
| 280 | + global x_inv = decode(x_normalizer,reshape(x,nx,ny,1))[:,:,1] |
| 281 | + imshow(x_inv,vmin=20,vmax=120);title("inversion by NN, $j iter"); |
| 282 | + |
| 283 | + # Update model and bound projection |
| 284 | + @. x = x + α*p::AbstractArray{T} |
| 285 | +end |
| 286 | + |
| 287 | +figure(figsize=(20,12)); |
| 288 | +subplot(1,3,1) |
| 289 | +imshow(x_init,vmin=20,vmax=120);title("initial permeability"); |
| 290 | +subplot(1,3,2); |
| 291 | +imshow(x_inv,vmin=20,vmax=120);title("inversion by coupled NN, $grad_iterations iter"); |
| 292 | +subplot(1,3,3); |
| 293 | +imshow(x_test_1,vmin=20,vmax=120);title("GT permeability"); |
| 294 | +suptitle("$nv vintages, $nsrc each, crosswell") |
| 295 | +figure(); |
| 296 | +plot(Grad_Loss) |
| 297 | + |
| 298 | +JLD2.@save "result/coupleinv$(grad_iterations)_$(nv)nv_$(nsrc)nsrc.jld2" x_inv Grad_Loss |
0 commit comments