-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmetrics.py
166 lines (109 loc) · 4.15 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from multiprocessing.pool import Pool
from typing import Any, Callable, Iterable, List, Set, Tuple, TypeVar, Union
import torch
import numpy as np
from torch import einsum
from torch import Tensor
from functools import partial
from scipy.ndimage import distance_transform_edt as distance
from scipy.spatial.distance import directed_hausdorff
# Small epsilon constant, presumably intended for numerical stability (e.g.
# guarding divisions) by importers of this module; it is not referenced by the
# functions visible in this file. NOTE(review): confirm callers before changing.
EPS = 1e-7
# Assert utils
def uniq(a: Tensor) -> Set:
    """Return the distinct values contained in ``a`` as a Python set.

    The tensor is copied to the CPU first so ``torch.unique``'s result can be
    turned into a NumPy array; the set therefore holds NumPy scalars.
    """
    distinct_values = torch.unique(a.cpu())
    return set(distinct_values.numpy())
def sset(a: Tensor, sub: Iterable) -> bool:
    """Return True when every distinct value of ``a`` also appears in ``sub``."""
    observed_values = uniq(a)
    return observed_values.issubset(sub)
def eq(a: Tensor, b) -> bool:
    """Return True if every element of ``a`` equals ``b``.

    ``b`` may be a scalar or a tensor broadcastable against ``a`` (anything
    ``torch.eq`` accepts). The reduction is converted to a plain Python
    ``bool`` so the function honours its ``-> bool`` annotation instead of
    returning a 0-d tensor.
    """
    return bool(torch.eq(a, b).all())
def simplex(t: Tensor, axis=1) -> bool:
    """Check that ``t`` sums to one along ``axis``.

    The sums are computed and compared in float32, using
    ``torch.allclose``'s default tolerances.
    """
    totals = t.sum(axis).to(torch.float32)
    return torch.allclose(totals, torch.ones_like(totals))
def one_hot(t: Tensor, axis=1) -> bool:
    """Return True if ``t`` is one-hot along ``axis``.

    Being one-hot means ``t`` sums to one along ``axis`` and contains only
    the values 0 and 1.
    """
    if not simplex(t, axis):
        return False
    return sset(t, [0, 1])
# Metrics
def meta_dice(sum_str: str, label: Tensor, pred: Tensor, smooth: float = 1e-8) -> Tensor:
    """Compute smoothed Dice coefficients between two one-hot tensors.

    Dice = (2 * |label & pred| + smooth) / (|label| + |pred| + smooth), where
    each ``|.|`` is a sum taken with ``einsum(sum_str, ...)``.

    Args:
        sum_str: single-operand ``einsum`` equation selecting which axes are
            summed, e.g. ``"bcwh->bc"`` (per sample and class) or
            ``"bcwh->c"`` (per class, pooled over the batch).
        label: ground truth, one-hot along axis 1; same shape as ``pred``.
        pred: prediction, one-hot along axis 1.
        smooth: added to numerator and denominator so that two empty masks
            give a Dice of 1 instead of 0/0.

    Returns:
        A float32 tensor of Dice scores whose shape is the output side of
        ``sum_str``. The previous ``-> float`` annotation was incorrect; the
        function has always returned a tensor.
    """
    assert label.shape == pred.shape
    assert one_hot(label)
    assert one_hot(pred)

    # ``intersection`` uses bitwise ``&``, so integer/bool one-hot tensors are expected.
    inter_size: Tensor = einsum(sum_str, [intersection(label, pred)]).type(torch.float32)
    sum_sizes: Tensor = (einsum(sum_str, [label]) + einsum(sum_str, [pred])).type(torch.float32)

    dices: Tensor = (2 * inter_size + smooth) / (sum_sizes + smooth)
    return dices
# Dice per sample and per class: sums over the two spatial axes only, so the
# result has shape (b, c).
dice_coef = partial(meta_dice, "bcwh->bc")
# Dice per class pooled over the whole batch: sums over b, w and h, so the
# result has shape (c,). Presumably used for 3d dice when a batch holds the
# slices of one volume — NOTE(review): confirm with callers.
dice_batch = partial(meta_dice, "bcwh->c") # used for 3d dice
def intersection(a: Tensor, b: Tensor) -> Tensor:
    """Element-wise AND of two binary (0/1) masks of identical shape.

    The bitwise ``&`` operator means integer or bool dtypes are expected.
    """
    assert a.shape == b.shape
    for mask in (a, b):
        assert sset(mask, [0, 1])
    return a & b
def union(a: Tensor, b: Tensor) -> Tensor:
    """Element-wise OR of two binary (0/1) masks of identical shape.

    The bitwise ``|`` operator means integer or bool dtypes are expected.
    """
    assert a.shape == b.shape
    # Short-circuit: ``b`` is only checked once ``a`` has passed.
    assert sset(a, [0, 1]) and sset(b, [0, 1])
    return a | b
def haussdorf(preds: Tensor, target: Tensor) -> Tensor:
    """Per-sample, per-class symmetric Hausdorff distance between one-hot masks.

    Both inputs must have the same 4-D shape (b, C, w, h) and be one-hot along
    axis 1. Each 2-D channel slice is scored with ``numpy_haussdorf``. The
    result is a float32 tensor of shape (b, C) on ``preds``' device.

    When C == 2, only channel 0 is evaluated and its value is written to both
    columns. This is valid because the masks are one-hot, so channel 1 is the
    complement of channel 0. ``numpy_haussdorf`` compares rows as vectors, and
    complementing both masks leaves every row-to-row Euclidean distance
    unchanged.
    """
    assert preds.shape == target.shape
    assert one_hot(preds)
    assert one_hot(target)

    batch_size, n_classes, _, _ = preds.shape
    res = torch.zeros((batch_size, n_classes), dtype=torch.float32, device=preds.device)

    pred_np = preds.cpu().numpy()
    target_np = target.cpu().numpy()

    for i in range(batch_size):
        if n_classes == 2:
            res[i, :] = numpy_haussdorf(pred_np[i, 0], target_np[i, 0])
        else:
            for k in range(n_classes):
                res[i, k] = numpy_haussdorf(pred_np[i, k], target_np[i, k])
    return res
def numpy_haussdorf(pred: np.ndarray, target: np.ndarray) -> float:
    """Symmetric Hausdorff distance between two 2-D arrays of equal shape.

    ``scipy.spatial.distance.directed_hausdorff`` treats each ROW of its
    inputs as a point. For a mask of shape (w, h), rows are therefore
    compared as h-dimensional vectors, not as pixel coordinates.
    """
    assert pred.ndim == 2
    assert pred.shape == target.shape
    forward, _, _ = directed_hausdorff(pred, target)
    backward, _, _ = directed_hausdorff(target, pred)
    return max(forward, backward)
# Conversions between label representations (probabilities, class maps, one-hot, distance maps)
def probs2class(probs: Tensor) -> Tensor:
    """Collapse per-class probabilities (b, C, w, h) into a label map (b, w, h).

    ``probs`` must sum to one along axis 1. Each pixel's label is the argmax
    over that axis.
    """
    batch, _, width, height = probs.shape  # type: Tuple[int, int, int, int]
    assert simplex(probs)

    labels = probs.argmax(dim=1)
    assert labels.shape == (batch, width, height)
    return labels
def class2one_hot(seg: Tensor, C: int) -> Tensor:
    """Expand a label map into an int32 one-hot tensor of shape (b, C, w, h).

    ``seg`` is (b, w, h), or (w, h) for a single image, in which case a batch
    axis of size 1 is added. All labels must lie in ``range(C)``.
    """
    if seg.dim() == 2:  # Only w, h, used by the dataloader
        seg = seg[None]
    assert sset(seg, list(range(C)))

    n, width, height = seg.shape  # type: Tuple[int, int, int]
    planes = [seg == k for k in range(C)]
    res = torch.stack(planes, dim=1).int()

    assert res.shape == (n, C, width, height)
    assert one_hot(res)
    return res
def probs2one_hot(probs: Tensor) -> Tensor:
    """Turn per-class probabilities (b, C, w, h) into a one-hot tensor of the same shape.

    The conversion goes through the argmax label map, via ``probs2class``
    followed by ``class2one_hot``.
    """
    _, n_classes, _, _ = probs.shape
    assert simplex(probs)

    labels = probs2class(probs)
    res = class2one_hot(labels, n_classes)

    assert res.shape == probs.shape
    assert one_hot(res)
    return res
def one_hot2dist(seg: np.ndarray, resolution=None, dtype=None) -> np.ndarray:
    """Build a signed Euclidean distance map for every channel of a one-hot mask.

    For each channel with at least one foreground pixel:
      * a background pixel gets its distance to the nearest foreground pixel
        (positive);
      * a foreground pixel gets ``-(d - 1)``, where ``d`` is its distance to
        the nearest background pixel. With unit spacing, foreground pixels
        touching the background are therefore 0 and the interior is negative.
    Channels without foreground are left at 0.

    Args:
        seg: array that is one-hot along axis 0 (channels first).
        resolution: optional pixel spacing per spatial axis, forwarded as
            ``sampling`` to ``scipy.ndimage.distance_transform_edt``.
            ``None`` (default) means unit spacing, as before.
        dtype: dtype of the returned array. With ``None`` (default), floating
            inputs keep their dtype, and integer/bool inputs get ``float64``
            so the distances are no longer truncated (previously they were
            written into an array of ``seg``'s integer dtype).

    Returns:
        Array of the same shape as ``seg``.
    """
    assert one_hot(torch.Tensor(seg), axis=0)
    C: int = len(seg)

    if dtype is None:
        dtype = seg.dtype if np.issubdtype(seg.dtype, np.floating) else np.float64
    res = np.zeros_like(seg, dtype=dtype)

    for c in range(C):
        # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the supported spelling.
        posmask = seg[c].astype(bool)
        if posmask.any():
            negmask = ~posmask
            res[c] = (distance(negmask, sampling=resolution) * negmask
                      - (distance(posmask, sampling=resolution) - 1) * posmask)
        # Channels with no foreground stay 0: since the mask is one-hot,
        # another channel covers those pixels.
    return res
def nanmean(x):
    """Computes the arithmetic mean ignoring any NaNs.

    If every entry is NaN, the mean of the empty selection is NaN.
    """
    keep = ~torch.isnan(x)
    return x[keep].mean()
def onehot2indice(x):
    """Convert NCHW to NHW by taking the argmax over the channel axis (dim 1)."""
    return x.argmax(dim=1)
def _fast_hist(true, pred, num_classes):
    """Confusion matrix of integer labels as a float tensor of shape (num_classes, num_classes).

    Rows index the true class and columns the predicted class. Positions
    whose true label lies outside ``[0, num_classes)`` are ignored; ``pred``
    is not range-checked.
    """
    valid = (true >= 0) & (true < num_classes)
    # Encode each (true, pred) pair as one flat index so a single bincount builds the matrix.
    flat_index = true[valid] * num_classes + pred[valid]
    counts = torch.bincount(flat_index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes).float()