-
Notifications
You must be signed in to change notification settings - Fork 81
Expand file tree
/
Copy pathdirect_utils.cpp
More file actions
71 lines (62 loc) · 2.2 KB
/
direct_utils.cpp
File metadata and controls
71 lines (62 loc) · 2.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <direct_utils.h>

#include <algorithm>
#include <limits>
namespace nvfuser::python {
namespace {
// Converts a single Python object into a PolymorphicValue.
//
// Recognized inputs, tested in this order: torch.Tensor, bool, int, float and
// complex. Any other object falls through to NVF_THROW.
PolymorphicValue toPolymorphicValue(const py::handle& obj) {
  // The torch.Tensor type object is imported on the first call and reused
  // afterwards.
  static py::object tensor_type = py::module_::import("torch").attr("Tensor");

  if (py::isinstance(obj, tensor_type)) {
    return PolymorphicValue(obj.cast<at::Tensor>());
  }
  // The bool test stays ahead of the int test: Python's bool derives from int,
  // so the order of these two checks is significant.
  if (py::isinstance<py::bool_>(obj)) {
    return PolymorphicValue(obj.cast<bool>());
  }
  if (py::isinstance<py::int_>(obj)) {
    return PolymorphicValue(obj.cast<int64_t>());
  }
  if (py::isinstance<py::float_>(obj)) {
    return PolymorphicValue(obj.cast<double>());
  }
  if (PyComplex_Check(obj.ptr())) {
    return PolymorphicValue(obj.cast<std::complex<double>>());
  }

  NVF_THROW("Cannot convert provided py::handle to a PolymorphicValue.");
}
} // namespace
// Builds a KernelArgumentHolder from a Python iterable of arguments.
//
// Every element of `iter` is converted with toPolymorphicValue and pushed in
// iteration order. An element that is itself a Python list or tuple is
// flattened by one level: its items are converted and pushed individually.
//
// `device`, when it holds a value, is narrowed to int8_t and handed to
// setDeviceIndex; otherwise std::nullopt is handed over. A value outside the
// int8_t range fails the NVF_CHECK instead of silently wrapping around.
KernelArgumentHolder from_pyiterable(
    const py::iterable& iter,
    std::optional<int64_t> device) {
  KernelArgumentHolder args;
  for (py::handle obj : iter) {
    // Allows a vector of sizes to be passed in as a list/tuple.
    if (py::isinstance<py::list>(obj) || py::isinstance<py::tuple>(obj)) {
      for (py::handle item : obj) {
        args.push(toPolymorphicValue(item));
      }
    } else {
      args.push(toPolymorphicValue(obj));
    }
  }

  // Narrow the int64_t device to int8_t. The bounds are those of the
  // destination type, so the cast below can never change the value. Negative
  // indices that fit are passed through unchanged for setDeviceIndex to
  // interpret.
  std::optional<int8_t> selected_device = std::nullopt;
  if (device.has_value()) {
    constexpr int64_t min_device_index = std::numeric_limits<int8_t>::min();
    constexpr int64_t max_device_index = std::numeric_limits<int8_t>::max();
    const int64_t requested_device = device.value();
    NVF_CHECK(
        requested_device >= min_device_index &&
            requested_device <= max_device_index,
        "Device index must fit in int8_t; maximum device index is 127");
    selected_device = static_cast<int8_t>(requested_device);
  }
  args.setDeviceIndex(selected_device);
  return args;
}
std::vector<at::Tensor> to_tensor_vector(const KernelArgumentHolder& outputs) {
// Convert outputs KernelArgumentHolder to std::vector<at::Tensor>
std::vector<at::Tensor> out_tensors;
out_tensors.reserve(outputs.size());
std::transform(
outputs.begin(),
outputs.end(),
std::back_inserter(out_tensors),
[](const PolymorphicValue& out) { return out.as<at::Tensor>(); });
return out_tensors;
}
} // namespace nvfuser::python