Skip to content

Commit 84bd54c

Browse files
committed
test with sklearn default
1 parent d3b4f50 commit 84bd54c

File tree

3 files changed

+113
-51
lines changed

3 files changed

+113
-51
lines changed

sobolev_alignment/krr_approx.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030
from falkon import Falkon
3131
from falkon.kernels import GaussianKernel, LaplacianKernel, MaternKernel
3232
from falkon.options import FalkonOptions
33+
34+
FALKON_IMPORTED = True
3335
except ImportError:
36+
FALKON_IMPORTED = False
3437
print("FALKON NOT INSTALLED, OR NOT IMPORTED. USING FALKON WOULD RESULT IN BETTER PERFORMANCE.", flush=True)
3538
from sklearn.gaussian_process.kernels import Matern, PairwiseKernel
3639
from sklearn.kernel_ridge import KernelRidge
@@ -58,10 +61,10 @@ class KRRApprox:
5861
}
5962

6063
falkon_kernel = {
61-
"rbf": GaussianKernel,
62-
"gaussian": GaussianKernel,
63-
"laplacian": LaplacianKernel,
64-
"matern": MaternKernel,
64+
"rbf": GaussianKernel if FALKON_IMPORTED else None,
65+
"gaussian": GaussianKernel if FALKON_IMPORTED else None,
66+
"laplacian": LaplacianKernel if FALKON_IMPORTED else None,
67+
"matern": MaternKernel if FALKON_IMPORTED else None,
6568
}
6669

6770
default_kernel_params = {
@@ -244,7 +247,7 @@ def _setup_falkon_clf(self):
244247
penalty=self.penalization,
245248
M=self.M,
246249
maxiter=self.maxiter,
247-
options=FalkonOptions(**self.falkon_options),
250+
options=FalkonOptions(**self.falkon_options) if FALKON_IMPORTED else None,
248251
)
249252
return True
250253

tests/test_krr_approx.py

+77-44
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@
1313
M = 500
1414

1515

16+
@pytest.fixture(scope="module")
17+
def falkon_import():
18+
try:
19+
return True
20+
except ImportError:
21+
return False
22+
23+
1624
@pytest.fixture(scope="module")
1725
def input():
1826
return torch.normal(0, 1, size=(n_samples, n_genes))
@@ -81,41 +89,60 @@ def fit_sklearn_matern_ridge(self, sklearn_matern_KRR, input, embedding):
8189
###
8290

8391
@pytest.fixture(scope="class")
84-
def falkon_rbf_KRR(self):
85-
return KRRApprox(
86-
method="falkon", kernel="rbf", kernel_params={"sigma": np.sqrt(2 * n_genes)}, penalization=penalization, M=M
87-
)
92+
def falkon_rbf_KRR(self, falkon_import):
93+
if falkon_import:
94+
return KRRApprox(
95+
method="falkon",
96+
kernel="rbf",
97+
kernel_params={"sigma": np.sqrt(2 * n_genes)},
98+
penalization=penalization,
99+
M=M,
100+
)
101+
else:
102+
return None
88103

89104
@pytest.fixture(scope="class")
90-
def falkon_matern_KRR(self):
91-
return KRRApprox(
92-
method="falkon",
93-
kernel="matern",
94-
kernel_params={"sigma": np.sqrt(2 * n_genes), "nu": 1.5},
95-
penalization=penalization,
96-
M=M,
97-
)
105+
def falkon_matern_KRR(self, falkon_import):
106+
if falkon_import:
107+
return KRRApprox(
108+
method="falkon",
109+
kernel="matern",
110+
kernel_params={"sigma": np.sqrt(2 * n_genes), "nu": 1.5},
111+
penalization=penalization,
112+
M=M,
113+
)
114+
else:
115+
return None
98116

99117
@pytest.fixture(scope="class")
100-
def falkon_laplacian_KRR(self):
101-
return KRRApprox(
102-
method="falkon",
103-
kernel="laplacian",
104-
kernel_params={"sigma": np.sqrt(2 * n_genes)},
105-
penalization=penalization,
106-
M=M,
107-
)
118+
def falkon_laplacian_KRR(self, falkon_import):
119+
if falkon_import:
120+
return KRRApprox(
121+
method="falkon",
122+
kernel="laplacian",
123+
kernel_params={"sigma": np.sqrt(2 * n_genes)},
124+
penalization=penalization,
125+
M=M,
126+
)
127+
else:
128+
return None
108129

109130
@pytest.fixture(scope="class")
110131
def fit_falkon_rbf_ridge(self, falkon_rbf_KRR, input, embedding):
132+
if falkon_rbf_KRR is None:
133+
return None
111134
return falkon_rbf_KRR.fit(input, embedding)
112135

113136
@pytest.fixture(scope="class")
114137
def fit_falkon_laplacian_ridge(self, falkon_laplacian_KRR, input, embedding):
138+
if falkon_laplacian_KRR is None:
139+
return None
115140
return falkon_laplacian_KRR.fit(input, embedding)
116141

117142
@pytest.fixture(scope="class")
118143
def fit_falkon_matern_ridge(self, falkon_matern_KRR, input, embedding):
144+
if falkon_matern_KRR is None:
145+
return None
119146
return falkon_matern_KRR.fit(input, embedding)
120147

121148
###
@@ -131,9 +158,10 @@ def test_all_sklearn_kernels(self):
131158
KRRApprox(kernel=kernel, method="sklearn")
132159
return True
133160

134-
def test_all_falkon_kernels(self):
135-
for kernel in KRRApprox.falkon_kernel:
136-
KRRApprox(kernel=kernel, method="falkon")
161+
def test_all_falkon_kernels(self, falkon_import):
162+
if falkon_import:
163+
for kernel in KRRApprox.falkon_kernel:
164+
KRRApprox(kernel=kernel, method="falkon")
137165
return True
138166

139167
###
@@ -160,32 +188,37 @@ def test_laplacian_sklearn_fit(self, fit_sklearn_laplacian_ridge, valid_input, v
160188
###
161189

162190
def test_rbf_falkon_fit(self, fit_falkon_rbf_ridge, valid_input, valid_embedding):
163-
pred = fit_falkon_rbf_ridge.transform(valid_input)
164-
pearson_corr = scipy.stats.pearsonr(pred.flatten(), valid_embedding.detach().numpy().flatten())
165-
assert pearson_corr[0] > pearson_threshold
191+
if fit_falkon_rbf_ridge is not None:
192+
pred = fit_falkon_rbf_ridge.transform(valid_input)
193+
pearson_corr = scipy.stats.pearsonr(pred.flatten(), valid_embedding.detach().numpy().flatten())
194+
assert pearson_corr[0] > pearson_threshold
166195

167196
def test_matern_falkon_fit(self, fit_falkon_matern_ridge, valid_input, valid_embedding):
168-
pred = fit_falkon_matern_ridge.transform(valid_input)
169-
pearson_corr = scipy.stats.pearsonr(pred.flatten(), valid_embedding.detach().numpy().flatten())
170-
assert pearson_corr[0] > pearson_threshold
197+
if fit_falkon_matern_ridge is not None:
198+
pred = fit_falkon_matern_ridge.transform(valid_input)
199+
pearson_corr = scipy.stats.pearsonr(pred.flatten(), valid_embedding.detach().numpy().flatten())
200+
assert pearson_corr[0] > pearson_threshold
171201

172202
def test_laplacian_falkon_fit(self, fit_falkon_laplacian_ridge, valid_input, valid_embedding):
173-
pred = fit_falkon_laplacian_ridge.transform(valid_input)
174-
pearson_corr = scipy.stats.pearsonr(pred.flatten(), valid_embedding.detach().numpy().flatten())
175-
assert pearson_corr[0] > pearson_threshold
203+
if fit_falkon_laplacian_ridge is not None:
204+
pred = fit_falkon_laplacian_ridge.transform(valid_input)
205+
pearson_corr = scipy.stats.pearsonr(pred.flatten(), valid_embedding.detach().numpy().flatten())
206+
assert pearson_corr[0] > pearson_threshold
176207

177208
def test_ridge_coef_sklearn_fit(self, fit_sklearn_laplacian_ridge, input, valid_input):
178-
pred_reconstruct = fit_sklearn_laplacian_ridge.kernel_(
179-
valid_input, input[fit_sklearn_laplacian_ridge.ridge_samples_idx_, :]
180-
)
181-
pred_reconstruct = pred_reconstruct.dot(fit_sklearn_laplacian_ridge.sample_weights_)
182-
np.testing.assert_array_almost_equal(
183-
pred_reconstruct, fit_sklearn_laplacian_ridge.transform(valid_input), decimal=3
184-
)
209+
if fit_sklearn_laplacian_ridge is not None:
210+
pred_reconstruct = fit_sklearn_laplacian_ridge.kernel_(
211+
valid_input, input[fit_sklearn_laplacian_ridge.ridge_samples_idx_, :]
212+
)
213+
pred_reconstruct = pred_reconstruct.dot(fit_sklearn_laplacian_ridge.sample_weights_)
214+
np.testing.assert_array_almost_equal(
215+
pred_reconstruct, fit_sklearn_laplacian_ridge.transform(valid_input), decimal=3
216+
)
185217

186218
def test_ridge_coef_falkon_fit(self, fit_falkon_laplacian_ridge, input, valid_input):
187-
pred_reconstruct = fit_falkon_laplacian_ridge.kernel_(valid_input, fit_falkon_laplacian_ridge.anchors())
188-
pred_reconstruct = pred_reconstruct.matmul(fit_falkon_laplacian_ridge.sample_weights_)
189-
np.testing.assert_array_almost_equal(
190-
pred_reconstruct, fit_falkon_laplacian_ridge.transform(valid_input), decimal=3
191-
)
219+
if fit_falkon_laplacian_ridge is not None:
220+
pred_reconstruct = fit_falkon_laplacian_ridge.kernel_(valid_input, fit_falkon_laplacian_ridge.anchors())
221+
pred_reconstruct = pred_reconstruct.matmul(fit_falkon_laplacian_ridge.sample_weights_)
222+
np.testing.assert_array_almost_equal(
223+
pred_reconstruct, fit_falkon_laplacian_ridge.transform(valid_input), decimal=3
224+
)

tests/test_sobolev_alignment.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@
1313
frac_save_artificial = 0.1
1414

1515

16+
@pytest.fixture(scope="module")
17+
def falkon_import():
18+
try:
19+
return True
20+
except ImportError:
21+
return False
22+
23+
1624
@pytest.fixture(scope="module")
1725
def source_data():
1826
poisson_coef = np.random.randint(1, 25, size=n_genes)
@@ -87,20 +95,38 @@ def target_scvi_params():
8795

8896
class TestSobolevAlignment:
8997
@pytest.fixture(scope="class")
90-
def sobolev_alignment_raw(self, source_scvi_params, target_scvi_params):
98+
def sobolev_alignment_raw(self, falkon_import, source_scvi_params, target_scvi_params):
99+
if falkon_import:
100+
source_krr_params = {"method": "falkon"}
101+
target_krr_params = {"method": "falkon"}
102+
else:
103+
source_krr_params = {"method": "sklearn"}
104+
target_krr_params = {"method": "sklearn"}
105+
91106
return SobolevAlignment(
92107
source_scvi_params=source_scvi_params,
93108
target_scvi_params=target_scvi_params,
109+
source_krr_params=source_krr_params,
110+
target_krr_params=target_krr_params,
94111
source_batch_name=None,
95112
target_batch_name=None,
96113
no_posterior_collapse=False,
97114
)
98115

99116
@pytest.fixture(scope="class")
100-
def sobolev_alignment_batch(self, source_scvi_params, target_scvi_params):
117+
def sobolev_alignment_batch(self, falkon_import, source_scvi_params, target_scvi_params):
118+
if falkon_import:
119+
source_krr_params = {"method": "falkon"}
120+
target_krr_params = {"method": "falkon"}
121+
else:
122+
source_krr_params = {"method": "sklearn"}
123+
target_krr_params = {"method": "sklearn"}
124+
101125
return SobolevAlignment(
102126
source_scvi_params=source_scvi_params,
103127
target_scvi_params=target_scvi_params,
128+
source_krr_params=source_krr_params,
129+
target_krr_params=target_krr_params,
104130
source_batch_name="batch",
105131
target_batch_name="batch",
106132
n_artificial_samples=n_artificial_samples,

0 commit comments

Comments
 (0)