From 64c94282abb4e678f3cac532a59ce91cab0ebd1e Mon Sep 17 00:00:00 2001 From: Matt Dancho Date: Tue, 5 Nov 2024 15:55:43 -0500 Subject: [PATCH] TSCV: add plotting method --- .../crossvalidation/time_series_cv.py | 175 +++++++++++++++--- 1 file changed, 151 insertions(+), 24 deletions(-) diff --git a/src/pytimetk/crossvalidation/time_series_cv.py b/src/pytimetk/crossvalidation/time_series_cv.py index 46ac0d84..dde5b563 100644 --- a/src/pytimetk/crossvalidation/time_series_cv.py +++ b/src/pytimetk/crossvalidation/time_series_cv.py @@ -1,6 +1,9 @@ import pandas as pd import numpy as np +import plotly.graph_objects as go +from plotly.subplots import make_subplots + from timebasedcv import TimeBasedSplit from timebasedcv.splitstate import SplitState from timebasedcv.utils._types import ModeType @@ -20,7 +23,7 @@ class TimeSeriesCV(TimeBasedSplit): """`TimeSeriesCV` is a subclass of `TimeBasedSplit` with default mode set to 'backward' - and an optional `slice_limit` to return the first `n` slices of time series cross-validation sets. + and an optional `split_limit` to return the first `n` slices of time series cross-validation sets. Parameters ---------- @@ -41,8 +44,8 @@ class TimeSeriesCV(TimeBasedSplit): The type of window to use, either "rolling" or "expanding". mode: ModeType, optional The mode to use for cross-validation. Default is 'backward'. - slice_limit: int, optional - The maximum number of slices to return. If not provided, all slices are returned. + split_limit: int, optional + The maximum number of splits to return. If not provided, all splits are returned. Raises: ---------- @@ -59,7 +62,7 @@ class TimeSeriesCV(TimeBasedSplit): Examples: --------- - ```python + ``` {python} import pandas as pd import numpy as np from pytimetk import TimeSeriesCV @@ -97,29 +100,35 @@ class TimeSeriesCV(TimeBasedSplit): forecast_horizon=5, gap=1, stride=0, - slice_limit=3 # Limiting to 3 slices + split_limit=3 # Limiting to 3 splits ) X, y = df.loc[:, ["a", "b"]], df["y"] - # If `time_series` is not provided, it will use the index of `X` or `y` if available - for X_train, X_forecast, y_train, y_forecast in tscv.split(X, y): - - # Get the start and end dates for the training and forecast periods - train_start_date = min(X_train.index) - train_end_date = max(X_train.index) - forecast_start_date = min(X_forecast.index) - forecast_end_date = max(X_forecast.index) - - print(f"Train: {X_train.shape}, Forecast: {X_forecast.shape}") - print(f"Train Period: {train_start_date} to {train_end_date}") - print(f"Forecast Period: {forecast_start_date} to {forecast_end_date}\n") + # Creates a split generator + splits = tscv.split(X, y) + + for X_train, X_forecast, y_train, y_forecast in splits: + print(X_train) + print(X_forecast) + ``` + + ``` {python} + # Also, you can use `glimpse()` to print summary information about the splits + + tscv.glimpse(y) + ``` + + ``` {python} + # You can also plot the splits by calling `plot()` on the `TimeSeriesCV` instance with the `y` Pandas series + + tscv.plot(y) ``` """ - def __init__(self, *args, mode: ModeType = "backward", slice_limit: int = None, **kwargs): + def __init__(self, *args, mode: ModeType = "backward", split_limit: int = None, **kwargs): super().__init__(*args, mode=mode, **kwargs) - self.slice_limit = slice_limit + self.split_limit = split_limit def split( self, @@ -129,7 +138,7 @@ def split( end_dt: NullableDatetime = None, return_splitstate: bool = False, ) -> Generator[Union[Tuple[TL, ...], Tuple[Tuple[TL, ...], SplitState]], None, None]: - """Returns a generator of split arrays with an optional `slice_limit`. + """Returns a generator of split arrays with an optional `split_limit`. Arguments: *arrays: @@ -145,8 +154,8 @@ def split( Whether to return the `SplitState` instance for each split. Yields: - A generator of tuples of arrays containing the training and forecast data. If `slice_limit` is set, - yields only up to `slice_limit` splits. + A generator of tuples of arrays containing the training and forecast data. If `split_limit` is set, + yields only up to `split_limit` splits. """ # If time_series is not provided, attempt to extract it from the index of the first array if time_series is None: @@ -163,14 +172,132 @@ def split( *arrays, time_series=time_series, start_dt=start_dt, end_dt=end_dt, return_splitstate=return_splitstate ) - if self.slice_limit is not None: + if self.split_limit is not None: for i, split in enumerate(split_generator): - if i >= self.slice_limit: + if i >= self.split_limit: break yield split else: yield from split_generator + def glimpse(self, *arrays: TL, time_series: SeriesLike[DateTimeLike] = None): + """Prints summary information about the splits, focusing on the first two arrays. + + Arguments: + *arrays: + The arrays to split. Only the first two will be used for summary information. + time_series: + The time series used for splitting. If not provided, the index of the first array is used. Default is None. + """ + + # Use only the first array for splitting and summary + X = arrays[0] + + if time_series is None: + if isinstance(X, (pd.DataFrame, pd.Series)): + time_series = X.index + else: + raise ValueError("time_series must be provided if the first array does not have a time-based index.") + + # If the time_series is an index, convert it to a Series for easier handling + if isinstance(time_series, pd.Index): + time_series = pd.Series(time_series, index=time_series) + + # Iterate through the splits and print summary information + for split_number, (X_train, X_forecast) in enumerate(self.split(X, time_series=time_series), start=1): + # Get the start and end dates for the training and forecast periods + train_start_date = time_series[X_train.index[0]] + train_end_date = time_series[X_train.index[-1]] + forecast_start_date = time_series[X_forecast.index[0]] + forecast_end_date = time_series[X_forecast.index[-1]] + + # Print summary information + print(f"Split Number: {split_number}") + print(f"Train Shape: {X_train.shape}, Forecast Shape: {X_forecast.shape}") + print(f"Train Period: {train_start_date} to {train_end_date}") + print(f"Forecast Period: {forecast_start_date} to {forecast_end_date}\n") + + + def plot(self, y: pd.Series, time_series: pd.Series = None): + """Plots the cross-validation sets using Plotly with each fold in a separate subplot. + + Arguments: + y: Pandas.Series + The Pandas series of target values to plot. + time_series: Optional[pd.Series] + The time series used for the x-axis. If not provided, the index of `y` will be used. + """ + # Use the index of y if time_series is not provided + if time_series is None: + if isinstance(y, pd.Series): + time_series = y.index + else: + raise ValueError("time_series must be provided if y does not have a time-based index.") + + # Ensure time_series is a Pandas Index + if not isinstance(time_series, pd.Index): + raise ValueError("time_series must be a Pandas Index or convertible to one.") + + # Determine the number of folds + splits = list(self.split(y, time_series=time_series, return_splitstate=True)) + num_folds = len(splits) + + # Create subplots + fig = make_subplots( + rows=num_folds, cols=1, # One column, multiple rows + shared_xaxes=True, # Share the x-axis across all subplots + subplot_titles=[f"Fold {i+1}" for i in range(num_folds)] + ) + + # Enumerate through the splits and add traces to each subplot + for fold, (train_forecast, split_state) in enumerate(splits, start=1): + train, forecast = train_forecast + + ts = split_state.train_start + te = split_state.train_end + fs = split_state.forecast_start + fe = split_state.forecast_end + + # Add train set trace to the current subplot + fig.add_trace( + go.Scatter( + x=time_series[(time_series >= ts) & (time_series < te)], + y=train + fold, + name=f"Train Fold {fold}", + mode="markers", + marker={"color": "rgb(57, 105, 172)"} + ), + row=fold, col=1 + ) + + # Add forecast set trace to the current subplot + fig.add_trace( + go.Scatter( + x=time_series[(time_series >= fs) & (time_series < fe)], + y=forecast + fold, + name=f"Forecast Fold {fold}", + mode="markers", + marker={"color": "indianred"} + ), + row=fold, col=1 + ) + + # Update layout + fig.update_layout( + title={ + "text": "Time-Based Cross Validation", + "y": 0.95, "x": 0.5, + "xanchor": "center", + "yanchor": "top" + }, + showlegend=True, + height=300 * num_folds, # Adjust height based on the number of folds + xaxis_title="Time", + yaxis_title="Fold" + ) + + return fig + # class TimeSeriesCV: