-
Notifications
You must be signed in to change notification settings - Fork 1.8k
SIMBA Improvements #8077
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
SIMBA Improvements #8077
Conversation
|
||
def parse_value(value, annotation): | ||
annotation = _strip_optional(annotation) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is to allow support for Optional fields (i.e. where a field could be None, or str), which is the case for KIE and was throwing errors before.
temperature: float = 0.0, | ||
max_tokens: int = 1000, | ||
temperature: Optional[float] = None, | ||
max_tokens: Optional[int] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updating defaults, which were throwing errors for reasoning models. Now, instead of defaulting to a temp of 0.0 and max_tokens 1000 (and erroring out automatically for o3mini), we are setting temp and max_tokens based on whether the model is a reasoning model or not. If the user has intentionally set one of the values to something the reasoning model can't handle (i.e. temperature=0.7), then we will still throw an error.
@@ -41,6 +44,8 @@ def __init__( | |||
self.num_candidates = num_candidates | |||
self.max_steps = max_steps | |||
self.max_demos = max_demos |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding support for prompt / teacher models
@@ -310,7 +316,7 @@ def register_new_program(prog: dspy.Module, score_list: list[float]): | |||
trial_logs[idx_prog-1]["train_score"] = avg_score | |||
|
|||
best_idx = scores.index(max(scores)) if scores else 0 | |||
best_program = candidate_programs[best_idx] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixing max recursion depth error
# Check to see if our model is a reasoning model, which means temp must stay as 1.0 | ||
model_family = lm.model.split("/")[-1].lower() if "/" in lm.model else lm.model.lower() | ||
model_pattern = re.match(r"^o([13])(?:-mini)?", model_family) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function has been updated to:
- add support for teacher model (used for 1 of the N trajectories)
- add support for reasoning models by varying the seed
@@ -28,30 +45,46 @@ def wrapped_program(example): | |||
print(e) | |||
trace = dspy.settings.trace.copy() | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to handle additional metric metadata in addition to the score. To do this, we check if the output from the metric is a float or int (in which case we use it as the score) or a dspy.Prediction object
@@ -116,12 +149,16 @@ def append_a_rule(bucket, system, **kwargs): | |||
worse_program_outputs=dict(bad["prediction"] or {}), | |||
worse_reward_value=bad["score"], | |||
better_reward_value=good["score"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
again, adding additional metric metadata here
|
||
trace = bucket[0]["trace"] | ||
good = bucket[0] | ||
trace = good["trace"] | ||
name2demo = {} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Double checking that the demo we're appending is not below the 10th percentile of scores
This PR makes the following updates:
Handled here, but merged sooner in another commit: