# mv_normal_gamma.jl
"""
    mv_normal_gamma

Multivariate normal-gamma distribution over pairs `(θ, τ)` with `θ ∈ ℝᴰ` and `τ > 0`.
The joint density implemented here factorises as

    p(θ, τ) = N(θ | μ, (τΛ)⁻¹) ⋅ Gamma(τ | shape = α, rate = β)

`pdf`, `logpdf` and `params` add methods to the corresponding `Distributions` functions
(rather than defining new, shadowing functions), so this module can be loaded next to
`Distributions` without export name clashes.
"""
module mv_normal_gamma

using LinearAlgebra
using Distributions
using SpecialFunctions
# Extend, rather than shadow, the Distributions generics: defining fresh functions with
# these names makes `using Distributions, .mv_normal_gamma` report export conflicts.
import Distributions: pdf, logpdf, params

export MvNormalGamma, pdf, logpdf, params, dimensions

"""
    MvNormalGamma(mean, precision_matrix, shape, rate)

Normal-gamma distribution with location `mean` (μ, length D), D×D precision matrix
`precision_matrix` (Λ), Gamma shape `shape` (α) and Gamma rate `rate` (β).

All parameters are promoted to one common floating-point element type `T`; arrays whose
element type is already `T` are stored as-is (not copied).

# Throws
- `ErrorException` if `shape` or `rate` is not strictly positive (NaN included), or if
  `precision_matrix` is not D×D.

No symmetry or positive-definiteness check is performed on `precision_matrix`.
"""
mutable struct MvNormalGamma{T<:Real} <: ContinuousMultivariateDistribution
    D::Int          # dimension of θ, i.e. length(μ)
    μ::Vector{T}    # location of θ
    Λ::Matrix{T}    # precision matrix of θ (scaled by τ in the density)
    α::T            # Gamma shape
    β::T            # Gamma rate
    function MvNormalGamma{T}(mean::AbstractVector, precision_matrix::AbstractMatrix,
                              shape::Real, rate::Real) where {T<:Real}
        # Written as `x > 0 || …` (not `x <= 0 && …`) so that NaN is rejected as well.
        shape > 0 || error("Shape parameter must be positive.")
        rate > 0 || error("Rate parameter must be positive.")
        D = length(mean)
        if size(precision_matrix, 1) != D
            error("Number of rows of precision matrix does not match mean vector length.")
        end
        if size(precision_matrix, 2) != D
            error("Number of columns of precision matrix does not match mean vector length.")
        end
        # `new` converts every argument to its field type; for arrays that already have
        # element type `T` this conversion is the identity (no copy).
        return new{T}(D, mean, precision_matrix, shape, rate)
    end
end

function MvNormalGamma(
    mean::AbstractVector{<:Real},
    precision_matrix::AbstractMatrix{<:Real},
    shape::Real,
    rate::Real,
)
    T = float(promote_type(eltype(mean), eltype(precision_matrix),
                           typeof(shape), typeof(rate)))
    return MvNormalGamma{T}(mean, precision_matrix, shape, rate)
end

"""
    dims(p::MvNormalGamma)

Return the dimension D of θ (the length of the mean vector).
"""
dims(p::MvNormalGamma) = p.D

# Exported alias of `dims`.
dimensions(p::MvNormalGamma) = dims(p)

# Distributions' multivariate-distribution interface reports the dimension via `length`.
Base.length(p::MvNormalGamma) = dims(p)

"""
    params(p::MvNormalGamma)

Return the parameters as a tuple `(μ, Λ, α, β)`. The arrays are the stored fields
themselves, not copies.
"""
params(p::MvNormalGamma) = (p.μ, p.Λ, p.α, p.β)

"""
    logpdf(p::MvNormalGamma, θ, τ)

Log-density of `p` at `(θ, τ)`:

    ½ log|Λ| - (D/2) log 2π + α log β - log Γ(α) + (α + D/2 - 1) log τ
        - (τ/2) ((θ - μ)ᵀ Λ (θ - μ) + 2β)

Returns `-Inf` for `τ < 0`, which lies outside the support.
"""
function logpdf(p::MvNormalGamma, θ, τ)
    # Outside the support the density is zero; `log(τ)` would throw a DomainError here.
    τ < 0 && return oftype(float(τ), -Inf)
    μ, Λ, α, β = params(p)
    D = dims(p)
    r = θ - μ
    quad = dot(r, Λ, r)    # (θ - μ)ᵀ Λ (θ - μ) without materialising Λ * r
    # `loggamma(α)` stays finite where `log(gamma(α))` overflows to Inf (α ≳ 171).
    return logdet(Λ) / 2 - D / 2 * log(2π) + α * log(β) - loggamma(α) +
           (α + D / 2 - 1) * log(τ) - τ / 2 * (quad + 2β)
end

"""
    pdf(p::MvNormalGamma, θ, τ)

Density of `p` at `(θ, τ)`, computed as `exp(logpdf(p, θ, τ))`. Working in log space
avoids the overflow of `gamma(α)`, `β^α` and `τ^(α + D/2 - 1)` for large shape parameters.
"""
pdf(p::MvNormalGamma, θ, τ) = exp(logpdf(p, θ, τ))

end # module