See More

// clang-format off /* * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on #include #include #include #include #include #include #include #include #include #include "base.h" #include #include #include #include #include namespace nvfuser::python { namespace { // Check if a type is an optional via type_traits template struct is_optional : std::false_type {}; template struct is_optional<:optional>> : std::true_type {}; template inline constexpr bool is_optional_v = is_optional::value; // A struct to hold default values for keyword arguments. template struct KeywordArgument { using type = T; // The data type of the argument std::string name; std::optional default_value; }; // Check if the NvFuser Val can be represented as a Python scalar. bool isPythonScalar(const Val* v) { // short_circuit: Symbolic values are not Python scalars. if (v->isSymbolic()) { return false; } // Check if the dtype is compatible with a Python scalar. // e.g., Python scalar cannot distiguish between ComplexDouble and // ComplexFloat. In this case, define_scalar must be used to specify custom // dtype. PrimDataType value_dtype(std::get(v->dtype().type)); switch (value_dtype) { case PrimDataType::Bool: case PrimDataType::Int: case PrimDataType::Index: case PrimDataType::ComplexDouble: case PrimDataType::Double: return true; default: return false; } } class PythonPrinter { public: PythonPrinter(std::ostream& os) : os_(&os) {} // Generate a python string for a string value. std::string toString(const std::string& s) { return s; } // Generate a python string for a boolean value. std::string toString(bool b) { return b ? "True" : "False"; } // Generate a python string for an int64_t value. std::string toString(int64_t i) { return std::to_string(i); } // Generate a python string for an size_t value. 
std::string toString(size_t i) { return std::to_string(i); } // Generate a python string for a complex double value. std::string toString(std::complex c) { std::stringstream ss; ss << std::showpoint << std::real(c) << "+" << std::showpoint << std::imag(c) << "j"; return ss.str(); } // Generate a python string for a double value. std::string toString(double d) { if (std::isinf(d)) { if (std::signbit(d)) { return "float(\"-inf\")"; } else { return "float(\"inf\")"; } } else if (std::isnan(d)) { return "float(\"nan\")"; } else { std::stringstream ss; ss << std::showpoint << d; return ss.str(); } } // Generate a python string for a Datatype. std::string toString(DataType dtype) { return dtypeToPyString(std::get(dtype.type)); } // Generate a python string for a PolymorphicValue with simple types. std::string toString(const PolymorphicValue& pv) { if (pv.is()) { return toString(pv.as()); } else if (pv.is()) { return toString(pv.as()); } else if (pv.is<:complex>>()) { return toString(pv.as<:complex>>()); } else if (pv.is()) { return toString(pv.as()); } else if (pv.is<:monostate>()) { return "None"; } else { NVF_THROW("Unsupported PolymorphicValue type"); } } // Generate a unique name for a Val. Map val to name to track Val's lifetime. std::string toString(const nvfuser::Val* v, bool is_lvalue = false) { std::stringstream ss; if (v == nullptr) { return "None"; } else if (v->isA()) { ss << "tv" << v->name(); } else if (!is_lvalue && isPythonScalar(v)) { ss << toString(v->value()); } else { ss << "c" << v->name(); } return ss.str(); } // Generate a python string for an optional value. template std::string toString(std::optional optional, bool skip_none = true) { if (optional.has_value()) { return toString(optional.value()); } else if (!skip_none) { return "None"; } else { return ""; } } // Generate a python string for a keyword argument. 
template std::string toString( const std::string& name, T value, const std::string& separator) { std::string result = toString(value); if (result.empty()) { return ""; } return separator + name + "=" + result; } // Generate a python list of values. template std::string toString(const std::vector& vec, bool is_list = true) { std::stringstream ss; if (is_list) { ss << "["; } for (auto&& [i, val] : enumerate(vec)) { if constexpr (is_optional_v) { ss << toString(val, /*skip_none=*/false); } else { ss << toString(val); } if (i < std::ssize(vec) - 1) { ss << ", "; } } if (is_list) { ss << "]"; } return ss.str(); } // Generate a python list of values. std::string generateOutputs(const std::vector& vec) { std::stringstream ss; for (auto&& [i, val] : enumerate(vec)) { if (val == nullptr) { ss << "_"; } else { NVF_ERROR( !isPythonScalar(val), "A constant scalar Val* cannot be an output lvalue. Got\t", val->toString()); ss << toString(val, /*is_lvalue=*/true); } if (i < std::ssize(vec) - 1) { ss << ", "; } } return ss.str(); } // Generate a python list of values. template std::string generateList(std::tuple const& args) { if (sizeof...(Ts) == 0) { return ""; } std::stringstream ss; std::apply( [&](Ts const&... tuple_args) { size_t i = 0; (((ss << (i > 0 ? ", " : "") << toString(tuple_args)), ++i), ...); }, args); return ss.str(); } // Generate a python list of values with string keyword arguments. template std::string generateNamedList( const std::vector<:string>& argument_names, std::tuple const& args) { NVF_ERROR( argument_names.size() == sizeof...(Ts), "Input argument names and args must have the same size."); // Use std::apply to unpack tuple of arguments into a lambda. The lambda // contains a C++17 fold expression on a comma operator that writes each // argument to stringstream and increments argument position. std::stringstream ss; std::apply( [this, &ss, &argument_names](Ts const&... 
tuple_args) { size_t i = 0; (((ss << toString( argument_names[i], tuple_args, (i > 0 ? ", " : ""))), ++i), ...); }, args); return ss.str(); } // Generate a python list of values with string keyword arguments. // * A tuple of default argument values is provided to the function. // * If the default argument is optional and equal to the provided argument, // then skip printing the keyword-argument pair. template std::string generateNamedList( std::tuple const& default_args, std::tuple const& args) { NVF_ERROR( sizeof...(Ds) == sizeof...(Ts), "The default and given arguments must have the same size."); // This immediately-invoked generic lambda uses a C++17 fold expression // to emulate a loop over the tuple elements. // // 1. `std::make_index_sequence` generates a compile-time sequence of // indices (0, 1, 2...). // 2. The lambda accepts this sequence, deducing the indices into the // template pack `Is...`. // 3. A fold expression over the comma operator expands the code for each // index. // 4. A ternary operator `(condition ? ... : ...)` performs the conditional // logic. // 5. If condition is true, another comma operator // `(write_to_stream, increment_counters)` chains the side effects of // writing to the stringstream and advancing the counters. // 6. If condition is false, increment `printed_arg_pos` counter by zero. std::stringstream ss; [&]<:size_t... is>(std::index_sequence) { size_t printed_arg_pos = 0; (((!std::get(default_args).default_value.has_value() || std::get(default_args).default_value.value() != std::get(args)) ? ((ss << toString( std::get(default_args).name, std::get(args), (printed_arg_pos > 0 ? ", " : ""))), ++printed_arg_pos) : (printed_arg_pos += 0)), ...); }(std::make_index_sequence{}); // Generate indices 0..N-1 return ss.str(); } // Generate a python operation with a list of inputs and outputs. 
void generateOperation( const std::string& op_name, const std::vector& inputs, const std::vector& outputs) { (*os_) << kTab; if (!outputs.empty()) { (*os_) << generateOutputs(outputs) << " = "; } (*os_) << op_name << "(" << toString(inputs, /*is_list=*/false) << ")\n"; } // Generate a python operation with a list of inputs and outputs. // A string keyword argument is added for each input. The default_kwargs // argument allows skipping arguments if it isn't strictly necessary. template < typename... arg_types, typename... default_kwarg_types, typename... kwargs_types> void generateKwargsOperation( const std::string& op_name, const std::tuple& args, const std::tuple& default_kwargs, const std::tuple& kwargs, const std::vector& outputs) { std::string kwargs_str = generateNamedList(default_kwargs, kwargs); constexpr bool any_arguments = sizeof...(arg_types) == 0; std::string connect = (any_arguments || kwargs_str.empty()) ? "" : ", "; (*os_) << kTab << generateOutputs(outputs) << " = " << op_name << "(" << generateList(args) << connect << kwargs_str << ")\n"; } // Generate a python operation with a list of inputs and outputs. // A string is added for each keyword argument. // // NOTES // ------ // - args and kwargs are a tuple, so it accepts a fixed set of arguments of // any type at compile-time. // - outputs is a vector of nvfuser values that is converted into a string. template void generateKwargsOperation( const std::string& op_name, const std::tuple& args, const std::vector<:string>& kwargs_names, const std::tuple& kwargs, const std::vector& outputs) { std::string connect = (sizeof...(arg_types) == 0) ? "" : ", "; (*os_) << kTab << generateOutputs(outputs) << " = " << op_name << "(" << generateList(args) << connect << generateNamedList(kwargs_names, kwargs) << ")\n"; } // Generate a python operation with a list of inputs and a single output. // A string is added for each keyword argument. 
// // NOTES // ------ // - args and kwargs are a tuple, so it accepts a fixed set of arguments of // any type at compile-time. // - output_name is a string. template void generateKwargsOperation( const std::string& op_name, const std::tuple& args, const std::vector<:string>& kwargs_names, const std::tuple& kwargs, const std::string& output_name) { std::string connect = (sizeof...(arg_types) == 0) ? "" : ", "; (*os_) << kTab << output_name << " = " << op_name << "(" << generateList(args) << connect << generateNamedList(kwargs_names, kwargs) << ")\n"; } // Generate a python operation with a list of inputs and outputs. // A string is added for each keyword argument. // // NOTES // ------ // - args and outputs are vectors of nvfuser values that are converted into // strings. // - kwargs is a tuple, so it accepts a fixed set of arguments of // any type at compile-time. template void generateKwargsOperation( const std::string& op_name, const std::vector& args, const std::vector<:string>& kwargs_names, const std::tuple& kwargs, const std::vector& outputs) { std::string connect = args.empty() ? "" : ", "; (*os_) << kTab << generateOutputs(outputs) << " = " << op_name << "(" << toString(args) << connect << generateNamedList(kwargs_names, kwargs) << ")\n"; } // Generate a python definition for a FusionDefinition. void generateFusionDefinition() { (*os_) << "def nvfuser_fusion(fd : FusionDefinition) -> None :\n"; } private: //! The stream to print the python function to. std::ostream* os_; //! Indentation for python code. static constexpr const char* kTab = " "; }; // PythonTranslator converts CPP Fusion to an equivalent python definition. // // How to add support for an expression not yet overriden by PythonTranslator? // 1. Create handle function for expression. // a. void handle(const SomeOp* op) final // 2. Check if IR node pointer is not nullptr. // 3. Add output values for Expr node to visited_vals_. // 4. Create scalar input arguments. 
This step is for view and expand // operations. // a. TensorView input arguments are handled via DAG traversal. // 5. Use PythonPrinter to create string for operation. // a. output = operation(inputs...) // 6. Use `PythonPrinter::generateOperation` if the operation only uses // positional arguments. This is mainly used for unary and binary // operations. // 7. Use `PythonPrinter::generateKwargsOperation` if the operation uses // keyword arguments. // a. If none of the keyword arguments have default arguments, create a // static vector of strings. // b. If some of the keyword arguments have default arguments, create a // vector of KeywordArgument. The KeywordArgument struct hold default // values for keyword arguments. Use `std::nullopt` for keyword // arguments without default values. // // How to debug PythonTranslator? // 1. Recompile with debug symbols // `export NVFUSER_BUILD_BUILD_TYPE=RelwithDebInfo` // 2. Run `gdb python` // 3. Catch exception in gdb `(gdb) catch throw`. // 4. Run failing test. // `r -m pytest test_python_frontend.py -k [your_failing_test]` // 5. At gdb Catchpoint, get backtrace for call stack using `(gdb) bt`. // 6. Find and fix failure in PythonTranslate. // // TODO: Python operations without a corresponding Fusion IR node require // pattern matching. // 1. Map a series of scalar values to `define_vector` // 2. Map Squeeze, Reduction, and Broadcast to a single reduction operation // with keepdim argument. // 3. Map Broadcast and Expand to `broadcast_in_dim` // 4. var_mean // 5. var class PythonTranslator : public OptInConstDispatch { public: // Returns a map from the values in the CPP fusion to its corresponding // FusionDefinition State index. static void print(std::ostream& os, Fusion* fusion) { PythonTranslator translator(os, fusion); translator.translate(); } private: PythonTranslator(std::ostream& os, Fusion* fusion) : printer_(os), fusion_(fusion) {} // The output TensorView shape can be dynamic for operations like ReshapeOp. 
// Check that all dynamic scalar dependencies are handled first before // handling the expression. bool checkDynamicShapeDependency(const Expr* op) { // short-circuit: Only check operations with dynamic shapes encoded in the // output TensorView. if (!op->isOneOf()) { return true; } NVF_ERROR_EQ(op->outputs().size(), 1); const std::vector& logical_out_domain = op->output(0)->as()->domain()->logical(); std::vector logical_domain_extents; std::ranges::copy( logical_out_domain | std::views::transform([](IterDomain* id) { return id->getMaybeExpandedExtent(); }), std::back_inserter(logical_domain_extents)); return std::ranges::all_of(logical_domain_extents, [&](Val* v) { return v->definition() == nullptr || visited_vals_.count(v) > 0; }); } // Gather the expressions necessary to create a scalar value. std::vector gatherScalarExpressions(Val* v) { NVF_ERROR(v != nullptr); NVF_ERROR(v->isScalar()); // short-circuit: v does not have a definition. if (v->definition() == nullptr) { return {}; } std::vector expression_chain; std::unordered_set visited; std::vector to_visit = {v->definition()}; while (!to_visit.empty()) { Expr* e = to_visit.back(); to_visit.pop_back(); expression_chain.push_back(e); visited.insert(e); for (Val* input : e->inputs()) { // short-circuit: input does not have a definition. if (input->definition() == nullptr) { continue; } // short-circuit: input definition is already visited. if (visited.count(input->definition()) > 0) { continue; } to_visit.push_back(input->definition()); } } return expression_chain; } // Gather the scalar expressions necessary to create the logical domain for a // TensorView. 
std::vector gatherScalarExpressions(TensorView* tv) { NVF_ERROR(tv != nullptr); std::vector logical_domain_expressions; const std::vector& logical_out_domain = tv->domain()->logical(); for (IterDomain* id : logical_out_domain) { std::vector extent_definitions = gatherScalarExpressions(id->getMaybeExpandedExtent()); logical_domain_expressions.insert( logical_domain_expressions.end(), extent_definitions.begin(), extent_definitions.end()); } return logical_domain_expressions; } // Check that all of the expression's inputs are defined in FusionDefinition. bool checkExpressionDependencies(Expr* e) { // short-circuit: Found an operation without all its dynamic shape // dependencies. if (!checkDynamicShapeDependency(e)) { return false; } return std::all_of( e->inputs().begin(), e->inputs().end(), [&](const Val* v) { return isPythonScalar(v) || visited_vals_.count(v) > 0; }); } void translate() { printer_.generateFusionDefinition(); // Add Fusion inputs to FusionDefinition for (nvfuser::Val* v : fusion_->inputs()) { dispatch(v); } // Gather all expressions in CPP Fusion. const std::vector<:expr> fusion_exprs = fusion_->exprs(); std::deque<:expr> to_visit( fusion_exprs.begin(), fusion_exprs.end()); // Scalar expressions are not handled by Fusion::exprs, so gather them // manually. 
for (Expr* e : to_visit) { if (e->isOneOf()) { NVF_ERROR_EQ(e->outputs().size(), 1); std::vector extent_definitions = gatherScalarExpressions(e->output(0)->as()); to_visit.insert( to_visit.end(), extent_definitions.begin(), extent_definitions.end()); } } // Topological search of Fusion expressions size_t skip_count = 0; std::unordered_set<:expr> visited; while (!to_visit.empty()) { Expr* e = to_visit.front(); to_visit.pop_front(); NVF_ERROR( skip_count <= to_visit.size(), "Cycle detected: None of the expressions can be processed!"); // short-circuit: skip if already visited if (visited.count(e) > 0) { continue; } // TODO: short-circuit: skip Split and Merge expressions created by // Reshape // TODO: short-circuit: skip Resize expressions created by Slice // TODO: direct bindings does not support scheduled expressions. // Handle scalars and constants not generated by separate expression. std::vector scalars; std::ranges::copy_if( e->inputs(), std::back_inserter(scalars), [](Val* v) { return v->isScalar(); }); std::ranges::for_each(scalars, [this](const Val* v) { dispatch(v); }); // short-circuit: add to back of stack if not all of the expression's // dependencies are satisfied. if (!checkExpressionDependencies(e)) { ++skip_count; to_visit.push_back(e); continue; } // Create string representation given inputs, outputs, and attributes. visited.insert(e); dispatch(e); skip_count = 0; } // Add tensor outputs and handle aliased outputs std::unordered_set<:val> visited_alias_output; for (nvfuser::Val* v : fusion_->outputs()) { NVF_ERROR(v->isA()); const AliasInfo& alias_info = fusion_->getOutputAlias(v); switch (alias_info.type) { case AllocationType::New: { handleOutput(v->as()); break; } case AllocationType::ReuseBuffer: { // Only apply aliasing once if (visited_alias_output.count(v) == 0) { visited_alias_output.insert(v); handleOutput(v->as(), alias_info); } // If not hide_output, then the aliased output is returned as a // fusion output. 
if (alias_info.visibility == OutputVisibility::kVisible) { handleOutput(v->as()); } break; } default: NVF_THROW("Unsupported AllocationType"); } } } // ================================================================================= // Filter Functions // Gather all TensorViews and FusionDefinition indices std::vector tensors() { std::vector tensors; std::ranges::copy_if( visited_vals_, std::back_inserter(tensors), [](const nvfuser::Val* v) { return v->isA(); }); return tensors; } // ================================================================================= // Create scalar for given nvfuser value. The nvfuser value must not already // exist and have a definition. It can be a fusion input, a constant, or a // tensor's extent. void handle(const Val* v) final { NVF_ERROR(v != nullptr); // short-circuit: scalar definition has a definition if (v->definition() != nullptr) { return; } // short-circuit: value already exists in FusionDefinition if (visited_vals_.count(v) > 0) { return; } // short-circuit: print python scalar directly if (isPythonScalar(v)) { return; } visited_vals_.insert(v); // Since scalars can come from TensorView dimension sizes, search through // all TensorViews for an iterDomain whose extent matches the desired // value and then use size op. 
for (const nvfuser::Val* tv_val : tensors()) { const auto* tv = tv_val->as(); // Get extents for each IterDomain std::vector filtered_logical_domain = TensorDomain::noReductions(tv->domain()->logical()); std::vector extents; extents.reserve(filtered_logical_domain.size()); std::ranges::copy( filtered_logical_domain | std::views::transform([](IterDomain* id) { return id->getMaybeExpandedExtent(); }), std::back_inserter(extents)); // Check if value matches iterdomain extent auto iter = std::ranges::find(extents, v); if (iter == extents.end()) { continue; } int64_t dim = std::distance(extents.begin(), iter); static const std::vector<:string> argument_names = {"dim"}; printer_.generateKwargsOperation( "fd.ops.size", std::make_tuple(tv), argument_names, std::make_tuple(dim), {v}); return; } static const std::vector<:string> argument_names = {"dtype"}; printer_.generateKwargsOperation( "fd.define_scalar", std::make_tuple(v->value()), argument_names, std::make_tuple(v->dtype()), {v}); } // Add Tensor value to Fusion Definition void handle(const TensorView* tv) final { NVF_ERROR(tv != nullptr); // short-circuit: value already exists in FusionDefinition if (visited_vals_.count(tv) > 0) { return; } visited_vals_.insert(tv); std::vector shape; std::transform( tv->domain()->logical().begin(), tv->domain()->logical().end(), std::back_inserter(shape), [](IterDomain* id) { return (id->getMaybeExpandedExtent()->isConstScalar()) ? id->getMaybeExpandedExtent()->evaluate().as() : -1; }); const std::vector& stride_order = tv->domain()->strideOrder(); static const std::vector<:string> argument_names = { "shape", "contiguity", "dtype", "is_cpu", "stride_order"}; printer_.generateKwargsOperation( "fd.define_tensor", std::make_tuple(), argument_names, std::make_tuple( shape, tv->domain()->contiguity(), tv->dtype(), tv->isCpuScalar(), (stride_order.empty()) ? 
std::nullopt : std::make_optional(stride_order)), {tv}); } // ================================================================================= // Utility functions // Create a vector for the logical domain of TensorView. // Used with ReshapeOp, ExpandOp, and FullOp handlers std::vector getShape(TensorView* tv) { const std::vector& logical_out_domain = tv->domain()->logical(); std::vector logical_domain_extents; // Use expanded extent if available for IterDomain. std::ranges::copy( logical_out_domain | std::views::transform([](IterDomain* id) { return id->getMaybeExpandedExtent(); }), std::back_inserter(logical_domain_extents)); return logical_domain_extents; } // Find integer index corresponding with reduction iterDomains std::vector getReductionAxes(TensorView* tv) { std::vector axes; const std::vector& logical_domain = tv->domain()->logical(); for (int64_t dim : c10::irange((int64_t)logical_domain.size())) { if (logical_domain.at(dim)->isReduction()) { axes.push_back(dim); } } return axes; } // ================================================================================= // Handle add_output variants // Add Tensor output to FusionDefinition void handleOutput(const TensorView* tv) { NVF_ERROR(tv != nullptr); printer_.generateOperation("fd.add_output", {tv}, {}); } // Alias output Tensor with input tensor void handleOutput(const TensorView* tv, const AliasInfo& alias_info) { NVF_ERROR(tv != nullptr); printer_.generateOperation( "fd.add_output", {tv, alias_info.aliased_io}, {}); } // ================================================================================= // Map CPP Expression classes to corresponding RecordFunctors in // python_frontend void handle(const UnaryOp* uop) final { NVF_ERROR(uop != nullptr); // short-circuit: Handle cast operation separately if (uop->getUnaryOpType() == UnaryOpType::Cast) { return handleCastOp(uop); } // Map remaining UnaryOp to python_frontend visited_vals_.insert(uop->out()); printer_.generateOperation( "fd.ops." 
+ nvfuser::python::toString(uop), {uop->in()}, {uop->out()}); } void handleCastOp(const UnaryOp* uop) { NVF_ERROR(uop->getUnaryOpType() == UnaryOpType::Cast); visited_vals_.insert(uop->out()); static const std::vector<:string> argument_names = {"dtype"}; printer_.generateKwargsOperation( "fd.ops.cast", std::make_tuple(uop->in()), argument_names, std::make_tuple(uop->out()->dtype()), {uop->out()}); } void handle(const BinaryOp* bop) final { NVF_ERROR(bop != nullptr); if (visited_vals_.count(bop->out()) > 0) { return; } visited_vals_.insert(bop->out()); printer_.generateOperation( "fd.ops." + nvfuser::python::toString(bop), {bop->lhs(), bop->rhs()}, {bop->out()}); } void handle(const TernaryOp* top) final { NVF_ERROR(top != nullptr); visited_vals_.insert(top->out()); printer_.generateOperation( "fd.ops." + nvfuser::python::toString(top), {top->in1(), top->in2(), top->in3()}, {top->out()}); } void handle(const ReductionOp* rop) final { NVF_ERROR(rop != nullptr); NVF_ERROR(rop->out()->isA()); visited_vals_.insert(rop->out()); // The min and max reduction operations expect the dtype argument to by // PrimDataType::Null DataType dtype = (rop->getReductionOpType() == BinaryOpType::Min || rop->getReductionOpType() == BinaryOpType::FMin || rop->getReductionOpType() == BinaryOpType::Max || rop->getReductionOpType() == BinaryOpType::FMax) ? DataType::Null : rop->out()->dtype(); std::vector dims = getReductionAxes(rop->out()->as()); // TODO: keepdim is always False in ReductionOp because a separate // BroadcastOp node exists if keepdim is True. Detect this pattern to // minify the python definition. static const auto default_args = std::make_tuple( KeywordArgument{ .name = "dims", .default_value = std::nullopt}, KeywordArgument{.name = "keepdim", .default_value = false}, KeywordArgument{ .name = "dtype", .default_value = DataType::Null}); printer_.generateKwargsOperation( "fd.ops." 
+ nvfuser::python::toString(rop), std::make_tuple(rop->in()), default_args, std::make_tuple(dims, false, dtype), {rop->out()}); } void handle(const ScanOp* sop) final { NVF_ERROR(sop != nullptr); visited_vals_.insert(sop->out()); static const auto default_args = std::make_tuple( KeywordArgument{.name = "dim", .default_value = std::nullopt}); printer_.generateKwargsOperation( "fd.ops." + toString(sop), std::make_tuple(sop->in()), default_args, std::make_tuple(sop->dim()), {sop->out()}); } void handle(const WelfordOp* wop) final { NVF_ERROR(wop != nullptr); NVF_ERROR(wop->initAvg()->evaluate().as() == 0.0); NVF_ERROR(wop->initVar()->evaluate().as() == 0.0); NVF_ERROR(wop->initN()->evaluate().as() == 0); visited_vals_.insert(wop->outAvg()); visited_vals_.insert(wop->outVar()); visited_vals_.insert(wop->outN()); static const std::vector<:string> argument_names = {"dims"}; printer_.generateKwargsOperation( "fd.ops.welford", std::make_tuple(wop->in()), argument_names, std::make_tuple(getReductionAxes(wop->outAvg()->as())), {wop->outAvg(), wop->outVar(), wop->outN()}); } void handle(const BroadcastOp* bcast_op) final { NVF_ERROR(bcast_op != nullptr); visited_vals_.insert(bcast_op->out()); static const std::vector<:string> broadcast_argument_names = { "is_broadcast_dim"}; printer_.generateKwargsOperation( "fd.ops.broadcast", std::make_tuple(bcast_op->in()), broadcast_argument_names, std::make_tuple(bcast_op->getBroadcastDimFlags()), {bcast_op->out()}); } void handle(const MatmulOp* matmul_op) final { NVF_ERROR(matmul_op != nullptr); visited_vals_.insert(matmul_op->out()); printer_.generateOperation( "fd.ops.matmul", {matmul_op->inA(), matmul_op->inB()}, {matmul_op->out()}); } void handle(const LinearOp* lop) final { NVF_ERROR(lop != nullptr); visited_vals_.insert(lop->out()); static const auto default_args = std::make_tuple( KeywordArgument{.name = "bias", .default_value = nullptr}); printer_.generateKwargsOperation( "fd.ops.linear", std::make_tuple(lop->inA(), 
lop->inB()), default_args, std::make_tuple(lop->bias()), {lop->out()}); } void handle(const GroupedMmaOp* gmm_op) final { NVF_ERROR(gmm_op != nullptr); TensorView* out_tv = gmm_op->out(); visited_vals_.insert(gmm_op->out()); int64_t out_block_scale_size = 0; PrimDataType out_block_scale_dtype = DataType::BFloat16; bool out_gamma = false; TensorView* out_block_scale_tv = gmm_op->outScale(); if (out_block_scale_tv != nullptr) { visited_vals_.insert(gmm_op->outScale()); const std::vector& logical = out_block_scale_tv->getLogicalDomain(); Val* block_size_extent = logical.at(logical.size() - 1)->extent(); NVF_CHECK( block_size_extent->isConstInt(), "Block size extent needs to be a constant integer"); out_block_scale_size = block_size_extent->evaluate().as(); out_block_scale_dtype = std::get(out_block_scale_tv->dtype().type); } TensorView* out_gamma_tv = gmm_op->outGamma(); if (out_gamma_tv != nullptr) { visited_vals_.insert(gmm_op->outGamma()); out_gamma = true; } if (gmm_op->inputs().size() == 3) { printer_.generateOperation( "fd.ops.grouped_mm", {gmm_op->matrix1(), gmm_op->matrix2(), gmm_op->offsets()}, {gmm_op->out()}); } else { static const auto default_args = std::make_tuple( KeywordArgumentalpha())>{ .name = "alpha", .default_value = nullptr}, KeywordArgumentbias())>{ .name = "bias", .default_value = nullptr}, KeywordArgumentbeta())>{ .name = "beta", .default_value = nullptr}, KeywordArgument{ .name = "dtype", .default_value = DataType::BFloat16}, KeywordArgument{ .name = "output_block_scale_size", .default_value = 0}, KeywordArgument{ .name = "output_block_scale_dtype", .default_value = DataType::BFloat16}, KeywordArgument{ .name = "output_gamma", .default_value = false}); printer_.generateKwargsOperation( "fd.ops.grouped_mm", std::make_tuple( gmm_op->matrix1(), gmm_op->matrix2(), gmm_op->offsets(), gmm_op->scale1(), gmm_op->scale2()), default_args, std::make_tuple( gmm_op->alpha(), gmm_op->bias(), gmm_op->beta(), out_tv->dtype(), out_block_scale_size, 
out_block_scale_dtype, out_gamma), {gmm_op->out(), out_block_scale_tv, out_gamma_tv}); } } void handle(const CutlassNvfp4GroupedMmaOp* cmm_op) final { NVF_ERROR(cmm_op != nullptr); visited_vals_.insert(cmm_op->out()); printer_.generateOperation( "fd.ops.cutlass_nvfp4_grouped_mm", {cmm_op->matrix1(), cmm_op->matrix2(), cmm_op->scale1(), cmm_op->scale2(), cmm_op->alpha(), cmm_op->problemSizes(), cmm_op->expertOffsets(), cmm_op->scalingFactorOffsets()}, {cmm_op->out()}); } void handle(const ScaledMmaOp* smm_op) final { NVF_ERROR(smm_op != nullptr); TensorView* out_tv = smm_op->out(); visited_vals_.insert(smm_op->out()); int64_t out_block_scale_size = 0; PrimDataType out_block_scale_dtype = DataType::BFloat16; bool out_gamma = false; TensorView* out_block_scale_tv = smm_op->outScale(); if (out_block_scale_tv != nullptr) { visited_vals_.insert(smm_op->outScale()); const std::vector& logical = out_block_scale_tv->getLogicalDomain(); Val* block_size_extent = logical.at(logical.size() - 1)->extent(); NVF_CHECK( block_size_extent->isConstInt(), "Block size extent needs to be a constant integer"); out_block_scale_size = block_size_extent->evaluate().as(); out_block_scale_dtype = std::get(out_block_scale_tv->dtype().type); } TensorView* out_gamma_tv = smm_op->outGamma(); if (out_gamma_tv != nullptr) { visited_vals_.insert(smm_op->outGamma()); out_gamma = true; } static const auto default_args = std::make_tuple( KeywordArgumentalpha())>{ .name = "alpha", .default_value = nullptr}, KeywordArgumentbias())>{ .name = "bias", .default_value = nullptr}, KeywordArgumentbeta())>{ .name = "beta", .default_value = nullptr}, KeywordArgument{ .name = "dtype", .default_value = DataType::BFloat16}, KeywordArgument{ .name = "output_block_scale_size", .default_value = 0}, KeywordArgument{ .name = "output_block_scale_dtype", .default_value = DataType::BFloat16}, KeywordArgument{.name = "output_gamma", .default_value = false}); printer_.generateKwargsOperation( "fd.ops.scaled_mm", 
std::make_tuple( smm_op->matrix1(), smm_op->matrix2(), smm_op->scale1(), smm_op->scale2()), default_args, std::make_tuple( smm_op->alpha(), smm_op->bias(), smm_op->beta(), out_tv->dtype(), out_block_scale_size, out_block_scale_dtype, out_gamma), {smm_op->out(), out_block_scale_tv, out_gamma_tv}); } void handle(const SdpaFwdOp* sdpa_fwd_op) final { NVF_ERROR(sdpa_fwd_op != nullptr); static const auto default_args = std::make_tuple( KeywordArgument{.name = "bias", .default_value = nullptr}, KeywordArgument{.name = "mask", .default_value = nullptr}, KeywordArgument{.name = "dropout_p", .default_value = nullptr}, KeywordArgument{.name = "is_causal", .default_value = nullptr}, KeywordArgument{.name = "scale", .default_value = nullptr}); visited_vals_.insert(sdpa_fwd_op->attn_out()); visited_vals_.insert(sdpa_fwd_op->logsumexp()); visited_vals_.insert(sdpa_fwd_op->philox_seed()); visited_vals_.insert(sdpa_fwd_op->philox_offset()); printer_.generateKwargsOperation( "fd.ops.sdpfa_fwd", std::make_tuple( sdpa_fwd_op->query(), sdpa_fwd_op->key(), sdpa_fwd_op->value()), default_args, std::make_tuple( sdpa_fwd_op->bias(), sdpa_fwd_op->mask(), sdpa_fwd_op->dropout_p(), sdpa_fwd_op->is_causal(), sdpa_fwd_op->scale()), {sdpa_fwd_op->attn_out(), sdpa_fwd_op->logsumexp(), sdpa_fwd_op->philox_seed(), sdpa_fwd_op->philox_offset()}); } void handle(const SdpaBwdOp* sdpa_bwd_op) final { NVF_ERROR(sdpa_bwd_op != nullptr); static const std::vector<:string> argument_names = { "dropout_p", "is_causal", "philox_seed", "philox_offset", "scale"}; visited_vals_.insert(sdpa_bwd_op->grad_query()); visited_vals_.insert(sdpa_bwd_op->grad_key()); visited_vals_.insert(sdpa_bwd_op->grad_value()); printer_.generateKwargsOperation( "fd.ops.sdpfa_bwd", std::make_tuple( sdpa_bwd_op->grad_attn(), sdpa_bwd_op->query(), sdpa_bwd_op->key(), sdpa_bwd_op->value(), sdpa_bwd_op->attn_out(), sdpa_bwd_op->logsumexp()), argument_names, std::make_tuple( sdpa_bwd_op->dropout_p(), sdpa_bwd_op->is_causal(), 
sdpa_bwd_op->philox_seed(),
            sdpa_bwd_op->philox_offset(),
            sdpa_bwd_op->scale()),
        {sdpa_bwd_op->grad_query(),
         sdpa_bwd_op->grad_key(),
         sdpa_bwd_op->grad_value()});
  }

  // Emit `fd.ops.squeeze` for a SqueezeOp.
  // `dims` is the list of positions whose squeeze flag is set.
  // `squeeze_expanded` is true when at least one squeezed dimension of the
  // input's logical domain is a broadcast that has an expanded extent.
  // NOTE(review): template arguments look stripped from this text (e.g.
  // `std::vector squeeze_dims`, `.as()`); confirm against upstream.
  void handle(const SqueezeOp* sop) final {
    NVF_ERROR(sop != nullptr);
    visited_vals_.insert(sop->out());

    // Collect the indices i for which is_squeeze_dims[i] is true.
    const std::vector& is_squeeze_dims = sop->getSqueezeDimFlags();
    auto filter_range = std::views::iota(0UL, is_squeeze_dims.size()) |
        std::views::filter([&is_squeeze_dims](int64_t dim) {
                          return is_squeeze_dims.at(dim);
                        });
    std::vector squeeze_dims(filter_range.begin(), filter_range.end());

    auto* in_tv = sop->in()->as();
    NVF_ERROR(in_tv != nullptr);

    // TODO: Use std::ranges::zip_view AND std::ranges::any_of with cpp23
    bool squeeze_expanded = false;
    for (auto [squeeze_dim, id] :
         zip(is_squeeze_dims, in_tv->getLogicalDomain())) {
      // Only dimensions that are actually squeezed are inspected.
      if (!squeeze_dim) {
        continue;
      }
      squeeze_expanded |= (id->isBroadcast() && id->hasExpandedExtent());
    }

    static const auto default_args = std::make_tuple(
        KeywordArgument{
            .name = "dims", .default_value = std::nullopt},
        KeywordArgument{
            .name = "squeeze_expanded", .default_value = false});
    printer_.generateKwargsOperation(
        "fd.ops.squeeze",
        std::make_tuple(sop->in()),
        default_args,
        std::make_tuple(squeeze_dims, squeeze_expanded),
        {sop->out()});
  }

  // Emit `fd.ops.reshape` for a ReshapeOp.
  // `new_shape` is obtained from getShape() on the output tensor; every entry
  // is dispatched before the reshape call itself is generated.
  void handle(const ReshapeOp* vop) final {
    NVF_ERROR(vop != nullptr);

    // Get extents for the output's logical domain
    auto* out_tv = vop->out()->as();
    std::vector new_shape = getShape(out_tv);

    // TODO Check if new_shape is a vector of symbolic fusion inputs
    // TODO Use define_vector to create more pythonic syntax

    static const std::vector<:string> reshape_argument_names = {
        "new_shape"};
    // Add CPP values to Fusion Definition if necessary
    std::ranges::for_each(new_shape, [this](const Val* v) { dispatch(v); });

    visited_vals_.insert(vop->out());
    printer_.generateKwargsOperation(
        "fd.ops.reshape",
        std::make_tuple(vop->in()),
        reshape_argument_names,
        std::make_tuple(new_shape),
        {vop->out()});
  }

  // Emit `fd.ops.expand` for an ExpandOp.
  // Input and output must have the same number of dimensions; `shape` is
  // obtained from getShape() on the output tensor and each entry is
  // dispatched before the expand call is generated.
  void handle(const ExpandOp* eop) final {
    NVF_ERROR(eop !=
nullptr);
    auto* in_tv = eop->in()->as();
    auto* out_tv = eop->out()->as();
    // Expand keeps the rank: input and output have the same dimension count.
    NVF_ERROR(in_tv->nDims() == out_tv->nDims());
    std::vector shape = getShape(out_tv);

    static const std::vector<:string> expand_argument_names = {"shape"};

    // Add CPP values to Fusion Definition if necessary
    std::ranges::for_each(shape, [this](const Val* v) { dispatch(v); });

    visited_vals_.insert(eop->out());
    printer_.generateKwargsOperation(
        "fd.ops.expand",
        std::make_tuple(eop->in()),
        expand_argument_names,
        std::make_tuple(shape),
        {eop->out()});
  }

  // Emit `fd.ops.slice` for a SliceOp.
  // The start, stop and step of every range returned by getRanges() are
  // gathered into three parallel lists and passed as the keyword arguments
  // start_indices, end_indices and strides; manual_normalization is always
  // passed as true.
  // NOTE(review): element types of the vectors below are not visible in this
  // text (template arguments look stripped) - confirm upstream.
  void handle(const SliceOp* sop) final {
    NVF_ERROR(sop != nullptr);
    std::vector<:slice> slices = sop->getRanges();

    std::vector start_indices;
    start_indices.reserve(slices.size());

    std::vector stop_indices;
    stop_indices.reserve(slices.size());

    std::vector strides;
    strides.reserve(slices.size());

    // One (start, stop, step) triple per slice, kept in slice order.
    for (const nvfuser::Slice& s : slices) {
      start_indices.push_back(s.start);
      stop_indices.push_back(s.stop);
      strides.push_back(s.step);
    }

    visited_vals_.insert(sop->out());
    // Since the normalization operations are expressed in the Fusion IR,
    // manual_normalization argument is always true and default arguments are
    // not used here.
    static const std::vector<:string> slice_argument_names = {
        "start_indices", "end_indices", "strides", "manual_normalization"};
    printer_.generateKwargsOperation(
        "fd.ops.slice",
        std::make_tuple(sop->in()),
        slice_argument_names,
        std::make_tuple(
            start_indices,
            stop_indices,
            strides,
            /*manual_normalization=*/true),
        {sop->out()});
  }

  // Emit `fd.ops.pad` for a PadOp.
  // The pad widths returned by getPadWidths() are re-ordered (steps 1-4)
  // before being passed as the `pad_widths` keyword argument; `value` is the
  // op's value().
  void handle(const PadOp* pad_op) final {
    NVF_ERROR(pad_op != nullptr);

    // Step 1: Get pad widths in normalized order.
    std::vector normalized_pad_widths = pad_op->getPadWidths();
    int64_t total_size = (int64_t)normalized_pad_widths.size();

    // Step 2: Get indices for normalized pad widths.
std::vector normalized_indices(total_size);
    std::iota(normalized_indices.begin(), normalized_indices.end(), 0);

    // Step 3: Transform to indices for original pad widths
    // Each normalized index is mapped to its position in the original order.
    // Assuming ceilDiv rounds up, the (even, odd) index pairs are reversed
    // while the order inside each pair is kept, e.g. for four widths
    // [a, b, c, d] the original order is [c, d, a, b].
    std::vector original_indices;
    original_indices.reserve(normalized_indices.size());
    std::ranges::transform(
        normalized_indices,
        std::back_inserter(original_indices),
        [=](int64_t normalized_idx) {
          int64_t offset = total_size - normalized_idx;
          int64_t dim = ceilDiv(offset, 2) - 1;
          int64_t original_idx = dim * 2;
          // right pad values require an additional offset
          if (offset % 2 == 1) {
            original_idx += 1;
          }
          return original_idx;
        });

    // Step 4: Get pad widths in original order.
    std::vector original_order_pad_widths(total_size, nullptr);
    for (int64_t normalized_idx : normalized_indices) {
      original_order_pad_widths.at(original_indices.at(normalized_idx)) =
          normalized_pad_widths.at(normalized_idx);
    }

    // Check that no pad width values are nullptr.
    NVF_ERROR(std::ranges::all_of(
        original_order_pad_widths, [](Val* v) { return v != nullptr; }));

    visited_vals_.insert(pad_op->out());
    // NOTE(review): the element types of the KeywordArgument entries are not
    // visible in this text (template arguments look stripped).
    static const auto default_args = std::make_tuple(
        KeywordArgument{
            .name = "pad_widths", .default_value = std::nullopt},
        KeywordArgument{.name = "value", .default_value = nullptr});
    printer_.generateKwargsOperation(
        "fd.ops.pad",
        std::make_tuple(pad_op->in()),
        default_args,
        std::make_tuple(original_order_pad_widths, pad_op->value()),
        {pad_op->out()});
  }

  // Emit `fd.ops.cat` for a CatOp.
  // All inputs of the op are passed positionally; `dim` is the op's
  // concatenatedDim() and output(0) is the result.
  void handle(const CatOp* cat_op) final {
    NVF_ERROR(cat_op != nullptr);
    visited_vals_.insert(cat_op->output(0));
    // The manual_padding argument is always passed as true and default
    // arguments are not used here.
    // NOTE(review): this comment previously referred to
    // `manual_normalization`, but the keyword emitted for fd.ops.cat is
    // `manual_padding`.
static const std::vector<:string> cat_argument_names = {
        "dim", "manual_padding"};
    printer_.generateKwargsOperation(
        "fd.ops.cat",
        cat_op->inputs(),
        cat_argument_names,
        std::make_tuple(cat_op->concatenatedDim(), /*manual_padding=*/true),
        {cat_op->output(0)});
  }

  // Map RNGOp to RandomDistOpRecord
  //
  // Uniform and UniformRange are emitted as `fd.ops.uniform`; NormalStandard
  // and NormalGeneral as `fd.ops.normal`; any other RNG type is an error.
  // The op must carry either two parameters or none. When it carries none,
  // the fusion's zero and one values are used as the two leading arguments.
  // rng_seed, rng_offset and dtype are passed as keyword arguments.
  void handle(const RNGOp* rop) final {
    NVF_ERROR(rop != nullptr);
    visited_vals_.insert(rop->output(0));

    std::string rng_op_name;
    switch (rop->getRNGOpType()) {
      case RNGOpType::Uniform:
      case RNGOpType::UniformRange:
        rng_op_name = "fd.ops.uniform";
        break;
      case RNGOpType::NormalStandard:
      case RNGOpType::NormalGeneral:
        rng_op_name = "fd.ops.normal";
        break;
      default:
        NVF_ERROR(false, "Unsupported RNGOpType.");
    }

    // NOTE(review): the element types of the KeywordArgument entries are not
    // visible in this text (template arguments look stripped).
    static const auto default_args = std::make_tuple(
        KeywordArgument{.name = "rng_seed", .default_value = nullptr},
        KeywordArgument{.name = "rng_offset", .default_value = nullptr},
        KeywordArgument{
            .name = "dtype", .default_value = DataType::Float});

    Val* first_arg = nullptr;
    Val* second_arg = nullptr;
    // Only "two explicit parameters" or "no parameters" are supported.
    NVF_ERROR(rop->getParameters().size() == 2 || rop->getParameters().empty());
    if (rop->getParameters().size() == 2) {
      first_arg = rop->getParameters().at(0);
      second_arg = rop->getParameters().at(1);
    } else {
      // Default arg1 and arg2 is (0, 1) for both uniform and normal.
      first_arg = fusion_->zeroVal();
      second_arg = fusion_->oneVal();
    }
    NVF_ERROR(first_arg != nullptr && second_arg != nullptr);

    printer_.generateKwargsOperation(
        rng_op_name,
        std::make_tuple(first_arg, second_arg, rop->getShape()),
        default_args,
        std::make_tuple(
            rop->getRNGSeedVal(), rop->getRNGOffsetVal(), rop->dtype()),
        {rop->output(0)});
  }

  // If input and output values share the same type, a LoadStoreOp will be
  // created instead of a CastOp.
  //
  // A LoadStoreOp whose output is a tensor with a root domain is forwarded to
  // handlePermute(); one whose output has an allocation domain is forwarded
  // to handleStrideOrder(). Having both at once is rejected. Every other
  // LoadStoreOp is emitted as `fd.ops.cast` to the output dtype.
  // NOTE(review): unlike the sibling handlers there is no
  // `NVF_ERROR(lsop != nullptr)` guard here - confirm whether that is
  // intentional.
  void handle(const LoadStoreOp* lsop) final {
    if (lsop->out()->isA()) {
      auto* out_tv = lsop->out()->as();
      NVF_ERROR(!(out_tv->hasRoot() && out_tv->hasAllocation()));
      // short-circuit: lsop is a permutation.
if (out_tv->hasRoot()) {
        return handlePermute(lsop);
      }
      // short-circuit: lsop is a stride_order.
      if (out_tv->hasAllocation()) {
        return handleStrideOrder(lsop);
      }
    }
    // Remaining case: emitted as a cast, so input and output dtypes must
    // already agree.
    NVF_ERROR(
        lsop->in()->dtype() == lsop->out()->dtype(),
        "Expected the dtype for input and output to be the same");
    visited_vals_.insert(lsop->out());
    static const std::vector<:string> argument_names = {"dtype"};
    printer_.generateKwargsOperation(
        "fd.ops.cast",
        std::make_tuple(lsop->in()),
        argument_names,
        std::make_tuple(lsop->out()->dtype()),
        {lsop->out()});
  }

  // Emit `fd.ops.permute` for a LoadStoreOp whose output has a root domain.
  // The `dims` argument is the result of ir_utils::computePermutation() on
  // the output's root and logical domains; it is an error if no permutation
  // is found.
  // NOTE(review): the variable name suggests a new-to-old mapping - confirm
  // against ir_utils::computePermutation. Template arguments also look
  // stripped from this text (e.g. `std::optional<:vector>>`, `.as()`).
  void handlePermute(const LoadStoreOp* lsop) {
    auto* out_tv = lsop->out()->as();
    std::optional<:vector>> new2old_opt = ir_utils::computePermutation(
        out_tv->getRootDomain(), out_tv->getLogicalDomain());
    NVF_ERROR(new2old_opt.has_value(), "Expected permutation");
    visited_vals_.insert(lsop->out());
    static const std::vector<:string> argument_names = {"dims"};
    printer_.generateKwargsOperation(
        "fd.ops.permute",
        std::make_tuple(lsop->in()),
        argument_names,
        std::make_tuple(new2old_opt.value()),
        {lsop->out()});
  }

  // Emit `fd.ops.stride_order` for a LoadStoreOp whose output has an
  // allocation domain. The `stride_order` argument is taken from
  // out_tv->domain()->strideOrder().
  void handleStrideOrder(const LoadStoreOp* lsop) {
    auto* out_tv = lsop->out()->as();
    visited_vals_.insert(lsop->out());
    static const std::vector<:string> argument_names = {"stride_order"};
    printer_.generateKwargsOperation(
        "fd.ops.stride_order",
        std::make_tuple(lsop->in()),
        argument_names,
        std::make_tuple(out_tv->domain()->strideOrder()),
        {lsop->out()});
  }

  // Emit `fd.ops.full` for a FullOp.
  // There are no positional inputs; shape (from getShape() on the output),
  // fill_value and dtype are all passed as keyword arguments.
  void handle(const FullOp* fop) final {
    NVF_ERROR(fop != nullptr);
    auto* out_tv = fop->output(0)->as();
    visited_vals_.insert(out_tv);

    // Fill value can be dynamic so create it
    dispatch(fop->getFillValue());

    static const std::vector<:string> argument_names = {
        "shape", "fill_value", "dtype"};
    printer_.generateKwargsOperation(
        "fd.ops.full",
        std::make_tuple(),
        argument_names,
        std::make_tuple(getShape(out_tv), fop->getFillValue(), out_tv->dtype()),
        {out_tv});
  }

  // Emit `fd.ops.iota` for an IotaOp.
  // length, start and step are dispatched first and then passed, together
  // with dtype, as keyword arguments; there are no positional inputs.
  void handle(const IotaOp* iop) final {
    NVF_ERROR(iop != nullptr);
    auto* out_tv = iop->output(0)->as();
    visited_vals_.insert(out_tv);

dispatch(iop->length());
    dispatch(iop->start());
    dispatch(iop->step());

    // Keyword names and defaults for fd.ops.iota, in the order of the keyword
    // values passed below.
    // NOTE(review): the template arguments of these KeywordArgument entries
    // look stripped from this text (e.g. `KeywordArgumentlength())>`).
    static const auto default_args = std::make_tuple(
        KeywordArgumentlength())>{
            .name = "length", .default_value = std::nullopt},
        KeywordArgumentstart())>{
            .name = "start", .default_value = nullptr},
        KeywordArgumentstep())>{
            .name = "step", .default_value = nullptr},
        KeywordArgument{
            .name = "dtype", .default_value = DataType::Int});
    printer_.generateKwargsOperation(
        "fd.ops.iota",
        std::make_tuple(),
        default_args,
        std::make_tuple(iop->length(), iop->start(), iop->step(), iop->dtype()),
        {out_tv});
  }

  // Emit `fd.ops.index_select` for an IndexSelectOp.
  // The lookup and index tensors are positional; `dim` is a keyword argument.
  void handle(const IndexSelectOp* isop) final {
    NVF_ERROR(isop != nullptr);
    auto* out_tv = isop->output(0)->as();
    visited_vals_.insert(out_tv);
    static const std::vector<:string> argument_names = {"dim"};
    printer_.generateKwargsOperation(
        "fd.ops.index_select",
        std::make_tuple(isop->lookupTv(), isop->indexTv()),
        argument_names,
        std::make_tuple(isop->dim()),
        {out_tv});
  }

  // Emit `fd.ops.select` for a SelectOp.
  // The lookup tensor and the op's second input (input(1)) are positional;
  // `dim` is a keyword argument.
  void handle(const SelectOp* sop) final {
    NVF_ERROR(sop != nullptr);
    auto* out_tv = sop->output(0)->as();
    visited_vals_.insert(out_tv);
    static const std::vector<:string> argument_names = {"dim"};
    printer_.generateKwargsOperation(
        "fd.ops.select",
        std::make_tuple(sop->lookupTv(), sop->input(1)),
        argument_names,
        std::make_tuple(sop->dim()),
        {out_tv});
  }

  // Emit `fd.ops.scatter` for a ScatterOp.
  // in, index and src are positional; `dim` is a keyword argument.
  void handle(const ScatterOp* sop) final {
    NVF_ERROR(sop != nullptr);
    auto* out_tv = sop->output(0)->as();
    visited_vals_.insert(out_tv);
    static const std::vector<:string> argument_names = {"dim"};
    printer_.generateKwargsOperation(
        "fd.ops.scatter",
        std::make_tuple(sop->in(), sop->index(), sop->src()),
        argument_names,
        std::make_tuple(sop->dim()),
        {out_tv});
  }

  // Emit a gather-style call for a GatherOp.
  // The operation name is `fd.ops.take_along_axis` when exactSizes() is true
  // and `fd.ops.gather` otherwise. The lookup and index tensors are
  // positional; `dim` is a keyword argument.
  void handle(const GatherOp* gop) final {
    NVF_ERROR(gop != nullptr);
    auto* out_tv = gop->output(0)->as();
    visited_vals_.insert(out_tv);
    static const std::vector<:string> argument_names = {"dim"};
    printer_.generateKwargsOperation(
        (gop->exactSizes() ?
"fd.ops.take_along_axis" : "fd.ops.gather"),
        std::make_tuple(gop->lookupTv(), gop->indexTv()),
        argument_names,
        std::make_tuple(gop->dim()),
        {out_tv});
  }

  // Emit `fd.ops.topk` for a TopKOp.
  // in and k are positional; dim, largest and sorted are keyword arguments
  // with declared defaults -1, true and false. The op has two results,
  // output(0) and output(1), both recorded in visited_vals_.
  // NOTE(review): template arguments look stripped from this text (e.g.
  // `KeywordArgumentdim())>`, `std::vector{...}`); confirm against upstream.
  void handle(const TopKOp* topkop) final {
    NVF_ERROR(topkop != nullptr);
    visited_vals_.insert(topkop->output(0));
    visited_vals_.insert(topkop->output(1));
    static const auto default_args = std::make_tuple(
        KeywordArgumentdim())>{
            .name = "dim", .default_value = -1},
        KeywordArgument{.name = "largest", .default_value = true},
        KeywordArgument{.name = "sorted", .default_value = false});
    printer_.generateKwargsOperation(
        "fd.ops.topk",
        std::make_tuple(topkop->in(), topkop->k()),
        default_args,
        std::make_tuple(topkop->dim(), topkop->isLargest(), topkop->isSorted()),
        std::vector{topkop->output(0), topkop->output(1)});
  }

  // Emit `fd.ops.argsort` for an ArgsortOp.
  // in is positional; dim, descending and stable are keyword arguments.
  // dim is declared with std::nullopt as its default, descending and stable
  // with false.
  void handle(const ArgsortOp* argsortop) final {
    NVF_ERROR(argsortop != nullptr);
    auto* out_tv = argsortop->output(0)->as();
    visited_vals_.insert(out_tv);
    static const auto default_args = std::make_tuple(
        KeywordArgumentdim())>{
            .name = "dim", .default_value = std::nullopt},
        KeywordArgument{.name = "descending", .default_value = false},
        KeywordArgument{.name = "stable", .default_value = false});
    printer_.generateKwargsOperation(
        "fd.ops.argsort",
        std::make_tuple(argsortop->in()),
        default_args,
        std::make_tuple(
            argsortop->dim(), argsortop->isDescending(), argsortop->isStable()),
        {out_tv});
  }

  // Map EmbeddingFwdOp to python frontend
  //
  // Emits `fd.ops.embedding_fwd`. in and weight are positional; padding_idx,
  // max_norm, norm_type, scale_grad_by_freq and sparse are keyword arguments,
  // each declared with a nullptr default.
  void handle(const EmbeddingFwdOp* eop) final {
    NVF_ERROR(eop != nullptr);
    visited_vals_.insert(eop->output(0));
    static const auto default_args = std::make_tuple(
        KeywordArgument{.name = "padding_idx", .default_value = nullptr},
        KeywordArgument{.name = "max_norm", .default_value = nullptr},
        KeywordArgument{.name = "norm_type", .default_value = nullptr},
        KeywordArgument{
            .name = "scale_grad_by_freq", .default_value = nullptr},
        KeywordArgument{.name = "sparse", .default_value = nullptr});
    printer_.generateKwargsOperation(
        "fd.ops.embedding_fwd",
        std::make_tuple(eop->in(),
eop->weight()),
        default_args,
        // Keyword values, in the order of the entries of default_args.
        std::make_tuple(
            eop->padding_idx(),
            eop->max_norm(),
            eop->norm_type(),
            eop->scale_grad_by_freq(),
            eop->sparse()),
        {eop->out()});
  }

  // Emit `fd.ops.preprocess_grouped_matmul_input_sf` for a
  // PreprocessGroupedMatmulInputSf op.
  // The input (after an `as` cast), the input offsets and the output offsets
  // are passed in that order; out() is the only result.
  // NOTE(review): the target type of `.as()` is not visible in this text
  // (template arguments look stripped) - confirm upstream.
  void handle(const PreprocessGroupedMatmulInputSf* layout_op) final {
    NVF_ERROR(layout_op != nullptr);
    visited_vals_.insert(layout_op->output(0));
    printer_.generateOperation(
        "fd.ops.preprocess_grouped_matmul_input_sf",
        {layout_op->in()->as(),
         layout_op->inputOffsets(),
         layout_op->outputOffsets()},
        {layout_op->out()});
  }

  // Map BlockQuantizationOp to python frontend
  //
  // Emits `fd.ops.nv_block_quantize`. in is positional; global_scale,
  // block_size, swizzle_block_scales and dtype are keyword arguments with
  // declared defaults nullptr, 16, false and Float4_e2m1fn. The dtype value
  // is read from the op's quantized output. The op has two results.
  void handle(const BlockQuantizationOp* bqop) final {
    NVF_ERROR(bqop != nullptr);
    visited_vals_.insert(bqop->output(0));
    visited_vals_.insert(bqop->output(1));

    static const auto default_args = std::make_tuple(
        KeywordArgumentglobalScale())>{
            .name = "global_scale", .default_value = nullptr},
        KeywordArgument{.name = "block_size", .default_value = 16},
        KeywordArgument{
            .name = "swizzle_block_scales", .default_value = false},
        KeywordArgument{
            .name = "dtype", .default_value = DataType::Float4_e2m1fn});

    auto dtype = bqop->quantizedOutput()->as()->dtype();
    printer_.generateKwargsOperation(
        "fd.ops.nv_block_quantize",
        std::make_tuple(bqop->in()),
        default_args,
        std::make_tuple(
            bqop->globalScale(),
            bqop->blockSize(),
            bqop->isSwizzledScales(),
            dtype),
        std::vector{bqop->output(0), bqop->output(1)});
  }

  // Emit `fd.ops.nv_grouped_block_quantize` for a GroupedBlockQuantizationOp.
  // in, inputOffsets and outputOffsets are positional; global_scale,
  // block_size and dtype are keyword arguments with declared defaults
  // nullptr, 16 and Float4_e2m1fn. The dtype value is read from the op's
  // quantized output. The op has two results.
  void handle(const GroupedBlockQuantizationOp* grouped_bqop) final {
    NVF_ERROR(grouped_bqop != nullptr);
    visited_vals_.insert(grouped_bqop->output(0));
    visited_vals_.insert(grouped_bqop->output(1));

    static const auto default_args = std::make_tuple(
        KeywordArgumentglobalScale())>{
            .name = "global_scale", .default_value = nullptr},
        KeywordArgument{.name = "block_size", .default_value = 16},
        KeywordArgument{
            .name = "dtype", .default_value = DataType::Float4_e2m1fn});

    auto dtype = grouped_bqop->quantizedOutput()->as()->dtype();
    printer_.generateKwargsOperation(
        "fd.ops.nv_grouped_block_quantize",
        std::make_tuple(
            grouped_bqop->in(),
            grouped_bqop->inputOffsets(),
grouped_bqop->outputOffsets()),
        default_args,
        // Keyword values, in the order of the entries of default_args.
        std::make_tuple(
            grouped_bqop->globalScale(), grouped_bqop->blockSize(), dtype),
        std::vector{
            grouped_bqop->output(0), grouped_bqop->output(1)});
  }

 private:
  //! Convert CPP values to python syntax.
  PythonPrinter printer_;
  //! The reference CPP fusion to be translated.
  Fusion* fusion_ = nullptr;
  //! Set of NvFuser Val's created in the Fusion.
  //! NOTE(review): the element type of this set is not visible in this text
  //! (template arguments look stripped); the handlers insert op outputs.
  std::unordered_set visited_vals_;
};

} // namespace

// Run PythonTranslator::print for fusion `f` into an in-memory stream and
// return everything it wrote as a string.
std::string translateFusion(nvfuser::Fusion* f) {
  std::stringstream ss;
  PythonTranslator::print(ss, f);
  return ss.str();
}

} // namespace nvfuser::python