Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add nanobind support
  • Loading branch information
Matthias Wittgen committed Mar 10, 2025
commit 4ecb321c3195c0517455356d7ac82490f9e2486a
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
working-directory: ${{github.workspace}}/build
shell: bash -l {0}
# Execute the build. You can specify a specific target with "--target <NAME>"
run: cmake --build . --config $BUILD_TYPE
run: cmake --build . --verbose --config $BUILD_TYPE

- name: Test
working-directory: ${{github.workspace}}/build
Expand Down
5 changes: 3 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ set (CMAKE_CXX_STANDARD 11)

option(NDARRAY_TEST "Enable tests?" ON)
option(NDARRAY_EIGEN "Enable Eigen tests?" ON)
option(NDARRAY_FFTW "Enable FFTW tests?" ON)
option(NDARRAY_PYBIND11 "Enable Pybind11 tests?" OFF)
option(NDARRAY_FFTW "Enable FFTW tests?" Off)
option(NDARRAY_PYBIND11 "Enable Pybind11 tests?" On)
option(NDARRAY_NANOBIND "Enable Nanobind tests?" On)

add_subdirectory(include)

Expand Down
1 change: 1 addition & 0 deletions etc/conda-forge-testing.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ dependencies:
- fftw
- numpy
- pybind11
- nanobind
- c-compiler
- eigen
- cmake
Expand Down
273 changes: 273 additions & 0 deletions include/ndarray/nanobind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
/*
* LSST Data Management System
* Copyright 2008-2016 AURA/LSST.
*
* This product includes software developed by the
* LSST Project (http://www.lsst.org/).
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the LSST License Statement and
* the GNU General Public License along with this program. If not,
* see <https://www.lsstcorp.org/LegalNotices/>.
*/

#ifndef NDARRAY_nanobind_h_INCLUDED
#define NDARRAY_nanobind_h_INCLUDED

/**
* @file ndarray/nanobind.h
* @brief Public header file for pybind11-based Python support.
*
* \warning Both the Numpy C-API headers "arrayobject.h" and
* "ufuncobject.h" must be included before ndarray/python.hpp
* or any of the files in ndarray/python.
*
* \note This file is not included by the main "ndarray.h" header file.
*/

/** \defgroup ndarrayPythonGroup Python Support
*
* The ndarray Python support module provides conversion
* functions between ndarray objects, notably Array and
* Vector, and Python Numpy objects.
*/

#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include "ndarray.h"
#include "ndarray/eigen.h"
#include "ndarray/Array.h"

#include <typeinfo>

namespace nb = nanobind;
namespace ndarray {
namespace detail {

inline void destroyCapsule(PyObject *p) {
void *m = PyCapsule_GetPointer(p, "ndarray.Manager");
Manager::Ptr *b = reinterpret_cast<Manager::Ptr *>(m);
delete b;
}

} // namespace ndarray::detail

inline PyObject *makePyManager(Manager::Ptr const &m) {
return PyCapsule_New(
new Manager::Ptr(m),
"ndarray.Manager",
detail::destroyCapsule
);
}

template<typename T, int N, int C>
struct
#ifdef __GNUG__
// pybind11 hides all symbols in its namespace only when this is set,
// and in that case we should hide these classes too.
__attribute__((visibility("hidden")))
#endif
NanobindHelper {
};

} // namespace ndarray
NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)
template<typename T, int N, int C>
struct type_caster<::ndarray::Array<T,N,C>> {
using Wrapper = std::remove_const_t<nb::ndarray<nb::numpy, T>>;
using ArrayType = nb::ndarray<nb::numpy, typename std::remove_const_t<T>> ;
using Array = std::conditional_t<std::is_const_v<T>, const ArrayType, ArrayType>;
using Element = typename std::remove_const_t<T>;
static constexpr bool isConst = std::is_const<Element>::value;

using Value = ::ndarray::Array<T,N,C>;
static constexpr auto Name = const_name("ndarray");
template<typename T_> using Cast = movable_cast_t<T_>;

static handle from_cpp(Value *p, rv_policy policy, cleanup_list *list) {
if (!p)return none().release();
return from_cpp(*p, policy, list);
}

bool init(nb::handle src, cleanup_list *cleanup) {
isNone = src.is_none();
if (isNone) {
return true;
}

int64_t shape[N];
ndarray_config config;
config.shape = shape;
wrapper = Wrapper(ndarray_import(
src.ptr(), &config, true, cleanup));
return wrapper.is_valid();
}

bool check() const {
if (isNone) {
return true;
}

if (wrapper.ndim() != N) {
return false;
}
if(wrapper.dtype().bits != sizeof(Element) * 8) {
return false;
}
switch(dlpack::dtype_code(wrapper.dtype().code)) {
case dlpack::dtype_code::Float:
if(!std::is_floating_point_v<Element>) return false;
break;
case dlpack::dtype_code::Int:
if(!(std::is_signed_v<Element> && std::is_integral_v<Element>)) return false;
break;
case dlpack::dtype_code::UInt:
if(!(std::is_unsigned_v<Element> && std::is_integral_v<Element>)) return false;
break;
case dlpack::dtype_code::Bool:
if(!std::is_same_v<Element, bool>) return false;
break;
default:
return false;
}

//if (!isConst && !wrapper.writeable()) {
// return false;
//}

int64_t const * shape = wrapper.shape_ptr();
int64_t const * strides = wrapper.stride_ptr();
size_t const itemsize = wrapper.itemsize();
if (C > 0) {
// If the shape is zero in any dimension, we don't
// worry about the strides.
for (int i = 0; i < C; ++i) {
if (shape[N-i-1] == 0) {
return true;
}
}

int64_t requiredStride = 1;//itemsize;
for (int i = 0; i < C; ++i) {
if (strides[N-i-1] != requiredStride) {
return false;
}
requiredStride *= shape[N-i-1];
}
} else if (C < 0) {
// If the shape is zero in any dimension, we don't
// worry about the strides.
for (int i = 0; i < -C; ++i) {
if (shape[i] == 0) {
return true;
}
}
size_t requiredStride = itemsize;
for (int i = 0; i < -C; ++i) {
if (strides[i] != requiredStride) {
return false;
}
requiredStride *= shape[i];
}
}
return true;
}

Value convert() const {
if (isNone) {
return Value();
}

//if (!wrapper.dtype().attr("isnative")) {
// throw nb::type_error("Only arrays with native byteorder can be converted to C++.");
//}

::ndarray::Vector<::ndarray::Size,N> nShape;
::ndarray::Vector<::ndarray::Offset,N> nStrides;
int64_t const * pShape = wrapper.shape_ptr();
int64_t const * pStrides = wrapper.stride_ptr();
size_t itemsize = wrapper.itemsize();
for (int i = 0; i < N; ++i) {
nShape[i] = pShape[i];
nStrides[i] = pStrides[i];
}

auto *p = const_cast<Element*>(wrapper.data());

return Value (
::ndarray::external(const_cast<Element*>(wrapper.data()),
nShape, nStrides, wrapper)
);
}

void set_value() {
value = convert();
}

explicit operator Value * () {
if (isNone) {
return nullptr;
} else {
set_value();
return &value;
}
}

explicit operator Value &() {
set_value();
return (Value &) value;
}

explicit operator Value &&() {
set_value();
return (Value &&) value;
}


bool from_python(nb::handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
bool result = init(src, cleanup) && check();
return result;
}
static nb::handle from_cpp(const ::ndarray::Array<T, N, C> &src, rv_policy policy,
cleanup_list *cleanup) noexcept {
using ArrayType = nb::ndarray<nb::numpy, typename std::remove_const_t<T>> ;
using Array = std::conditional_t<std::is_const_v<T>, const ArrayType, ArrayType>;
using Element = typename std::remove_const_t<T>;
::ndarray::Vector<::ndarray::Size,N> nShape = src.getShape();
::ndarray::Vector<::ndarray::Offset,N> nStrides = src.getStrides();
std::vector<size_t> pShape(N);
std::vector<int64_t> pStrides(N);
for (int i = 0; i < N; ++i) {
pShape[i] = nShape[i];
pStrides[i] = nStrides[i];
}
nb::object base = nb::object();
if (src.getManager()) {
base = nb::steal<nb::object>(::ndarray::makePyManager(src.getManager()));
}
Array array((Element*)src.getData(), N, pShape.data(), base, pStrides.data());

nb::handle result = ndarray_export(array.handle(), nb::numpy::value, policy, cleanup);
if (std::is_const_v<T>) {
result.attr("flags")["WRITEABLE"] = false;
}
return result;
}
private:
bool isNone = false;
Value value;
Wrapper wrapper = Wrapper();
};
NAMESPACE_END(detail);
NAMESPACE_END(NB_NAMESPACE);
#endif
11 changes: 7 additions & 4 deletions include/ndarray/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ Pybind11Helper {
if (src.getManager()) {
base = pybind11::reinterpret_steal<pybind11::object>(ndarray::makePyManager(src.getManager()));
}

Wrapper result(pShape, pStrides, src.getData(), base);
if (std::is_const<Element>::value) {
result.attr("flags")["WRITEABLE"] = false;
Expand All @@ -220,9 +221,9 @@ class type_caster< ndarray::Array<T,N,C> > {
using Helper = ndarray::Pybind11Helper<T,N,C>;
public:

bool load(handle src, bool) {
bool load(handle src, bool) {
return _helper.init(src) && _helper.check();
}
}

void set_value() {
_value = _helper.convert();
Expand All @@ -238,7 +239,7 @@ class type_caster< ndarray::Array<T,N,C> > {
return cast(*src, policy, parent);
}

operator ndarray::Array<T,N,C> * () {
explicit operator ndarray::Array<T,N,C> * () {
if (_helper.isNone) {
return nullptr;
} else {
Expand All @@ -247,7 +248,9 @@ class type_caster< ndarray::Array<T,N,C> > {
}
}

operator ndarray::Array<T,N,C> & () { set_value(); return _value; }
explicit operator ndarray::Array<T,N,C> & () {
set_value(); return _value;
}

template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;

Expand Down
31 changes: 29 additions & 2 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
# (4) Addition of the test executable via add_test.

### Core tests, which rely only on boost-test and ndarray.
find_package(Boost COMPONENTS unit_test_framework REQUIRED)
find_package(Boost COMPONENTS headers unit_test_framework REQUIRED)

include_directories( ${PROJECT_SOURCE_DIR}/include)

add_executable(ndarray_test ndarray.cc)

target_link_libraries(ndarray_test ndarray Boost::unit_test_framework)
target_link_libraries(ndarray_test ndarray Boost::headers Boost::unit_test_framework)
add_test(test_ndarray ndarray_test)

add_executable(views views.cc)
Expand Down Expand Up @@ -64,3 +64,30 @@ if(NDARRAY_PYBIND11)
message(STATUS "Skipping pybind11 tests as they depend on Eigen")
endif(NDARRAY_EIGEN)
endif(NDARRAY_PYBIND11)

###Nanobind dependency tests (also depend on Eigen)
if(NDARRAY_NANOBIND)
if(NDARRAY_EIGEN)
find_package(Python
REQUIRED COMPONENTS Interpreter Development.Module
OPTIONAL_COMPONENTS Development.SABIModule)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)

nanobind_add_module(nanobind_test_mod
nanobind_test_mod.cc
STABLE_ABI
NB_SHARED)
target_link_libraries(nanobind_test_mod PRIVATE Eigen3::Eigen)
configure_file(nanobind_test.py nanobind_test.py COPYONLY)
add_test(NAME nanobind_test
COMMAND ${Python_EXECUTABLE}
${CMAKE_CURRENT_BINARY_DIR}/nanobind_test.py)
else(NDARRAY_EIGEN)
message(STATUS "Skipping nanobind tests as they depend on Eigen")
endif(NDARRAY_EIGEN)
endif(NDARRAY_NANOBIND)

Loading