
DecisionBoundaryDisplay.from_estimator only displays up to 7 distinct colours #32866

@ThexXTURBOXx


Describe the bug

I was trying to use DecisionBoundaryDisplay.from_estimator to display the regions classified by NuSVC:

[Image: actual plot]

I expected to see 11 regions (one per class), but by visually counting I could only find 10.
Also, some colours were obviously duplicated in this figure.
This is how it should actually look:

[Image: expected plot]

Proposed fix

I suspect this will eventually be fixed by matplotlib/matplotlib@0eadaf0. However, that fix is not yet available in any stable matplotlib release.
Manually changing the hard-wired default number of levels (7) in matplotlib's contour code fixes the problem.
The best fix is therefore probably for scikit-learn to pass the levels parameter itself here:

self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs)

That line should look something like this:

self.surface_ = plot_func(self.xx0, self.xx1, self.response, <levels>, **kwargs)

where <levels> is replaced by the appropriate number of levels.
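One way to compute `<levels>` is to place one boundary halfway between every pair of consecutive response values, plus a half-step below the smallest and above the largest, so that each class falls into its own filled band. The sketch below assumes the `predict` response holds one distinct numeric code per class; `class_levels` is a hypothetical helper, not part of scikit-learn or matplotlib.

```python
def class_levels(values):
    """Return contour boundaries that put each distinct value in its own band.

    Boundaries sit halfway between consecutive sorted distinct values, plus
    one half-step below the smallest and one above the largest value.
    """
    codes = sorted(set(values))
    if not codes:
        raise ValueError("at least one value is required")
    if len(codes) == 1:
        # A single class still needs two boundaries to form one band.
        return [codes[0] - 0.5, codes[0] + 0.5]
    inner = [(a + b) / 2 for a, b in zip(codes, codes[1:])]
    low = codes[0] - (codes[1] - codes[0]) / 2
    high = codes[-1] + (codes[-1] - codes[-2]) / 2
    return [low, *inner, high]


# 11 class codes 0..10 give 12 boundaries -0.5, 0.5, ..., 10.5 (11 bands).
print(class_levels(range(11)))
```

Passing all class codes, rather than only the codes that happen to appear on the plotting grid, would keep the colour assigned to each class stable across plots.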

The case where self.response.ndim == 3 may also be affected, but I have not tested it yet.
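The matplotlib side of the problem can be reproduced without scikit-learn. This sketch draws a synthetic "predict" surface of 11 integer class codes with `contourf`, once with the default levels and once with explicit boundaries halfway between the codes. On a stable matplotlib such as 3.10, the default panel should show noticeably fewer bands than classes; the exact count depends on matplotlib's internal level selection, so it is only printed here.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; remove to display with plt.show()
import matplotlib.pyplot as plt
import numpy as np

n_classes = 11
# Synthetic response: 11 vertical stripes holding class codes 0..10.
xx, yy = np.meshgrid(np.linspace(0, n_classes, 221), np.linspace(0, 1, 5))
zz = np.minimum(np.floor(xx), n_classes - 1)

fig, (ax_default, ax_explicit) = plt.subplots(1, 2, figsize=(10, 3))
# Without `levels`, contourf chooses the levels itself.
default = ax_default.contourf(xx, yy, zz, cmap="gist_rainbow")
# One boundary halfway between each pair of codes: one band per class.
explicit = ax_explicit.contourf(
    xx, yy, zz, levels=np.arange(n_classes + 1) - 0.5, cmap="gist_rainbow"
)
print("default bands: ", len(default.levels) - 1)
print("explicit bands:", len(explicit.levels) - 1)
```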

Steps/Code to Reproduce

import numpy as np
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.svm import NuSVC
import matplotlib.pyplot as plt

# 11 distinct points, one per class
pts = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1],
                [2, 2], [3, 2], [3, 3], [4, 3],
                [4, 4], [5, 4], [5, 5]])
svc = NuSVC(nu=0.1)
svc.fit(pts, [1,2,3,4,5,6,7,8,9,10,11])
DecisionBoundaryDisplay.from_estimator(
    svc,
    pts,
    response_method="predict",
    cmap="gist_rainbow"
)
plt.show()

Expected Results

[Image: expected plot with 11 distinctly coloured regions]

Actual Results

[Image: actual plot with duplicated colours]

Versions

System:
    python: 3.12.7 | packaged by Anaconda, Inc. | (main, Oct  4 2024, 13:27:36) [GCC 11.2.0]
executable: /home/nico/miniconda3/envs/pca-assessment/bin/python
   machine: Linux-6.14.0-36-generic-x86_64-with-glibc2.39

Python dependencies:
      sklearn: 1.7.2
          pip: 24.3.1
   setuptools: 72.1.0
        numpy: 2.3.3
        scipy: 1.16.2
       Cython: None
       pandas: 2.3.2
   matplotlib: 3.10.6
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 16
         prefix: libscipy_openblas
       filepath: /home/nico/miniconda3/envs/pca-assessment/lib/python3.12/site-packages/numpy.libs/libscipy_openblas64_-8fb3d286.so
        version: 0.3.30
threading_layer: pthreads
   architecture: Haswell

       user_api: blas
   internal_api: mkl
    num_threads: 12
         prefix: libmkl_rt
       filepath: /home/nico/miniconda3/envs/pca-assessment/lib/libmkl_rt.so.2
        version: 2025.0-Product
threading_layer: gnu

       user_api: openmp
   internal_api: openmp
    num_threads: 16
         prefix: libgomp
       filepath: /home/nico/miniconda3/envs/pca-assessment/lib/libgomp.so.1.0.0
        version: None
