
Commit

refactor storageconfig
kevinzwang committed Dec 17, 2024
1 parent 9b141a3 commit 4ba24ad
Showing 18 changed files with 261 additions and 376 deletions.
15 changes: 2 additions & 13 deletions daft/daft/__init__.pyi
@@ -586,26 +586,15 @@ class IOConfig:
         """Replaces values if provided, returning a new IOConfig."""
         ...
 
-class NativeStorageConfig:
-    """Storage configuration for the Rust-native I/O layer."""
+class StorageConfig:
+    """Configuration for interacting with a particular storage backend."""
 
     # Whether or not to use a multithreaded tokio runtime for processing I/O
     multithreaded_io: bool
     io_config: IOConfig
 
     def __init__(self, multithreaded_io: bool, io_config: IOConfig): ...
 
-class StorageConfig:
-    """Configuration for interacting with a particular storage backend, using a particular I/O layer implementation."""
-
-    @staticmethod
-    def native(config: NativeStorageConfig) -> StorageConfig:
-        """Create from a native storage config."""
-        ...
-
-    @property
-    def config(self) -> NativeStorageConfig: ...
-
 class ScanTask:
     """A batch of scan tasks for reading data from an external source."""
 
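In short, this commit folds NativeStorageConfig into StorageConfig itself. A minimal sketch of the before/after call pattern the stub change implies (illustrative only; constructor arguments follow the stub above):

from daft.daft import IOConfig, StorageConfig

io_config = IOConfig()

# Before: wrap a NativeStorageConfig and unwrap it again to reach its fields
#   storage_config = StorageConfig.native(NativeStorageConfig(True, io_config))
#   io_config = storage_config.config.io_config

# After: construct StorageConfig directly and read its fields
storage_config = StorageConfig(True, io_config)  # (multithreaded_io, io_config)
current_io_config = storage_config.io_config
use_multithreaded = storage_config.multithreaded_io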
2 changes: 1 addition & 1 deletion daft/delta_lake/delta_lake_scan.py
@@ -41,7 +41,7 @@ def __init__(
         # Thus, if we don't detect any credentials being available, we attempt to detect it from the environment using our Daft credentials chain.
         #
         # See: https://github.com/delta-io/delta-rs/issues/2117
-        deltalake_sdk_io_config = storage_config.config.io_config
+        deltalake_sdk_io_config = storage_config.io_config
         scheme = urlparse(table_uri).scheme
         if scheme == "s3" or scheme == "s3a":
             # Try to get region from boto3
2 changes: 1 addition & 1 deletion daft/hudi/hudi_scan.py
@@ -25,7 +25,7 @@
 class HudiScanOperator(ScanOperator):
     def __init__(self, table_uri: str, storage_config: StorageConfig) -> None:
         super().__init__()
-        resolved_path, self._resolved_fs = _resolve_paths_and_filesystem(table_uri, storage_config.config.io_config)
+        resolved_path, self._resolved_fs = _resolve_paths_and_filesystem(table_uri, storage_config.io_config)
         self._table = HudiTable(table_uri, self._resolved_fs, resolved_path[0])
         self._storage_config = storage_config
         self._schema = Schema.from_pyarrow_schema(self._table.schema)
3 changes: 1 addition & 2 deletions daft/io/_csv.py
@@ -8,7 +8,6 @@
     CsvSourceConfig,
     FileFormatConfig,
     IOConfig,
-    NativeStorageConfig,
     StorageConfig,
 )
 from daft.dataframe import DataFrame
@@ -87,7 +86,7 @@ def read_csv(
         chunk_size=_chunk_size,
     )
     file_format_config = FileFormatConfig.from_csv_config(csv_config)
-    storage_config = StorageConfig.native(NativeStorageConfig(True, io_config))
+    storage_config = StorageConfig(True, io_config)
 
     builder = get_tabular_files_scan(
         path=path,
6 changes: 3 additions & 3 deletions daft/io/_deltalake.py
@@ -4,7 +4,7 @@
 
 from daft import context
 from daft.api_annotations import PublicAPI
-from daft.daft import IOConfig, NativeStorageConfig, ScanOperatorHandle, StorageConfig
+from daft.daft import IOConfig, ScanOperatorHandle, StorageConfig
 from daft.dataframe import DataFrame
 from daft.dependencies import unity_catalog
 from daft.io.catalog import DataCatalogTable
@@ -60,7 +60,7 @@ def read_deltalake(
     )
 
     io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
-    storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config))
+    storage_config = StorageConfig(multithreaded_io, io_config)
 
     if isinstance(table, str):
         table_uri = table
@@ -72,7 +72,7 @@
         # Override the storage_config with the one provided by Unity catalog
         table_io_config = table.io_config
         if table_io_config is not None:
-            storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, table_io_config))
+            storage_config = StorageConfig(multithreaded_io, table_io_config)
     else:
         raise ValueError(
             f"table argument must be a table URI string, DataCatalogTable or UnityCatalogTable instance, but got: {type(table)}, {table}"
4 changes: 2 additions & 2 deletions daft/io/_hudi.py
@@ -4,7 +4,7 @@
 
 from daft import context
 from daft.api_annotations import PublicAPI
-from daft.daft import IOConfig, NativeStorageConfig, ScanOperatorHandle, StorageConfig
+from daft.daft import IOConfig, ScanOperatorHandle, StorageConfig
 from daft.dataframe import DataFrame
 from daft.logical.builder import LogicalPlanBuilder
 
@@ -33,7 +33,7 @@ def read_hudi(
     io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
 
     multithreaded_io = context.get_context().get_or_create_runner().name != "ray"
-    storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config))
+    storage_config = StorageConfig(multithreaded_io, io_config)
 
     hudi_operator = HudiScanOperator(table_uri, storage_config=storage_config)
 
4 changes: 2 additions & 2 deletions daft/io/_iceberg.py
@@ -4,7 +4,7 @@
 
 from daft import context
 from daft.api_annotations import PublicAPI
-from daft.daft import IOConfig, NativeStorageConfig, ScanOperatorHandle, StorageConfig
+from daft.daft import IOConfig, ScanOperatorHandle, StorageConfig
 from daft.dataframe import DataFrame
 from daft.logical.builder import LogicalPlanBuilder
 
@@ -123,7 +123,7 @@ def read_iceberg(
     io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
 
     multithreaded_io = context.get_context().get_or_create_runner().name != "ray"
-    storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config))
+    storage_config = StorageConfig(multithreaded_io, io_config)
 
     iceberg_operator = IcebergScanOperator(pyiceberg_table, snapshot_id=snapshot_id, storage_config=storage_config)
 
3 changes: 1 addition & 2 deletions daft/io/_json.py
@@ -8,7 +8,6 @@
     FileFormatConfig,
     IOConfig,
     JsonSourceConfig,
-    NativeStorageConfig,
     StorageConfig,
 )
 from daft.dataframe import DataFrame
@@ -64,7 +63,7 @@ def read_json(
 
     json_config = JsonSourceConfig(_buffer_size, _chunk_size)
     file_format_config = FileFormatConfig.from_json_config(json_config)
-    storage_config = StorageConfig.native(NativeStorageConfig(True, io_config))
+    storage_config = StorageConfig(True, io_config)
 
     builder = get_tabular_files_scan(
         path=path,
3 changes: 1 addition & 2 deletions daft/io/_parquet.py
@@ -7,7 +7,6 @@
 from daft.daft import (
     FileFormatConfig,
     IOConfig,
-    NativeStorageConfig,
     ParquetSourceConfig,
     StorageConfig,
 )
@@ -84,7 +83,7 @@ def read_parquet(
     file_format_config = FileFormatConfig.from_parquet_config(
         ParquetSourceConfig(coerce_int96_timestamp_unit=pytimeunit, row_groups=row_groups, chunk_size=_chunk_size)
     )
-    storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config))
+    storage_config = StorageConfig(multithreaded_io, io_config)
 
     builder = get_tabular_files_scan(
         path=path,
7 changes: 3 additions & 4 deletions daft/table/table_io.py
@@ -17,7 +17,6 @@
     JsonConvertOptions,
     JsonParseOptions,
     JsonReadOptions,
-    NativeStorageConfig,
     StorageConfig,
 )
 from daft.dependencies import pa, pacsv, pads, pq
@@ -90,7 +89,7 @@ def read_json(
     Returns:
         MicroPartition: Parsed MicroPartition from JSON
     """
-    config = storage_config.config if storage_config is not None else NativeStorageConfig(True, IOConfig())
+    config = storage_config if storage_config is not None else StorageConfig(True, IOConfig())
     assert isinstance(file, (str, pathlib.Path)), "Native downloader only works on string inputs to read_json"
     json_convert_options = JsonConvertOptions(
         limit=read_options.num_rows,
@@ -126,7 +125,7 @@ def read_parquet(
     Returns:
         MicroPartition: Parsed MicroPartition from Parquet
     """
-    config = storage_config.config if storage_config is not None else NativeStorageConfig(True, IOConfig())
+    config = storage_config if storage_config is not None else StorageConfig(True, IOConfig())
     assert isinstance(
         file, (str, pathlib.Path)
     ), "Native downloader only works on string or Path inputs to read_parquet"
@@ -211,7 +210,7 @@ def read_csv(
     Returns:
         MicroPartition: Parsed MicroPartition from CSV
     """
-    config = storage_config.config if storage_config is not None else NativeStorageConfig(True, IOConfig())
+    config = storage_config if storage_config is not None else StorageConfig(True, IOConfig())
     assert isinstance(file, (str, pathlib.Path)), "Native downloader only works on string or Path inputs to read_csv"
     has_header = csv_options.header_index is not None
     csv_convert_options = CsvConvertOptions(
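The same default appears in read_json, read_parquet, and read_csv above. Expressed as a hypothetical helper (not part of this diff), the fallback is simply:

from typing import Optional

from daft.daft import IOConfig, StorageConfig


def _storage_config_or_default(storage_config: Optional[StorageConfig]) -> StorageConfig:
    # Mirror the fallback used by the table_io readers above: default to a
    # multithreaded reader with an empty IOConfig when no config is supplied.
    return storage_config if storage_config is not None else StorageConfig(True, IOConfig())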
18 changes: 9 additions & 9 deletions src/daft-local-execution/src/sources/scan_task.rs
@@ -15,7 +15,7 @@ use daft_io::IOStatsRef;
 use daft_json::{JsonConvertOptions, JsonParseOptions, JsonReadOptions};
 use daft_micropartition::MicroPartition;
 use daft_parquet::read::{read_parquet_bulk_async, ParquetSchemaInferenceOptions};
-use daft_scan::{storage_config::StorageConfig, ChunkSpec, ScanTask};
+use daft_scan::{ChunkSpec, ScanTask};
 use futures::{Stream, StreamExt, TryStreamExt};
 use snafu::ResultExt;
 use tracing::instrument;
@@ -241,14 +241,14 @@ async fn stream_scan_task(
     }
     let source = scan_task.sources.first().unwrap();
     let url = source.get_path();
-    let (io_config, multi_threaded_io) = match scan_task.storage_config.as_ref() {
-        StorageConfig::Native(native_storage_config) => (
-            native_storage_config.io_config.as_ref(),
-            native_storage_config.multithreaded_io,
-        ),
-    };
-    let io_config = Arc::new(io_config.cloned().unwrap_or_default());
-    let io_client = daft_io::get_io_client(multi_threaded_io, io_config)?;
+    let io_config = Arc::new(
+        scan_task
+            .storage_config
+            .io_config
+            .clone()
+            .unwrap_or_default(),
+    );
+    let io_client = daft_io::get_io_client(scan_task.storage_config.multithreaded_io, io_config)?;
     let table_stream = match scan_task.file_format_config.as_ref() {
         FileFormatConfig::Parquet(ParquetSourceConfig {
             coerce_int96_timestamp_unit,