Skip to content

Commit f7dbf73

Browse files
committed
Sharpen types, including precise islpy
1 parent 9e4099e commit f7dbf73

40 files changed

+53265
-90073
lines changed

.basedpyright/baseline.json

Lines changed: 51617 additions & 89315 deletions
Large diffs are not rendered by default.

loopy/check.py

Lines changed: 31 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -25,10 +25,12 @@
2525

2626
import logging
2727
from collections import defaultdict
28+
from dataclasses import dataclass
2829
from functools import reduce
29-
from typing import TYPE_CHECKING
30+
from typing import TYPE_CHECKING, cast
3031

3132
import numpy as np
33+
from typing_extensions import override
3234

3335
import islpy as isl
3436
from islpy import dim_type
@@ -76,6 +78,7 @@
7678
if TYPE_CHECKING:
7779
from collections.abc import Mapping, Sequence
7880

81+
import pymbolic.primitives as p
7982
from pymbolic.typing import Expression
8083

8184
from loopy.kernel import LoopKernel
@@ -707,26 +710,31 @@ def subst_func(x):
707710
# }}}
708711

709712

710-
class _AccessCheckMapper(WalkMapper):
711-
def __init__(self, kernel, callables_table):
712-
self.kernel = kernel
713-
self.callables_table = callables_table
714-
super().__init__()
713+
@dataclass
714+
class _AccessCheckMapper(WalkMapper[[isl.BasicSet, str]]):
715+
kernel: LoopKernel
716+
callables_table: CallablesTable
715717

716718
@memoize_method
717719
def _make_slab(self, space, iname, start, stop):
718720
from loopy.isl_helpers import make_slab
719721
return make_slab(space, iname, start, stop)
720722

721723
@memoize_method
722-
def _get_access_range(self, domain, subscript):
723-
from loopy.symbolic import UnableToDetermineAccessRangeError, get_access_map
724+
def _get_access_range(
725+
self,
726+
domain: isl.BasicSet,
727+
subscript: tuple[Expression, ...]
728+
):
729+
from loopy.diagnostic import UnableToDetermineAccessRangeError
730+
from loopy.symbolic import get_access_map
724731
try:
725732
return get_access_map(domain, subscript).range()
726733
except UnableToDetermineAccessRangeError:
727734
return None
728735

729-
def map_subscript(self, expr, domain, insn_id):
736+
@override
737+
def map_subscript(self, expr: p.Subscript, domain: isl.BasicSet, insn_id: str):
730738
WalkMapper.map_subscript(self, expr, domain, insn_id)
731739

732740
from pymbolic.primitives import Variable
@@ -742,6 +750,7 @@ def map_subscript(self, expr, domain, insn_id):
742750
shape = tv.shape
743751

744752
if shape is not None:
753+
assert isinstance(shape, tuple)
745754
subscript = expr.index
746755

747756
if not isinstance(subscript, tuple):
@@ -787,7 +796,8 @@ def map_subscript(self, expr, domain, insn_id):
787796
" establish '%s' is a subset of '%s')."
788797
% (expr, insn_id, access_range, shape_domain))
789798

790-
def map_if(self, expr, domain, insn_id):
799+
@override
800+
def map_if(self, expr: p. If, domain: isl.BasicSet, insn_id: str):
791801
from loopy.symbolic import condition_to_set
792802
then_set = condition_to_set(domain.space, expr.condition)
793803
if then_set is None:
@@ -800,7 +810,8 @@ def map_if(self, expr, domain, insn_id):
800810
self.rec(expr.then, domain & then_set, insn_id)
801811
self.rec(expr.else_, domain & else_set, insn_id)
802812

803-
def map_call(self, expr, domain, insn_id):
813+
@override
814+
def map_call(self, expr: p.Call, domain: isl.BasicSet, insn_id: str):
804815
# perform access checks on the call arguments
805816
super().map_call(expr, domain, insn_id)
806817

@@ -817,7 +828,9 @@ def map_call(self, expr, domain, insn_id):
817828
and isinstance(self.callables_table[expr.function.name],
818829
CallableKernel)):
819830

820-
subkernel = self.callables_table[expr.function.name].subkernel
831+
subkernel = cast(
832+
"CallableKernel",
833+
self.callables_table[expr.function.name]).subkernel
821834

822835
# The plan here is to add the constraints coming from the values
823836
# args passed at a call-site as assumptions to the callee. To avoid
@@ -835,8 +848,8 @@ def map_call(self, expr, domain, insn_id):
835848

836849
kw_space = isl.Space.create_from_names(
837850
subkernel.isl_context, set=[],
838-
params=(get_dependencies(tuple(kwargs.values()))
839-
| set(kwargs.keys())))
851+
params=[*get_dependencies(tuple(kwargs.values())),
852+
*kwargs.keys()])
840853

841854
extra_assumptions = isl.BasicSet.universe(kw_space).params()
842855

@@ -894,8 +907,8 @@ def _check_bounds_inner(kernel: LoopKernel, callables_table: CallablesTable) ->
894907
domain, assumptions = isl.align_two(domain, kernel.assumptions)
895908
domain_with_assumptions = domain & assumptions
896909

897-
def run_acm(expr):
898-
acm(expr, domain_with_assumptions, insn.id) # noqa: B023
910+
def run_acm(expr: Expression):
911+
acm(expr, domain_with_assumptions, not_none(insn.id)) # noqa: B023
899912
return expr
900913

901914
insn.with_transformed_expressions(run_acm)
@@ -1875,7 +1888,7 @@ def pre_codegen_checks(t_unit: TranslationUnit) -> None:
18751888

18761889
def check_implemented_domains(
18771890
kernel: LoopKernel,
1878-
implemented_domains: Mapping[str, isl.Set],
1891+
implemented_domains: Mapping[str, Sequence[isl.Set]],
18791892
code: str | None = None,
18801893
) -> bool:
18811894
from islpy import align_two, dim_type
@@ -1942,7 +1955,7 @@ def check_implemented_domains(
19421955
d_minus_i = desired_domain - insn_impl_domain
19431956

19441957
parameter_inames = {
1945-
insn_domain.get_dim_name(dim_type.param, i)
1958+
not_none(insn_domain.get_dim_name(dim_type.param, i))
19461959
for i in range(insn_impl_domain.dim(dim_type.param))}
19471960

19481961
lines = []

loopy/codegen/__init__.py

Lines changed: 22 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,9 @@
3434

3535
import constantdict
3636

37+
from loopy.types import NumpyType
38+
from loopy.typing import not_none
39+
3740

3841
logger = logging.getLogger(__name__)
3942

@@ -53,14 +56,15 @@
5356

5457

5558
if TYPE_CHECKING:
59+
from pymbolic import Expression
60+
5661
from loopy.codegen.result import CodeGenerationResult, GeneratedProgram
5762
from loopy.codegen.tools import CodegenOperationCacheManager
5863
from loopy.kernel import LoopKernel
5964
from loopy.library.reduction import ReductionOpFunction
6065
from loopy.target import TargetBase
6166
from loopy.translation_unit import CallablesTable, TranslationUnit
6267
from loopy.types import LoopyType
63-
from loopy.typing import Expression
6468

6569

6670
__doc__ = """
@@ -206,7 +210,7 @@ def intersect(self, other):
206210
new_impl, new_other = isl.align_two(self.implemented_domain, other)
207211
return self.copy(implemented_domain=new_impl & new_other)
208212

209-
def fix(self, iname, aff):
213+
def fix(self, iname: str, aff: isl.Aff) -> CodeGenerationState:
210214
new_impl_domain = self.implemented_domain
211215

212216
impl_space = self.implemented_domain.get_space()
@@ -341,8 +345,12 @@ class PreambleInfo:
341345

342346
# {{{ main code generation entrypoint
343347

344-
def generate_code_for_a_single_kernel(kernel, callables_table, target,
345-
is_entrypoint):
348+
def generate_code_for_a_single_kernel(
349+
kernel: LoopKernel,
350+
callables_table: CallablesTable,
351+
target: TargetBase,
352+
is_entrypoint: bool,
353+
) -> CodeGenerationResult:
346354
"""
347355
:returns: a :class:`CodeGenerationResult`
348356
@@ -359,7 +367,8 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
359367
# {{{ examine arg list
360368

361369
allow_complex = False
362-
for var in kernel.args + list(kernel.temporary_variables.values()):
370+
for var in [*kernel.args, *kernel.temporary_variables.values()]:
371+
assert isinstance(var.dtype, NumpyType)
363372
if var.dtype.involves_complex():
364373
allow_complex = True
365374

@@ -376,7 +385,7 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
376385
codegen_state = CodeGenerationState(
377386
kernel=kernel,
378387
target=target,
379-
implemented_domain=initial_implemented_domain,
388+
implemented_domain=isl.Set.from_basic_set(initial_implemented_domain),
380389
implemented_predicates=frozenset(),
381390
seen_dtypes=seen_dtypes,
382391
seen_functions=seen_functions,
@@ -389,7 +398,7 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
389398
target.host_program_name_prefix
390399
+ kernel.name
391400
+ kernel.target.host_program_name_suffix),
392-
schedule_index_end=len(kernel.linearization),
401+
schedule_index_end=len(not_none(kernel.linearization)),
393402
callables_table=callables_table,
394403
is_entrypoint=is_entrypoint,
395404
codegen_cache_manager=CodegenOperationCacheManager.from_kernel(kernel),
@@ -418,7 +427,8 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
418427
if kernel.all_inames():
419428
seen_dtypes.add(kernel.index_dtype)
420429

421-
preambles = kernel.preambles + codegen_result.device_preambles
430+
preambles = [
431+
*kernel.preambles, *codegen_result.device_preambles]
422432

423433
preamble_info = PreambleInfo(
424434
kernel=kernel,
@@ -429,10 +439,10 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
429439
codegen_state=codegen_state
430440
)
431441

432-
preamble_generators = (list(kernel.preamble_generators)
433-
+ list(target.get_device_ast_builder().preamble_generators()))
434-
for prea_gen in preamble_generators:
435-
preambles = preambles + tuple(prea_gen(preamble_info))
442+
for prea_gen in [
443+
*kernel.preamble_generators,
444+
*target.get_device_ast_builder().preamble_generators()]:
445+
preambles.extend(prea_gen(preamble_info))
436446

437447
codegen_result = codegen_result.copy(device_preambles=preambles)
438448

loopy/codegen/control.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,10 +25,11 @@
2525
"""
2626

2727
from functools import partial
28+
from typing import TYPE_CHECKING
2829

2930
import islpy as isl
3031

31-
from loopy.codegen.result import merge_codegen_results, wrap_in_if
32+
from loopy.codegen.result import CodeGenerationResult, merge_codegen_results, wrap_in_if
3233
from loopy.diagnostic import LoopyError
3334
from loopy.schedule import (
3435
Barrier,
@@ -41,8 +42,17 @@
4142
)
4243

4344

44-
def generate_code_for_sched_index(codegen_state, sched_index):
45+
if TYPE_CHECKING:
46+
from loopy.codegen import CodeGenerationState
47+
48+
49+
def generate_code_for_sched_index(
50+
codegen_state: CodeGenerationState,
51+
sched_index: int
52+
) -> CodeGenerationResult:
4553
kernel = codegen_state.kernel
54+
assert kernel.linearization is not None
55+
4656
sched_item = kernel.linearization[sched_index]
4757

4858
if isinstance(sched_item, CallKernel):

loopy/codegen/loop.py

Lines changed: 21 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -24,19 +24,29 @@
2424
"""
2525

2626

27+
from typing import TYPE_CHECKING
28+
2729
import islpy as isl
2830
from islpy import dim_type
2931
from pymbolic.mapper.stringifier import PREC_NONE
3032

3133
from loopy.codegen.control import build_loop_nest
32-
from loopy.codegen.result import merge_codegen_results
34+
from loopy.codegen.result import CodeGenerationResult, merge_codegen_results
3335
from loopy.diagnostic import LoopyError, warn
3436
from loopy.symbolic import flatten
3537

3638

39+
if TYPE_CHECKING:
40+
from collections.abc import Sequence
41+
42+
from loopy.codegen import CodeGenerationState
43+
from loopy.kernel import LoopKernel
44+
from loopy.typing import InameStr
45+
46+
3747
# {{{ conditional-reducing slab decomposition
3848

39-
def get_slab_decomposition(kernel, iname):
49+
def get_slab_decomposition(kernel: LoopKernel, iname: InameStr):
4050
iname_domain = kernel.get_inames_domain(iname)
4151

4252
if iname_domain.is_empty():
@@ -231,10 +241,14 @@ def intersect_kernel_with_slab(kernel, slab, iname):
231241

232242
# {{{ hw-parallel loop
233243

234-
def set_up_hw_parallel_loops(codegen_state, schedule_index, next_func,
235-
hw_inames_left=None):
244+
def set_up_hw_parallel_loops(
245+
codegen_state: CodeGenerationState, schedule_index: int,
246+
next_func,
247+
hw_inames_left: Sequence[InameStr] | None = None) -> CodeGenerationResult:
236248
kernel = codegen_state.kernel
237249

250+
assert kernel.linearization is not None
251+
238252
from loopy.kernel.data import (
239253
GroupInameTag,
240254
HardwareConcurrentTag,
@@ -248,7 +262,7 @@ def set_up_hw_parallel_loops(codegen_state, schedule_index, next_func,
248262
schedule_index)
249263

250264
if hw_inames_left is None:
251-
all_inames_by_insns = set()
265+
all_inames_by_insns: set[InameStr] = set()
252266
for insn_id in insn_ids_for_block:
253267
all_inames_by_insns |= kernel.insn_inames(insn_id)
254268

@@ -262,7 +276,7 @@ def set_up_hw_parallel_loops(codegen_state, schedule_index, next_func,
262276
global_size, local_size = kernel.get_grid_sizes_for_insn_ids(
263277
insn_ids_for_block, codegen_state.callables_table, return_dict=True)
264278

265-
hw_inames_left = hw_inames_left[:]
279+
hw_inames_left = list(hw_inames_left)
266280
iname = hw_inames_left.pop()
267281

268282
from loopy.symbolic import GroupHardwareAxisIndex, LocalHardwareAxisIndex
@@ -300,7 +314,7 @@ def set_up_hw_parallel_loops(codegen_state, schedule_index, next_func,
300314
# It's ok to find a bound that's too "loose". The conditional
301315
# generators will mop up after us.
302316
from loopy.kernel.tools import get_hw_axis_base_for_codegen
303-
lower_bound = get_hw_axis_base_for_codegen(kernel, iname)
317+
lower_bound = isl.PwAff.from_aff(get_hw_axis_base_for_codegen(kernel, iname))
304318

305319
# These bounds are 'implemented' by the hardware. Make sure
306320
# that the downstream conditional generators realize that.

loopy/codegen/result.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -119,7 +119,7 @@ class CodeGenerationResult:
119119
"""
120120
host_program: GeneratedProgram | None
121121
device_programs: Sequence[GeneratedProgram]
122-
implemented_domains: Mapping[str, islpy.Set]
122+
implemented_domains: Mapping[str, list[islpy.Set]]
123123
host_preambles: Sequence[tuple[str, str]] = ()
124124
device_preambles: Sequence[tuple[str, str]] = ()
125125

@@ -227,8 +227,9 @@ def with_new_ast(self, codegen_state, new_ast):
227227
# {{{ support code for AST merging
228228

229229
def merge_codegen_results(
230-
codegen_state: CodeGenerationState,
231-
elements: Sequence[CodeGenerationResult | Any], collapse=True
230+
codegen_state: CodeGenerationState,
231+
elements: Sequence[CodeGenerationResult | Any],
232+
collapse: bool = True
232233
) -> CodeGenerationResult:
233234
elements = [el for el in elements if el is not None]
234235

@@ -247,7 +248,7 @@ def merge_codegen_results(
247248
new_device_programs = []
248249
new_device_preambles: list[tuple[str, str]] = []
249250
dev_program_names = set()
250-
implemented_domains: dict[str, islpy.Set] = {}
251+
implemented_domains: dict[str, list[islpy.Set]] = {}
251252
codegen_result = None
252253

253254
block_cls = codegen_state.ast_builder.ast_block_class

loopy/frontend/fortran/expression.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@
3838
if TYPE_CHECKING:
3939
from collections.abc import Mapping
4040

41-
from loopy.symbolic import LexTable
41+
from pytools.lex import LexTable
4242

4343

4444
_less_than = intern("less_than")

0 commit comments

Comments
 (0)