Skip to content

Commit 6fc67e4

Browse files
swachhandlpak-laura
authored andcommitted
Replace CHECK with returning an InternalError on failing to create python tuple
Returns InternalError if a non UTF-8 string is passed in as token, instead of crashing. PiperOrigin-RevId: 478123498
1 parent 5dbe90a commit 6fc67e4

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

tensorflow/python/lib/core/py_func.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ bool IsCPUDevice(const Device* d) {
8383
return d == nullptr || d->tensorflow_accelerator_device_info() == nullptr;
8484
}
8585

86-
// Givens the 'call', prepares the token and inputs as a python tuple
87-
// that is appropriate for calling the trampoline.
86+
// Given the 'call', prepares the token and inputs as a python tuple that is
87+
// appropriate for calling the trampoline.
8888
Status MakeArgTuple(const PyCall* call, TFE_Context* ctx, PyObject** tuple) {
8989
int64_t n = call->ins.size();
9090
PyObject* lst = PyList_New(n);
@@ -119,8 +119,12 @@ Status MakeArgTuple(const PyCall* call, TFE_Context* ctx, PyObject** tuple) {
119119
PyList_SetItem(lst, i, arg);
120120
}
121121
*tuple = Py_BuildValue("(ssN)", call->token.c_str(), device_name, lst);
122-
CHECK(*tuple);
123-
return Status::OK();
122+
if (*tuple == nullptr) {
123+
return errors::Internal(
124+
"Failed to create python tuple. Please make sure `token` is a "
125+
"well-formed UTF-8 string.");
126+
}
127+
return OkStatus();
124128
}
125129

126130
bool IsSingleNone(PyObject* obj) {

tensorflow/python/ops/script_ops_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616

1717
from tensorflow.python.eager import def_function
1818
from tensorflow.python.framework import dtypes
19+
from tensorflow.python.framework import errors
1920
from tensorflow.python.framework import test_util
2021
from tensorflow.python.framework import constant_op
22+
from tensorflow.python.ops import gen_script_ops
23+
from tensorflow.python.ops import resource_variable_ops
2124
from tensorflow.python.ops import script_ops
2225
from tensorflow.python.ops.script_ops import numpy_function
2326
from tensorflow.python.platform import test
@@ -87,5 +90,30 @@ def func_stateful(a, b):
8790
2) # as stateful, func is guaranteed to execute twice
8891

8992

93+
class PyFunctionTest(test.TestCase):
94+
95+
@test_util.run_in_graph_and_eager_modes
96+
def test_variable_arguments(self):
97+
98+
def plus(a, b):
99+
return a + b
100+
101+
v1 = resource_variable_ops.ResourceVariable(1)
102+
self.evaluate(v1.initializer)
103+
104+
actual_result = script_ops.eager_py_func(plus, [v1, 2], dtypes.int32)
105+
expect_result = constant_op.constant(3, dtypes.int32)
106+
self.assertAllEqual(actual_result, expect_result)
107+
108+
@test_util.run_in_graph_and_eager_modes
109+
def test_fail_on_non_utf8_token(self):
110+
value = constant_op.constant(value=[1, 2])
111+
token = b"\xb0"
112+
data_type = [dtypes.int32]
113+
with self.assertRaises((errors.InternalError, UnicodeDecodeError)):
114+
self.evaluate(
115+
gen_script_ops.py_func(input=[value], token=token, Tout=data_type))
116+
117+
90118
if __name__ == "__main__":
91119
test.main()

0 commit comments

Comments
 (0)