Skip to content

Commit 07c7109

Browse files
committed
Show quality improvement for different values of k
1 parent ef173a4 commit 07c7109

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

kmeans.py

+18 -13
Original file line number | Diff line number | Diff line change
@@ -9,9 +9,6 @@ def inner(*moreargs):
99
return func(*args, *moreargs)
1010
return inner
1111

12-
Point = Tuple[float, ...]
13-
Centroid = Point
14-
1512
def mean(data: Iterable[float]) -> float:
1613
'Accurate arithmetic mean'
1714
data = list(data)
@@ -21,8 +18,11 @@ def transpose(matrix: Iterable[Iterable]) -> Iterable[tuple]:
2118
'Swap rows with columns for a 2-D array'
2219
return zip(*matrix)
2320

21+
Point = Tuple[float, ...]
22+
Centroid = Point
23+
2424
def dist(p: Point, q: Point, sqrt=sqrt, fsum=fsum, zip=zip) -> float:
    """Return the Euclidean distance between the points p and q.

    Coordinates are paired up with zip, so the points may have any
    number of dimensions. The sqrt, fsum and zip parameters default to
    the same-named functions that are in scope when dist is defined.
    """
    squared_deltas = ((a - b) ** 2.0 for a, b in zip(p, q))
    return sqrt(fsum(squared_deltas))
2727

2828
def assign_data(centroids: Sequence[Centroid], data: Iterable[Point]) -> Dict[Centroid, Sequence[Point]]:
@@ -46,10 +46,16 @@ def k_means(data: Iterable[Point], k:int=2, iterations:int=10) -> List[Point]:
4646
centroids = compute_centroids(labeled.values())
4747
return centroids
4848

49+
def quality(labeled: Dict[Centroid, Sequence[Point]]) -> float:
    """Return the mean squared distance from each point to its centroid.

    labeled maps each centroid to the sequence of points assigned to it.
    """
    squared_distances = (
        dist(centroid, point) ** 2
        for centroid, members in labeled.items()
        for point in members
    )
    return mean(squared_distances)
52+
53+
4954
if __name__ == '__main__':
5055

5156
from pprint import pprint
5257

58+
print('Simple example with six 3-D points clustered into two groups')
5359
points = [
5460
(10, 41, 23),
5561
(22, 30, 29),
@@ -62,12 +68,10 @@ def k_means(data: Iterable[Point], k:int=2, iterations:int=10) -> List[Point]:
6268
centroids = k_means(points, k=2)
6369
pprint(assign_data(centroids, points))
6470

65-
if __name__ == '__main__':
66-
# https://www.datascience.com/blog/introduction-to-k-means-clustering-algorithm-learn-data-science-tutorials
67-
from pprint import pprint
71+
print('\nExample with a richer dataset.')
72+
print('See: https://www.datascience.com/blog/introduction-to-k-means-clustering-algorithm-learn-data-science-tutorials')
6873

6974
data = [
70-
7175
(10, 30),
7276
(12, 50),
7377
(14, 70),
@@ -89,8 +93,9 @@ def k_means(data: Iterable[Point], k:int=2, iterations:int=10) -> List[Point]:
8993
(90, 160),
9094
]
9195

92-
# 5583 1338 1202 668 611 409 463
93-
centroids = k_means(data, k=4, iterations=20)
94-
d = assign_data(centroids, data)
95-
pprint(d)
96-
96+
print('k quality')
97+
print('- -------')
98+
for k in range(1, 8):
99+
centroids = k_means(data, k, iterations=20)
100+
d = assign_data(centroids, data)
101+
print(f'{k} {quality(d) :8,.1f}')

0 commit comments

Comments
 (0)