// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include
#include
#include
#include
#include
namespace nvfuser::python {
//! Registers nvFuser's C++ enumerations as Python enums on the given module.
//!
//! Each enum is exposed under the Python name passed to py::enum_, and each
//! enumerator under the string passed to .value(). Unless noted otherwise the
//! bindings are declared py::module_local().
//!
//! \param nvfuser The pybind11 module that receives the enum types.
void bindEnums(py::module& nvfuser) {
  //! DataTypes supported by nvFuser in the FusionDefinition. The python
  //! DataType maps to the CPP PrimDataType. On the CPP side, there is also a
  //! DataType that includes struct, array, pointer, or opaque datatypes.
  py::enum_<PrimDataType>(nvfuser, "DataType", py::module_local())
      .value("Double", DataType::Double)
      .value("Float", DataType::Float)
      .value("Half", DataType::Half)
      .value("Int", DataType::Int)
      .value("Int32", DataType::Int32)
      .value("UInt64", DataType::UInt64)
      .value("Index", DataType::Index)
      .value("Bool", DataType::Bool)
      .value("BFloat16", DataType::BFloat16)
      .value("Float8_e4m3fn", DataType::Float8_e4m3fn)
      .value("Float8_e5m2", DataType::Float8_e5m2)
      .value("Float8_e8m0fnu", DataType::Float8_e8m0fnu)
      .value("Float4_e2m1fn", DataType::Float4_e2m1fn)
      .value("Float4_e2m1fn_x2", DataType::Float4_e2m1fn_x2)
      .value("ComplexFloat", DataType::ComplexFloat)
      .value("ComplexDouble", DataType::ComplexDouble)
      .value("Null", DataType::Null);

  // The Python-facing names differ from the C++ enumerators:
  // mesh_* -> DID*, grid_* -> BID*, block_* -> TID*, tma -> Bulk.
  py::enum_<ParallelType>(nvfuser, "ParallelType", py::module_local())
      .value("mesh_x", ParallelType::DIDx)
      .value("mesh_y", ParallelType::DIDy)
      .value("mesh_z", ParallelType::DIDz)
      .value("grid_x", ParallelType::BIDx)
      .value("grid_y", ParallelType::BIDy)
      .value("grid_z", ParallelType::BIDz)
      .value("block_x", ParallelType::TIDx)
      .value("block_y", ParallelType::TIDy)
      .value("block_z", ParallelType::TIDz)
      .value("mma", ParallelType::Mma)
      .value("serial", ParallelType::Serial)
      .value("tma", ParallelType::Bulk)
      .value("unroll", ParallelType::Unroll)
      .value("unswitch", ParallelType::Unswitch)
      .value("vectorize", ParallelType::Vectorize)
      .value("stream", ParallelType::Stream);

  py::enum_<CommunicatorBackend>(
      nvfuser, "CommunicatorBackend", py::module_local())
      .value("nccl", CommunicatorBackend::kNccl)
      .value("ucc", CommunicatorBackend::kUcc)
      .value("cuda", CommunicatorBackend::kCuda);

  py::enum_<SchedulerType>(nvfuser, "SchedulerType", py::module_local())
      .value("none", SchedulerType::None)
      .value("no_op", SchedulerType::NoOp)
      .value("pointwise", SchedulerType::PointWise)
      .value("matmul", SchedulerType::Matmul)
      .value("reduction", SchedulerType::Reduction)
      .value("inner_persistent", SchedulerType::InnerPersistent)
      .value("inner_outer_persistent", SchedulerType::InnerOuterPersistent)
      .value("outer_persistent", SchedulerType::OuterPersistent)
      .value("transpose", SchedulerType::Transpose)
      .value("expr_eval", SchedulerType::ExprEval)
      .value("resize", SchedulerType::Resize);

  // Only a subset of names is exposed here; "tma" maps to
  // CpAsyncBulkTensorTile.
  py::enum_<LoadStoreOpType>(nvfuser, "LoadStoreOpType", py::module_local())
      .value("set", LoadStoreOpType::Set)
      .value("load_matrix", LoadStoreOpType::LdMatrix)
      .value("cp_async", LoadStoreOpType::CpAsync)
      .value("tma", LoadStoreOpType::CpAsyncBulkTensorTile);

  py::enum_<MemoryType>(nvfuser, "MemoryType", py::module_local())
      .value("tensor", MemoryType::Tensor)
      .value("local", MemoryType::Local)
      .value("shared", MemoryType::Shared)
      .value("global", MemoryType::Global)
      .value("symmetric", MemoryType::Symmetric);

  py::enum_<CacheOp>(nvfuser, "CacheOp", py::module_local())
      .value("unspecified", CacheOp::Unspecified)
      .value("all_levels", CacheOp::AllLevels)
      .value("streaming", CacheOp::Streaming)
      .value("global", CacheOp::Global);

  // NOTE(review): this is the only enum registered without
  // py::module_local(). Confirm whether that is intentional before changing
  // it, since it affects how the type is shared with other extension modules.
  py::enum_<IdMappingMode>(nvfuser, "IdMappingMode")
      .value("exact", IdMappingMode::EXACT)
      .value("almost_exact", IdMappingMode::ALMOSTEXACT)
      .value("broadcast", IdMappingMode::BROADCAST)
      .value("permissive", IdMappingMode::PERMISSIVE)
      .value("loop", IdMappingMode::LOOP);

  // Enums nested in MatmulParams are exposed with a "Matmul" prefix.
  py::enum_<MatmulParams::TilingStrategy>(
      nvfuser, "MatmulTilingStrategy", py::module_local())
      .value("one_tile_per_cta", MatmulParams::TilingStrategy::OneTilePerCTA)
      .value(
          "distribute_tiles_across_sms",
          MatmulParams::TilingStrategy::DistributeTilesAcrossSMs)
      .value(
          "distribute_stages_across_sms",
          MatmulParams::TilingStrategy::DistributeStagesAcrossSMs);

  py::enum_<MatmulParams::BufferingLoopLevel>(
      nvfuser, "MatmulBufferingLoopLevel", py::module_local())
      .value("cta_tiles", MatmulParams::BufferingLoopLevel::CTATiles)
      .value("warp_tiles", MatmulParams::BufferingLoopLevel::WarpTiles);

  py::enum_<MatmulParams::CircularBufferingStrategy>(
      nvfuser, "MatmulCircularBufferingStrategy", py::module_local())
      .value(
          "pipelined", MatmulParams::CircularBufferingStrategy::Pipelined)
      .value(
          "warp_specialized",
          MatmulParams::CircularBufferingStrategy::WarpSpecialized);
}
} // namespace nvfuser::python