Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions crates/derive-impl/src/from_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ struct ArgAttribute {
name: Option<String>,
kind: ParameterKind,
default: Option<DefaultValue>,
error_msg: Option<String>,
}

impl ArgAttribute {
Expand All @@ -63,6 +64,7 @@ impl ArgAttribute {
name: None,
kind,
default: None,
error_msg: None,
});
return Ok(());
};
Expand Down Expand Up @@ -94,6 +96,12 @@ impl ArgAttribute {
}
let val = meta.value()?.parse::<syn::LitStr>()?;
self.name = Some(val.value())
} else if meta.path.is_ident("error_msg") {
if self.error_msg.is_some() {
return Err(meta.error("already have an error_msg"));
}
let val = meta.value()?.parse::<syn::LitStr>()?;
self.error_msg = Some(val.value())
} else {
return Err(meta.error("Unrecognized pyarg attribute"));
}
Expand Down Expand Up @@ -146,8 +154,15 @@ fn generate_field((i, field): (usize, &Field)) -> Result<TokenStream> {
.or(name_string)
.ok_or_else(|| err_span!(field, "field in tuple struct must have name attribute"))?;

let middle = quote! {
.map(|x| ::rustpython_vm::convert::TryFromObject::try_from_object(vm, x)).transpose()?
let middle = if let Some(error_msg) = &attr.error_msg {
quote! {
.map(|x| ::rustpython_vm::convert::TryFromObject::try_from_object(vm, x)
.map_err(|_| vm.new_type_error(#error_msg))).transpose()?
}
} else {
quote! {
.map(|x| ::rustpython_vm::convert::TryFromObject::try_from_object(vm, x)).transpose()?
}
};

let ending = if let Some(default) = attr.default {
Expand Down
8 changes: 4 additions & 4 deletions crates/vm/src/builtins/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -855,13 +855,13 @@ pub struct PyFunctionNewArgs {
code: PyRef<PyCode>,
#[pyarg(positional)]
globals: PyDictRef,
#[pyarg(any, optional)]
#[pyarg(any, optional, error_msg = "arg 3 (name) must be None or string")]
name: OptionalArg<PyStrRef>,
#[pyarg(any, optional)]
#[pyarg(any, optional, error_msg = "arg 4 (defaults) must be None or tuple")]
argdefs: Option<PyTupleRef>,
#[pyarg(any, optional)]
#[pyarg(any, optional, error_msg = "arg 5 (closure) must be None or tuple")]
closure: Option<PyTupleRef>,
#[pyarg(any, optional)]
#[pyarg(any, optional, error_msg = "arg 6 (kwdefaults) must be None or dict")]
kwdefaults: Option<PyDictRef>,
}

Expand Down
36 changes: 15 additions & 21 deletions crates/vm/src/builtins/interpolation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,26 +59,16 @@ impl Constructor for PyInterpolation {
type Args = InterpolationArgs;

fn py_new(_cls: &Py<PyType>, args: Self::Args, vm: &VirtualMachine) -> PyResult<Self> {
let conversion = match args.conversion {
OptionalArg::Present(c) => {
if vm.is_none(&c) {
vm.ctx.none()
} else {
let s = c.downcast::<PyStr>().map_err(|_| {
vm.new_type_error(
"Interpolation() argument 'conversion' must be str or None",
)
})?;
let s_str = s.as_str();
if s_str.len() != 1 || !matches!(s_str.chars().next(), Some('s' | 'r' | 'a')) {
return Err(vm.new_value_error(
"Interpolation() argument 'conversion' must be one of 's', 'a' or 'r'",
));
}
s.into()
}
let conversion: PyObjectRef = if let Some(s) = args.conversion {
let s_str = s.as_str();
if s_str.len() != 1 || !matches!(s_str.chars().next(), Some('s' | 'r' | 'a')) {
return Err(vm.new_value_error(
"Interpolation() argument 'conversion' must be one of 's', 'a' or 'r'",
));
}
OptionalArg::Missing => vm.ctx.none(),
s.into()
} else {
vm.ctx.none()
};

let expression = args
Expand All @@ -103,8 +93,12 @@ pub struct InterpolationArgs {
value: PyObjectRef,
#[pyarg(any, optional)]
expression: OptionalArg<PyStrRef>,
#[pyarg(any, optional)]
conversion: OptionalArg<PyObjectRef>,
#[pyarg(
any,
optional,
error_msg = "Interpolation() argument 'conversion' must be str or None"
)]
conversion: Option<PyStrRef>,
#[pyarg(any, optional)]
format_spec: OptionalArg<PyStrRef>,
}
Expand Down
9 changes: 3 additions & 6 deletions crates/vm/src/builtins/super.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ impl Constructor for PySuper {

#[derive(FromArgs)]
pub struct InitArgs {
#[pyarg(positional, optional)]
py_type: OptionalArg<PyObjectRef>,
#[pyarg(positional, optional, error_msg = "super() argument 1 must be a type")]
py_type: OptionalArg<PyTypeRef>,
#[pyarg(positional, optional)]
py_obj: OptionalArg<PyObjectRef>,
}
Expand All @@ -75,10 +75,7 @@ impl Initializer for PySuper {
vm: &VirtualMachine,
) -> PyResult<()> {
// Get the type:
let (typ, obj) = if let OptionalArg::Present(ty_obj) = py_type {
let ty = ty_obj
.downcast::<PyType>()
.map_err(|_| vm.new_type_error("super() argument 1 must be a type"))?;
let (typ, obj) = if let OptionalArg::Present(ty) = py_type {
(ty, py_obj.unwrap_or_none(vm))
} else {
let frame = vm
Expand Down
Loading