-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathobs_noise_variation.jl
57 lines (48 loc) · 2.02 KB
/
obs_noise_variation.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
include("experiment_common.jl")
using CSV
using DataFrames
# NOTE(review): turns off SEDL's finiteness checking for this whole run; the reason is
# not visible here (presumably speed, or tolerance of non-finite intermediates) — confirm.
SEDL.should_check_finite[] = false

# Experiment: train every method in `train_methods` `n_repeats` times on the
# "alpha_truck" real-car scenario with a fixed bearing-observation noise level, and record
# per method both the best test score over the repeats and the per-metric summary produced
# by `SEDL.to_measurement` (presumably mean ± uncertainty — confirm).
# Output: `reports/obs_noise_variation/<σ_deg>°/{best,average}.csv`. Both CSVs are
# rewritten after every method, so partial results survive a crash mid-run.
with_alert("obs_noise_variation.jl") do
    σ_deg = 20.0      # bearing observation noise level in degrees (converted via `°` below)
    n_repeats = 3     # independent training runs per method
    train_methods = [:Handwritten, :FitHand, :FitTV, :FitTruth, :EM, :EM_NS, :SVI]

    group_name = "obs_noise_variation"
    result_name = "$(σ_deg)°"
    println("Starting experiment: $result_name...")

    # Validate inputs *before* creating the result directory: a failure after `mkpath`
    # would leave an empty directory behind, and the "already exists" guard would then
    # block every re-run until it is deleted by hand.
    isdefined(Main, :GPU_ID) || error(
        "`GPU_ID` is not defined in Main; set it in the REPL before running the script",
    )
    result_dir = joinpath("reports", group_name, result_name)
    if ispath(result_dir)
        error("\"$result_dir\" already exists")
    end
    mkpath(result_dir)

    perf_best = NamedTuple[]     # one row per method: best score over the repeats
    perf_measure = NamedTuple[]  # one row per method: `SEDL.to_measurement` summary
    for train_method in train_methods
        # you can find the available args inside `train_models.jl`.
        local script_args = (;
            # is_quick_test=true,
            scenario=SEDL.RealCarScenario("alpha_truck"),
            # n_train_ex=160,
            gpu_id=Main.GPU_ID, # set this in the REPL before running the script
            # use_fixed_variance=true,
            # use_simple_obs_model=true,
            σ_bearing=σ_deg * °,
            exp_group=group_name,
            # SVI is model-selected on RMSE; every other method on `log_obs`.
            validation_metric=(train_method == :SVI ? :RMSE : :log_obs),
            train_method,
        )
        local perfs = train_multiple_times(script_args, n_repeats).test_performance
        local measure = map(SEDL.to_measurement, perfs)
        # "Best" convention: `log_obs` is maximized (higher is better); every other
        # metric is minimized (lower is better).
        local max_metrics = (; log_obs=maximum(perfs.log_obs))
        local min_metrics = map(minimum, SEDL.dropnames(perfs, (:log_obs,)))
        local best = merge(max_metrics, min_metrics)
        push!(perf_measure, merge((; method=train_method), measure))
        push!(perf_best, merge((; method=train_method), best))
        # Rewrite after each method so finished results are on disk even if a later
        # method crashes.
        CSV.write(joinpath(result_dir, "best.csv"), DataFrame(perf_best))
        CSV.write(joinpath(result_dir, "average.csv"), DataFrame(perf_measure))
    end

    println("----------------------------------------------")
    println("Best performance:")
    display(DataFrame(perf_best))
    println("----------------------------------------------")
    println("Average performance:")
    display(DataFrame(perf_measure))
    @info("Results saved to $result_dir ...")
end