Collaborator

@liam-sbhoo liam-sbhoo commented Jul 28, 2025

Development

This pull request introduces significant updates to the tabpfn_time_series package, focusing on refactoring, improving modularity, and enhancing functionality for time series prediction. The most important changes include introducing a new modular worker and model adapter architecture, and refactoring the TabPFNTimeSeriesPredictor class for better flexibility and extensibility.

Predictor Class Refactoring:

  • Introduced TimeSeriesPredictor: Added a new generic class that lets any sklearn-compatible model serve as the regressor. It provides two factory methods: from_point_prediction_regressor (for sklearn-compatible models) and from_tabpfn_family (for the TabPFN regressor family).
  • Refactored TabPFNTimeSeriesPredictor: It now inherits from TimeSeriesPredictor and is kept mainly for backward compatibility.
  • See tests/test_predictor.py for examples.

Modular Architecture Enhancements:

  • Introduced BaseModelAdapter and PointPredictionModelAdapter: Added a new abstraction layer in tabpfn_time_series/worker/model_adapter.py to support flexible integration with different model types, including point prediction and probabilistic models.
  • Refactored worker architecture: Removed the legacy tabpfn_worker.py file and replaced it with a modular worker system (ParallelWorker, CPUParallelWorker, GPUParallelWorker) to streamline parallel execution and GPU utilization.
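
As a rough illustration of the adapter layer described above, the sketch below shows how a base adapter and a point-prediction adapter might fit together. The `predict` signature, the constructor arguments, and the `MeanRegressor` stand-in are assumptions made for this example; they are not taken from tabpfn_time_series/worker/model_adapter.py.

```python
from abc import ABC, abstractmethod


class BaseModelAdapter(ABC):
    """Hypothetical interface: fit on one series' history, predict its horizon."""

    @abstractmethod
    def predict(self, train_X, train_y, test_X) -> dict:
        ...


class PointPredictionModelAdapter(BaseModelAdapter):
    """Wraps any regressor class that exposes sklearn-style fit/predict."""

    def __init__(self, model_class, model_config=None):
        self.model_class = model_class
        self.model_config = model_config or {}

    def predict(self, train_X, train_y, test_X) -> dict:
        model = self.model_class(**self.model_config)  # fresh model per series
        model.fit(train_X, train_y)
        return {"target": list(model.predict(test_X))}


class MeanRegressor:
    """Stand-in for an sklearn-compatible regressor (illustration only)."""

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]


adapter = PointPredictionModelAdapter(MeanRegressor)
forecast = adapter.predict([[0], [1], [2]], [1.0, 2.0, 3.0], [[3], [4]])
```

Because the adapter only relies on `fit` and `predict`, any class with that shape could be passed in place of `MeanRegressor` without touching the worker code.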

These changes collectively enhance the maintainability, extensibility, and performance of the tabpfn_time_series package, making it more robust for diverse time series prediction use cases.

- Introduced a new `TimeSeriesPredictor` class to handle prediction logic, separating it from the TabPFN-specific implementation.
- Updated `TabPFNTimeSeriesPredictor` to inherit from `TimeSeriesPredictor`, allowing for a more flexible inference engine setup.
- Removed the `tabpfn_worker.py` file and replaced it with a modular worker system, including `ParallelWorker`, `CPUParallelWorker`, and `GPUParallelWorker`.
- Added new base inference engine classes for better abstraction and maintainability.
- Updated tests to reflect changes in the predictor's structure and ensure proper initialization and functionality.
@liam-sbhoo liam-sbhoo requested review from noahho and Copilot July 28, 2025 13:01

@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @liam-sbhoo, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the time series prediction inference code by separating the core model prediction logic from the parallel execution strategy. This architectural change enhances modularity and simplifies the integration of alternative prediction models in the future, making the codebase more flexible and easier to extend.

Highlights

  • Dependency Updates: The pyproject.toml file has been updated to specify tabpfn version 2.1.0 and to include matplotlib>=3.10.3 as a new development dependency.
  • Core Code Restructuring: The core TabPFNTimeSeriesPredictor class has been refactored to inherit from a new TimeSeriesPredictor base class. This change decouples the model prediction logic (now handled by InferenceEngine) from the parallelization strategy (now handled by ParallelWorker), improving modularity and maintainability.
  • Introduction of New Abstractions: New abstract base classes, InferenceEngine and ParallelWorker, have been introduced to define clear interfaces for model prediction and parallel execution. Concrete implementations like GPUParallelWorker, CPUParallelWorker, and various TabPFNInferenceEngine types have been added to support different operational modes.
  • Removal of Legacy Code: The monolithic tabpfn_time_series/tabpfn_worker.py file, which previously contained combined prediction and parallelization logic, has been removed. Its responsibilities are now distributed among the newly introduced InferenceEngine and ParallelWorker classes.
  • Test Suite Adjustments: Existing tests in tests/test_predictor.py have been updated to align with the refactored class structure, specifically changing assertions from predictor.tabpfn_worker to predictor.worker.
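
The separation between prediction logic and parallel execution summarized above can be sketched as follows. This is a minimal, assumption-laden illustration: `NaiveInferenceEngine`, `ThreadedWorker`, and their method signatures are hypothetical names invented here, and the real InferenceEngine and ParallelWorker classes may look quite different.

```python
from concurrent.futures import ThreadPoolExecutor


class NaiveInferenceEngine:
    """Hypothetical engine: repeats the last observed value for each step."""

    def predict(self, history, horizon):
        return [history[-1]] * horizon


class ThreadedWorker:
    """Hypothetical worker: knows how to fan out work, not how to predict."""

    def __init__(self, engine, max_workers=2):
        self.engine = engine
        self.max_workers = max_workers

    def predict(self, series_by_id, horizon):
        items = list(series_by_id.items())
        with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
            # One task per series; map() preserves the input order.
            preds = pool.map(lambda kv: self.engine.predict(kv[1], horizon), items)
            return {item_id: p for (item_id, _), p in zip(items, preds)}


worker = ThreadedWorker(NaiveInferenceEngine())
result = worker.predict({"a": [1, 2, 3], "b": [10, 20]}, horizon=2)
```

Swapping the model then means passing a different engine object, while the choice of thread pool, process pool, or GPU scheduling stays inside the worker.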

Copilot

This comment was marked as outdated.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request does a great job of refactoring the time series prediction code by decoupling the model prediction logic into an InferenceEngine and the parallelization into a ParallelWorker. This significantly improves modularity and will make it easier to experiment with different models in the future. The overall structure is sound. I've identified a few issues, including a critical bug in the mock inference engine, a regression in handling constant-valued time series, and a loss of functionality for configuring worker parallelism. I've also included some suggestions to improve code quality and maintainability.
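
The review mentions a regression in handling constant-valued time series without showing the affected code. One common guard, sketched here under the assumption that a flat history should simply be extended, is to short-circuit before the model is called. The function name and signature are hypothetical and do not come from the PR.

```python
def predict_with_constant_guard(history, horizon, model_predict):
    """Return the constant value directly when the history never varies.

    `model_predict` is any callable (history, horizon) -> list of floats.
    """
    if len(set(history)) == 1:
        # Nothing to learn from a flat series; skip the model entirely.
        return [history[0]] * horizon
    return model_predict(history, horizon)


def failing_model(history, horizon):
    raise RuntimeError("model should not be called for constant input")


flat = predict_with_constant_guard([5.0, 5.0, 5.0], 3, failing_model)
varying = predict_with_constant_guard([1.0, 2.0], 2, lambda h, n: [h[-1]] * n)
```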

Contributor

@noahho noahho left a comment

Great, this nicely decouples everything. Two points:

  1. I'm still struggling to understand whether the InferenceEngine could be dropped as well. Is the code for downloading models (which should happen automatically) and renaming for the client (could we use different settings?) really needed?

  2. In the current setup, to make models even easier to replace, what do you think about the change below? With this I could use RandomForestTabPFN or AutoTabPFN without implementing another inference engine.

# In tabpfn_time_series/worker/tabpfn_inference_engine.py

import numpy as np
from copy import deepcopy
from typing import Type

# Keep the base class for the template method pattern
from tabpfn_time_series.worker.base_inference_engine import InferenceEngine

# --- Helper Function (Shared Logic) ---

def process_tabpfn_pred_output(
    pred_output: dict, output_selection: str, quantiles: list[float | str]
) -> dict[str, np.ndarray]:
    """Translates raw TabPFN output to the standardized dictionary format."""
    result = {"target": pred_output[output_selection]}
    result.update({q: q_pred for q, q_pred in zip(quantiles, pred_output["quantiles"])})
    return result

# --- New Shared Base Class for TabPFN Models ---

class BaseTabPFNInferenceEngine(InferenceEngine):
    """
    A base engine that handles the common fit/predict logic for any
    sklearn-compatible TabPFN regressor.
    """
    def __init__(self, model_class: Type, config: dict):
        super().__init__(config)
        self.model_class = model_class

    def _predict(
        self, train_X: np.ndarray, train_y: np.ndarray, test_X: np.ndarray, quantiles: list
    ) -> dict[str, np.ndarray]:
        """
        This concrete implementation now serves both Local and Client modes.
        """
        # 1. Instantiate the specific model class (Local or Client)
        model = self.model_class(**self.config["tabpfn_internal"], random_state=0)

        # 2. Fit and predict
        model.fit(train_X, train_y)
        pred_output = model.predict(test_X, output_type="main")

        # 3. Process the output
        return process_tabpfn_pred_output(
            pred_output, self.config["tabpfn_output_selection"], quantiles
        )

# --- Final, Lean Implementations ---

class LocalTabPFNInferenceEngine(BaseTabPFNInferenceEngine):
    """Handles local TabPFN: downloads the model and passes the correct class."""
    def __init__(self, config: dict):
        from tabpfn import TabPFNRegressor # Local import
        self._download_model(config["tabpfn_internal"]["model_path"])
        super().__init__(model_class=TabPFNRegressor, config=config)

    @staticmethod
    def _download_model(model_name: str):
        # (Implementation for downloading model weights...)
        from tabpfn.model.loading import resolve_model_path, download_model
        model_path, _, model_name, which = resolve_model_path(model_name, "regressor")
        if not model_path.exists():
            download_model(model_path, which, "v2", model_name)


class TabPFNClientInferenceEngine(BaseTabPFNInferenceEngine):
    """Handles client TabPFN: initializes the client and passes the correct class."""
    def __init__(self, config: dict):
        from tabpfn_client import init, TabPFNRegressor # Local import
        
        # Unique setup for the client
        init()
        config_copy = deepcopy(config)
        config_copy["tabpfn_internal"]["model_path"] = self._parse_model_name(
            config_copy["tabpfn_internal"]["model_path"]
        )
        
        super().__init__(model_class=TabPFNRegressor, config=config_copy)

    def _parse_model_name(self, model_name: str) -> str:
        # (Implementation for parsing client model name...)
        from tabpfn_client import TabPFNRegressor
        available = TabPFNRegressor.list_available_models()
        for m in available:
            if m in model_name:
                return m
        raise ValueError(f"Model {model_name} not found. Available: {available}")

TabPFNMode.MOCK: CPUParallelWorker,
}

inference_engine = inference_engine_mapping[tabpfn_mode]()
Contributor

Could we add a way to overwrite this? Currently we still couldn't use our own implementation of another model / inference engine
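
One way the requested override could look is sketched below: the caller may pass a ready-made engine, and the mode-to-class mapping is only used as a fallback. All names here (`Mode`, `DEFAULT_ENGINES`, `build_engine`, the `inference_engine` parameter) are hypothetical and chosen for illustration; the PR's actual mapping and constructor are not reproduced.

```python
from enum import Enum


class Mode(Enum):
    LOCAL = "local"
    MOCK = "mock"


class LocalEngine:
    name = "local"


class MockEngine:
    name = "mock"


# Built-in fallback, analogous in spirit to a mode-to-engine mapping.
DEFAULT_ENGINES = {Mode.LOCAL: LocalEngine, Mode.MOCK: MockEngine}


def build_engine(mode=Mode.LOCAL, inference_engine=None):
    """Use the caller's engine if given, else fall back to the mode mapping."""
    if inference_engine is not None:
        return inference_engine
    return DEFAULT_ENGINES[mode]()


class MyCustomEngine:
    name = "custom"


default_engine = build_engine(Mode.MOCK)
custom_engine = build_engine(inference_engine=MyCustomEngine())
```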

@liam-sbhoo liam-sbhoo marked this pull request as draft July 28, 2025 22:09
Collaborator Author

liam-sbhoo commented Jul 28, 2025

@noahho
Thanks for the review. It makes sense, and I agree the code could be simpler.

I think InferenceEngine might still be relevant because ideally that part of the code can be model-agnostic (I'm in the process of evaluating other TFMs too, hence these changes).

Anyway, I've noted your "wishes" above. Let me do a few more iterations while using it myself for the evaluation.

Contributor

noahho commented Jul 30, 2025

@liam-sbhoo yep I saw there was agnostic code in there. I just wondered if that code could go somewhere else or if the class was needed. It's definitely a step forward :-)

@liam-sbhoo liam-sbhoo requested a review from Copilot August 1, 2025 16:16
Contributor

@Copilot Copilot AI left a comment

Pull Request Overview

This PR refactors the time series prediction inference code to decouple model prediction logic from parallelization concerns, facilitating experimentation with different models beyond TabPFN.

  • Extracts model prediction logic into new InferenceEngine abstraction with specialized adapters
  • Moves parallelization logic into dedicated ParallelWorker classes
  • Introduces a more flexible TimeSeriesPredictor class with factory methods

Reviewed Changes

Copilot reviewed 10 out of 12 changed files in this pull request and generated 5 comments.

Show a summary per file
| File | Description |
| --- | --- |
| tests/test_predictor.py | Updates tests to reflect new class structure and removes mock mode testing |
| tabpfn_time_series/worker/tabpfn_model_adapter.py | Implements TabPFN-specific model adapter for handling TabPFN prediction logic |
| tabpfn_time_series/worker/parallel.py | New parallel processing framework supporting both CPU and GPU execution |
| tabpfn_time_series/worker/model_adapter.py | Base model adapter interface and point prediction adapter implementation |
| tabpfn_time_series/tabpfn_worker.py | Removes old TabPFN worker implementation (295 lines deleted) |
| tabpfn_time_series/predictor.py | Refactors predictor classes to use new adapter pattern |
| tabpfn_time_series/defaults.py | Simplifies default configuration structure |
| tabpfn_time_series/__init__.py | Updates exports to reflect renamed constants |
| pyproject.toml | Updates TabPFN dependency version and adds matplotlib |
| gift_eval/tabpfn_ts_wrapper.py | Updates to use renamed quantile configuration constant |
Comments suppressed due to low confidence (1)

pyproject.toml:22

  • Based on my knowledge cutoff in January 2025, tabpfn version 2.1.0 may not exist. The latest known version was around 2.0.x. Please verify this version exists before deployment.
    "tabpfn>=2.1.0",

Comment on lines +89 to +96
     """Test that predict method calls the worker's predict method"""
     # Create predictor and call predict
-    predictor = TabPFNTimeSeriesPredictor(tabpfn_mode=TabPFNMode.MOCK)
+    predictor = TabPFNTimeSeriesPredictor(tabpfn_mode=TabPFNMode.LOCAL)

     with self.assertRaises(ValueError):
         _ = predictor.predict(self.train_tsdf, self.test_tsdf)


Copilot AI Aug 1, 2025

The test creates a LOCAL mode predictor but then expects it to raise a ValueError when calling predict(). This appears to be testing the wrong behavior - if LOCAL mode is not supported, the ValueError should occur during predictor initialization, not during prediction.

Copilot uses AI. Check for mistakes.
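
The point raised above, that an unsupported mode should fail at construction rather than at prediction time, can be illustrated with a small self-contained test. The `Predictor` class and the set of supported modes are invented for this sketch; in particular, treating LOCAL as unsupported is an assumption for the example and says nothing about the real TabPFNTimeSeriesPredictor.

```python
import unittest
from enum import Enum


class Mode(Enum):
    CLIENT = "client"
    LOCAL = "local"


class Predictor:
    """Hypothetical predictor that rejects unsupported modes up front."""

    SUPPORTED = {Mode.CLIENT}

    def __init__(self, mode):
        if mode not in self.SUPPORTED:
            raise ValueError(f"Unsupported mode: {mode}")
        self.mode = mode


class TestPredictorInit(unittest.TestCase):
    def test_unsupported_mode_fails_at_init(self):
        # The error surfaces at construction, before predict() is reached.
        with self.assertRaises(ValueError):
            Predictor(Mode.LOCAL)


suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestPredictorInit)
outcome = unittest.TextTestRunner(verbosity=0).run(suite)
```

Placing `assertRaises` around the constructor makes the test document where the failure is expected, which is the ambiguity the review comment points at.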


def __new__(
    cls,
    tabpfn_mode: TabPFNMode = TabPFNMode.LOCAL,

Copilot AI Aug 1, 2025

The default value for tabpfn_mode has changed from TabPFNMode.CLIENT to TabPFNMode.LOCAL. This is a breaking change that could affect existing users who relied on the previous default behavior.

Suggested change
- tabpfn_mode: TabPFNMode = TabPFNMode.LOCAL,
+ tabpfn_mode: TabPFNMode = TabPFNMode.CLIENT,

Contributor

?

@liam-sbhoo liam-sbhoo requested a review from noahho August 1, 2025 16:23
Collaborator Author

liam-sbhoo commented Aug 1, 2025

@noahho
I think this implementation is in a better state now!

I think you'll like it: you can now easily use different models (TabPFN variants or not) to perform prediction 😄

See example at:

class TestTimeSeriesPredictor(unittest.TestCase):
    def setUp(self):
        self.train_tsdf, self.test_tsdf = create_test_data()
        if os.getenv("GITHUB_ACTIONS"):
            setup_github_actions_tabpfn_client()

    def test_from_tabpfn_family(self):
        from tabpfn_client import TabPFNRegressor as TabPFNClientRegressor

        predictor = TimeSeriesPredictor.from_tabpfn_family(
            tabpfn_class=TabPFNClientRegressor,
            tabpfn_config={"n_estimators": 1},
            tabpfn_output_selection="median",
        )
        result = predictor.predict(self.train_tsdf, self.test_tsdf)
        assert result is not None

    def test_from_point_prediction_regressor(self):
        from sklearn.ensemble import RandomForestRegressor

        predictor = TimeSeriesPredictor.from_point_prediction_regressor(
            regressor_class=RandomForestRegressor,
            regressor_config={"n_estimators": 1},
            regressor_fit_config={
                # "...": "...",
            },
            regressor_predict_config={
                # "...": "...",
            },
        )
        result = predictor.predict(self.train_tsdf, self.test_tsdf)
        assert result is not None

Thank you.

@liam-sbhoo liam-sbhoo marked this pull request as ready for review August 1, 2025 16:28
Contributor

@noahho noahho left a comment

Wow, this seems like a really well-engineered PR!

Unrelated, but I just thought of this (feel free to ignore): it might pay off to create an examples folder where you keep usage examples, e.g. for using a custom model. I at least found that an easy way to show and understand core functionality. You can also check in the Colab notebook there to make sure you don't change it by mistake; for the tabpfn repo this turned out to be super useful! The link becomes something like: https://colab.research.google.com/github/PriorLabs/TabPFN/blob/main/examples/notebooks/TabPFN_Demo_Local.ipynb

"""
Given a TimeSeriesDataFrame (multiple time series), perform prediction on each time series individually.
"""
# class TimeSeriesPredictor:
Contributor

Why is this code here and commented out? This kind of code often goes out of date quickly, and retrieving it from the git history usually works fine if you want to preserve this state.


def __new__(
    cls,
    tabpfn_mode: TabPFNMode = TabPFNMode.LOCAL,
Contributor

?

Contributor

noahho commented Sep 4, 2025

@liam-sbhoo should we merge this?
