Replace nan_to_num with where for CuPy#595
Conversation
|
The failing disassemble and SymPy (unrelated) tests were expected. I wasn't expecting the Numba tests to fail. |
38518ae to
42b138f
Compare
8ab72f1 to
9d7a648
Compare
7c1626a to
fc7f41f
Compare
Saransh-cpp
left a comment
There was a problem hiding this comment.
Even though the SymPy solution is not ideal, @ikrommyd's comment - #593 (comment) - suggests that CuPy releasing a fix in the future versions will not work well with older CUDA versions. We can easily revert back to the nan_to_num implementation once CuPy fixes this issue.
I'd leave it to the other maintainers to decide if this should go in.
src/vector/_compute/spatial/eta.py
Outdated
| posinf=inf, | ||
| neginf=-inf, | ||
| ) | ||
| return lib.where(z != 0, lib.arcsinh(z / lib.sqrt(x**2 + y**2)), z) * 1 |
There was a problem hiding this comment.
There is absolutely no issue in switching to this definition for the numerical backends, so this is not a hack.
There was a problem hiding this comment.
I think the only caveat this has is that for the JAX backend gradients could be wrongly propagated as NaNs. This is typically solved by using a double-where statement, see: https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
There was a problem hiding this comment.
Thanks, @pfackeldey! TIL.
I have updated the PR with nested wheres.
(We do really need GPU/Dask/JAX integration tests for Scikit-HEP...)
src/vector/_lib.py
Outdated
| # this function is to handle exceptional values — we know that the "normal" values | ||
| # are in the second argument and the "exceptional" ones are in the third argument. | ||
| # TODO: remove once https://github.com/cupy/cupy/issues/9143 is fixed. | ||
| def where(self, val1: sympy.Expr, val2: sympy.Expr, val3: sympy.Expr) -> sympy.Expr: |
There was a problem hiding this comment.
This might be considered a hack/quick fix for the SymPy backend (as explained in the comment above). The only issue with this PR is that we will have to manually check that where is not used as a regular np.where in compute functions.
I’m basically suggesting that when cupy puts out a fix, it’s gonna naturally be only in the latest release and therefore you can’t expect people to always be on the latest release because for whatever reason they may not have a cuda version compatible with that release. This is a “safe” assumption I made. I don’t know the cupy/cuda support matrix by heart. My point was that we may not want to put all our faith on cupy in order for vector to work with the cuda backend if awkward array. |
updates: - [github.com/astral-sh/ruff-pre-commit: v0.11.12 → v0.11.13](astral-sh/ruff-pre-commit@v0.11.12...v0.11.13) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix: get codecov back up again * upload cov files * upload cov files * some debugging * style: pre-commit fixes * specify path * try XML report * use extend * all fixed * revert more debug statements --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
#608) Bumps the actions group with 1 update: [actions/attest-build-provenance](https://github.com/actions/attest-build-provenance). Updates `actions/attest-build-provenance` from 2.3.0 to 2.4.0 - [Release notes](https://github.com/actions/attest-build-provenance/releases) - [Changelog](https://github.com/actions/attest-build-provenance/blob/main/RELEASE.md) - [Commits](actions/attest-build-provenance@db473fd...e8998f9) --- updated-dependencies: - dependency-name: actions/attest-build-provenance dependency-version: 2.4.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: actions ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
* Update talks.md * add citation info in README * add citing information in docs * style: pre-commit fixes * uniformity * style: pre-commit fixes * plural --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #595 +/- ##
==========================================
+ Coverage 87.49% 87.58% +0.09%
==========================================
Files 95 95
Lines 12066 12057 -9
==========================================
+ Hits 10557 10560 +3
+ Misses 1509 1497 -12 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
I'm working for the actual fix on cupy in cupy/cupy#9240 but I think it's still good to merge this because it will be a while since this comes in a cupy release (+ you'd still need people to be on the latest cupy). |
|
pinging @pfackeldey and @ianna here |
ianna
left a comment
There was a problem hiding this comment.
@Saransh-cpp - thanks for fixing it! As it is a temporary fix, I think it would be better if there is an issue linked to the cupy one. Thanks!
|
Thanks for the review, @ianna! I've opened a new issue. This should be ready to review again. |
ianna
left a comment
There was a problem hiding this comment.
@Saransh-cpp - thank you for fixing it! Looks good to me!
Description
Fixes #593
XRef #615
Using
whereis okay for other backends, but it is an issue for the SymPy backend. The shimSympyLib.wherefunction is not a generic replacement fornp.where, instead, it will only work for this particular case (return the second argument (first value) as that has the normal values). Hence, we will have to manually check that no other compute function useswhereas actualnp.where.