-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy path_db_loader.py
More file actions
299 lines (274 loc) · 11.6 KB
/
_db_loader.py
File metadata and controls
299 lines (274 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
"""
MLSTRUCT-FP - DB - DBLOADER
Loads a given dataset .json file.
"""
__all__ = ['DbLoader']
from MLStructFP.db._floor import Floor
from MLStructFP.db._c_rect import Rect
from MLStructFP.db._c_point import Point
from MLStructFP.db._c_slab import Slab
from MLStructFP.db._c_room import Room
from MLStructFP.db._c_item import Item
from MLStructFP._types import Tuple
import json
import math
import matplotlib.pyplot as plt
import os
import tabulate
from collections import Counter
from IPython.display import HTML, display
from pathlib import Path
from typing import Dict, Callable, Optional, List
class DbLoader(object):
    """
    Dataset loader.

    Reads a dataset ``.json`` file, creates one ``Floor`` per entry of its
    ``floor`` section and, unless ``floor_only`` is set, instantiates the
    rects, points, slabs, rooms and items that reference those floors.
    Floors can then be accessed by ID (``loader[floor_id]``), iterated through
    :attr:`floors` (optionally filtered), tabulated or summarized in a histogram.
    """
    __filter: Optional[Callable[['Floor'], bool]]  # Active floor filter, None if disabled
    __filtered_floors: List['Floor']  # Cache of the floors accepted by the filter
    __floor: Dict[int, 'Floor']  # Floor ID => floor
    __floor_categories: Dict[int, str]  # Category ID => category name
    __path: str  # Directory that contains the dataset file

    def __init__(self, db: str, floor_only: bool = False) -> None:
        """
        Loads a dataset file.

        :param db: Dataset path
        :param floor_only: If true, load only floors
        """
        assert os.path.isfile(db), f'Dataset file {db} not found'
        self.__filter = None
        self.__filtered_floors = []
        self.__floor = {}
        self.__floor_categories = {}
        self.__path = str(Path(os.path.realpath(db)).parent)

        # The file is only needed while parsing; everything else works on the parsed dict
        with open(db, 'r', encoding='utf8') as dbfile:
            data: dict = json.load(dbfile)
        meta: dict = data.get('meta', {})

        # Load metadata. The file maps name => value, the lookups below are inverted (ID => name)
        for cat_name, cat_id in meta.get('floor_categories', {}).items():
            self.__floor_categories[cat_id] = cat_name

        # Each entry is name => [ID, second field]; the second field is later passed as "color"
        item_types: Dict[int, Tuple[str, str]] = {
            ic[0]: (cat_name, ic[1]) for cat_name, ic in meta.get('item_types', {}).items()
        }

        project_label: Dict[int, str] = {}
        for pid, label in meta.get('project_label', {}).items():
            try:
                project_label[int(pid)] = label
            except ValueError:
                # Keys that are not integer project IDs are ignored on purpose
                continue

        room_categories: Dict[int, Tuple[str, str]] = {
            rc[0]: (cat_name, rc[1]) for cat_name, rc in meta.get('room_categories', {}).items()
        }

        # Load floors
        for f_id, f_data in data.get('floor', {}).items():
            floor_id: int = int(f_id)
            f_cat: int = int(f_data.get('category', 0))
            project_id: int = f_data.get('project', -1)
            self.__floor[floor_id] = Floor(
                floor_id=floor_id,
                image_path=os.path.join(self.__path, f_data['image']),  # Relative to the dataset file
                image_scale=f_data['scale'],
                project_id=project_id,
                project_label=project_label.get(project_id, ''),
                category=f_cat,
                category_name=self.__floor_categories.get(f_cat, ''),
                elevation=f_data.get('elevation', False)
            )
        if floor_only:
            return

        # Load objects. The created instances are not stored here; each one receives its
        # floor (presumably registering itself on it -- confirm in the object classes)
        for rect_id, rect_data in data.get('rect', {}).items():
            rect_a = rect_data['angle']
            rect_line = rect_data['line']
            Rect(
                rect_id=int(rect_id),
                wall_id=int(rect_data['wallID']),
                floor=self.__floor[rect_data['floorID']],
                angle=rect_a[0] if isinstance(rect_a, list) else rect_a,  # Angle may come as a list
                length=rect_data['length'],
                thickness=rect_data['thickness'],
                x=rect_data['x'],
                y=rect_data['y'],
                line_m=rect_line[0],  # Slope
                line_n=rect_line[1],  # Intercept
                line_theta=rect_line[2],  # Theta
                partition=rect_data.get('partition', False)  # Is partition
            )

        for point_id, point_data in data.get('point', {}).items():
            Point(
                point_id=int(point_id),
                wall_id=int(point_data['wallID']),
                floor=self.__floor[point_data['floorID']],
                x=point_data['x'],
                y=point_data['y'],
                topo=int(point_data['topo'])
            )

        for slab_id, slab_data in data.get('slab', {}).items():
            Slab(
                slab_id=int(slab_id),
                floor=self.__floor[slab_data['floorID']],
                x=slab_data['x'],
                y=slab_data['y']
            )

        for room_id, room_data in data.get('room', {}).items():
            room_cat = int(room_data['category'])
            # Unknown categories get an empty name and a black color
            room_name, room_color = room_categories.get(room_cat, ('', '#000000'))
            Room(
                room_id=int(room_id),
                floor=self.__floor[room_data['floorID']],
                x=room_data['x'],
                y=room_data['y'],
                color=room_color,
                category=room_cat,
                category_name=room_name
            )

        for item_id, item_data in data.get('item', {}).items():
            item_cat = int(item_data['category'])
            # Unknown categories get an empty name and a black color
            item_name, item_color = item_types.get(item_cat, ('', '#000000'))
            Item(
                item_id=int(item_id),
                floor=self.__floor[item_data['floorID']],
                x=item_data['x'],
                y=item_data['y'],
                color=item_color,
                category=item_cat,
                category_name=item_name
            )

    def __getitem__(self, item: int) -> 'Floor':
        """
        Returns a floor by its ID, ignoring the filter.

        :param item: Floor ID
        :return: Floor object
        """
        return self.__floor[item]

    def add_floor(self, floor_image: str, scale: float, category: int, elevation: bool) -> 'Floor':
        """
        Adds a floor to the dataset. No project.

        :param floor_image: Floor image file
        :param scale: Image scale
        :param category: Floor category
        :param elevation: Floor is elevation
        :return: Added floor object
        """
        assert os.path.isfile(floor_image)
        # Dataset IDs are not guaranteed to be contiguous, so "number of floors + 1" may
        # already be taken; move forward to the first free ID instead of overwriting
        f_id: int = len(self.__floor) + 1
        while f_id in self.__floor:
            f_id += 1
        f = Floor(
            floor_id=f_id,
            image_path=floor_image,
            image_scale=scale,
            project_id=-1,
            project_label='',
            category=category,
            category_name=self.__floor_categories.get(category, ''),
            elevation=elevation
        )
        self.__floor[f_id] = f
        # Drop the cached filtered list, otherwise the new floor would not show up in "floors"
        self.__filtered_floors.clear()
        return f

    @property
    def floors(self) -> Tuple['Floor', ...]:
        """
        :return: Floors accepted by the current filter (all floors if no filter is set)
        """
        if len(self.__filtered_floors) == 0:
            # Lazily (re)build the cache; it is emptied by set_filter and add_floor
            self.__filtered_floors = [
                f for f in self.__floor.values()
                if self.__filter is None or self.__filter(f)
            ]
        return tuple(self.__filtered_floors)

    @property
    def path(self) -> str:
        """
        :return: Directory that contains the dataset file
        """
        return self.__path

    @property
    def scale_limits(self) -> Tuple[float, float]:
        """
        :return: (min, max) image scale among the filtered floors; (inf, 0) if there are none
        """
        sc_min = math.inf
        sc_max = 0
        for f in self.floors:
            sc_min = min(sc_min, f.image_scale)
            sc_max = max(sc_max, f.image_scale)
        return sc_min, sc_max

    def set_filter(self, f_filter: Optional[Callable[['Floor'], bool]]) -> None:
        """
        Set floor filter.

        :param f_filter: Floor filter. If "None", it is removed
        """
        self.__filter = f_filter
        self.__filtered_floors.clear()  # Force "floors" to be recomputed with the new filter

    def _table_rows(self, limit: int, legacy: bool,
                    f_filter: Optional[Callable[['Floor'], bool]],
                    category_name: bool) -> List[list]:
        """
        Builds the data rows (without header) shown by tabulate.

        :param limit: Maximum number of rows, 0 means no limit
        :param legacy: Use the legacy (reduced) column set
        :param f_filter: Additional floor filter, applied on top of the loader filter
        :param category_name: If true, use the category name instead of its numeric value
        :return: List of rows; the first cell is the position of the floor within "floors"
        """
        rows: List[list] = []
        for j, f in enumerate(self.floors):
            if f_filter is not None and not f_filter(f):
                continue
            f_file: str = os.path.basename(f.image_path)
            if legacy:
                row = [j, f.id, len(f.rect), len(f.slab), f_file]
            else:
                row = [
                    j, f.project_id, f.project_label, f.id,
                    f.category_name if category_name else f.category,
                    1 if f.elevation else 0,
                    len(f.rect), len(f.point), len(f.slab), len(f.room), len(f.item),
                    f_file
                ]
            rows.append(row)
            # Count emitted rows (not floor positions) so that limit=1 and filtered
            # listings are honored as well
            if 0 < limit <= len(rows):
                break
        return rows

    def tabulate(self, limit: int = 0, legacy: bool = False,
                 f_filter: Optional[Callable[['Floor'], bool]] = None,
                 category_name: bool = False) -> None:
        """
        Tabulates each floor, with their file and number of rects.

        :param limit: Limits the number of items, 0 shows all of them
        :param legacy: Show legacy mode
        :param f_filter: Floor filter
        :param category_name: If true, shows category name instead of numeric value
        """
        assert isinstance(limit, int) and limit >= 0, 'Limit must be an integer greater or equal than zero'
        if legacy:
            theads = ['#', 'Floor ID', 'Rects', 'Slabs', 'Floor image path']
        else:
            theads = ['#', 'Project ID', 'Project label', 'Floor ID', 'Cat', 'Elev',
                      'Rects', 'Points', 'Slabs', 'Rooms', 'Items', 'Floor image path']
        table = [theads] + self._table_rows(limit, legacy, f_filter, category_name)
        display(HTML(tabulate.tabulate(
            table,
            headers='firstrow',
            numalign='center',
            stralign='center',
            tablefmt='html'
        )))

    def hist(self,
             f_hist: Callable[['Floor'], List[str]] = lambda f: [f.category_name],
             f_filter: Optional[Callable[['Floor'], bool]] = None,
             sort_cat: bool = True,
             show_plot: bool = True
             ) -> Tuple[str, ...]:
        """
        Create a histogram of object categories.

        :param f_hist: Function that feeds histogram with object categories
        :param f_filter: Floor filter
        :param sort_cat: Sort object categories by decreasing frequency
        :param show_plot: Show plot
        :return: All categories, considering sort
        """
        cat: List[str] = []
        for f in self.floors:
            if f_filter is not None and not f_filter(f):
                continue
            fh = f_hist(f)
            assert isinstance(fh, list), (f'f_hist must return a list of categories to assemble histogram, '
                                          f'"{fh}" is not allowed')
            for c in fh:
                assert isinstance(c, str), f'f_hist must return only strings, "{c}" is not allowed'
                cat.append(c)

        category_counts = Counter(cat)
        # most_common() is a stable sort by decreasing count; unlike unpacking zip(*...),
        # the code below also works when no category was collected
        ordered = category_counts.most_common() if sort_cat else list(category_counts.items())
        categories = tuple(name for name, _ in ordered)
        counts = tuple(total for _, total in ordered)

        # The figure is always created; it is only displayed if show_plot is set
        lc = len(categories)
        plt.figure(figsize=(12, 6))
        plt.bar(categories, counts)
        plt.xticks(rotation=45, fontsize=8 if lc > 10 else 10, ha='right')
        plt.xlabel('Category')
        plt.ylabel('Frequency')
        plt.title(f'Histogram ({lc} categories / {len(cat)} objects)')
        plt.tight_layout()
        if show_plot:
            plt.show()
        return categories