Description
Describe the bug
This issue was discovered while reviewing #32867, but since it's not directly related, let's open a dedicated issue to avoid derailing the original discussion.
As can be seen in the plots below, passing multiclass_colors="gist_rainbow" has no effect when response_method="predict", and I don't see any reason why it should be ignored in that case.
I think we should refactor the code to have consistent colormap configuration for all the response methods (in particular for the multiclass case).
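To make the idea concrete, here is a rough sketch, using plain matplotlib only, of how a single colormap name could be turned into one discrete color per class and applied to integer class predictions. This is not scikit-learn's current code: `discrete_class_cmap` is a hypothetical helper name, and the grid of fake predictions stands in for whatever `predict` would return on the mesh.

```python
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm, ListedColormap


def discrete_class_cmap(name, n_classes):
    """Hypothetical helper: one color per class, sampled from a named colormap."""
    base = matplotlib.colormaps[name]
    # Evenly spaced samples over the full range of the base colormap.
    colors = base(np.linspace(0, 1, n_classes))
    cmap = ListedColormap(colors)
    # Bin edges at -0.5, 0.5, ..., n_classes - 0.5 so that the integer
    # label k falls in bin k and maps to colors[k].
    norm = BoundaryNorm(np.arange(-0.5, n_classes), ncolors=n_classes)
    return cmap, norm


n_classes = 7
cmap, norm = discrete_class_cmap("gist_rainbow", n_classes)

# Fake "predict" output on a grid: integer labels in [0, n_classes).
xx0, xx1 = np.meshgrid(np.linspace(-3, 4, 50), np.linspace(-2, 4, 50))
labels = np.clip(np.floor(xx0 + 3), 0, n_classes - 1).astype(int)

fig, ax = plt.subplots()
ax.pcolormesh(xx0, xx1, labels, cmap=cmap, norm=norm, alpha=0.5)
```

The same `cmap` and `norm` pair could in principle be reused for the scatter points and for the other response methods, which is the kind of consistency this issue asks for.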
Steps/Code to Reproduce
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.inspection import DecisionBoundaryDisplay

data = np.array(
    [
        [-1, -1],
        [-2, -1],
        [1, 1],
        [2, 1],
        [2, 2],
        [3, 2],
        [3, 3],
    ]
)
target = np.arange(data.shape[0])
clf = LogisticRegression().fit(data, target)
cmap = "gist_rainbow"
_, axes = plt.subplots(nrows=3, ncols=3, figsize=(12, 12), constrained_layout=True)
for plot_method_idx, plot_method in enumerate(["contourf", "contour", "pcolormesh"]):
    for response_method_idx, response_method in enumerate(
        ["predict_proba", "decision_function", "predict"]
    ):
        ax = axes[plot_method_idx, response_method_idx]
        display = DecisionBoundaryDisplay.from_estimator(
            clf,
            data,
            multiclass_colors=cmap,
            response_method=response_method,
            plot_method=plot_method,
            ax=ax,
            alpha=0.5,
        )
        ax.scatter(
            data[:, 0],
            data[:, 1],
            c=target.astype(int),
            edgecolors="black",
            cmap=cmap,
        )
        ax.set_title(f"plot_method={plot_method}\nresponse_method={response_method}")
Expected Results
I would expect the colored areas of the rightmost column to also use the "gist_rainbow" colormap, like the other columns.
Actual Results
The rightmost column seems to ignore the multiclass_colors argument and falls back to the default "viridis" colormap, which is not very well suited for multiclass problems with a large number of classes.
Versions
System:
    python: 3.13.7 | packaged by conda-forge | (main, Sep 3 2025, 14:24:46) [Clang 19.1.7 ]
    executable: /Users/ogrisel/miniforge3/envs/dev/bin/python3.13
    machine: macOS-15.6.1-arm64-arm-64bit-Mach-O

Python dependencies:
    sklearn: 1.9.dev0
    pip: 25.2
    setuptools: 80.9.0
    numpy: 2.3.3
    scipy: 1.16.3
    Cython: 3.1.4
    pandas: 3.0.0.dev0+2566.g2bb3fef887
    matplotlib: 3.10.6
    joblib: 1.5.2
    threadpoolctl: 3.6.0

Built with OpenMP: True

threadpoolctl info:
    user_api: blas
    internal_api: openblas
    num_threads: 10
    prefix: libopenblas
    filepath: /Users/ogrisel/miniforge3/envs/dev/lib/libopenblas.0.dylib
    version: 0.3.30
    threading_layer: openmp
    architecture: VORTEX

    user_api: openmp
    internal_api: openmp
    num_threads: 10
    prefix: libomp
    filepath: /Users/ogrisel/miniforge3/envs/dev/lib/libomp.dylib
    version: None