-
Notifications
You must be signed in to change notification settings - Fork 81
Expand file tree
/
Copy pathid_model.cpp
More file actions
153 lines (137 loc) · 3.51 KB
/
id_model.cpp
File metadata and controls
153 lines (137 loc) · 3.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <bindings.h>
#include <id_model/id_model.h>
#include <val_graph.h>
namespace nvfuser::python {
namespace {
// Registers nvfuser::IdModel as the Python class `IdModel` in the given
// submodule, exposing its constructor, `__str__`, and `maybe_build_graph`.
//
// Lifetime handling (the bound C++ objects only hold raw pointers/references,
// so Python must be told about the dependencies):
//  * The constructor receives a raw `Fusion*`, and `maybe_build_graph` later
//    builds graphs without being handed the fusion again, so the IdModel must
//    not outlive its Fusion. `py::keep_alive<1, 2>()` keeps the Python
//    `fusion` argument alive for as long as the new IdModel (`self`) lives.
//  * `maybe_build_graph` hands back a reference to a graph that is built at
//    most once and kept by the IdModel (per its docstring), i.e. it is not
//    owned by Python. `reference_internal` (instead of plain `reference`)
//    ties the returned ValGraph wrapper to `self`, so releasing the IdModel
//    while still holding the graph cannot leave a dangling reference.
void bindIdModelClass(py::module_& idm) {
  py::class_<IdModel, std::unique_ptr<IdModel>> id_model(idm, "IdModel");
  id_model.def(
      py::init([](Fusion* fusion,
                  bool build_graphs,
                  bool allow_self_mapping,
                  bool validate) {
        return std::make_unique<IdModel>(
            fusion, build_graphs, allow_self_mapping, validate);
      }),
      py::arg("fusion"),
      py::arg("build_graphs") = false,
      py::arg("allow_self_mapping") = true,
      py::arg("validate") = false,
      // Index 1 is the IdModel being constructed, index 2 is `fusion`:
      // keep the fusion alive while the IdModel exists.
      py::keep_alive<1, 2>(),
      R"(
Create a new IdModel for the given fusion.
Parameters
----------
fusion : Fusion
The fusion to create the IdModel for
build_graphs : bool
Whether to build graphs
allow_self_mapping : bool
Whether to allow self mapping
validate : bool
Whether to validate graphs
Returns
-------
IdModel
The created IdModel
)");
  id_model.def(
      "__str__",
      &IdModel::toString,
      R"(
Returns the string representation of the IdModel.
)");
  id_model.def(
      "maybe_build_graph",
      &IdModel::maybeBuildGraph,
      py::arg("mode"),
      // The returned graph is kept by this IdModel; keep `self` alive for as
      // long as Python references the graph.
      py::return_value_policy::reference_internal,
      R"(
Build a graph if not already built.
Dependent graphs are also built if not yet done.
Parameters
----------
mode : IdMappingMode
The mode to build the graph for
Returns
-------
ValGraph
The graph built
)");
}
// Registers nvfuser::ValGraph as the Python class `ValGraph` in the given
// submodule, exposing `disjoint_val_sets`, `__str__`, and `map_vals`.
// No Python constructor is bound; instances are only obtained by reference
// from other bound APIs.
//
// `disjoint_val_sets` returns a reference to data stored inside the graph.
// It uses `reference_internal` rather than plain `reference`, so the returned
// DisjointValSets wrapper keeps the ValGraph Python object it came from
// alive. Otherwise the sets could outlive the graph they point into.
void bindValGraph(py::module_& idm) {
  py::class_<ValGraph, std::unique_ptr<ValGraph>> val_graph(idm, "ValGraph");
  val_graph.def(
      "disjoint_val_sets",
      &ValGraph::disjointValSets,
      // Result points into `self`; keep `self` alive while it is referenced.
      py::return_value_policy::reference_internal,
      R"(
Returns the disjoint val set.
Returns
-------
DisjointValSets
The disjoint val set
)");
  val_graph.def(
      "__str__",
      &ValGraph::toString,
      R"(
Returns the string representation of the ValGraph.
)");
  // Mutates the graph in place by mapping `val0` and `val1` together.
  val_graph.def(
      "map_vals",
      &ValGraph::mapVals,
      py::arg("val0"),
      py::arg("val1"),
      R"(Maps the two values.
Parameters
----------
val0 : Val
The first value to map
val1 : Val
The second value to map
)");
}
void bindDisjointSets(py::module_& id_model) {
py::class_<DisjointSets<Val*>, std::unique_ptr<DisjointSets<Val*>>>
disjoint_sets(id_model, "DisjointValSets");
disjoint_sets.def(
"__str__",
&DisjointSets<Val*>::toString,
R"(
Returns the string representation of the DisjointSets.
)");
disjoint_sets.def(
"strict_are_mapped",
&DisjointSets<Val*>::strictAreMapped,
py::arg("entry0"),
py::arg("entry1"),
R"(
Returns if the two entries are strictly mapped.
Parameters
----------
entry0 : Val
The first entry to check
entry1 : Val
The second entry to check
Returns
-------
bool
True if the two entries are strictly mapped, False otherwise.
)");
}
} // namespace
// Creates the `idm` submodule of the `nvfuser` Python module and registers
// the IdModel, ValGraph and DisjointValSets bindings in it.
//
// Spelled with `py::module_`, like the helpers above. pybind11 renamed the
// module type to `module_` because `module` is a C++20 keyword; `py::module`
// is only a backward-compatibility alias of the same type, so this
// definition still matches a declaration written with `py::module&`.
void bindIdModel(py::module_& nvfuser) {
  py::module_ idm = nvfuser.def_submodule(
      "idm", "This submodule contains all id model operators for NvFuser.");
  bindIdModelClass(idm);
  bindValGraph(idm);
  bindDisjointSets(idm);
}
} // namespace nvfuser::python