Skip to content

Commit 45a80a4

Browse files
committed
Add StudyDataset
Signed-off-by: Guillaume Tauzin <4648633+gtauzin@users.noreply.github.com>
1 parent 44d2ec0 commit 45a80a4

File tree

7 files changed

+659
-0
lines changed

7 files changed

+659
-0
lines changed

kedro-datasets/RELEASE.md

+9
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44

55
- Added a parameter to enable/disable lazy saving for `PartitionedDataset`.
66
- Added `ibis-athena` and `ibis-databricks` extras for the backends added in Ibis 10.0.
7+
- Added the following new **experimental** datasets:
8+
9+
| Type | Description | Location |
10+
| ----------------------------- | --------------------------------------------------------------- | ------------------------------------- |
11+
| `optuna.StudyDataset` | A dataset for saving and loading Optuna studies. | `kedro_datasets_experimental.optuna` |
712

813
## Bug fixes and other changes
914

@@ -16,6 +21,10 @@
1621

1722
## Community contributions
1823

24+
Many thanks to the following Kedroids for contributing PRs to this release:
25+
26+
- [Guillaume Tauzin](https://github.com/gtauzin)
27+
1928
# Release 6.0.0
2029

2130
## Major features and improvements

kedro-datasets/docs/source/api/kedro_datasets_experimental.rst

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ kedro_datasets_experimental
1717
langchain.ChatOpenAIDataset
1818
langchain.OpenAIEmbeddingsDataset
1919
netcdf.NetCDFDataset
20+
optuna.StudyDataset
2021
prophet.ProphetModelDataset
2122
pytorch.PyTorchDataset
2223
rioxarray.GeoTIFFDataset
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
"""Provide data loading and saving functionality for Optuna's study."""

from typing import Any

import lazy_loader as lazy

# Declared for static type checkers; the real class is imported lazily from
# the ``study_dataset`` submodule on first attribute access.
StudyDataset: Any

__getattr__, __dir__, __all__ = lazy.attach(
    __name__,
    submod_attrs={"study_dataset": ["StudyDataset"]},
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,345 @@
1+
"""``StudyDataset`` loads/saves data from/to an optuna Study."""
2+
3+
from __future__ import annotations
4+
5+
import fnmatch
6+
import logging
7+
import os
8+
from copy import deepcopy
9+
from pathlib import PurePosixPath
10+
from typing import Any
11+
12+
import optuna
13+
from kedro.io.core import (
14+
AbstractVersionedDataset,
15+
DatasetError,
16+
Version,
17+
)
18+
from sqlalchemy import URL
19+
from sqlalchemy.dialects import registry
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
class StudyDataset(AbstractVersionedDataset[optuna.Study, optuna.Study]):
    """``StudyDataset`` loads/saves data from/to an optuna Study.

    Example usage for the
    `YAML API <https://docs.kedro.org/en/stable/data/data_catalog_yaml_examples.html>`_:

    .. code-block:: yaml

        review_prediction_study:
          type: optuna.StudyDataset
          backend: sqlite
          database: data/05_model_input/review_prediction_study.db
          study_name: review_prediction
          load_args:
            sampler:
              class: TPESampler
              n_startup_trials: 10
              n_ei_candidates: 5
            pruner:
              class: NopPruner
          versioned: true

        price_prediction_study:
          type: optuna.StudyDataset
          backend: postgresql
          database: optuna_db
          study_name: price_prediction
          credentials: dev_optuna_postgresql

    Example usage for the
    `Python API <https://docs.kedro.org/en/stable/data/\
    advanced_data_catalog_usage.html>`_:

    .. code-block:: pycon

        >>> from kedro_datasets.optuna import StudyDataset
        >>> from optuna.distributions import FloatDistribution
        >>> import optuna
        >>>
        >>> study = optuna.create_study()
        >>> trial = optuna.trial.create_trial(
        ...     params={"x": 2.0},
        ...     distributions={"x": FloatDistribution(0, 10)},
        ...     value=4.0,
        ... )
        >>> study.add_trial(trial)
        >>>
        >>> dataset = StudyDataset(
        ...     backend="sqlite", database="optuna.db", study_name="my_study"
        ... )
        >>> dataset.save(study)
        >>> reloaded = dataset.load()
        >>> assert len(reloaded.trials) == 1
        >>> assert reloaded.trials[0].params["x"] == 2.0
    """

    DEFAULT_LOAD_ARGS: dict[str, Any] = {"sampler": None, "pruner": None}

    def __init__(  # noqa: PLR0913
        self,
        *,
        backend: str,
        database: str,
        study_name: str,
        load_args: dict[str, Any] | None = None,
        version: Version = None,
        credentials: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Creates a new instance of ``StudyDataset`` pointing to a concrete optuna
        Study on a specific relational database.

        Args:
            backend: Name of the database backend. This name should correspond to a module
                in ``SQLAlchemy``.
            database: Name of the database.
            study_name: Name of the optuna Study.
            load_args: Optuna options for loading studies. Accepts a `sampler` and a
                `pruner`. If either are provided, a `class` matching any Optuna `sampler`,
                respectively `pruner` class name should be provided, optionally with
                their arguments. Here you can find all available samplers and pruners
                and their arguments:
                - https://optuna.readthedocs.io/en/stable/reference/samplers/index.html
                - https://optuna.readthedocs.io/en/stable/reference/pruners.html
                All defaults are preserved.
            version: If specified, should be an instance of
                ``kedro.io.core.Version``. If its ``load`` attribute is
                None, the latest version will be loaded. If its ``save``
                attribute is None, save version will be autogenerated.
            credentials: Credentials required to get access to the underlying RDB.
                They can include `username`, `password`, `host`, and `port`.
            metadata: Any arbitrary metadata.
                This is ignored by Kedro, but may be consumed by users or external plugins.

        Raises:
            ValueError: If ``backend`` is not a registered SQLAlchemy dialect, or if
                ``database``, ``study_name``, or ``credentials`` fail validation.
            FileNotFoundError: If the directory of a sqlite ``database`` does not exist.
        """
        self._backend = self._validate_backend(backend=backend)
        self._database = self._validate_database(backend=backend, database=database)
        self._study_name = self._validate_study_name(study_name=study_name)

        credentials = self._validate_credentials(backend=backend, credentials=credentials)
        # Build the SQLAlchemy URL; URL.create escapes special characters in
        # the credentials for us.
        storage = URL.create(
            drivername=backend,
            database=database,
            **credentials,
        )

        self._storage = str(storage)
        self.metadata = metadata

        # Only sqlite is file-backed; other backends have no filesystem path.
        filepath = None
        if backend == "sqlite":
            filepath = PurePosixPath(os.path.realpath(database))

        super().__init__(
            filepath=filepath,
            version=version,
            exists_function=self._study_name_exists,
            glob_function=self._study_name_glob,
        )

        # Handle default load arguments.
        self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}

    def _validate_backend(self, backend: str) -> str:
        # Accept any dialect registered with SQLAlchemy plus the built-in ones,
        # which do not appear in ``registry.impls`` until first use.
        valid_backends = list(registry.impls.keys()) + ["mssql", "mysql", "oracle", "postgresql", "sqlite"]
        if backend not in valid_backends:
            raise ValueError(
                f"Requested `backend` '{backend}' is not registered as an SQLAlchemy dialect."
            )
        return backend

    def _validate_database(self, backend: str, database: str) -> str:
        if not isinstance(database, str):
            raise ValueError(f"`database` '{database}' is not a string.")

        if backend == "sqlite":
            # ":memory:" is a valid, transient sqlite database.
            if database == ":memory:":
                return database

            # Check if the directory exists.
            database_dir = os.path.dirname(database)
            if len(database_dir) and not os.path.isdir(database_dir):
                raise FileNotFoundError(
                    f"The directory of the sqlite DB '{database_dir}' does not exist."
                )

            # Check if the file has an extension; a bare name is most likely a
            # misconfiguration (e.g. a directory name).
            _, extension = os.path.splitext(database)
            if not extension:
                raise ValueError(f"The sqlite file `database` '{database}' does not have an extension.")

        return database

    def _validate_study_name(self, study_name: str) -> str:
        if not isinstance(study_name, str):
            raise ValueError(f"`study_name` '{study_name}' is not a string.")
        return study_name

    def _validate_credentials(self, backend: str, credentials: dict[str, Any] | None) -> dict[str, Any]:
        # sqlite is file-based and takes no credentials.
        if backend == "sqlite" or credentials is None:
            return {}

        if not set(credentials.keys()) <= {"username", "password", "host", "port"}:
            raise ValueError(
                "Incorrect `credentials`. Provided `credentials` should contain "
                "`'username'`, `'password'`, `'host'`, and/or `'port'`. It contains "
                f"{set(credentials.keys())}."
            )

        # Copy so later mutation by the caller cannot affect the dataset.
        return deepcopy(credentials)

    def _get_versioned_path(self, version: str) -> PurePosixPath:
        # Mirror kedro's "<name>/<version>/<name>" versioned layout, applied to
        # the study name instead of a filepath.
        study_name_posix = PurePosixPath(self._study_name)
        return study_name_posix / version / study_name_posix

    def resolve_load_version(self) -> str | None:
        """Compute the version the dataset should be loaded with."""
        if not self._version:
            return None
        if self._version.load:
            return self._version.load
        return self._fetch_latest_load_version()

    def _get_load_path(self) -> PurePosixPath:
        # Path is not affected by versioning; versions are encoded in the
        # study name, not in the storage location.
        return self._filepath

    def _get_load_study_name(self) -> str:
        if not self._version:
            # When versioning is disabled, load from original study name.
            return self._study_name

        load_version = self.resolve_load_version()
        return str(self._get_versioned_path(load_version))

    def _get_save_path(self) -> PurePosixPath:
        # Path is not affected by versioning (see _get_load_path).
        return self._filepath

    def _get_save_study_name(self) -> str:
        if not self._version:
            # When versioning is disabled, return original study name.
            return self._study_name

        save_version = self.resolve_save_version()
        versioned_study_name = self._get_versioned_path(save_version)

        # A versioned save must never overwrite an existing version.
        if self._exists_function(str(versioned_study_name)):
            raise DatasetError(
                f"Study name '{versioned_study_name}' for {self!s} must not exist if "
                f"versioning is enabled."
            )

        return str(versioned_study_name)

    def _describe(self) -> dict[str, Any]:
        return {
            "backend": self._backend,
            "database": self._database,
            "study_name": self._study_name,
            "load_args": self._load_args,
            "version": self._version,
        }

    def _get_sampler(self, sampler_config: dict[str, Any] | None) -> optuna.samplers.BaseSampler | None:
        """Instantiate an optuna sampler from a `load_args`-style config dict.

        The config must carry a `class` key naming a class in
        ``optuna.samplers``; remaining keys are passed as keyword arguments.
        Nested sampler configs are resolved recursively.
        """
        if sampler_config is None:
            return None

        if "class" not in sampler_config:
            raise ValueError(
                "Optuna `sampler` 'class' should be specified when trying to load study "
                f"named '{self._study_name}' with a `sampler`."
            )

        sampler_class_name = sampler_config.pop("class")
        # These samplers wrap an independent sampler for parameters they cannot
        # handle themselves. Default to None (optuna picks its own default)
        # instead of raising KeyError when the key is absent.
        if sampler_class_name in ["QMCSampler", "CmaEsSampler", "GPSampler"]:
            sampler_config["independent_sampler"] = self._get_sampler(
                sampler_config.pop("independent_sampler", None)
            )

        if sampler_class_name == "PartialFixedSampler":
            sampler_config["base_sampler"] = self._get_sampler(
                sampler_config.pop("base_sampler", None)
            )

        sampler_class = getattr(optuna.samplers, sampler_class_name)

        return sampler_class(**sampler_config)

    def _get_pruner(self, pruner_config: dict[str, Any] | None) -> optuna.pruners.BasePruner | None:
        """Instantiate an optuna pruner from a `load_args`-style config dict.

        The config must carry a `class` key naming a class in
        ``optuna.pruners``; remaining keys are passed as keyword arguments.
        """
        if pruner_config is None:
            return None

        if "class" not in pruner_config:
            raise ValueError(
                "Optuna `pruner` 'class' should be specified when trying to load study "
                f"named '{self._study_name}' with a `pruner`."
            )

        pruner_class_name = pruner_config.pop("class")
        # PatientPruner wraps another pruner; default to None when absent
        # instead of raising KeyError.
        if pruner_class_name == "PatientPruner":
            pruner_config["wrapped_pruner"] = self._get_pruner(
                pruner_config.pop("wrapped_pruner", None)
            )

        pruner_class = getattr(optuna.pruners, pruner_class_name)

        return pruner_class(**pruner_config)

    def load(self) -> optuna.Study:
        """Load the optuna Study from the configured storage.

        Returns:
            The loaded ``optuna.Study``, with the sampler/pruner configured in
            ``load_args`` attached if provided.
        """
        # Work on a copy: _get_sampler/_get_pruner pop keys destructively.
        load_args = deepcopy(self._load_args)
        sampler_config = load_args.pop("sampler")
        sampler = self._get_sampler(sampler_config)

        pruner_config = load_args.pop("pruner")
        pruner = self._get_pruner(pruner_config)

        study = optuna.load_study(
            storage=self._storage,
            study_name=self._get_load_study_name(),
            sampler=sampler,
            pruner=pruner,
        )

        return study

    def save(self, study: optuna.Study) -> None:
        """Save the optuna Study to the configured storage.

        Any pre-existing study with the target name is deleted first, so a
        save always overwrites.
        """
        save_study_name = self._get_save_study_name()
        if self._backend == "sqlite":
            os.makedirs(os.path.dirname(self._filepath), exist_ok=True)

            # NOTE: this creates an unnamed placeholder study purely to
            # initialize the sqlite file so copy_study below can connect.
            if not os.path.isfile(self._filepath):
                optuna.create_study(
                    storage=self._storage,
                )

        # To overwrite an existing study, we need to first delete it if it exists.
        if self._study_name_exists(save_study_name):
            optuna.delete_study(
                storage=self._storage,
                study_name=save_study_name,
            )

        # copy_study needs the source storage; Study only exposes it as the
        # private attribute `_storage`.
        optuna.copy_study(
            from_study_name=study.study_name,
            from_storage=study._storage,
            to_storage=self._storage,
            to_study_name=save_study_name,
        )

    def _study_name_exists(self, study_name: str) -> bool:
        study_names = optuna.study.get_all_study_names(storage=self._storage)
        return study_name in study_names

    def _study_name_glob(self, pattern: str):
        # Used by the versioning machinery to enumerate saved versions.
        study_names = optuna.study.get_all_study_names(storage=self._storage)
        for study_name in study_names:
            if fnmatch.fnmatch(study_name, pattern):
                yield study_name

    def _exists(self) -> bool:
        try:
            load_study_name = self._get_load_study_name()
        except DatasetError:
            # No matching version found.
            return False

        return self._study_name_exists(load_study_name)

kedro-datasets/kedro_datasets_experimental/tests/optuna/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)