DecisionBoundaryDisplay with response_method="predict" has inconsistent handling for the colormap in the multiclass case #32872

@ogrisel

Describe the bug

This issue was discovered while reviewing #32867, but since it's not directly related, let's open a dedicated issue to avoid derailing the original discussion.

As can be seen in the plots below, passing multiclass_colors="gist_rainbow" has no effect when response_method="predict", and I don't see any reason why the argument should be ignored in that case.

[Image: 3x3 grid of decision boundary plots, one row per plot_method and one column per response_method]

I think we should refactor the code to have consistent colormap configuration for all the response methods (in particular for the multiclass case).
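As a rough illustration of what a shared configuration could look like, the sketch below resolves a colormap name into one color per class and wraps the result in a discrete colormap that any response method could reuse. The helper name `per_class_colormap` is hypothetical and is not part of scikit-learn; this is only one possible approach, not the project's implementation.

```python
import numpy as np
import matplotlib
from matplotlib.colors import ListedColormap


def per_class_colormap(cmap_name, n_classes):
    """Hypothetical helper: sample one color per class from a named colormap."""
    base = matplotlib.colormaps[cmap_name]
    # Evenly spaced positions in [0, 1], one per class.
    colors = base(np.linspace(0, 1, n_classes))
    return ListedColormap(colors)


# Seven classes, as in the reproducer of this issue.
class_cmap = per_class_colormap("gist_rainbow", 7)
```

With such a helper, the same per-class colors could in principle be applied whether the surface comes from predict_proba, decision_function, or predict.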

Steps/Code to Reproduce

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.inspection import DecisionBoundaryDisplay


data = np.array(
    [
        [-1, -1],
        [-2, -1],
        [1, 1],
        [2, 1],
        [2, 2],
        [3, 2],
        [3, 3],
    ]
)
target = np.arange(data.shape[0])
clf = LogisticRegression().fit(data, target)

cmap = "gist_rainbow"
_, axes = plt.subplots(nrows=3, ncols=3, figsize=(12, 12), constrained_layout=True)
for plot_method_idx, plot_method in enumerate(["contourf", "contour", "pcolormesh"]):
    for response_method_idx, response_method in enumerate(
        ["predict_proba", "decision_function", "predict"]
    ):
        ax = axes[plot_method_idx, response_method_idx]
        display = DecisionBoundaryDisplay.from_estimator(
            clf,
            data,
            multiclass_colors=cmap,
            response_method=response_method,
            plot_method=plot_method,
            ax=ax,
            alpha=0.5,
        )
        ax.scatter(
            data[:, 0],
            data[:, 1],
            c=target.astype(int),
            edgecolors="black",
            cmap=cmap,
        )
        ax.set_title(f"plot_method={plot_method}\nresponse_method={response_method}")

Expected Results

I would expect the colored areas of the rightmost column to also use the "gist_rainbow" colormap, like the other columns.

Actual Results

The rightmost column seems to ignore the multiclass_colors argument and falls back to the default "viridis" colormap, which is not well suited to multiclass problems with a large number of classes.

Versions

System:
    python: 3.13.7 | packaged by conda-forge | (main, Sep  3 2025, 14:24:46) [Clang 19.1.7 ]
executable: /Users/ogrisel/miniforge3/envs/dev/bin/python3.13
   machine: macOS-15.6.1-arm64-arm-64bit-Mach-O

Python dependencies:
      sklearn: 1.9.dev0
          pip: 25.2
   setuptools: 80.9.0
        numpy: 2.3.3
        scipy: 1.16.3
       Cython: 3.1.4
       pandas: 3.0.0.dev0+2566.g2bb3fef887
   matplotlib: 3.10.6
       joblib: 1.5.2
threadpoolctl: 3.6.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 10
         prefix: libopenblas
       filepath: /Users/ogrisel/miniforge3/envs/dev/lib/libopenblas.0.dylib
        version: 0.3.30
threading_layer: openmp
   architecture: VORTEX

       user_api: openmp
   internal_api: openmp
    num_threads: 10
         prefix: libomp
       filepath: /Users/ogrisel/miniforge3/envs/dev/lib/libomp.dylib
        version: None
