Skip to content

Commit

Permalink
add @functions and have LearnAPI.functions() return accessors
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Dec 16, 2024
1 parent 8e8123a commit 6279b25
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 22 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ julia> ridge = Ridge(lambda=0.1)
Inspect available functionality:

```
julia> LearnAPI.functions(ridge)
(:(LearnAPI.fit), :(LearnAPI.learner), :(LearnAPI.strip), :(LearnAPI.obs),
:(LearnAPI.features), :(LearnAPI.target), :(LearnAPI.predict), :(LearnAPI.coefficients))
julia> @functions ridge
(fit, LearnAPI.learner, LearnAPI.strip, obs, LearnAPI.features, LearnAPI.target, predict, LearnAPI.coefficients
```

Train:
Expand Down
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ y = <some training target>
Xnew = <some test or production features>

# List LearnaAPI functions implemented for `forest`:
LearnAPI.functions(forest)
@functions forest

# Train:
model = fit(forest, X, y)
Expand Down
3 changes: 2 additions & 1 deletion docs/src/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,9 @@ minimal (but useless) implementation, see the implementation of `SmallLearner`
## Utilities

```@docs
@functions
LearnAPI.clone
LearnAPI.@trait
@trait
```

---
Expand Down
2 changes: 1 addition & 1 deletion src/LearnAPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ include("accessor_functions.jl")
include("traits.jl")
include("clone.jl")

export @trait
export @trait, @functions
export fit, update, update_observations, update_features
export predict, transform, inverse_transform, obs

Expand Down
31 changes: 16 additions & 15 deletions src/accessor_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,23 +312,23 @@ function training_labels end

# :extras intentionally excluded:
const ACCESSOR_FUNCTIONS_WITHOUT_EXTRAS = (
learner,
coefficients,
intercept,
tree,
trees,
feature_names,
feature_importances,
training_labels,
training_losses,
training_predictions,
training_scores,
components,
:(LearnAPI.learner),
:(LearnAPI.coefficients),
:(LearnAPI.intercept),
:(LearnAPI.tree),
:(LearnAPI.trees),
:(LearnAPI.feature_names),
:(LearnAPI.feature_importances),
:(LearnAPI.training_labels),
:(LearnAPI.training_losses),
:(LearnAPI.training_predictions),
:(LearnAPI.training_scores),
:(LearnAPI.components),
)

const ACCESSOR_FUNCTIONS_WITHOUT_EXTRAS_LIST = join(
map(ACCESSOR_FUNCTIONS_WITHOUT_EXTRAS) do f
"[`LearnAPI.$f`](@ref)"
"[`$f`](@ref)"
end,
", ",
" and ",
Expand All @@ -354,11 +354,12 @@ $(DOC_IMPLEMENTED_METHODS(":(LearnAPI.training_labels)")).
"""
function extras end

const ACCESSOR_FUNCTIONS = (extras, ACCESSOR_FUNCTIONS_WITHOUT_EXTRAS...)
const ACCESSOR_FUNCTIONS =
(:(LearnAPI.extras), ACCESSOR_FUNCTIONS_WITHOUT_EXTRAS...)

const ACCESSOR_FUNCTIONS_LIST = join(
map(ACCESSOR_FUNCTIONS) do f
"[`LearnAPI.$f`](@ref)"
"[`$f`](@ref)"
end,
", ",
" and ",
Expand Down
35 changes: 34 additions & 1 deletion src/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,18 @@ with `learner`, or an associated model (object returned by `fit(learner, ...)`,
first argument. Learner traits (methods for which `learner` is the *only* argument)
are excluded.
To return actual functions, instead of symbols, use [`@functions`](@ref)` learner`
instead.
The returned tuple may include expressions like `:(DecisionTree.print_tree)`, which
reference functions not owned by LearnAPI.jl.
The understanding is that `learner` is a LearnAPI-compliant object whenever the return
value is non-empty.
Do `LearnAPI.functions()` to list all possible elements of the return value owned by
LearnAPI.jl.
# Extended help
# New implementations
Expand Down Expand Up @@ -100,6 +106,7 @@ learner-specific ones. The LearnAPI.jl accessor functions are: $ACCESSOR_FUNCTIO
(`LearnAPI.strip` is always included).
"""
functions(::Any) = ()
functions() = (
:(LearnAPI.fit),
:(LearnAPI.learner),
Expand All @@ -114,8 +121,34 @@ functions() = (
:(LearnAPI.predict),
:(LearnAPI.transform),
:(LearnAPI.inverse_transform),
ACCESSOR_FUNCTIONS...,
)
functions(::Any) = ()

"""
@functions learner
Return a tuple of functions that can be meaningfully applied with `learner`, or an
associated model, as the first argument. An "associated model" is an object returned by
`fit(learner, ...)`. Learner traits (methods for which `learner` is the *only* argument)
are excluded.
```
julia> @functions my_feature_selector
(fit, LearnAPI.learner, strip, obs, transform)
```
New learner implementations should overload [`LearnAPI.functions`](@ref).
See also [`LearnAPI.functions`](@ref).
"""
macro functions(learner)
quote
exs = LearnAPI.functions(learner)
eval.(exs)
end |> esc
end

"""
LearnAPI.kinds_of_proxy(learner)
Expand Down

0 comments on commit 6279b25

Please sign in to comment.