Skip to content

Commit d68d824

Browse files
Preserve order of samples/classes/labels for plot_pca_2d_projection reiinakano#108
1 parent bdf1116 commit d68d824

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

scikitplot/decomposition.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
import numpy as np
1414

1515

16-
def plot_pca_component_variance(clf, title='PCA Component Explained Variances',
17-
target_explained_variance=0.75, ax=None,
18-
figsize=None, title_fontsize="large",
19-
text_fontsize="medium"):
16+
def plot_pca_component_variance(
17+
clf, title='PCA Component Explained Variances',
18+
target_explained_variance=0.75, ax=None,
19+
figsize=None, title_fontsize="large",
20+
text_fontsize="medium"
21+
):
2022
"""Plots PCA components' explained variance ratios. (new in v0.2.2)
2123
2224
Args:
@@ -95,11 +97,13 @@ def plot_pca_component_variance(clf, title='PCA Component Explained Variances',
9597
return ax
9698

9799

98-
def plot_pca_2d_projection(clf, X, y, title='PCA 2-D Projection',
99-
biplot=False, feature_labels=None,
100-
ax=None, figsize=None, cmap='Spectral',
101-
title_fontsize="large", text_fontsize="medium",
102-
dimensions=[0, 1], label_dots=False, ):
100+
def plot_pca_2d_projection(
101+
clf, X, y, title='PCA 2-D Projection',
102+
biplot=False, feature_labels=None,
103+
ax=None, figsize=None, cmap='Spectral',
104+
title_fontsize="large", text_fontsize="medium",
105+
dimensions=[0, 1], label_dots=False,
106+
):
103107
"""Plots the 2-dimensional projection of PCA on a given dataset.
104108
105109
Args:
@@ -165,6 +169,7 @@ def plot_pca_2d_projection(clf, X, y, title='PCA 2-D Projection',
165169
fig, ax = plt.subplots(1, 1, figsize=figsize)
166170

167171
ax.set_title(title, fontsize=title_fontsize)
172+
168173
# Get unique classes from y, preserving order of class occurence in y
169174
_, class_indexes = np.unique(np.array(y), return_index=True)
170175
classes = np.array(y)[np.sort(class_indexes)]

0 commit comments

Comments
 (0)