Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 7 additions & 16 deletions crates/stdlib/src/faulthandler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ pub(crate) use decl::make_module;
#[pymodule(name = "faulthandler")]
mod decl {
use crate::vm::{
PyObjectRef, PyResult, VirtualMachine, builtins::PyFloat, frame::Frame,
function::OptionalArg, py_io::Write,
PyObjectRef, PyResult, VirtualMachine,
frame::Frame,
function::{ArgIntoFloat, OptionalArg},
py_io::Write,
};
use alloc::sync::Arc;
use core::sync::atomic::{AtomicBool, AtomicI32, Ordering};
Expand Down Expand Up @@ -762,8 +764,8 @@ mod decl {
#[derive(FromArgs)]
#[allow(unused)]
struct DumpTracebackLaterArgs {
#[pyarg(positional)]
timeout: PyObjectRef,
#[pyarg(positional, error_msg = "timeout must be a number (int or float)")]
timeout: ArgIntoFloat,
#[pyarg(any, default = false)]
repeat: bool,
#[pyarg(any, default)]
Expand All @@ -774,18 +776,7 @@ mod decl {

#[pyfunction]
fn dump_traceback_later(args: DumpTracebackLaterArgs, vm: &VirtualMachine) -> PyResult<()> {
use num_traits::ToPrimitive;
// Convert timeout to f64 (accepting int or float)
let timeout: f64 = if let Some(float) = args.timeout.downcast_ref::<PyFloat>() {
float.to_f64()
} else if let Some(int) = args.timeout.try_index_opt(vm).transpose()? {
int.as_bigint()
.to_i64()
.ok_or_else(|| vm.new_overflow_error("timeout value is too large".to_owned()))?
as f64
} else {
return Err(vm.new_type_error("timeout must be a number (int or float)".to_owned()));
};
let timeout: f64 = args.timeout.into_float();

if timeout <= 0.0 {
return Err(vm.new_value_error("timeout must be greater than 0".to_owned()));
Expand Down
63 changes: 32 additions & 31 deletions crates/stdlib/src/ssl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ mod _ssl {
VirtualMachine,
builtins::{PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyType, PyTypeRef},
convert::IntoPyException,
function::{ArgBytesLike, ArgMemoryBuffer, FuncArgs, OptionalArg, PyComparisonValue},
function::{
ArgBytesLike, ArgMemoryBuffer, Either, FuncArgs, OptionalArg, PyComparisonValue,
},
stdlib::warnings,
types::{Comparable, Constructor, Hashable, PyComparisonOp, Representable},
},
Expand Down Expand Up @@ -821,20 +823,20 @@ mod _ssl {

#[derive(FromArgs)]
struct LoadVerifyLocationsArgs {
#[pyarg(any, optional)]
cafile: OptionalArg<Option<PyObjectRef>>,
#[pyarg(any, optional)]
capath: OptionalArg<Option<PyObjectRef>>,
#[pyarg(any, optional)]
cadata: OptionalArg<PyObjectRef>,
#[pyarg(any, optional, error_msg = "path should be a str or bytes")]
cafile: OptionalArg<Option<Either<PyStrRef, ArgBytesLike>>>,
#[pyarg(any, optional, error_msg = "path should be a str or bytes")]
capath: OptionalArg<Option<Either<PyStrRef, ArgBytesLike>>>,
#[pyarg(any, optional, error_msg = "cadata should be a str or bytes")]
cadata: OptionalArg<Option<Either<PyStrRef, ArgBytesLike>>>,
}

#[derive(FromArgs)]
struct LoadCertChainArgs {
#[pyarg(any)]
certfile: PyObjectRef,
#[pyarg(any, optional)]
keyfile: OptionalArg<Option<PyObjectRef>>,
#[pyarg(any, error_msg = "path should be a str or bytes")]
certfile: Either<PyStrRef, ArgBytesLike>,
#[pyarg(any, optional, error_msg = "path should be a str or bytes")]
keyfile: OptionalArg<Option<Either<PyStrRef, ArgBytesLike>>>,
#[pyarg(any, optional)]
password: OptionalArg<PyObjectRef>,
}
Expand Down Expand Up @@ -1229,7 +1231,7 @@ mod _ssl {
// Check that at least one argument is provided
let has_cafile = matches!(&args.cafile, OptionalArg::Present(Some(_)));
let has_capath = matches!(&args.capath, OptionalArg::Present(Some(_)));
let has_cadata = matches!(&args.cadata, OptionalArg::Present(obj) if !vm.is_none(obj));
let has_cadata = matches!(&args.cadata, OptionalArg::Present(Some(_)));

if !has_cafile && !has_capath && !has_cadata {
return Err(
Expand All @@ -1250,10 +1252,8 @@ mod _ssl {
None
};

let cadata_parsed = if let OptionalArg::Present(ref cadata_obj) = args.cadata
&& !vm.is_none(cadata_obj)
{
let is_string = PyStrRef::try_from_object(vm, cadata_obj.clone()).is_ok();
let cadata_parsed = if let OptionalArg::Present(Some(ref cadata_obj)) = args.cadata {
let is_string = matches!(cadata_obj, Either::A(_));
let data_vec = self.parse_cadata_arg(cadata_obj, vm)?;
Some((data_vec, is_string))
} else {
Expand Down Expand Up @@ -1989,14 +1989,14 @@ mod _ssl {
// Helper functions (private):

/// Parse path argument (str or bytes) to string
fn parse_path_arg(arg: &PyObject, vm: &VirtualMachine) -> PyResult<String> {
if let Ok(s) = PyStrRef::try_from_object(vm, arg.to_owned()) {
Ok(s.as_str().to_owned())
} else if let Ok(b) = ArgBytesLike::try_from_object(vm, arg.to_owned()) {
String::from_utf8(b.borrow_buf().to_vec())
.map_err(|_| vm.new_value_error("path contains invalid UTF-8".to_owned()))
} else {
Err(vm.new_type_error("path should be a str or bytes".to_owned()))
fn parse_path_arg(
arg: &Either<PyStrRef, ArgBytesLike>,
vm: &VirtualMachine,
) -> PyResult<String> {
match arg {
Either::A(s) => Ok(s.as_str().to_owned()),
Either::B(b) => String::from_utf8(b.borrow_buf().to_vec())
.map_err(|_| vm.new_value_error("path contains invalid UTF-8".to_owned())),
}
}

Expand Down Expand Up @@ -2167,13 +2167,14 @@ mod _ssl {
}

/// Helper: Parse cadata argument (str or bytes)
fn parse_cadata_arg(&self, arg: &PyObject, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
if let Ok(s) = PyStrRef::try_from_object(vm, arg.to_owned()) {
Ok(s.as_str().as_bytes().to_vec())
} else if let Ok(b) = ArgBytesLike::try_from_object(vm, arg.to_owned()) {
Ok(b.borrow_buf().to_vec())
} else {
Err(vm.new_type_error("cadata should be a str or bytes".to_owned()))
fn parse_cadata_arg(
&self,
arg: &Either<PyStrRef, ArgBytesLike>,
_vm: &VirtualMachine,
) -> PyResult<Vec<u8>> {
match arg {
Either::A(s) => Ok(s.as_str().as_bytes().to_vec()),
Either::B(b) => Ok(b.borrow_buf().to_vec()),
}
}

Expand Down
Loading