
Commit 18e9c78

michaelsexton, merelcht, and ankatiyar committed
feat(datasets): Add CSVDataset to dask module (kedro-org#627)
* Add CSVDataset to dask module
* Add tests to dask.CSVDataset
* Fix formatting issues in example usage
* Fix error in example usage that is causing test to fail
* Remove arguments from example usage
* Fix issue with folder used as path for CSV file
* Change number of partitions to fix failing assertion
* Fix syntax issue
* Remove temp path
* Add default save args
* Add to documentation and release notes
* Fix lint
* Try fix netcdfdataset doctest
* Try fix netcdfdataset doctest pointing at file
* Fix moto mock_aws import
* Fix lint and test
* Mypy
* docs test
* docs test
* docs test
* Fix unit tests
* Remove extra comments
* Try fix test
* Release notes + test
* Suggestion from code review

Signed-off-by: Michael Sexton <michael.sexton@ga.gov.au>
Signed-off-by: Merel Theisen <49397448+merelcht@users.noreply.github.com>
Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com>
Signed-off-by: Ankita Katiyar <110245118+ankatiyar@users.noreply.github.com>
Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>
Co-authored-by: Merel Theisen <49397448+merelcht@users.noreply.github.com>
Co-authored-by: Merel Theisen <merel.theisen@quantumblack.com>
Co-authored-by: Ankita Katiyar <110245118+ankatiyar@users.noreply.github.com>
Co-authored-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>
1 parent 807fa43 commit 18e9c78

File tree

6 files changed: +296 -1 lines changed


kedro-datasets/RELEASE.md

+8
@@ -11,12 +11,19 @@
 | `langchain.ChatOpenAIDataset` | A dataset for loading a ChatOpenAI langchain model. | `kedro_datasets_experimental.langchain` |
 | `netcdf.NetCDFDataset` | A dataset for loading and saving "*.nc" files. | `kedro_datasets_experimental.netcdf` |
 * `netcdf.NetCDFDataset` moved from `kedro_datasets` to `kedro_datasets_experimental`.
+
+* Added the following new core datasets:
+
+| Type              | Description                                  | Location              |
+|-------------------|----------------------------------------------|-----------------------|
+| `dask.CSVDataset` | A dataset for loading CSV files using `dask` | `kedro_datasets.dask` |
+
 * Extended preview feature to `yaml.YAMLDataset`.

 ## Community contributions

 Many thanks to the following Kedroids for contributing PRs to this release:
 * [Lukas Innig](https://github.com/derluke)
+* [Michael Sexton](https://github.com/michaelsexton)


 # Release 3.0.1
@@ -58,6 +65,7 @@ Many thanks to the following Kedroids for contributing PRs to this release:
 * [Eduardo Romero Lopez](https://github.com/eromerobilbomatica)
 * [Jerome Asselin](https://github.com/jerome-asselin-buspatrol)

+
 # Release 2.1.0
 ## Major features and improvements

kedro-datasets/docs/source/api/kedro_datasets.rst

+1
@@ -13,6 +13,7 @@ kedro_datasets

    kedro_datasets.api.APIDataset
    kedro_datasets.biosequence.BioSequenceDataset
+   kedro_datasets.dask.CSVDataset
    kedro_datasets.dask.ParquetDataset
    kedro_datasets.databricks.ManagedTableDataset
    kedro_datasets.email.EmailMessageDataset

kedro-datasets/kedro_datasets/dask/__init__.py

+3 -1

@@ -6,7 +6,9 @@

 # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901
 ParquetDataset: Any
+CSVDataset: Any

 __getattr__, __dir__, __all__ = lazy.attach(
-    __name__, submod_attrs={"parquet_dataset": ["ParquetDataset"]}
+    __name__,
+    submod_attrs={"parquet_dataset": ["ParquetDataset"], "csv_dataset": ["CSVDataset"]},
 )
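The `lazy.attach` call above only registers the `csv_dataset` submodule; nothing Dask-related is imported until `CSVDataset` is actually accessed. A minimal sketch of the resulting behaviour, assuming `lazy` is the `lazy_loader` package used elsewhere in `kedro_datasets` (illustrative only, not part of this commit):

# Illustrative sketch, not part of the commit; assumes lazy_loader's attach()
# semantics as used throughout kedro_datasets.
from kedro_datasets import dask as dask_datasets

# Attached names are advertised without importing the heavy submodules yet.
assert "CSVDataset" in dir(dask_datasets)

# First attribute access imports kedro_datasets.dask.csv_dataset (and with it
# dask.dataframe), then returns the class.
CSVDataset = dask_datasets.CSVDataset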

kedro-datasets/kedro_datasets/dask/csv_dataset.py

+125

@@ -0,0 +1,125 @@
"""``CSVDataset`` is a data set used to load and save data to CSV files using Dask
dataframe"""
from __future__ import annotations

from copy import deepcopy
from typing import Any

import dask.dataframe as dd
import fsspec
from kedro.io.core import AbstractDataset, get_protocol_and_path


class CSVDataset(AbstractDataset[dd.DataFrame, dd.DataFrame]):
    """``CSVDataset`` loads and saves data to comma-separated value file(s). It uses Dask
    remote data services to handle the corresponding load and save operations:
    https://docs.dask.org/en/latest/how-to/connect-to-remote-data.html

    Example usage for the
    `YAML API <https://kedro.readthedocs.io/en/stable/data/\
    data_catalog_yaml_examples.html>`_:

    .. code-block:: yaml

        cars:
          type: dask.CSVDataset
          filepath: s3://bucket_name/path/to/folder
          save_args:
            compression: GZIP
          credentials:
            client_kwargs:
              aws_access_key_id: YOUR_KEY
              aws_secret_access_key: YOUR_SECRET

    Example usage for the
    `Python API <https://kedro.readthedocs.io/en/stable/data/\
    advanced_data_catalog_usage.html>`_:

    .. code-block:: pycon

        >>> from kedro_datasets.dask import CSVDataset
        >>> import pandas as pd
        >>> import numpy as np
        >>> import dask.dataframe as dd
        >>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [[5, 6], [7, 8]]})
        >>> ddf = dd.from_pandas(data, npartitions=1)
        >>> dataset = CSVDataset(filepath="path/to/folder/*.csv")
        >>> dataset.save(ddf)
        >>> reloaded = dataset.load()
        >>> assert np.array_equal(ddf.compute(), reloaded.compute())
    """

    DEFAULT_LOAD_ARGS: dict[str, Any] = {}
    DEFAULT_SAVE_ARGS: dict[str, Any] = {"index": False}

    def __init__(  # noqa: PLR0913
        self,
        filepath: str,
        load_args: dict[str, Any] | None = None,
        save_args: dict[str, Any] | None = None,
        credentials: dict[str, Any] | None = None,
        fs_args: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Creates a new instance of ``CSVDataset`` pointing to concrete
        CSV files.

        Args:
            filepath: Filepath in POSIX format to a CSV file, a CSV collection
                or the directory of a multipart CSV.
            load_args: Additional loading options for `dask.dataframe.read_csv`:
                https://docs.dask.org/en/latest/generated/dask.dataframe.read_csv.html
            save_args: Additional saving options for `dask.dataframe.to_csv`:
                https://docs.dask.org/en/latest/generated/dask.dataframe.to_csv.html
            credentials: Credentials required to get access to the underlying filesystem.
                E.g. for ``GCSFileSystem`` it should look like `{"token": None}`.
            fs_args: Optional parameters to the backend file system driver:
                https://docs.dask.org/en/latest/how-to/connect-to-remote-data.html#optional-parameters
            metadata: Any arbitrary metadata.
                This is ignored by Kedro, but may be consumed by users or external plugins.
        """
        self._filepath = filepath
        self._fs_args = deepcopy(fs_args) or {}
        self._credentials = deepcopy(credentials) or {}

        self.metadata = metadata

        # Handle default load and save arguments
        self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
        if load_args is not None:
            self._load_args.update(load_args)
        self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
        if save_args is not None:
            self._save_args.update(save_args)

    @property
    def fs_args(self) -> dict[str, Any]:
        """Property of optional file system parameters.

        Returns:
            A dictionary of backend file system parameters, including credentials.
        """
        fs_args = deepcopy(self._fs_args)
        fs_args.update(self._credentials)
        return fs_args

    def _describe(self) -> dict[str, Any]:
        return {
            "filepath": self._filepath,
            "load_args": self._load_args,
            "save_args": self._save_args,
        }

    def _load(self) -> dd.DataFrame:
        return dd.read_csv(
            self._filepath, storage_options=self.fs_args, **self._load_args
        )

    def _save(self, data: dd.DataFrame) -> None:
        data.to_csv(self._filepath, storage_options=self.fs_args, **self._save_args)

    def _exists(self) -> bool:
        protocol = get_protocol_and_path(self._filepath)[0]
        file_system = fsspec.filesystem(protocol=protocol, **self.fs_args)
        files = file_system.glob(self._filepath)
        return bool(files)
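
A note on the implementation above: the `fs_args` property merges `fs_args` with `credentials`, and the merged dict is what `_load` and `_save` pass to Dask as `storage_options`. A short, hypothetical usage sketch (the bucket name and endpoint URL are placeholders, not taken from the commit):

# Hypothetical usage sketch; "my-bucket" and the endpoint URL are placeholders.
from kedro_datasets.dask import CSVDataset

dataset = CSVDataset(
    filepath="s3://my-bucket/output/*.csv",
    credentials={"key": "YOUR_KEY", "secret": "YOUR_SECRET"},
    fs_args={"client_kwargs": {"endpoint_url": "http://localhost:9000"}},
)

# fs_args and credentials are merged into one dict, which is passed as
# storage_options to dd.read_csv in _load and to DataFrame.to_csv in _save.
assert dataset.fs_args == {
    "client_kwargs": {"endpoint_url": "http://localhost:9000"},
    "key": "YOUR_KEY",
    "secret": "YOUR_SECRET",
}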

kedro-datasets/kedro_datasets_experimental/netcdf/netcdf_dataset.py

+1
@@ -62,6 +62,7 @@ class NetCDFDataset(AbstractDataset):
         ... )
         >>> dataset.save(ds)
         >>> reloaded = dataset.load()
+        >>> assert ds.equals(reloaded)
         """

     DEFAULT_LOAD_ARGS: dict[str, Any] = {}

kedro-datasets/tests/dask/test_csv_dataset.py

+158

@@ -0,0 +1,158 @@
import boto3
import dask.dataframe as dd
import numpy as np
import pandas as pd
import pytest
from kedro.io.core import DatasetError
from moto import mock_aws
from s3fs import S3FileSystem

from kedro_datasets.dask import CSVDataset

FILE_NAME = "*.csv"
BUCKET_NAME = "test_bucket"
AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"}

# Pathlib cannot be used since it strips out the second slash from "s3://"
S3_PATH = f"s3://{BUCKET_NAME}/{FILE_NAME}"


@pytest.fixture
def mocked_s3_bucket():
    """Create a bucket for testing using moto."""
    with mock_aws():
        conn = boto3.client(
            "s3",
            aws_access_key_id="fake_access_key",
            aws_secret_access_key="fake_secret_key",
        )
        conn.create_bucket(Bucket=BUCKET_NAME)
        yield conn


@pytest.fixture
def dummy_dd_dataframe() -> dd.DataFrame:
    df = pd.DataFrame(
        {"Name": ["Alex", "Bob", "Clarke", "Dave"], "Age": [31, 12, 65, 29]}
    )
    return dd.from_pandas(df, npartitions=1)


@pytest.fixture
def mocked_s3_object(tmp_path, mocked_s3_bucket, dummy_dd_dataframe: dd.DataFrame):
    """Creates test data and adds it to mocked S3 bucket."""
    pandas_df = dummy_dd_dataframe.compute()
    temporary_path = tmp_path / "test.csv"
    pandas_df.to_csv(str(temporary_path))

    mocked_s3_bucket.put_object(
        Bucket=BUCKET_NAME, Key=FILE_NAME, Body=temporary_path.read_bytes()
    )
    return mocked_s3_bucket


@pytest.fixture
def s3_dataset(load_args, save_args):
    return CSVDataset(
        filepath=S3_PATH,
        credentials=AWS_CREDENTIALS,
        load_args=load_args,
        save_args=save_args,
    )


@pytest.fixture()
def s3fs_cleanup():
    # clear cache so we get a clean slate every time we instantiate a S3FileSystem
    yield
    S3FileSystem.cachable = False


@pytest.mark.usefixtures("s3fs_cleanup")
class TestCSVDataset:
    def test_incorrect_credentials_load(self):
        """Test that incorrect credential keys won't instantiate dataset."""
        pattern = r"unexpected keyword argument"
        with pytest.raises(DatasetError, match=pattern):
            CSVDataset(
                filepath=S3_PATH,
                credentials={
                    "client_kwargs": {"access_token": "TOKEN", "access_key": "KEY"}
                },
            ).load().compute()

    @pytest.mark.parametrize("bad_credentials", [{"key": None, "secret": None}])
    def test_empty_credentials_load(self, bad_credentials):
        csv_dataset = CSVDataset(filepath=S3_PATH, credentials=bad_credentials)
        pattern = r"Failed while loading data from data set CSVDataset\(.+\)"
        with pytest.raises(DatasetError, match=pattern):
            csv_dataset.load().compute()

    @pytest.mark.xfail
    def test_pass_credentials(self, mocker):
        """Test that AWS credentials are passed successfully into boto3
        client instantiation on creating S3 connection."""
        client_mock = mocker.patch("botocore.session.Session.create_client")
        s3_dataset = CSVDataset(filepath=S3_PATH, credentials=AWS_CREDENTIALS)
        pattern = r"Failed while loading data from data set CSVDataset\(.+\)"
        with pytest.raises(DatasetError, match=pattern):
            s3_dataset.load().compute()

        assert client_mock.call_count == 1
        args, kwargs = client_mock.call_args_list[0]
        assert args == ("s3",)
        assert kwargs["aws_access_key_id"] == AWS_CREDENTIALS["key"]
        assert kwargs["aws_secret_access_key"] == AWS_CREDENTIALS["secret"]

    def test_save_data(self, s3_dataset, mocked_s3_bucket):
        """Test saving the data to S3."""
        pd_data = pd.DataFrame(
            {"col1": ["a", "b"], "col2": ["c", "d"], "col3": ["e", "f"]}
        )
        dd_data = dd.from_pandas(pd_data, npartitions=1)
        s3_dataset.save(dd_data)
        loaded_data = s3_dataset.load()
        np.array_equal(loaded_data.compute(), dd_data.compute())

    def test_load_data(self, s3_dataset, dummy_dd_dataframe, mocked_s3_object):
        """Test loading the data from S3."""
        loaded_data = s3_dataset.load()
        np.array_equal(loaded_data, dummy_dd_dataframe.compute())

    def test_exists(self, s3_dataset, dummy_dd_dataframe, mocked_s3_bucket):
        """Test `exists` method invocation for both existing and
        nonexistent data set."""
        assert not s3_dataset.exists()
        s3_dataset.save(dummy_dd_dataframe)
        assert s3_dataset.exists()

    def test_save_load_locally(self, tmp_path, dummy_dd_dataframe):
        """Test loading the data locally."""
        file_path = str(tmp_path / "some" / "dir" / FILE_NAME)
        dataset = CSVDataset(filepath=file_path)

        assert not dataset.exists()
        dataset.save(dummy_dd_dataframe)
        assert dataset.exists()
        loaded_data = dataset.load()
        dummy_dd_dataframe.compute().equals(loaded_data.compute())

    @pytest.mark.parametrize(
        "load_args", [{"k1": "v1", "index": "value"}], indirect=True
    )
    def test_load_extra_params(self, s3_dataset, load_args):
        """Test overriding the default load arguments."""
        for key, value in load_args.items():
            assert s3_dataset._load_args[key] == value

    @pytest.mark.parametrize(
        "save_args", [{"k1": "v1", "index": "value"}], indirect=True
    )
    def test_save_extra_params(self, s3_dataset, save_args):
        """Test overriding the default save arguments."""

        for key, value in save_args.items():
            assert s3_dataset._save_args[key] == value

        for key, value in s3_dataset.DEFAULT_SAVE_ARGS.items():
            assert s3_dataset._save_args[key] != value
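
The `load_args` and `save_args` fixtures consumed by `s3_dataset` and by the `indirect=True` parametrization are not defined in this file, so they must come from a `conftest.py` elsewhere in the test suite. A hypothetical sketch of what such pass-through fixtures typically look like (the names match the tests above; the implementation is assumed, not part of this diff):

# Hypothetical conftest.py sketch, not part of this commit: pass-through
# fixtures that return the parametrized value, or None when not parametrized.
import pytest


@pytest.fixture
def load_args(request):
    return getattr(request, "param", None)


@pytest.fixture
def save_args(request):
    return getattr(request, "param", None)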
