// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "base.h"
#include
#include
#include
#include
#include
namespace nvfuser::python {
namespace {
// Type trait that reports whether a type is a std::optional specialization.
// The primary template yields false; the partial specialization below yields
// true for optional types. It lets PythonPrinter::toString(std::vector) print
// an empty optional element as "None" instead of skipping it.
template
struct is_optional : std::false_type {};
// Partial specialization selected for std::optional types.
template
struct is_optional<:optional>> : std::true_type {};
// Convenience variable template for the trait's ::value.
template
inline constexpr bool is_optional_v = is_optional::value;
// Describes one keyword argument of a generated Python call: the keyword
// name and an optional default value. PythonPrinter::generateNamedList (the
// overload taking a tuple of KeywordArgument) omits the keyword argument when
// default_value is set and compares equal to the argument actually passed.
// A std::nullopt default_value means the argument is always printed.
template
struct KeywordArgument {
  using type = T; // The data type of the argument
  // Python keyword used when printing the argument, e.g. "dtype".
  std::string name;
  // Default value of the Python argument; std::nullopt means "no default".
  std::optional default_value;
};
// Check if the NvFuser Val can be represented as a Python scalar.
bool isPythonScalar(const Val* v) {
// short_circuit: Symbolic values are not Python scalars.
if (v->isSymbolic()) {
return false;
}
// Check if the dtype is compatible with a Python scalar.
// e.g., Python scalar cannot distiguish between ComplexDouble and
// ComplexFloat. In this case, define_scalar must be used to specify custom
// dtype.
PrimDataType value_dtype(std::get(v->dtype().type));
switch (value_dtype) {
case PrimDataType::Bool:
case PrimDataType::Int:
case PrimDataType::Index:
case PrimDataType::ComplexDouble:
case PrimDataType::Double:
return true;
default:
return false;
}
}
// Emits Python source text for a FusionDefinition onto an output stream.
// The toString overloads render individual C++ values as Python expressions;
// the generate* helpers assemble them into whole lines such as
// `tv2 = fd.ops.add(tv0, tv1)` and write those lines to the stream.
class PythonPrinter {
 public:
  // Only a pointer to the stream is stored, so the stream must outlive the
  // printer.
  PythonPrinter(std::ostream& os) : os_(&os) {}

  // Generate a python string for a string value. The text is emitted
  // verbatim, without adding quotes.
  std::string toString(const std::string& s) {
    return s;
  }

  // Generate a python string for a boolean value.
  std::string toString(bool b) {
    return b ? "True" : "False";
  }

  // Generate a python string for an int64_t value.
  std::string toString(int64_t i) {
    return std::to_string(i);
  }

  // Generate a python string for an size_t value.
  std::string toString(size_t i) {
    return std::to_string(i);
  }

  // Generate a python string for a complex double value, e.g.
  // "1.00000+2.00000j".
  // NOTE(review): uses the stream's default 6-significant-digit precision and
  // does not special-case inf/nan components — confirm this is acceptable.
  std::string toString(std::complex c) {
    std::stringstream ss;
    ss << std::showpoint << std::real(c) << "+" << std::showpoint
       << std::imag(c) << "j";
    return ss.str();
  }

  // Generate a python string for a double value.
  //
  // Finite values are first printed with the stream's default precision (six
  // significant digits). If that text does not parse back to exactly the
  // same double, the value is re-printed with 17 significant digits, which
  // is enough to round-trip any IEEE-754 double. The generated definition
  // therefore never silently changes a constant, while values that are
  // already exact keep their short form.
  std::string toString(double d) {
    if (std::isinf(d)) {
      if (std::signbit(d)) {
        return "float(\"-inf\")";
      } else {
        return "float(\"inf\")";
      }
    } else if (std::isnan(d)) {
      return "float(\"nan\")";
    }
    std::stringstream ss;
    ss << std::showpoint << d;
    double parsed = 0.0;
    std::istringstream round_trip(ss.str());
    round_trip >> parsed;
    if (parsed != d) {
      // Reset the buffer and state bits; the showpoint flag is kept.
      ss.str("");
      ss.clear();
      ss.precision(17);
      ss << d;
    }
    return ss.str();
  }

  // Generate a python string for a Datatype.
  std::string toString(DataType dtype) {
    return dtypeToPyString(std::get(dtype.type));
  }

  // Generate a python string for a PolymorphicValue with simple types.
  // A monostate (empty) value prints as "None"; unsupported types throw.
  std::string toString(const PolymorphicValue& pv) {
    if (pv.is()) {
      return toString(pv.as());
    } else if (pv.is()) {
      return toString(pv.as());
    } else if (pv.is<:complex>>()) {
      return toString(pv.as<:complex>>());
    } else if (pv.is()) {
      return toString(pv.as());
    } else if (pv.is<:monostate>()) {
      return "None";
    } else {
      NVF_THROW("Unsupported PolymorphicValue type");
    }
  }

  // Generate a python string for a Val. The result is "None" for nullptr,
  // "tv<name>" for values passing the isA check, an inline literal for
  // Python-representable constants (unless the Val is being assigned to,
  // i.e. is_lvalue), and "c<name>" otherwise.
  std::string toString(const nvfuser::Val* v, bool is_lvalue = false) {
    std::stringstream ss;
    if (v == nullptr) {
      return "None";
    } else if (v->isA()) {
      ss << "tv" << v->name();
    } else if (!is_lvalue && isPythonScalar(v)) {
      ss << toString(v->value());
    } else {
      ss << "c" << v->name();
    }
    return ss.str();
  }

  // Generate a python string for an optional value. An empty optional is
  // printed as "None" when skip_none is false, otherwise as an empty string
  // so that callers can omit the argument entirely.
  template
  std::string toString(std::optional optional, bool skip_none = true) {
    if (optional.has_value()) {
      return toString(optional.value());
    } else if (!skip_none) {
      return "None";
    } else {
      return "";
    }
  }

  // Generate a python string for a keyword argument: "<separator>name=value".
  // Returns an empty string (no separator) when the value renders empty.
  template
  std::string toString(
      const std::string& name,
      T value,
      const std::string& separator) {
    std::string result = toString(value);
    if (result.empty()) {
      return "";
    }
    return separator + name + "=" + result;
  }

  // Generate a python list of values. When is_list is false, the
  // surrounding brackets are omitted (used for positional argument lists).
  // Empty optional elements print as "None" to keep their position.
  template
  std::string toString(const std::vector& vec, bool is_list = true) {
    std::stringstream ss;
    if (is_list) {
      ss << "[";
    }
    for (auto&& [i, val] : enumerate(vec)) {
      if constexpr (is_optional_v) {
        ss << toString(val, /*skip_none=*/false);
      } else {
        ss << toString(val);
      }
      if (i < std::ssize(vec) - 1) {
        ss << ", ";
      }
    }
    if (is_list) {
      ss << "]";
    }
    return ss.str();
  }

  // Generate the comma-separated left-hand side of an assignment. nullptr
  // entries become "_"; constant Python scalars are rejected because they
  // cannot be assigned to.
  std::string generateOutputs(const std::vector& vec) {
    std::stringstream ss;
    for (auto&& [i, val] : enumerate(vec)) {
      if (val == nullptr) {
        ss << "_";
      } else {
        NVF_ERROR(
            !isPythonScalar(val),
            "A constant scalar Val* cannot be an output lvalue. Got\t",
            val->toString());
        ss << toString(val, /*is_lvalue=*/true);
      }
      if (i < std::ssize(vec) - 1) {
        ss << ", ";
      }
    }
    return ss.str();
  }

  // Generate a comma-separated list of positional values from a tuple.
  template
  std::string generateList(std::tuple const& args) {
    if (sizeof...(Ts) == 0) {
      return "";
    }
    std::stringstream ss;
    std::apply(
        [&](Ts const&... tuple_args) {
          size_t i = 0;
          (((ss << (i > 0 ? ", " : "") << toString(tuple_args)), ++i), ...);
        },
        args);
    return ss.str();
  }

  // Generate a python list of values with string keyword arguments.
  //
  // A keyword argument whose value renders to an empty string (e.g. an empty
  // std::optional) is omitted. The ", " separator is chosen by how many
  // keyword arguments have actually been printed, not by argument position,
  // so an omitted leading argument cannot produce an invalid "(a, , x=1)".
  template
  std::string generateNamedList(
      const std::vector<:string>& argument_names,
      std::tuple const& args) {
    NVF_ERROR(
        argument_names.size() == sizeof...(Ts),
        "Input argument names and args must have the same size.");
    // Use std::apply to unpack the tuple of arguments into a lambda. A C++17
    // fold expression over the comma operator then prints each argument in
    // order.
    std::stringstream ss;
    std::apply(
        [this, &ss, &argument_names](Ts const&... tuple_args) {
          size_t arg_pos = 0;
          size_t printed_arg_pos = 0;
          auto print_kwarg = [&](const auto& value) {
            std::string kwarg = toString(
                argument_names[arg_pos++],
                value,
                (printed_arg_pos > 0 ? ", " : ""));
            if (!kwarg.empty()) {
              ss << kwarg;
              ++printed_arg_pos;
            }
          };
          (print_kwarg(tuple_args), ...);
        },
        args);
    return ss.str();
  }

  // Generate a python list of values with string keyword arguments.
  // * A tuple of default argument values is provided to the function.
  // * If the default argument is optional and equal to the provided argument,
  //   then skip printing the keyword-argument pair.
  template
  std::string generateNamedList(
      std::tuple const& default_args,
      std::tuple const& args) {
    NVF_ERROR(
        sizeof...(Ds) == sizeof...(Ts),
        "The default and given arguments must have the same size.");
    // This immediately-invoked generic lambda uses a C++17 fold expression
    // to emulate a loop over the tuple elements.
    //
    // 1. `std::make_index_sequence` generates a compile-time sequence of
    //    indices (0, 1, 2...).
    // 2. The lambda accepts this sequence, deducing the indices into the
    //    template pack `Is...`.
    // 3. A fold expression over the comma operator expands the code for each
    //    index.
    // 4. A ternary operator `(condition ? ... : ...)` performs the conditional
    //    logic.
    // 5. If condition is true, another comma operator
    //    `(write_to_stream, increment_counters)` chains the side effects of
    //    writing to the stringstream and advancing the counters.
    // 6. If condition is false, increment `printed_arg_pos` counter by zero.
    std::stringstream ss;
    [&]<:size_t... is>(std::index_sequence) {
      size_t printed_arg_pos = 0;
      (((!std::get(default_args).default_value.has_value() ||
         std::get(default_args).default_value.value() != std::get(args))
            ? ((ss << toString(
                    std::get(default_args).name,
                    std::get(args),
                    (printed_arg_pos > 0 ? ", " : ""))),
               ++printed_arg_pos)
            : (printed_arg_pos += 0)),
       ...);
    }(std::make_index_sequence{}); // Generate indices 0..N-1
    return ss.str();
  }

  // Generate a python operation with a list of inputs and outputs. When there
  // are no outputs, the call is emitted without an assignment.
  void generateOperation(
      const std::string& op_name,
      const std::vector& inputs,
      const std::vector& outputs) {
    (*os_) << kTab;
    if (!outputs.empty()) {
      (*os_) << generateOutputs(outputs) << " = ";
    }
    (*os_) << op_name << "(" << toString(inputs, /*is_list=*/false) << ")\n";
  }

  // Generate a python operation with a list of inputs and outputs.
  // A string keyword argument is added for each input. The default_kwargs
  // argument allows skipping arguments if it isn't strictly necessary.
  template <
      typename... arg_types,
      typename... default_kwarg_types,
      typename... kwargs_types>
  void generateKwargsOperation(
      const std::string& op_name,
      const std::tuple& args,
      const std::tuple& default_kwargs,
      const std::tuple& kwargs,
      const std::vector& outputs) {
    std::string kwargs_str = generateNamedList(default_kwargs, kwargs);
    constexpr bool has_positional_args = sizeof...(arg_types) > 0;
    // Only join positional and keyword arguments when both are present.
    std::string connect =
        (has_positional_args && !kwargs_str.empty()) ? ", " : "";
    (*os_) << kTab << generateOutputs(outputs) << " = " << op_name << "("
           << generateList(args) << connect << kwargs_str << ")\n";
  }

  // Generate a python operation with a list of inputs and outputs.
  // A string is added for each keyword argument.
  //
  // NOTES
  // ------
  // - args and kwargs are a tuple, so it accepts a fixed set of arguments of
  //   any type at compile-time.
  // - outputs is a vector of nvfuser values that is converted into a string.
  template
  void generateKwargsOperation(
      const std::string& op_name,
      const std::tuple& args,
      const std::vector<:string>& kwargs_names,
      const std::tuple& kwargs,
      const std::vector& outputs) {
    std::string kwargs_str = generateNamedList(kwargs_names, kwargs);
    // Only join positional and keyword arguments when both are present.
    std::string connect =
        (sizeof...(arg_types) == 0 || kwargs_str.empty()) ? "" : ", ";
    (*os_) << kTab << generateOutputs(outputs) << " = " << op_name << "("
           << generateList(args) << connect << kwargs_str << ")\n";
  }

  // Generate a python operation with a list of inputs and a single output.
  // A string is added for each keyword argument.
  //
  // NOTES
  // ------
  // - args and kwargs are a tuple, so it accepts a fixed set of arguments of
  //   any type at compile-time.
  // - output_name is a string.
  template
  void generateKwargsOperation(
      const std::string& op_name,
      const std::tuple& args,
      const std::vector<:string>& kwargs_names,
      const std::tuple& kwargs,
      const std::string& output_name) {
    std::string kwargs_str = generateNamedList(kwargs_names, kwargs);
    // Only join positional and keyword arguments when both are present.
    std::string connect =
        (sizeof...(arg_types) == 0 || kwargs_str.empty()) ? "" : ", ";
    (*os_) << kTab << output_name << " = " << op_name << "("
           << generateList(args) << connect << kwargs_str << ")\n";
  }

  // Generate a python operation with a list of inputs and outputs.
  // A string is added for each keyword argument.
  //
  // NOTES
  // ------
  // - args and outputs are vectors of nvfuser values that are converted into
  //   strings.
  // - kwargs is a tuple, so it accepts a fixed set of arguments of
  //   any type at compile-time.
  template
  void generateKwargsOperation(
      const std::string& op_name,
      const std::vector& args,
      const std::vector<:string>& kwargs_names,
      const std::tuple& kwargs,
      const std::vector& outputs) {
    std::string kwargs_str = generateNamedList(kwargs_names, kwargs);
    // Only join positional and keyword arguments when both are present.
    std::string connect = (args.empty() || kwargs_str.empty()) ? "" : ", ";
    (*os_) << kTab << generateOutputs(outputs) << " = " << op_name << "("
           << toString(args) << connect << kwargs_str << ")\n";
  }

  // Generate a python definition for a FusionDefinition.
  void generateFusionDefinition() {
    (*os_) << "def nvfuser_fusion(fd : FusionDefinition) -> None :\n";
  }

 private:
  //! The stream to print the python function to.
  std::ostream* os_;
  //! Indentation for python code.
  static constexpr const char* kTab = "    ";
};
// PythonTranslator converts CPP Fusion to an equivalent python definition.
//
// How to add support for an expression not yet overridden by PythonTranslator?
// 1. Create handle function for expression.
// a. void handle(const SomeOp* op) final
// 2. Check if IR node pointer is not nullptr.
// 3. Add output values for Expr node to visited_vals_.
// 4. Create scalar input arguments. This step is for view and expand
// operations.
// a. TensorView input arguments are handled via DAG traversal.
// 5. Use PythonPrinter to create string for operation.
// a. output = operation(inputs...)
// 6. Use `PythonPrinter::generateOperation` if the operation only uses
// positional arguments. This is mainly used for unary and binary
// operations.
// 7. Use `PythonPrinter::generateKwargsOperation` if the operation uses
// keyword arguments.
// a. If none of the keyword arguments have default arguments, create a
// static vector of strings.
// b. If some of the keyword arguments have default values, create a
// tuple of KeywordArgument (via std::make_tuple). The KeywordArgument struct
// holds default values for keyword arguments. Use `std::nullopt` for keyword
// arguments without default values.
//
// How to debug PythonTranslator?
// 1. Recompile with debug symbols
// `export NVFUSER_BUILD_BUILD_TYPE=RelWithDebInfo`
// 2. Run `gdb python`
// 3. Catch exception in gdb `(gdb) catch throw`.
// 4. Run failing test.
// `r -m pytest test_python_frontend.py -k [your_failing_test]`
// 5. At gdb Catchpoint, get backtrace for call stack using `(gdb) bt`.
// 6. Find and fix the failure in PythonTranslator.
//
// TODO: Python operations without a corresponding Fusion IR node require
// pattern matching.
// 1. Map a series of scalar values to `define_vector`
// 2. Map Squeeze, Reduction, and Broadcast to a single reduction operation
// with keepdim argument.
// 3. Map Broadcast and Expand to `broadcast_in_dim`
// 4. var_mean
// 5. var
class PythonTranslator : public OptInConstDispatch {
public:
// Returns a map from the values in the CPP fusion to its corresponding
// FusionDefinition State index.
static void print(std::ostream& os, Fusion* fusion) {
PythonTranslator translator(os, fusion);
translator.translate();
}
private:
// The printer writes to `os`; `fusion` is the CPP Fusion being translated.
PythonTranslator(std::ostream& os, Fusion* fusion)
    : printer_{os}, fusion_{fusion} {}
// The output TensorView shape can be dynamic for operations like ReshapeOp.
// Check that all dynamic scalar dependencies are handled first before
// handling the expression.
bool checkDynamicShapeDependency(const Expr* op) {
// short-circuit: Only check operations with dynamic shapes encoded in the
// output TensorView.
if (!op->isOneOf()) {
return true;
}
NVF_ERROR_EQ(op->outputs().size(), 1);
const std::vector& logical_out_domain =
op->output(0)->as()->domain()->logical();
std::vector logical_domain_extents;
std::ranges::copy(
logical_out_domain | std::views::transform([](IterDomain* id) {
return id->getMaybeExpandedExtent();
}),
std::back_inserter(logical_domain_extents));
return std::ranges::all_of(logical_domain_extents, [&](Val* v) {
return v->definition() == nullptr || visited_vals_.count(v) > 0;
});
}
// Gather the expressions needed to compute scalar `v`: its definition plus,
// transitively, the definitions of that expression's inputs. Expressions are
// returned in depth-first discovery order with no duplicates; the caller
// (translate) re-orders them by dependency.
std::vector gatherScalarExpressions(Val* v) {
  NVF_ERROR(v != nullptr);
  NVF_ERROR(v->isScalar());
  // short-circuit: v does not have a definition.
  if (v->definition() == nullptr) {
    return {};
  }
  std::vector expression_chain;
  std::unordered_set visited;
  std::vector to_visit = {v->definition()};
  // Mark expressions when they are queued, not when they are popped. An
  // expression reachable through several inputs (e.g. `x * x`) is then
  // queued and expanded only once.
  visited.insert(v->definition());
  while (!to_visit.empty()) {
    Expr* e = to_visit.back();
    to_visit.pop_back();
    expression_chain.push_back(e);
    for (Val* input : e->inputs()) {
      Expr* input_definition = input->definition();
      // short-circuit: input does not have a definition.
      if (input_definition == nullptr) {
        continue;
      }
      // short-circuit: input definition is already visited or queued.
      if (!visited.insert(input_definition).second) {
        continue;
      }
      to_visit.push_back(input_definition);
    }
  }
  return expression_chain;
}
// Collect the scalar expressions that compute the (possibly expanded) extent
// of every IterDomain in tv's logical domain, concatenated in domain order.
std::vector gatherScalarExpressions(TensorView* tv) {
  NVF_ERROR(tv != nullptr);
  std::vector logical_domain_expressions;
  for (IterDomain* id : tv->domain()->logical()) {
    std::ranges::copy(
        gatherScalarExpressions(id->getMaybeExpandedExtent()),
        std::back_inserter(logical_domain_expressions));
  }
  return logical_domain_expressions;
}
// Returns true when `e` can be emitted now: its dynamic-shape extents are
// ready, and every input is either printed inline as a Python scalar or was
// already defined in the FusionDefinition.
bool checkExpressionDependencies(Expr* e) {
  auto is_input_defined = [&](const Val* input) {
    return isPythonScalar(input) || visited_vals_.count(input) > 0;
  };
  // && short-circuits, so inputs are only inspected once the dynamic-shape
  // dependencies are satisfied.
  return checkDynamicShapeDependency(e) &&
      std::ranges::all_of(e->inputs(), is_input_defined);
}
// Emit the complete Python FusionDefinition for fusion_: the function header,
// the fusion inputs, every expression in dependency order, and finally the
// fusion outputs (including aliased outputs).
void translate() {
  printer_.generateFusionDefinition();
  // Add Fusion inputs to FusionDefinition
  for (nvfuser::Val* v : fusion_->inputs()) {
    dispatch(v);
  }
  // Gather all expressions in CPP Fusion.
  const std::vector<:expr> fusion_exprs = fusion_->exprs();
  std::deque<:expr> to_visit(fusion_exprs.begin(), fusion_exprs.end());
  // Scalar expressions are not handled by Fusion::exprs, so gather them
  // manually. Iterate over fusion_exprs, not to_visit: inserting into a
  // std::deque invalidates all of its iterators, so appending to to_visit
  // while range-iterating over it is undefined behavior. The appended
  // entries are scalar-extent expressions, which presumably never pass the
  // isOneOf filter below, so iterating fusion_exprs visits the same matches.
  for (Expr* e : fusion_exprs) {
    if (e->isOneOf()) {
      NVF_ERROR_EQ(e->outputs().size(), 1);
      std::vector extent_definitions =
          gatherScalarExpressions(e->output(0)->as());
      to_visit.insert(
          to_visit.end(), extent_definitions.begin(), extent_definitions.end());
    }
  }
  // Topological search of Fusion expressions. An expression whose inputs are
  // not yet available is re-queued at the back. skip_count counts consecutive
  // deferrals since the last emitted expression. If it exceeds the remaining
  // queue length, no queued expression can make progress.
  size_t skip_count = 0;
  std::unordered_set<:expr> visited;
  while (!to_visit.empty()) {
    Expr* e = to_visit.front();
    to_visit.pop_front();
    NVF_ERROR(
        skip_count <= to_visit.size(),
        "Cycle detected: None of the expressions can be processed!");
    // short-circuit: skip if already visited
    if (visited.count(e) > 0) {
      continue;
    }
    // TODO: short-circuit: skip Split and Merge expressions created by
    // Reshape
    // TODO: short-circuit: skip Resize expressions created by Slice
    // TODO: direct bindings does not support scheduled expressions.
    // Handle scalars and constants not generated by separate expression.
    std::vector scalars;
    std::ranges::copy_if(
        e->inputs(), std::back_inserter(scalars), [](Val* v) {
          return v->isScalar();
        });
    std::ranges::for_each(scalars, [this](const Val* v) { dispatch(v); });
    // short-circuit: add to back of queue if not all of the expression's
    // dependencies are satisfied.
    if (!checkExpressionDependencies(e)) {
      ++skip_count;
      to_visit.push_back(e);
      continue;
    }
    // Create string representation given inputs, outputs, and attributes.
    visited.insert(e);
    dispatch(e);
    skip_count = 0;
  }
  // Add tensor outputs and handle aliased outputs
  std::unordered_set<:val> visited_alias_output;
  for (nvfuser::Val* v : fusion_->outputs()) {
    NVF_ERROR(v->isA());
    const AliasInfo& alias_info = fusion_->getOutputAlias(v);
    switch (alias_info.type) {
      case AllocationType::New: {
        handleOutput(v->as());
        break;
      }
      case AllocationType::ReuseBuffer: {
        // Only apply aliasing once
        if (visited_alias_output.count(v) == 0) {
          visited_alias_output.insert(v);
          handleOutput(v->as(), alias_info);
        }
        // If not hide_output, then the aliased output is returned as a
        // fusion output.
        if (alias_info.visibility == OutputVisibility::kVisible) {
          handleOutput(v->as());
        }
        break;
      }
      default:
        NVF_THROW("Unsupported AllocationType");
    }
  }
}
// =================================================================================
// Filter Functions
// Return the already-emitted values (members of visited_vals_) that pass the
// isA check, i.e. the TensorViews defined so far in the FusionDefinition.
// The result follows visited_vals_'s iteration order.
// NOTE(review): visited_vals_'s container type is not visible here; if it is
// unordered, the order of the returned values is unspecified.
std::vector tensors() {
  std::vector tensors;
  std::ranges::copy_if(
      visited_vals_, std::back_inserter(tensors), [](const nvfuser::Val* v) {
        return v->isA();
      });
  return tensors;
}
// =================================================================================
// Create a scalar for the given nvfuser value, which must have no definition
// (e.g. a fusion input, a constant, or a tensor's extent). Nothing is emitted
// if the value has a definition, was already emitted, or is printed inline as
// a Python literal. Otherwise the value is emitted either as
// `fd.ops.size(tv, dim=...)`, when it equals the extent of a non-reduction
// logical IterDomain of an already-defined TensorView, or as
// `fd.define_scalar(value, dtype=...)`.
void handle(const Val* v) final {
  NVF_ERROR(v != nullptr);
  // short-circuit: the value is produced by an expression; translate()
  // queues that expression, whose handler emits the value.
  if (v->definition() != nullptr) {
    return;
  }
  // short-circuit: value already exists in FusionDefinition
  if (visited_vals_.count(v) > 0) {
    return;
  }
  // short-circuit: python scalars are printed inline at each use site
  if (isPythonScalar(v)) {
    return;
  }
  visited_vals_.insert(v);
  // Since scalars can come from TensorView dimension sizes, search through
  // all TensorViews for an iterDomain whose extent matches the desired
  // value and then use size op.
  // NOTE(review): the first matching TensorView in tensors() order wins; if
  // that order is unspecified, the chosen tensor may vary between runs.
  for (const nvfuser::Val* tv_val : tensors()) {
    const auto* tv = tv_val->as();
    // Get extents for each IterDomain. Reduction domains are dropped so that
    // `dim` indexes the non-reduction logical axes.
    std::vector filtered_logical_domain =
        TensorDomain::noReductions(tv->domain()->logical());
    std::vector extents;
    extents.reserve(filtered_logical_domain.size());
    std::ranges::copy(
        filtered_logical_domain | std::views::transform([](IterDomain* id) {
          return id->getMaybeExpandedExtent();
        }),
        std::back_inserter(extents));
    // Check if value matches iterdomain extent
    auto iter = std::ranges::find(extents, v);
    if (iter == extents.end()) {
      continue;
    }
    int64_t dim = std::distance(extents.begin(), iter);
    static const std::vector<:string> argument_names = {"dim"};
    printer_.generateKwargsOperation(
        "fd.ops.size",
        std::make_tuple(tv),
        argument_names,
        std::make_tuple(dim),
        {v});
    return;
  }
  // No TensorView extent matches: define the scalar from its value and dtype.
  static const std::vector<:string> argument_names = {"dtype"};
  printer_.generateKwargsOperation(
      "fd.define_scalar",
      std::make_tuple(v->value()),
      argument_names,
      std::make_tuple(v->dtype()),
      {v});
}
// Add Tensor value to Fusion Definition.
// Emits `tvN = fd.define_tensor(shape=..., contiguity=..., dtype=...,
// is_cpu=..., stride_order=...)` the first time a TensorView is seen. An
// extent that is a constant scalar is printed as its value; every other
// extent is printed as -1. stride_order is omitted when the TensorDomain has
// none, because an empty optional renders as an empty keyword argument.
void handle(const TensorView* tv) final {
  NVF_ERROR(tv != nullptr);
  // short-circuit: value already exists in FusionDefinition
  if (visited_vals_.count(tv) > 0) {
    return;
  }
  visited_vals_.insert(tv);
  std::vector shape;
  std::transform(
      tv->domain()->logical().begin(),
      tv->domain()->logical().end(),
      std::back_inserter(shape),
      [](IterDomain* id) {
        // Use the expanded extent when present; -1 marks a symbolic size.
        return (id->getMaybeExpandedExtent()->isConstScalar())
            ? id->getMaybeExpandedExtent()->evaluate().as()
            : -1;
      });
  const std::vector& stride_order = tv->domain()->strideOrder();
  static const std::vector<:string> argument_names = {
      "shape", "contiguity", "dtype", "is_cpu", "stride_order"};
  printer_.generateKwargsOperation(
      "fd.define_tensor",
      std::make_tuple(),
      argument_names,
      std::make_tuple(
          shape,
          tv->domain()->contiguity(),
          tv->dtype(),
          tv->isCpuScalar(),
          (stride_order.empty()) ? std::nullopt
                                 : std::make_optional(stride_order)),
      {tv});
}
// =================================================================================
// Utility functions
// Create a vector for the logical domain of TensorView.
// Used with ReshapeOp, ExpandOp, and FullOp handlers.
// Returns one extent Val per logical IterDomain, in domain order. The
// expanded extent is used when an IterDomain has one.
std::vector getShape(TensorView* tv) {
  const std::vector& logical_out_domain =
      tv->domain()->logical();
  std::vector logical_domain_extents;
  // Use expanded extent if available for IterDomain.
  std::ranges::copy(
      logical_out_domain | std::views::transform([](IterDomain* id) {
        return id->getMaybeExpandedExtent();
      }),
      std::back_inserter(logical_domain_extents));
  return logical_domain_extents;
}
// Find integer index corresponding with reduction iterDomains.
// Returns, in increasing order, the positions within tv's full logical domain
// (reduction IterDomains included) whose IterDomain is a reduction.
std::vector getReductionAxes(TensorView* tv) {
  std::vector axes;
  const std::vector& logical_domain = tv->domain()->logical();
  for (int64_t dim : c10::irange((int64_t)logical_domain.size())) {
    if (logical_domain.at(dim)->isReduction()) {
      axes.push_back(dim);
    }
  }
  return axes;
}
// =================================================================================
// Handle add_output variants
// Add Tensor output to FusionDefinition.
// Emits `fd.add_output(tvN)` as a bare call (no assignment).
void handleOutput(const TensorView* tv) {
  NVF_ERROR(tv != nullptr);
  printer_.generateOperation("fd.add_output", {tv}, {});
}
// Alias output Tensor with input tensor.
// Emits `fd.add_output(tvN, <aliased value>)`, passing
// alias_info.aliased_io as the second positional argument.
void handleOutput(const TensorView* tv, const AliasInfo& alias_info) {
  NVF_ERROR(tv != nullptr);
  printer_.generateOperation(
      "fd.add_output", {tv, alias_info.aliased_io}, {});
}
// =================================================================================
// Map CPP Expression classes to corresponding RecordFunctors in
// python_frontend
// Emit `out = fd.ops.<op>(in)` for a UnaryOp. Casts need a dtype keyword
// argument, so they are delegated to handleCastOp.
void handle(const UnaryOp* uop) final {
  NVF_ERROR(uop != nullptr);
  if (uop->getUnaryOpType() != UnaryOpType::Cast) {
    visited_vals_.insert(uop->out());
    printer_.generateOperation(
        "fd.ops." + nvfuser::python::toString(uop), {uop->in()}, {uop->out()});
    return;
  }
  handleCastOp(uop);
}
// Emit `out = fd.ops.cast(in, dtype=<out dtype>)` for a Cast UnaryOp.
void handleCastOp(const UnaryOp* uop) {
  NVF_ERROR(uop->getUnaryOpType() == UnaryOpType::Cast);
  visited_vals_.insert(uop->out());
  static const std::vector<:string> argument_names = {"dtype"};
  printer_.generateKwargsOperation(
      "fd.ops.cast",
      std::make_tuple(uop->in()),
      argument_names,
      std::make_tuple(uop->out()->dtype()),
      {uop->out()});
}
// Emit `out = fd.ops.<op>(lhs, rhs)` for a BinaryOp. Nothing is emitted if
// the output was already defined.
void handle(const BinaryOp* bop) final {
  NVF_ERROR(bop != nullptr);
  auto* out = bop->out();
  if (visited_vals_.count(out) == 0) {
    visited_vals_.insert(out);
    printer_.generateOperation(
        "fd.ops." + nvfuser::python::toString(bop),
        {bop->lhs(), bop->rhs()},
        {out});
  }
}
// Emit `out = fd.ops.<op>(in1, in2, in3)` for a TernaryOp; the Python op name
// comes from nvfuser::python::toString(top).
void handle(const TernaryOp* top) final {
  NVF_ERROR(top != nullptr);
  visited_vals_.insert(top->out());
  printer_.generateOperation(
      "fd.ops." + nvfuser::python::toString(top),
      {top->in1(), top->in2(), top->in3()},
      {top->out()});
}
void handle(const ReductionOp* rop) final {
NVF_ERROR(rop != nullptr);
NVF_ERROR(rop->out()->isA());
visited_vals_.insert(rop->out());
// The min and max reduction operations expect the dtype argument to by
// PrimDataType::Null
DataType dtype = (rop->getReductionOpType() == BinaryOpType::Min ||
rop->getReductionOpType() == BinaryOpType::FMin ||
rop->getReductionOpType() == BinaryOpType::Max ||
rop->getReductionOpType() == BinaryOpType::FMax)
? DataType::Null
: rop->out()->dtype();
std::vector dims = getReductionAxes(rop->out()->as());
// TODO: keepdim is always False in ReductionOp because a separate
// BroadcastOp node exists if keepdim is True. Detect this pattern to
// minify the python definition.
static const auto default_args = std::make_tuple(
KeywordArgument{
.name = "dims", .default_value = std::nullopt},
KeywordArgument{.name = "keepdim", .default_value = false},
KeywordArgument{
.name = "dtype", .default_value = DataType::Null});
printer_.generateKwargsOperation(
"fd.ops." + nvfuser::python::toString(rop),
std::make_tuple(rop->in()),
default_args,
std::make_tuple(dims, false, dtype),
{rop->out()});
}
// Emit `out = fd.ops.<scan>(in, dim=...)` for a ScanOp. `dim` has no default
// (std::nullopt), so it is always printed.
void handle(const ScanOp* sop) final {
  NVF_ERROR(sop != nullptr);
  visited_vals_.insert(sop->out());
  static const auto default_args = std::make_tuple(
      KeywordArgument{.name = "dim", .default_value = std::nullopt});
  printer_.generateKwargsOperation(
      "fd.ops." + toString(sop),
      std::make_tuple(sop->in()),
      default_args,
      std::make_tuple(sop->dim()),
      {sop->out()});
}
// Emit `avg, var, n = fd.ops.welford(in, dims=[...])` for a WelfordOp.
// The generated call passes only the reduction dims, so the op's initial
// avg/var/N are required to be zero.
// NOTE(review): presumably because fd.ops.welford has no initial-value
// arguments — confirm.
void handle(const WelfordOp* wop) final {
  NVF_ERROR(wop != nullptr);
  NVF_ERROR(wop->initAvg()->evaluate().as() == 0.0);
  NVF_ERROR(wop->initVar()->evaluate().as() == 0.0);
  NVF_ERROR(wop->initN()->evaluate().as() == 0);
  visited_vals_.insert(wop->outAvg());
  visited_vals_.insert(wop->outVar());
  visited_vals_.insert(wop->outN());
  static const std::vector<:string> argument_names = {"dims"};
  printer_.generateKwargsOperation(
      "fd.ops.welford",
      std::make_tuple(wop->in()),
      argument_names,
      std::make_tuple(getReductionAxes(wop->outAvg()->as())),
      {wop->outAvg(), wop->outVar(), wop->outN()});
}
// Emit `out = fd.ops.broadcast(in, is_broadcast_dim=[...])` using the op's
// broadcast-dimension flags.
void handle(const BroadcastOp* bcast_op) final {
  NVF_ERROR(bcast_op != nullptr);
  visited_vals_.insert(bcast_op->out());
  static const std::vector<:string> broadcast_argument_names = {
      "is_broadcast_dim"};
  printer_.generateKwargsOperation(
      "fd.ops.broadcast",
      std::make_tuple(bcast_op->in()),
      broadcast_argument_names,
      std::make_tuple(bcast_op->getBroadcastDimFlags()),
      {bcast_op->out()});
}
// Emit `out = fd.ops.matmul(a, b)` for a MatmulOp.
void handle(const MatmulOp* matmul_op) final {
  NVF_ERROR(matmul_op != nullptr);
  visited_vals_.insert(matmul_op->out());
  printer_.generateOperation(
      "fd.ops.matmul",
      {matmul_op->inA(), matmul_op->inB()},
      {matmul_op->out()});
}
// Emit `out = fd.ops.linear(a, b[, bias=...])` for a LinearOp. The bias
// keyword is printed only when the op has a non-null bias.
void handle(const LinearOp* lop) final {
  NVF_ERROR(lop != nullptr);
  visited_vals_.insert(lop->out());
  static const auto default_args = std::make_tuple(
      KeywordArgument{.name = "bias", .default_value = nullptr});
  printer_.generateKwargsOperation(
      "fd.ops.linear",
      std::make_tuple(lop->inA(), lop->inB()),
      default_args,
      std::make_tuple(lop->bias()),
      {lop->out()});
}
// Emit fd.ops.grouped_mm for a GroupedMmaOp.
// With exactly three inputs (matrix1, matrix2, offsets), a plain positional
// call is emitted. Otherwise the scaled variant is emitted with scale1/scale2
// as extra positional arguments, plus keyword arguments that are omitted when
// equal to their defaults. The optional block-scale and gamma outputs are
// registered and assigned in the scaled form; a null output prints as "_".
void handle(const GroupedMmaOp* gmm_op) final {
  NVF_ERROR(gmm_op != nullptr);
  TensorView* out_tv = gmm_op->out();
  visited_vals_.insert(gmm_op->out());
  int64_t out_block_scale_size = 0;
  PrimDataType out_block_scale_dtype = DataType::BFloat16;
  bool out_gamma = false;
  TensorView* out_block_scale_tv = gmm_op->outScale();
  if (out_block_scale_tv != nullptr) {
    visited_vals_.insert(gmm_op->outScale());
    // The block size is the innermost logical extent of the block-scale
    // output and must be a compile-time constant to be printed.
    const std::vector& logical =
        out_block_scale_tv->getLogicalDomain();
    Val* block_size_extent = logical.at(logical.size() - 1)->extent();
    NVF_CHECK(
        block_size_extent->isConstInt(),
        "Block size extent needs to be a constant integer");
    out_block_scale_size = block_size_extent->evaluate().as();
    out_block_scale_dtype =
        std::get(out_block_scale_tv->dtype().type);
  }
  TensorView* out_gamma_tv = gmm_op->outGamma();
  if (out_gamma_tv != nullptr) {
    visited_vals_.insert(gmm_op->outGamma());
    out_gamma = true;
  }
  // NOTE(review): in the 3-input branch, outScale/outGamma (if non-null) are
  // marked visited above but never assigned in the emitted Python — confirm
  // they are always null for unscaled grouped matmuls.
  if (gmm_op->inputs().size() == 3) {
    printer_.generateOperation(
        "fd.ops.grouped_mm",
        {gmm_op->matrix1(), gmm_op->matrix2(), gmm_op->offsets()},
        {gmm_op->out()});
  } else {
    // Defaults mirror the Python signature; arguments equal to these
    // defaults are not printed.
    static const auto default_args = std::make_tuple(
        KeywordArgumentalpha())>{
            .name = "alpha", .default_value = nullptr},
        KeywordArgumentbias())>{
            .name = "bias", .default_value = nullptr},
        KeywordArgumentbeta())>{
            .name = "beta", .default_value = nullptr},
        KeywordArgument{
            .name = "dtype", .default_value = DataType::BFloat16},
        KeywordArgument{
            .name = "output_block_scale_size", .default_value = 0},
        KeywordArgument{
            .name = "output_block_scale_dtype",
            .default_value = DataType::BFloat16},
        KeywordArgument{
            .name = "output_gamma", .default_value = false});
    printer_.generateKwargsOperation(
        "fd.ops.grouped_mm",
        std::make_tuple(
            gmm_op->matrix1(),
            gmm_op->matrix2(),
            gmm_op->offsets(),
            gmm_op->scale1(),
            gmm_op->scale2()),
        default_args,
        std::make_tuple(
            gmm_op->alpha(),
            gmm_op->bias(),
            gmm_op->beta(),
            out_tv->dtype(),
            out_block_scale_size,
            out_block_scale_dtype,
            out_gamma),
        {gmm_op->out(), out_block_scale_tv, out_gamma_tv});
  }
}
// Emit `out = fd.ops.cutlass_nvfp4_grouped_mm(...)` with all eight operands
// passed positionally, in the order listed below.
void handle(const CutlassNvfp4GroupedMmaOp* cmm_op) final {
  NVF_ERROR(cmm_op != nullptr);
  visited_vals_.insert(cmm_op->out());
  printer_.generateOperation(
      "fd.ops.cutlass_nvfp4_grouped_mm",
      {cmm_op->matrix1(),
       cmm_op->matrix2(),
       cmm_op->scale1(),
       cmm_op->scale2(),
       cmm_op->alpha(),
       cmm_op->problemSizes(),
       cmm_op->expertOffsets(),
       cmm_op->scalingFactorOffsets()},
      {cmm_op->out()});
}
// Emit fd.ops.scaled_mm for a ScaledMmaOp: matrices and their scales are
// positional; the remaining arguments are keywords that are omitted when
// equal to their defaults. The optional block-scale and gamma outputs are
// assigned as extra results; a null output prints as "_".
void handle(const ScaledMmaOp* smm_op) final {
  NVF_ERROR(smm_op != nullptr);
  TensorView* out_tv = smm_op->out();
  visited_vals_.insert(smm_op->out());
  int64_t out_block_scale_size = 0;
  PrimDataType out_block_scale_dtype = DataType::BFloat16;
  bool out_gamma = false;
  TensorView* out_block_scale_tv = smm_op->outScale();
  if (out_block_scale_tv != nullptr) {
    visited_vals_.insert(smm_op->outScale());
    // The block size is the innermost logical extent of the block-scale
    // output and must be a compile-time constant to be printed.
    const std::vector& logical =
        out_block_scale_tv->getLogicalDomain();
    Val* block_size_extent = logical.at(logical.size() - 1)->extent();
    NVF_CHECK(
        block_size_extent->isConstInt(),
        "Block size extent needs to be a constant integer");
    out_block_scale_size = block_size_extent->evaluate().as();
    out_block_scale_dtype =
        std::get(out_block_scale_tv->dtype().type);
  }
  TensorView* out_gamma_tv = smm_op->outGamma();
  if (out_gamma_tv != nullptr) {
    visited_vals_.insert(smm_op->outGamma());
    out_gamma = true;
  }
  // Defaults mirror the Python signature; arguments equal to these defaults
  // are not printed.
  static const auto default_args = std::make_tuple(
      KeywordArgumentalpha())>{
          .name = "alpha", .default_value = nullptr},
      KeywordArgumentbias())>{
          .name = "bias", .default_value = nullptr},
      KeywordArgumentbeta())>{
          .name = "beta", .default_value = nullptr},
      KeywordArgument{
          .name = "dtype", .default_value = DataType::BFloat16},
      KeywordArgument{
          .name = "output_block_scale_size", .default_value = 0},
      KeywordArgument{
          .name = "output_block_scale_dtype",
          .default_value = DataType::BFloat16},
      KeywordArgument{.name = "output_gamma", .default_value = false});
  printer_.generateKwargsOperation(
      "fd.ops.scaled_mm",
      std::make_tuple(
          smm_op->matrix1(),
          smm_op->matrix2(),
          smm_op->scale1(),
          smm_op->scale2()),
      default_args,
      std::make_tuple(
          smm_op->alpha(),
          smm_op->bias(),
          smm_op->beta(),
          out_tv->dtype(),
          out_block_scale_size,
          out_block_scale_dtype,
          out_gamma),
      {smm_op->out(), out_block_scale_tv, out_gamma_tv});
}
// Emit `attn_out, logsumexp, philox_seed, philox_offset =
// fd.ops.sdpfa_fwd(query, key, value, ...)`. The optional keyword arguments
// (bias, mask, dropout_p, is_causal, scale) are printed only when non-null.
void handle(const SdpaFwdOp* sdpa_fwd_op) final {
  NVF_ERROR(sdpa_fwd_op != nullptr);
  static const auto default_args = std::make_tuple(
      KeywordArgument{.name = "bias", .default_value = nullptr},
      KeywordArgument{.name = "mask", .default_value = nullptr},
      KeywordArgument{.name = "dropout_p", .default_value = nullptr},
      KeywordArgument{.name = "is_causal", .default_value = nullptr},
      KeywordArgument{.name = "scale", .default_value = nullptr});
  visited_vals_.insert(sdpa_fwd_op->attn_out());
  visited_vals_.insert(sdpa_fwd_op->logsumexp());
  visited_vals_.insert(sdpa_fwd_op->philox_seed());
  visited_vals_.insert(sdpa_fwd_op->philox_offset());
  printer_.generateKwargsOperation(
      "fd.ops.sdpfa_fwd",
      std::make_tuple(
          sdpa_fwd_op->query(), sdpa_fwd_op->key(), sdpa_fwd_op->value()),
      default_args,
      std::make_tuple(
          sdpa_fwd_op->bias(),
          sdpa_fwd_op->mask(),
          sdpa_fwd_op->dropout_p(),
          sdpa_fwd_op->is_causal(),
          sdpa_fwd_op->scale()),
      {sdpa_fwd_op->attn_out(),
       sdpa_fwd_op->logsumexp(),
       sdpa_fwd_op->philox_seed(),
       sdpa_fwd_op->philox_offset()});
}
void handle(const SdpaBwdOp* sdpa_bwd_op) final {
NVF_ERROR(sdpa_bwd_op != nullptr);
static const std::vector<:string> argument_names = {
"dropout_p", "is_causal", "philox_seed", "philox_offset", "scale"};
visited_vals_.insert(sdpa_bwd_op->grad_query());
visited_vals_.insert(sdpa_bwd_op->grad_key());
visited_vals_.insert(sdpa_bwd_op->grad_value());
printer_.generateKwargsOperation(
"fd.ops.sdpfa_bwd",
std::make_tuple(
sdpa_bwd_op->grad_attn(),
sdpa_bwd_op->query(),
sdpa_bwd_op->key(),
sdpa_bwd_op->value(),
sdpa_bwd_op->attn_out(),
sdpa_bwd_op->logsumexp()),
argument_names,
std::make_tuple(
sdpa_bwd_op->dropout_p(),
sdpa_bwd_op->is_causal(),
sdpa_bwd_op->philox_seed(),
sdpa_bwd_op->philox_offset(),
sdpa_bwd_op->scale()),
{sdpa_bwd_op->grad_query(),
sdpa_bwd_op->grad_key(),
sdpa_bwd_op->grad_value()});
}
void handle(const SqueezeOp* sop) final {
NVF_ERROR(sop != nullptr);
visited_vals_.insert(sop->out());
const std::vector& is_squeeze_dims = sop->getSqueezeDimFlags();
auto filter_range = std::views::iota(0UL, is_squeeze_dims.size()) |
std::views::filter([&is_squeeze_dims](int64_t dim) {
return is_squeeze_dims.at(dim);
});
std::vector squeeze_dims(filter_range.begin(), filter_range.end());
auto* in_tv = sop->in()->as();
NVF_ERROR(in_tv != nullptr);
// TODO: Use std::ranges::zip_view AND std::ranges::any_of with cpp23
bool squeeze_expanded = false;
for (auto [squeeze_dim, id] :
zip(is_squeeze_dims, in_tv->getLogicalDomain())) {
if (!squeeze_dim) {
continue;
}
squeeze_expanded |= (id->isBroadcast() && id->hasExpandedExtent());
}
static const auto default_args = std::make_tuple(
KeywordArgument{
.name = "dims", .default_value = std::nullopt},
KeywordArgument{
.name = "squeeze_expanded", .default_value = false});
printer_.generateKwargsOperation(
"fd.ops.squeeze",
std::make_tuple(sop->in()),
default_args,
std::make_tuple(squeeze_dims, squeeze_expanded),
{sop->out()});
}
void handle(const ReshapeOp* vop) final {
NVF_ERROR(vop != nullptr);
// Get extent's for output's logical domain
auto* out_tv = vop->out()->as();
std::vector new_shape = getShape(out_tv);
// TODO Check if new_shape is a vector of symbolic fusion inputs
// TODO Use define_vector to create more pythonic syntax
// Add CPP values to Fusion Definition if necessary
static const std::vector<:string> reshape_argument_names = {
"new_shape"};
std::ranges::for_each(new_shape, [this](const Val* v) { dispatch(v); });
visited_vals_.insert(vop->out());
printer_.generateKwargsOperation(
"fd.ops.reshape",
std::make_tuple(vop->in()),
reshape_argument_names,
std::make_tuple(new_shape),
{vop->out()});
}
void handle(const ExpandOp* eop) final {
NVF_ERROR(eop != nullptr);
auto* in_tv = eop->in()->as();
auto* out_tv = eop->out()->as();
NVF_ERROR(in_tv->nDims() == out_tv->nDims());
std::vector shape = getShape(out_tv);
static const std::vector<:string> expand_argument_names = {"shape"};
// Add CPP values to Fusion Definition if necessary
std::ranges::for_each(shape, [this](const Val* v) { dispatch(v); });
visited_vals_.insert(eop->out());
printer_.generateKwargsOperation(
"fd.ops.expand",
std::make_tuple(eop->in()),
expand_argument_names,
std::make_tuple(shape),
{eop->out()});
}
void handle(const SliceOp* sop) final {
NVF_ERROR(sop != nullptr);
std::vector<:slice> slices = sop->getRanges();
std::vector start_indices;
start_indices.reserve(slices.size());
std::vector stop_indices;
stop_indices.reserve(slices.size());
std::vector strides;
strides.reserve(slices.size());
for (const nvfuser::Slice& s : slices) {
start_indices.push_back(s.start);
stop_indices.push_back(s.stop);
strides.push_back(s.step);
}
visited_vals_.insert(sop->out());
// Since the normalization operations are expressed in the Fusion IR,
// manual_normalization argument is always true and default arguments is not
// used here.
static const std::vector<:string> slice_argument_names = {
"start_indices", "end_indices", "strides", "manual_normalization"};
printer_.generateKwargsOperation(
"fd.ops.slice",
std::make_tuple(sop->in()),
slice_argument_names,
std::make_tuple(
start_indices,
stop_indices,
strides,
/*manual_normalization=*/true),
{sop->out()});
}
void handle(const PadOp* pad_op) final {
NVF_ERROR(pad_op != nullptr);
// Step 1: Get pad widths in normalized order.
std::vector normalized_pad_widths = pad_op->getPadWidths();
int64_t total_size = (int64_t)normalized_pad_widths.size();
// Step 2: Get indices for normalized pad widths.
std::vector normalized_indices(total_size);
std::iota(normalized_indices.begin(), normalized_indices.end(), 0);
// Step 3: Transform to indices for original pad widths
std::vector original_indices;
original_indices.reserve(normalized_indices.size());
std::ranges::transform(
normalized_indices,
std::back_inserter(original_indices),
[=](int64_t normalized_idx) {
int64_t offset = total_size - normalized_idx;
int64_t dim = ceilDiv(offset, 2) - 1;
int64_t original_idx = dim * 2;
// right pad values require an additional offset
if (offset % 2 == 1) {
original_idx += 1;
}
return original_idx;
});
// Step 4: Get pad widths in original order.
std::vector original_order_pad_widths(total_size, nullptr);
for (int64_t normalized_idx : normalized_indices) {
original_order_pad_widths.at(original_indices.at(normalized_idx)) =
normalized_pad_widths.at(normalized_idx);
}
// Check that no pad width values are nullptr.
NVF_ERROR(std::ranges::all_of(
original_order_pad_widths, [](Val* v) { return v != nullptr; }));
visited_vals_.insert(pad_op->out());
static const auto default_args = std::make_tuple(
KeywordArgument{
.name = "pad_widths", .default_value = std::nullopt},
KeywordArgument{.name = "value", .default_value = nullptr});
printer_.generateKwargsOperation(
"fd.ops.pad",
std::make_tuple(pad_op->in()),
default_args,
std::make_tuple(original_order_pad_widths, pad_op->value()),
{pad_op->out()});
}
// Translate a CatOp into an `fd.ops.cat` call.
//
// All op inputs are passed positionally. Padding is already expressed in the
// Fusion IR, so manual_padding is always passed as true. No default arguments
// are used here.
void handle(const CatOp* cat_op) final {
  NVF_ERROR(cat_op != nullptr);
  static const std::vector<:string> cat_argument_names = {
      "dim", "manual_padding"};
  auto* concatenated = cat_op->output(0);
  visited_vals_.insert(concatenated);
  printer_.generateKwargsOperation(
      "fd.ops.cat",
      cat_op->inputs(),
      cat_argument_names,
      std::make_tuple(cat_op->concatenatedDim(), /*manual_padding=*/true),
      {concatenated});
}
// Translate an RNGOp into `fd.ops.uniform` (Uniform / UniformRange) or
// `fd.ops.normal` (NormalStandard / NormalGeneral). Any other RNGOpType is
// rejected.
//
// The two positional distribution parameters come from the op. When the op
// carries none, they default to (0, 1). The shape follows them. rng_seed,
// rng_offset and dtype are keyword arguments.
void handle(const RNGOp* rop) final {
  NVF_ERROR(rop != nullptr);
  visited_vals_.insert(rop->output(0));

  std::string rng_op_name;
  const auto rng_type = rop->getRNGOpType();
  if (rng_type == RNGOpType::Uniform || rng_type == RNGOpType::UniformRange) {
    rng_op_name = "fd.ops.uniform";
  } else if (
      rng_type == RNGOpType::NormalStandard ||
      rng_type == RNGOpType::NormalGeneral) {
    rng_op_name = "fd.ops.normal";
  } else {
    NVF_ERROR(false, "Unsupported RNGOpType.");
  }

  static const auto default_args = std::make_tuple(
      KeywordArgument{.name = "rng_seed", .default_value = nullptr},
      KeywordArgument{.name = "rng_offset", .default_value = nullptr},
      KeywordArgument{
          .name = "dtype", .default_value = DataType::Float});

  NVF_ERROR(rop->getParameters().size() == 2 || rop->getParameters().empty());
  // Past the check above, a non-empty parameter list has exactly two entries.
  const bool has_parameters = !rop->getParameters().empty();
  Val* first_arg =
      has_parameters ? rop->getParameters().at(0) : fusion_->zeroVal();
  Val* second_arg =
      has_parameters ? rop->getParameters().at(1) : fusion_->oneVal();
  NVF_ERROR(first_arg != nullptr && second_arg != nullptr);

  printer_.generateKwargsOperation(
      rng_op_name,
      std::make_tuple(first_arg, second_arg, rop->getShape()),
      default_args,
      std::make_tuple(
          rop->getRNGSeedVal(), rop->getRNGOffsetVal(), rop->dtype()),
      {rop->output(0)});
}
// Translate a LoadStoreOp.
//
// When the output is a TensorView, a root domain marks a permutation and an
// allocation domain marks a stride order. The two must not both be present.
// Otherwise the op is a same-type copy: the Fusion IR creates a LoadStoreOp
// instead of a CastOp when input and output share a dtype. That case is
// emitted as an `fd.ops.cast` to that dtype.
void handle(const LoadStoreOp* lsop) final {
  if (lsop->out()->isA()) {
    auto* out_tv = lsop->out()->as();
    NVF_ERROR(!(out_tv->hasRoot() && out_tv->hasAllocation()));
    if (out_tv->hasRoot()) {
      handlePermute(lsop);
      return;
    }
    if (out_tv->hasAllocation()) {
      handleStrideOrder(lsop);
      return;
    }
  }

  NVF_ERROR(
      lsop->in()->dtype() == lsop->out()->dtype(),
      "Expected the dtype for input and output to be the same");
  auto* copied = lsop->out();
  visited_vals_.insert(copied);
  static const std::vector<:string> argument_names = {"dtype"};
  printer_.generateKwargsOperation(
      "fd.ops.cast",
      std::make_tuple(lsop->in()),
      argument_names,
      std::make_tuple(copied->dtype()),
      {copied});
}
// Emit `fd.ops.permute` for a LoadStoreOp whose output TensorView has a root
// domain. The dims are the permutation that ir_utils::computePermutation
// derives from the output's root and logical domains. Failing to find one is
// an error.
void handlePermute(const LoadStoreOp* lsop) {
  auto* permuted_tv = lsop->out()->as();
  std::optional<:vector>> new2old_opt =
      ir_utils::computePermutation(
          permuted_tv->getRootDomain(), permuted_tv->getLogicalDomain());
  NVF_ERROR(new2old_opt.has_value(), "Expected permutation");
  visited_vals_.insert(lsop->out());
  static const std::vector<:string> argument_names = {"dims"};
  printer_.generateKwargsOperation(
      "fd.ops.permute",
      std::make_tuple(lsop->in()),
      argument_names,
      std::make_tuple(*new2old_opt),
      {lsop->out()});
}
// Emit `fd.ops.stride_order` for a LoadStoreOp whose output TensorView has an
// allocation domain. The order is read from the output's TensorDomain.
void handleStrideOrder(const LoadStoreOp* lsop) {
  static const std::vector<:string> argument_names = {"stride_order"};
  auto* ordered_tv = lsop->out()->as();
  visited_vals_.insert(lsop->out());
  printer_.generateKwargsOperation(
      "fd.ops.stride_order",
      std::make_tuple(lsop->in()),
      argument_names,
      std::make_tuple(ordered_tv->domain()->strideOrder()),
      {lsop->out()});
}
void handle(const FullOp* fop) final {
NVF_ERROR(fop != nullptr);
auto* out_tv = fop->output(0)->as();
visited_vals_.insert(out_tv);
// Fill value can be dynamic so create it
dispatch(fop->getFillValue());
static const std::vector<:string> argument_names = {
"shape", "fill_value", "dtype"};
printer_.generateKwargsOperation(
"fd.ops.full",
std::make_tuple(),
argument_names,
std::make_tuple(getShape(out_tv), fop->getFillValue(), out_tv->dtype()),
{out_tv});
}
// Translate an IotaOp into an `fd.ops.iota` call.
//
// length, start and step are dispatched in that order before the call is
// printed. All four arguments, including dtype, go through the keyword
// list; there are no positional arguments.
void handle(const IotaOp* iop) final {
  NVF_ERROR(iop != nullptr);
  auto* out_tv = iop->output(0)->as();
  visited_vals_.insert(out_tv);

  auto* length = iop->length();
  auto* start = iop->start();
  auto* step = iop->step();
  dispatch(length);
  dispatch(start);
  dispatch(step);

  static const auto default_args = std::make_tuple(
      KeywordArgumentlength())>{
          .name = "length", .default_value = std::nullopt},
      KeywordArgumentstart())>{
          .name = "start", .default_value = nullptr},
      KeywordArgumentstep())>{
          .name = "step", .default_value = nullptr},
      KeywordArgument{
          .name = "dtype", .default_value = DataType::Int});
  printer_.generateKwargsOperation(
      "fd.ops.iota",
      std::make_tuple(),
      default_args,
      std::make_tuple(length, start, step, iop->dtype()),
      {out_tv});
}
// Translate an IndexSelectOp into `fd.ops.index_select(lookup, index, dim)`.
// The dim is supplied through the argument-name list.
void handle(const IndexSelectOp* isop) final {
  NVF_ERROR(isop != nullptr);
  static const std::vector<:string> argument_names = {"dim"};
  auto* selected_tv = isop->output(0)->as();
  visited_vals_.insert(selected_tv);
  printer_.generateKwargsOperation(
      "fd.ops.index_select",
      std::make_tuple(isop->lookupTv(), isop->indexTv()),
      argument_names,
      std::make_tuple(isop->dim()),
      {selected_tv});
}
// Translate a SelectOp into `fd.ops.select`. The lookup tensor and the op's
// second input (the index) are positional; dim goes through the
// argument-name list.
void handle(const SelectOp* sop) final {
  NVF_ERROR(sop != nullptr);
  static const std::vector<:string> argument_names = {"dim"};
  auto* selected_tv = sop->output(0)->as();
  visited_vals_.insert(selected_tv);
  printer_.generateKwargsOperation(
      "fd.ops.select",
      std::make_tuple(sop->lookupTv(), sop->input(1)),
      argument_names,
      std::make_tuple(sop->dim()),
      {selected_tv});
}
// Translate a ScatterOp into `fd.ops.scatter(in, index, src, dim)`. The dim is
// supplied through the argument-name list.
void handle(const ScatterOp* sop) final {
  NVF_ERROR(sop != nullptr);
  static const std::vector<:string> argument_names = {"dim"};
  auto* scattered_tv = sop->output(0)->as();
  visited_vals_.insert(scattered_tv);
  printer_.generateKwargsOperation(
      "fd.ops.scatter",
      std::make_tuple(sop->in(), sop->index(), sop->src()),
      argument_names,
      std::make_tuple(sop->dim()),
      {scattered_tv});
}
// Translate a GatherOp. With exact sizes it is emitted as
// `fd.ops.take_along_axis`, otherwise as `fd.ops.gather`. In both cases the
// lookup and index tensors are positional and dim goes through the
// argument-name list.
void handle(const GatherOp* gop) final {
  NVF_ERROR(gop != nullptr);
  static const std::vector<:string> argument_names = {"dim"};
  auto* gathered_tv = gop->output(0)->as();
  visited_vals_.insert(gathered_tv);
  const char* op_name =
      gop->exactSizes() ? "fd.ops.take_along_axis" : "fd.ops.gather";
  printer_.generateKwargsOperation(
      op_name,
      std::make_tuple(gop->lookupTv(), gop->indexTv()),
      argument_names,
      std::make_tuple(gop->dim()),
      {gathered_tv});
}
// Translate a TopKOp into `fd.ops.topk(in, k, ...)` with keyword arguments
// dim (default -1), largest (default true) and sorted (default false). The op
// has two outputs, both recorded as visited.
void handle(const TopKOp* topkop) final {
  NVF_ERROR(topkop != nullptr);
  auto* first_out = topkop->output(0);
  auto* second_out = topkop->output(1);
  visited_vals_.insert(first_out);
  visited_vals_.insert(second_out);
  static const auto default_args = std::make_tuple(
      KeywordArgumentdim())>{
          .name = "dim", .default_value = -1},
      KeywordArgument{.name = "largest", .default_value = true},
      KeywordArgument{.name = "sorted", .default_value = false});
  printer_.generateKwargsOperation(
      "fd.ops.topk",
      std::make_tuple(topkop->in(), topkop->k()),
      default_args,
      std::make_tuple(topkop->dim(), topkop->isLargest(), topkop->isSorted()),
      std::vector{first_out, second_out});
}
// Translate an ArgsortOp into `fd.ops.argsort(in, ...)` with keyword
// arguments dim (no default value, std::nullopt), descending (default false)
// and stable (default false).
void handle(const ArgsortOp* argsortop) final {
  NVF_ERROR(argsortop != nullptr);
  static const auto default_args = std::make_tuple(
      KeywordArgumentdim())>{
          .name = "dim", .default_value = std::nullopt},
      KeywordArgument{.name = "descending", .default_value = false},
      KeywordArgument{.name = "stable", .default_value = false});
  auto* sorted_indices_tv = argsortop->output(0)->as();
  visited_vals_.insert(sorted_indices_tv);
  printer_.generateKwargsOperation(
      "fd.ops.argsort",
      std::make_tuple(argsortop->in()),
      default_args,
      std::make_tuple(
          argsortop->dim(), argsortop->isDescending(), argsortop->isStable()),
      {sorted_indices_tv});
}
// Translate an EmbeddingFwdOp into `fd.ops.embedding_fwd(in, weight, ...)`.
// padding_idx, max_norm, norm_type, scale_grad_by_freq and sparse are keyword
// arguments, each with a nullptr default.
void handle(const EmbeddingFwdOp* eop) final {
  NVF_ERROR(eop != nullptr);
  static const auto default_args = std::make_tuple(
      KeywordArgument{.name = "padding_idx", .default_value = nullptr},
      KeywordArgument{.name = "max_norm", .default_value = nullptr},
      KeywordArgument{.name = "norm_type", .default_value = nullptr},
      KeywordArgument{
          .name = "scale_grad_by_freq", .default_value = nullptr},
      KeywordArgument{.name = "sparse", .default_value = nullptr});
  visited_vals_.insert(eop->output(0));
  printer_.generateKwargsOperation(
      "fd.ops.embedding_fwd",
      std::make_tuple(eop->in(), eop->weight()),
      default_args,
      std::make_tuple(
          eop->padding_idx(),
          eop->max_norm(),
          eop->norm_type(),
          eop->scale_grad_by_freq(),
          eop->sparse()),
      {eop->out()});
}
// Translate a PreprocessGroupedMatmulInputSf into a plain positional
// `fd.ops.preprocess_grouped_matmul_input_sf(in, input_offsets,
// output_offsets)` call. No keyword arguments are involved.
void handle(const PreprocessGroupedMatmulInputSf* layout_op) final {
  NVF_ERROR(layout_op != nullptr);
  auto* produced = layout_op->output(0);
  visited_vals_.insert(produced);
  printer_.generateOperation(
      "fd.ops.preprocess_grouped_matmul_input_sf",
      {layout_op->in()->as(),
       layout_op->inputOffsets(),
       layout_op->outputOffsets()},
      {layout_op->out()});
}
// Translate a BlockQuantizationOp into `fd.ops.nv_block_quantize(in, ...)`.
// Keyword arguments are global_scale (default nullptr), block_size (default
// 16), swizzle_block_scales (default false) and dtype (default Float4_e2m1fn).
// The dtype is read from the quantized output. The op has two outputs, both
// recorded as visited.
void handle(const BlockQuantizationOp* bqop) final {
  NVF_ERROR(bqop != nullptr);
  auto* first_out = bqop->output(0);
  auto* second_out = bqop->output(1);
  visited_vals_.insert(first_out);
  visited_vals_.insert(second_out);
  static const auto default_args = std::make_tuple(
      KeywordArgumentglobalScale())>{
          .name = "global_scale", .default_value = nullptr},
      KeywordArgument{.name = "block_size", .default_value = 16},
      KeywordArgument{
          .name = "swizzle_block_scales", .default_value = false},
      KeywordArgument{
          .name = "dtype", .default_value = DataType::Float4_e2m1fn});
  auto quantized_dtype = bqop->quantizedOutput()->as()->dtype();
  printer_.generateKwargsOperation(
      "fd.ops.nv_block_quantize",
      std::make_tuple(bqop->in()),
      default_args,
      std::make_tuple(
          bqop->globalScale(),
          bqop->blockSize(),
          bqop->isSwizzledScales(),
          quantized_dtype),
      std::vector{first_out, second_out});
}
// Translate a GroupedBlockQuantizationOp into
// `fd.ops.nv_grouped_block_quantize(in, input_offsets, output_offsets, ...)`.
// Keyword arguments are global_scale (default nullptr), block_size (default
// 16) and dtype (default Float4_e2m1fn). The dtype is read from the quantized
// output. Both outputs are recorded as visited.
void handle(const GroupedBlockQuantizationOp* grouped_bqop) final {
  NVF_ERROR(grouped_bqop != nullptr);
  auto* first_out = grouped_bqop->output(0);
  auto* second_out = grouped_bqop->output(1);
  visited_vals_.insert(first_out);
  visited_vals_.insert(second_out);
  static const auto default_args = std::make_tuple(
      KeywordArgumentglobalScale())>{
          .name = "global_scale", .default_value = nullptr},
      KeywordArgument{.name = "block_size", .default_value = 16},
      KeywordArgument{
          .name = "dtype", .default_value = DataType::Float4_e2m1fn});
  auto quantized_dtype =
      grouped_bqop->quantizedOutput()->as()->dtype();
  printer_.generateKwargsOperation(
      "fd.ops.nv_grouped_block_quantize",
      std::make_tuple(
          grouped_bqop->in(),
          grouped_bqop->inputOffsets(),
          grouped_bqop->outputOffsets()),
      default_args,
      std::make_tuple(
          grouped_bqop->globalScale(),
          grouped_bqop->blockSize(),
          quantized_dtype),
      std::vector{first_out, second_out});
}
private:
  //! Convert CPP values to python syntax; every handler emits its generated
  //! call through this printer.
  PythonPrinter printer_;
  //! The reference CPP fusion to be translated. Also supplies the
  //! zeroVal()/oneVal() defaults used when an RNGOp carries no parameters.
  Fusion* fusion_ = nullptr;
  //! Set of NvFuser Val's already emitted by this translator. Each handler
  //! inserts the outputs of the expression it translates.
  std::unordered_set visited_vals_;
};
} // namespace
// Render `f` as a Python FusionDefinition script and return it as a string.
// The translation is written into an in-memory stream whose contents are
// returned.
std::string translateFusion(nvfuser::Fusion* f) {
  std::stringstream script;
  PythonTranslator::print(script, f);
  return script.str();
}
} // namespace nvfuser::python