Refactor time series prediction inference code #61
- Introduced a new `TimeSeriesPredictor` class to handle prediction logic, separating it from the TabPFN-specific implementation.
- Updated `TabPFNTimeSeriesPredictor` to inherit from `TimeSeriesPredictor`, allowing for a more flexible inference engine setup.
- Removed the `tabpfn_worker.py` file and replaced it with a modular worker system, including `ParallelWorker`, `CPUParallelWorker`, and `GPUParallelWorker`.
- Added new base inference engine classes for better abstraction and maintainability.
- Updated tests to reflect changes in the predictor's structure and ensure proper initialization and functionality.
Summary of Changes
Hello @liam-sbhoo, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly refactors the time series prediction inference code by separating the core model prediction logic from the parallel execution strategy. This architectural change enhances modularity and simplifies the integration of alternative prediction models in the future, making the codebase more flexible and easier to extend.
Highlights
- Dependency Updates: The `pyproject.toml` file has been updated to specify `tabpfn` version `2.1.0` and to include `matplotlib>=3.10.3` as a new development dependency.
- Core Code Restructuring: The core `TabPFNTimeSeriesPredictor` class has been refactored to inherit from a new `TimeSeriesPredictor` base class. This change decouples the model prediction logic (now handled by `InferenceEngine`) from the parallelization strategy (now handled by `ParallelWorker`), improving modularity and maintainability.
- Introduction of New Abstractions: New abstract base classes, `InferenceEngine` and `ParallelWorker`, have been introduced to define clear interfaces for model prediction and parallel execution. Concrete implementations like `GPUParallelWorker`, `CPUParallelWorker`, and various `TabPFNInferenceEngine` types have been added to support different operational modes.
- Removal of Legacy Code: The monolithic `tabpfn_time_series/tabpfn_worker.py` file, which previously contained combined prediction and parallelization logic, has been removed. Its responsibilities are now distributed among the newly introduced `InferenceEngine` and `ParallelWorker` classes.
- Test Suite Adjustments: Existing tests in `tests/test_predictor.py` have been updated to align with the refactored class structure, specifically changing assertions from `predictor.tabpfn_worker` to `predictor.worker`.
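The split described in the highlights can be pictured with a small sketch. This is not the PR's code: only the class names `InferenceEngine`, `ParallelWorker`, and `CPUParallelWorker` come from the summary above, while the method names (`predict`, `predict_all`), the `num_workers` parameter, the thread pool, and the `MeanInferenceEngine` stand-in model are assumptions made for illustration.

```python
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from statistics import fmean


class InferenceEngine(ABC):
    """Assumed interface: predicts for ONE time series at a time."""

    @abstractmethod
    def predict(self, train_X, train_y, test_X) -> dict:
        ...


class MeanInferenceEngine(InferenceEngine):
    """Hypothetical stand-in model: always predicts the training mean."""

    def predict(self, train_X, train_y, test_X) -> dict:
        return {"target": [fmean(train_y)] * len(test_X)}


class ParallelWorker(ABC):
    """Assumed interface: decides HOW many series are processed at once."""

    def __init__(self, engine: InferenceEngine):
        self.engine = engine

    @abstractmethod
    def predict_all(self, tasks) -> list:
        ...


class CPUParallelWorker(ParallelWorker):
    """Sketch using a thread pool; the real worker may use processes instead."""

    def __init__(self, engine: InferenceEngine, num_workers: int = 2):
        super().__init__(engine)
        self.num_workers = num_workers

    def predict_all(self, tasks) -> list:
        # Each task is a (train_X, train_y, test_X) tuple for one series.
        with ThreadPoolExecutor(max_workers=self.num_workers) as pool:
            # map() preserves the input order of the tasks.
            return list(pool.map(lambda t: self.engine.predict(*t), tasks))


tasks = [
    ([[0], [1], [2]], [1.0, 2.0, 3.0], [[3], [4]]),
    ([[0], [1]], [10.0, 20.0], [[2]]),
]
worker = CPUParallelWorker(MeanInferenceEngine(), num_workers=2)
results = worker.predict_all(tasks)
```

Because the worker only ever calls `engine.predict`, swapping in a different model would not touch the parallelization code, which is the decoupling the summary describes.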
Code Review
This pull request does a great job of refactoring the time series prediction code by decoupling the model prediction logic into an `InferenceEngine` and the parallelization into a `ParallelWorker`. This significantly improves modularity and will make it easier to experiment with different models in the future. The overall structure is sound. I've identified a few issues, including a critical bug in the mock inference engine, a regression in handling constant-valued time series, and a loss of functionality for configuring worker parallelism. I've also included some suggestions to improve code quality and maintainability.
Great, this nicely decouples everything. Two points:
- I'm still struggling to understand whether the `InferenceEngine` could be dropped as well. Is the code for downloading models (which should be done automatically) and for renaming models for the client (could we use different settings?) really needed?
- In the current setup, to make models even easier to replace, what do you think about the change below? With this I could use RandomForestTabPFN or AutoTabPFN without implementing another inference engine.
# In tabpfn_time_series/worker/tabpfn_inference_engine.py
import numpy as np
from copy import deepcopy
from typing import Type

# Keep the base class for the template method pattern
from tabpfn_time_series.worker.base_inference_engine import InferenceEngine


# --- Helper Function (Shared Logic) ---
def process_tabpfn_pred_output(
    pred_output: dict, output_selection: str, quantiles: list[float | str]
) -> dict[str, np.ndarray]:
    """Translates raw TabPFN output to the standardized dictionary format."""
    result = {"target": pred_output[output_selection]}
    result.update({q: q_pred for q, q_pred in zip(quantiles, pred_output["quantiles"])})
    return result


# --- New Shared Base Class for TabPFN Models ---
class BaseTabPFNInferenceEngine(InferenceEngine):
    """
    A base engine that handles the common fit/predict logic for any
    sklearn-compatible TabPFN regressor.
    """

    def __init__(self, model_class: Type, config: dict):
        super().__init__(config)
        self.model_class = model_class

    def _predict(
        self, train_X: np.ndarray, train_y: np.ndarray, test_X: np.ndarray, quantiles: list
    ) -> dict[str, np.ndarray]:
        """
        This concrete implementation now serves both Local and Client modes.
        """
        # 1. Instantiate the specific model class (Local or Client)
        model = self.model_class(**self.config["tabpfn_internal"], random_state=0)

        # 2. Fit and predict
        model.fit(train_X, train_y)
        pred_output = model.predict(test_X, output_type="main")

        # 3. Process the output
        return process_tabpfn_pred_output(
            pred_output, self.config["tabpfn_output_selection"], quantiles
        )


# --- Final, Lean Implementations ---
class LocalTabPFNInferenceEngine(BaseTabPFNInferenceEngine):
    """Handles local TabPFN: downloads the model and passes the correct class."""

    def __init__(self, config: dict):
        from tabpfn import TabPFNRegressor  # Local import

        self._download_model(config["tabpfn_internal"]["model_path"])
        super().__init__(model_class=TabPFNRegressor, config=config)

    @staticmethod
    def _download_model(model_name: str):
        # (Implementation for downloading model weights...)
        from tabpfn.model.loading import resolve_model_path, download_model

        model_path, _, model_name, which = resolve_model_path(model_name, "regressor")
        if not model_path.exists():
            download_model(model_path, which, "v2", model_name)


class TabPFNClientInferenceEngine(BaseTabPFNInferenceEngine):
    """Handles client TabPFN: initializes the client and passes the correct class."""

    def __init__(self, config: dict):
        from tabpfn_client import init, TabPFNRegressor  # Local import

        # Unique setup for the client
        init()
        config_copy = deepcopy(config)
        config_copy["tabpfn_internal"]["model_path"] = self._parse_model_name(
            config_copy["tabpfn_internal"]["model_path"]
        )
        super().__init__(model_class=TabPFNRegressor, config=config_copy)

    def _parse_model_name(self, model_name: str) -> str:
        # (Implementation for parsing client model name...)
        from tabpfn_client import TabPFNRegressor

        available = TabPFNRegressor.list_available_models()
        for m in available:
            if m in model_name:
                return m
        raise ValueError(f"Model {model_name} not found. Available: {available}")
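To make the output format of the shared helper in the proposal above concrete, here is a small worked example. The helper body is copied from the suggestion; the input dictionary is hand-made for illustration (plain lists stand in for NumPy arrays, and the values are not real model output).

```python
def process_tabpfn_pred_output(pred_output, output_selection, quantiles):
    """Same logic as the helper in the suggestion, with type hints omitted."""
    result = {"target": pred_output[output_selection]}
    result.update({q: q_pred for q, q_pred in zip(quantiles, pred_output["quantiles"])})
    return result


# Illustrative stand-in for a raw prediction dict: one entry per point
# estimate, plus a list holding one array per requested quantile level.
fake_pred_output = {
    "median": [1.0, 2.0],
    "quantiles": [[0.5, 1.5], [1.5, 2.5]],
}

standardized = process_tabpfn_pred_output(
    fake_pred_output, output_selection="median", quantiles=[0.1, 0.9]
)
# standardized maps "target" to the selected point estimate and each
# quantile level to its own array of predictions.
```

Note that `zip` pairs levels and arrays by position and silently stops at the shorter input, so the helper relies on `pred_output["quantiles"]` holding exactly one array per entry in `quantiles`, in the same order.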
tabpfn_time_series/predictor.py
Outdated
TabPFNMode.MOCK: CPUParallelWorker,
}

inference_engine = inference_engine_mapping[tabpfn_mode]()
Could we add a way to override this? Currently we still can't use our own implementation of another model / inference engine.
@noahho I think Anyway, I got your "wishes" up there, let me do a few more iterations while using it myself for the evaluation.
@liam-sbhoo yep I saw there was agnostic code in there. I just wondered if that code could go somewhere else or if the class was needed. It's definitely a step forward :-)
Pull Request Overview
This PR refactors the time series prediction inference code to decouple model prediction logic from parallelization concerns, facilitating experimentation with different models beyond TabPFN.
- Extracts model prediction logic into new `InferenceEngine` abstraction with specialized adapters
- Moves parallelization logic into dedicated `ParallelWorker` classes
- Introduces a more flexible `TimeSeriesPredictor` class with factory methods
Reviewed Changes
Copilot reviewed 10 out of 12 changed files in this pull request and generated 5 comments.
File | Description |
---|---|
tests/test_predictor.py | Updates tests to reflect new class structure and removes mock mode testing |
tabpfn_time_series/worker/tabpfn_model_adapter.py | Implements TabPFN-specific model adapter for handling TabPFN prediction logic |
tabpfn_time_series/worker/parallel.py | New parallel processing framework supporting both CPU and GPU execution |
tabpfn_time_series/worker/model_adapter.py | Base model adapter interface and point prediction adapter implementation |
tabpfn_time_series/tabpfn_worker.py | Removes old TabPFN worker implementation (295 lines deleted) |
tabpfn_time_series/predictor.py | Refactors predictor classes to use new adapter pattern |
tabpfn_time_series/defaults.py | Simplifies default configuration structure |
tabpfn_time_series/__init__.py | Updates exports to reflect renamed constants |
pyproject.toml | Updates TabPFN dependency version and adds matplotlib |
gift_eval/tabpfn_ts_wrapper.py | Updates to use renamed quantile configuration constant |
Comments suppressed due to low confidence (1)
pyproject.toml:22
- Based on my knowledge cutoff in January 2025, tabpfn version 2.1.0 may not exist. The latest known version was around 2.0.x. Please verify this version exists before deployment.
"tabpfn>=2.1.0",
"""Test that predict method calls the worker's predict method"""
# Create predictor and call predict
- predictor = TabPFNTimeSeriesPredictor(tabpfn_mode=TabPFNMode.MOCK)
+ predictor = TabPFNTimeSeriesPredictor(tabpfn_mode=TabPFNMode.LOCAL)

with self.assertRaises(ValueError):
    _ = predictor.predict(self.train_tsdf, self.test_tsdf)
The test creates a LOCAL mode predictor but then expects it to raise a ValueError when calling predict(). This appears to be testing the wrong behavior - if LOCAL mode is not supported, the ValueError should occur during predictor initialization, not during prediction.
def __new__(
    cls,
    tabpfn_mode: TabPFNMode = TabPFNMode.LOCAL,
The default value for tabpfn_mode has changed from TabPFNMode.CLIENT to TabPFNMode.LOCAL. This is a breaking change that could affect existing users who relied on the previous default behavior.
- tabpfn_mode: TabPFNMode = TabPFNMode.LOCAL,
+ tabpfn_mode: TabPFNMode = TabPFNMode.CLIENT,
?
@noahho I think you'll like it - you can now easily use different models (being TabPFN-variant or not) to perform prediction 😄 See example at: tabpfn-time-series/tests/test_predictor.py Lines 97 to 129 in 42d8204
Thank you.
Wow, seems really well engineered PR!
Unrelated but I just thought of this (feel free to ignore): It might pay off to create an examples folder where you keep usage examples, e.g. for using a custom model - I at least found that an easy way to show and understand core functionalities. You can also check in the colab notebook there to make sure you don't change it by mistake - for the tabpfn repo this turned out to be super useful! Link becomes sth like: https://colab.research.google.com/github/PriorLabs/TabPFN/blob/main/examples/notebooks/TabPFN_Demo_Local.ipynb
"""
Given a TimeSeriesDataFrame (multiple time series), perform prediction on each time series individually.
"""
# class TimeSeriesPredictor:
Why is this code here and commented out? This kind of code often becomes out of date quickly, and retrieving it from the git history usually works fine if you want to save this state.
def __new__(
    cls,
    tabpfn_mode: TabPFNMode = TabPFNMode.LOCAL,
?
@liam-sbhoo should we merge this?
This pull request introduces significant updates to the `tabpfn_time_series` package, focusing on refactoring, improving modularity, and enhancing functionality for time series prediction. The most important changes include introducing a new modular worker and model adapter architecture, and refactoring the `TabPFNTimeSeriesPredictor` class for better flexibility and extensibility.

Predictor Class Refactoring:
- `TimeSeriesPredictor`: Added a new generic class that allows sklearn-compatible models to be used as the regressor. Added factory methods: `from_point_prediction_regressor` (for sklearn-compatible models) and `from_tabpfn_family` (for the TabPFN regressor family).
- `TabPFNTimeSeriesPredictor`: Refactored this to inherit from `TimeSeriesPredictor`. This class is kept mainly for backward compatibility.
- See `tests/test_predictor.py` for examples!

Modular Architecture Enhancements:
- `BaseModelAdapter` and `PointPredictionModelAdapter`: Added a new abstraction layer in `tabpfn_time_series/worker/model_adapter.py` to support flexible integration with different model types, including point prediction and probabilistic models.
- Removed the `tabpfn_worker.py` file and replaced it with a modular worker system (`ParallelWorker`, `CPUParallelWorker`, `GPUParallelWorker`) to streamline parallel execution and GPU utilization.

These changes collectively enhance the maintainability, extensibility, and performance of the `tabpfn_time_series` package, making it more robust for diverse time series prediction use cases.
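As an illustration of the factory-method idea, the sketch below shows how a `from_point_prediction_regressor` constructor might wrap any object with `fit`/`predict` in an adapter. Only the names `TimeSeriesPredictor`, `PointPredictionModelAdapter`, and `from_point_prediction_regressor` come from the description; every signature here is an assumption, and `LastValueRegressor` is a hypothetical toy model. For actual usage, `tests/test_predictor.py` in the repository is the reference.

```python
class PointPredictionModelAdapter:
    """Assumed adapter: wraps a regressor class exposing fit(X, y) / predict(X)."""

    def __init__(self, regressor_class, regressor_kwargs=None):
        self.regressor_class = regressor_class
        self.regressor_kwargs = regressor_kwargs or {}

    def predict(self, train_X, train_y, test_X) -> dict:
        # A fresh model per series, so no state leaks between series.
        model = self.regressor_class(**self.regressor_kwargs)
        model.fit(train_X, train_y)
        return {"target": list(model.predict(test_X))}


class TimeSeriesPredictor:
    """Sketch of a model-agnostic predictor built around an adapter."""

    def __init__(self, adapter):
        self.adapter = adapter

    @classmethod
    def from_point_prediction_regressor(cls, regressor_class, regressor_kwargs=None):
        # Signature is an assumption made for this sketch.
        return cls(PointPredictionModelAdapter(regressor_class, regressor_kwargs))

    def predict(self, train_X, train_y, test_X) -> dict:
        return self.adapter.predict(train_X, train_y, test_X)


class LastValueRegressor:
    """Hypothetical toy model: repeats the last observed target value."""

    def fit(self, X, y):
        self.last_ = float(y[-1])
        return self

    def predict(self, X):
        return [self.last_] * len(X)


predictor = TimeSeriesPredictor.from_point_prediction_regressor(LastValueRegressor)
forecast = predictor.predict([[0], [1], [2]], [1.0, 2.0, 5.0], [[3], [4]])
```

Any sklearn-style regressor class could be passed in place of `LastValueRegressor` under these assumptions, since the adapter only relies on `fit` and `predict`.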