Skip to content

Commit

Permalink
refactor(prepare_data): use None as default for split_ratio
Browse files Browse the repository at this point in the history
Changed the default value of `split_ratio` parameter from `(0.8, 0.1,
0.1)` to `None`.

This change improves the readability of the function signature in both
code and documentation.

Fixes #35
  • Loading branch information
adosar committed Jan 10, 2025
1 parent a0c543e commit 53d63eb
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/aidsorb/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from . transforms import upsample_pcd


def prepare_data(source: str, split_ratio: Sequence=(0.8, 0.1, 0.1), seed: int = SEED):
def prepare_data(source: str, split_ratio: Sequence = None, seed: int = SEED):
r"""
Split point clouds into train, validation and test sets.
Expand All @@ -50,9 +50,9 @@ def prepare_data(source: str, split_ratio: Sequence=(0.8, 0.1, 0.1), seed: int =
----------
source : str
Absolute or relative path to the directory holding the point clouds.
split_ratio : sequence, default=(0.8, 0.1, 0.1)
split_ratio : sequence, default=None
Absolute sizes or fractions of splits of the form ``(train, val,
test)``.
test)``. If ``None``, it is set to ``(0.8, 0.1, 0.1)``.
seed : int, default=1
Controls randomness of the ``rng`` used for splitting.
Expand Down Expand Up @@ -84,6 +84,10 @@ def prepare_data(source: str, split_ratio: Sequence=(0.8, 0.1, 0.1), seed: int =
path = Path(source).parent
pcd_names = [name.removesuffix('.npy') for name in os.listdir(source)]

# Set default split ratio.
if split_ratio is None:
split_ratio = (0.8, 0.1, 0.1)

# Split the names of the point clouds.
train, val, test = random_split(pcd_names, split_ratio, generator=rng)

Expand Down

0 comments on commit 53d63eb

Please sign in to comment.