Skip to content

fix: jax backend tolist for tracers in logging #2580

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

pfackeldey
Copy link

@pfackeldey pfackeldey commented Mar 21, 2025

Description

This PR closes #1422

This failure is encountered during tracing time. JAX can not convert a tracer during .tolist (because it has no data to put in a list) which is why it fails with a different error. This PR checks if we're currently in tracing time and instead returns the tracer (or rather its abstract value for a little nicer representation).

The new error looks as expected like this now:

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.simplemodels.uncorrelated_background([10], [15], [5])
>>> pyhf.infer.mle.fit([12.5], m)
>>> ... InvalidPdfData: "eval failed as data has len 1 but 2 was expected"

The logging will print:

"Eval failed for data ShapedArray(float64[1]) pars: ShapedArray(float64[2])"

This is the most information available during tracing time (shape and dtype), there's not much more to give to the user.

Checklist Before Requesting Reviewer

  • Tests are passing
  • "WIP" removed from the title of the pull request
  • Selected an Assignee for the PR to be responsible for the log summary

Before Merging

For the PR Assignees:

  • Summarize commit messages into a comprehensive review of the PR

@kratsg
Copy link
Contributor

kratsg commented Mar 21, 2025

this likely needs #2566 in

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
fix A bug fix
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Raise better error message for pyhf.exceptions.InvalidPdfData for JAX backend
3 participants