Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 69 additions & 22 deletions RATapi/classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import collections
import contextlib
import warnings
from collections.abc import Iterable, Sequence
from collections.abc import Sequence
from typing import Any, Union

import numpy as np
Expand Down Expand Up @@ -96,8 +96,8 @@ def __setitem__(self, index: int, item: object) -> None:

def _setitem(self, index: int, item: object) -> None:
"""Auxiliary routine of "__setitem__" used to enable wrapping."""
self._check_classes(self + [item])
self._check_unique_name_fields(self + [item])
self._check_classes([item])
self._check_unique_name_fields([item])
self.data[index] = item

def __delitem__(self, index: int) -> None:
Expand All @@ -118,8 +118,8 @@ def _iadd(self, other: Sequence[object]) -> "ClassList":
other = [other]
if not hasattr(self, "_class_handle"):
self._class_handle = self._determine_class_handle(self + other)
self._check_classes(self + other)
self._check_unique_name_fields(self + other)
self._check_classes(other)
self._check_unique_name_fields(other)
super().__iadd__(other)
return self

Expand Down Expand Up @@ -168,8 +168,8 @@ def append(self, obj: object = None, **kwargs) -> None:
if obj:
if not hasattr(self, "_class_handle"):
self._class_handle = type(obj)
self._check_classes(self + [obj])
self._check_unique_name_fields(self + [obj])
self._check_classes([obj])
self._check_unique_name_fields([obj])
self.data.append(obj)
else:
if not hasattr(self, "_class_handle"):
Expand Down Expand Up @@ -215,8 +215,8 @@ def insert(self, index: int, obj: object = None, **kwargs) -> None:
if obj:
if not hasattr(self, "_class_handle"):
self._class_handle = type(obj)
self._check_classes(self + [obj])
self._check_unique_name_fields(self + [obj])
self._check_classes([obj])
self._check_unique_name_fields([obj])
self.data.insert(index, obj)
else:
if not hasattr(self, "_class_handle"):
Expand Down Expand Up @@ -252,8 +252,8 @@ def extend(self, other: Sequence[object]) -> None:
other = [other]
if not hasattr(self, "_class_handle"):
self._class_handle = self._determine_class_handle(self + other)
self._check_classes(self + other)
self._check_unique_name_fields(self + other)
self._check_classes(other)
self._check_unique_name_fields(other)
self.data.extend(other)

def set_fields(self, index: int, **kwargs) -> None:
Expand Down Expand Up @@ -312,13 +312,14 @@ def _validate_name_field(self, input_args: dict[str, Any]) -> None:
"""
names = [name.lower() for name in self.get_names()]
with contextlib.suppress(KeyError):
if input_args[self.name_field].lower() in names:
name = input_args[self.name_field].lower()
if name in names:
raise ValueError(
f"Input arguments contain the {self.name_field} '{input_args[self.name_field]}', "
f"which is already specified in the ClassList",
f"which is already specified at index {names.index(name)} of the ClassList",
)

def _check_unique_name_fields(self, input_list: Iterable[object]) -> None:
def _check_unique_name_fields(self, input_list: Sequence[object]) -> None:
"""Raise a ValueError if any value of the name_field attribute is used more than once in a list of class
objects.

Expand All @@ -333,11 +334,49 @@ def _check_unique_name_fields(self, input_list: Iterable[object]) -> None:
Raised if the input list defines more than one object with the same value of name_field.

"""
names = [getattr(model, self.name_field).lower() for model in input_list if hasattr(model, self.name_field)]
if len(set(names)) != len(names):
raise ValueError(f"Input list contains objects with the same value of the {self.name_field} attribute")
error_list = []
try:
existing_names = [name.lower() for name in self.get_names()]
except AttributeError:
existing_names = []

new_names = [getattr(model, self.name_field).lower() for model in input_list if hasattr(model, self.name_field)]
full_names = existing_names + new_names

# There are duplicate names if this test fails
if len(set(full_names)) != len(full_names):
unique_names = [*dict.fromkeys(new_names)]

for name in unique_names:
existing_indices = [i for i, other_name in enumerate(existing_names) if other_name == name]
new_indices = [i for i, other_name in enumerate(new_names) if other_name == name]
if (len(existing_indices) + len(new_indices)) > 1:
existing_string = ""
new_string = ""
if existing_indices:
existing_list = ", ".join(str(i) for i in existing_indices[:-1])
existing_string = (
f" item{f's {existing_list} and ' if existing_list else ' '}"
f"{existing_indices[-1]} of the existing ClassList"
)
if new_indices:
new_list = ", ".join(str(i) for i in new_indices[:-1])
new_string = (
f" item{f's {new_list} and ' if new_list else ' '}" f"{new_indices[-1]} of the input list"
)
error_list.append(
f" '{name}' is shared between{existing_string}"
f"{', and' if existing_string and new_string else ''}{new_string}"
)

def _check_classes(self, input_list: Iterable[object]) -> None:
if error_list:
newline = "\n"
raise ValueError(
f"The value of the '{self.name_field}' attribute must be unique for each item in the ClassList:\n"
f"{newline.join(error for error in error_list)}"
)

def _check_classes(self, input_list: Sequence[object]) -> None:
"""Raise a ValueError if any object in a list of objects is not of the type specified by self._class_handle.

Parameters
Expand All @@ -348,11 +387,19 @@ def _check_classes(self, input_list: Iterable[object]) -> None:
Raises
------
ValueError
Raised if the input list defines objects of different types.
Raised if the input list contains objects of any type other than that given in self._class_handle.

"""
if not (all(isinstance(element, self._class_handle) for element in input_list)):
raise ValueError(f"Input list contains elements of type other than '{self._class_handle.__name__}'")
error_list = []
for i, element in enumerate(input_list):
if not isinstance(element, self._class_handle):
error_list.append(f" index {i} is of type {type(element).__name__}")
if error_list:
newline = "\n"
raise ValueError(
f"This ClassList only supports elements of type {self._class_handle.__name__}. "
f"In the input list:\n{newline.join(error for error in error_list)}\n"
)

def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object, str]:
"""Return the object with the given value of the name_field attribute in the ClassList.
Expand All @@ -379,7 +426,7 @@ def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object,
@staticmethod
def _determine_class_handle(input_list: Sequence[object]):
"""When inputting a sequence of object to a ClassList, the _class_handle should be set as the type of the
element which satisfies "issubclass" for all of the other elements.
element which satisfies "issubclass" for all the other elements.

Parameters
----------
Expand Down
104 changes: 81 additions & 23 deletions tests/test_classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def test_different_classes(self, input_list: Sequence[object]) -> None:
"""If we initialise a ClassList with an input containing multiple classes, we should raise a ValueError."""
with pytest.raises(
ValueError,
match=f"Input list contains elements of type other than '{type(input_list[0]).__name__}'",
match=f"This ClassList only supports elements of type {type(input_list[0]).__name__}. In the input list:\n"
f" index 1 is of type {type(input_list[1]).__name__}\n",
):
ClassList(input_list)

Expand All @@ -134,7 +135,9 @@ def test_identical_name_fields(self, input_list: Sequence[object], name_field: s
"""
with pytest.raises(
ValueError,
match=f"Input list contains objects with the same value of the {name_field} attribute",
match=f"The value of the '{name_field}' attribute must be unique for each item in the "
f"ClassList:\n '{getattr(input_list[0], name_field).lower()}'"
f" is shared between items 0 and 1 of the input list",
):
ClassList(input_list, name_field=name_field)

Expand Down Expand Up @@ -194,7 +197,12 @@ def test_setitem(two_name_class_list: ClassList, new_item: InputAttributes, expe
)
def test_setitem_same_name_field(two_name_class_list: ClassList, new_item: InputAttributes) -> None:
"""If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError."""
with pytest.raises(ValueError, match="Input list contains objects with the same value of the name attribute"):
with pytest.raises(
ValueError,
match=f"The value of the '{two_name_class_list.name_field}' attribute must be unique for each item in the "
f"ClassList:\n '{new_item.name.lower()}' is shared between item 1 of the existing ClassList,"
f" and item 0 of the input list",
):
two_name_class_list[0] = new_item


Expand All @@ -206,7 +214,11 @@ def test_setitem_same_name_field(two_name_class_list: ClassList, new_item: Input
)
def test_setitem_different_classes(two_name_class_list: ClassList, new_values: dict[str, Any]) -> None:
"""If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError."""
with pytest.raises(ValueError, match="Input list contains elements of type other than 'InputAttributes'"):
with pytest.raises(
ValueError,
match=f"This ClassList only supports elements of type {two_name_class_list._class_handle.__name__}. "
f"In the input list:\n index 0 is of type {type(new_values).__name__}\n",
):
two_name_class_list[0] = new_values


Expand Down Expand Up @@ -403,7 +415,9 @@ def test_append_object_same_name_field(two_name_class_list: ClassList, new_objec
"""If we append an object with an already-specified name_field value to a ClassList we should raise a ValueError."""
with pytest.raises(
ValueError,
match=f"Input list contains objects with the same value of the " f"{two_name_class_list.name_field} attribute",
match=f"The value of the '{two_name_class_list.name_field}' attribute must be unique for each item in the "
f"ClassList:\n '{new_object.name.lower()}' is shared between item 0 of the existing ClassList, and "
f"item 0 of the input list",
):
two_name_class_list.append(new_object)

Expand All @@ -420,7 +434,7 @@ def test_append_kwargs_same_name_field(two_name_class_list: ClassList, new_value
ValueError,
match=f"Input arguments contain the {two_name_class_list.name_field} "
f"'{new_values[two_name_class_list.name_field]}', "
f"which is already specified in the ClassList",
f"which is already specified at index 0 of the ClassList",
):
two_name_class_list.append(**new_values)

Expand Down Expand Up @@ -526,7 +540,9 @@ def test_insert_object_same_name(two_name_class_list: ClassList, new_object: obj
"""If we insert an object with an already-specified name_field value to a ClassList we should raise a ValueError."""
with pytest.raises(
ValueError,
match=f"Input list contains objects with the same value of the " f"{two_name_class_list.name_field} attribute",
match=f"The value of the '{two_name_class_list.name_field}' attribute must be unique for each item in the "
f"ClassList:\n '{new_object.name.lower()}' is shared between item 0 of the existing "
f"ClassList, and item 0 of the input list",
):
two_name_class_list.insert(1, new_object)

Expand All @@ -543,7 +559,7 @@ def test_insert_kwargs_same_name(two_name_class_list: ClassList, new_values: dic
ValueError,
match=f"Input arguments contain the {two_name_class_list.name_field} "
f"'{new_values[two_name_class_list.name_field]}', "
f"which is already specified in the ClassList",
f"which is already specified at index 0 of the ClassList",
):
two_name_class_list.insert(1, **new_values)

Expand Down Expand Up @@ -702,7 +718,7 @@ def test_set_fields_same_name_field(two_name_class_list: ClassList, new_values:
ValueError,
match=f"Input arguments contain the {two_name_class_list.name_field} "
f"'{new_values[two_name_class_list.name_field]}', "
f"which is already specified in the ClassList",
f"which is already specified at index 1 of the ClassList",
):
two_name_class_list.set_fields(0, **new_values)

Expand Down Expand Up @@ -767,7 +783,7 @@ def test__validate_name_field(two_name_class_list: ClassList, input_dict: dict[s
"input_dict",
[
({"name": "Alice"}),
({"name": "ALICE"}),
({"name": "BOB"}),
({"name": "alice"}),
],
)
Expand All @@ -777,18 +793,18 @@ def test__validate_name_field_not_unique(two_name_class_list: ClassList, input_d
with pytest.raises(
ValueError,
match=f"Input arguments contain the {two_name_class_list.name_field} "
f"'{input_dict[two_name_class_list.name_field]}', "
f"which is already specified in the ClassList",
f"'{input_dict[two_name_class_list.name_field]}', which is already specified at index "
f"{two_name_class_list.index(input_dict['name'].lower())} of the ClassList",
):
two_name_class_list._validate_name_field(input_dict)


@pytest.mark.parametrize(
"input_list",
[
([InputAttributes(name="Alice"), InputAttributes(name="Bob")]),
([InputAttributes(surname="Morgan"), InputAttributes(surname="Terwilliger")]),
([InputAttributes(name="Alice", surname="Morgan"), InputAttributes(surname="Terwilliger")]),
([InputAttributes(name="Eve"), InputAttributes(name="Gareth")]),
([InputAttributes(surname="Polastri"), InputAttributes(surname="Mallory")]),
([InputAttributes(name="Eve", surname="Polastri"), InputAttributes(surname="Mallory")]),
([InputAttributes()]),
([]),
],
Expand All @@ -801,20 +817,59 @@ def test__check_unique_name_fields(two_name_class_list: ClassList, input_list: I


@pytest.mark.parametrize(
"input_list",
["input_list", "error_message"],
[
([InputAttributes(name="Alice"), InputAttributes(name="Alice")]),
([InputAttributes(name="Alice"), InputAttributes(name="ALICE")]),
([InputAttributes(name="Alice"), InputAttributes(name="alice")]),
(
[InputAttributes(name="Alice"), InputAttributes(name="Bob")],
(
" 'alice' is shared between item 0 of the existing ClassList, and item 0 of the input list\n"
" 'bob' is shared between item 1 of the existing ClassList, and item 1 of the input list"
),
),
(
[InputAttributes(name="Alice"), InputAttributes(name="Alice")],
" 'alice' is shared between item 0 of the existing ClassList, and items 0 and 1 of the input list",
),
(
[InputAttributes(name="Alice"), InputAttributes(name="ALICE")],
" 'alice' is shared between item 0 of the existing ClassList, and items 0 and 1 of the input list",
),
(
[InputAttributes(name="Alice"), InputAttributes(name="alice")],
" 'alice' is shared between item 0 of the existing ClassList, and items 0 and 1 of the input list",
),
(
[InputAttributes(name="Eve"), InputAttributes(name="Eve")],
" 'eve' is shared between items 0 and 1 of the input list",
),
(
[
InputAttributes(name="Bob"),
InputAttributes(name="Alice"),
InputAttributes(name="Eve"),
InputAttributes(name="Alice"),
InputAttributes(name="Eve"),
InputAttributes(name="Alice"),
],
(
" 'bob' is shared between item 1 of the existing ClassList, and item 0 of the input list\n"
" 'alice' is shared between item 0 of the existing ClassList,"
" and items 1, 3 and 5 of the input list\n"
" 'eve' is shared between items 2 and 4 of the input list"
),
),
],
)
def test__check_unique_name_fields_not_unique(two_name_class_list: ClassList, input_list: Iterable) -> None:
def test__check_unique_name_fields_not_unique(
two_name_class_list: ClassList, input_list: Sequence, error_message: str
) -> None:
"""We should raise a ValueError if an input list contains multiple objects with (case-insensitive) matching
name_field values defined.
"""
with pytest.raises(
ValueError,
match=f"Input list contains objects with the same value of the " f"{two_name_class_list.name_field} attribute",
match=f"The value of the '{two_name_class_list.name_field}' attribute must be unique for each item in the "
f"ClassList:\n{error_message}",
):
two_name_class_list._check_unique_name_fields(input_list)

Expand All @@ -837,12 +892,15 @@ def test__check_classes(input_list: Iterable) -> None:
([InputAttributes(name="Alice"), dict(name="Bob")]),
],
)
def test__check_classes_different_classes(input_list: Iterable) -> None:
def test__check_classes_different_classes(input_list: Sequence) -> None:
"""We should raise a ValueError if an input list contains objects of different types."""
class_list = ClassList([InputAttributes()])
with pytest.raises(
ValueError,
match=(f"Input list contains elements of type other " f"than '{class_list._class_handle.__name__}'"),
match=(
f"This ClassList only supports elements of type {class_list._class_handle.__name__}. "
f"In the input list:\n index 1 is of type {type(input_list[1]).__name__}"
),
):
class_list._check_classes(input_list)

Expand Down