Skip to content

Commit d55d1c4

Browse files
committed
Sharpen types, including precise islpy
Bump islpy dependency
1 parent 4c19613 commit d55d1c4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+49562
-110889
lines changed

.basedpyright/baseline.json

Lines changed: 46980 additions & 109688 deletions
Large diffs are not rendered by default.

doc/ref_internals.rst

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,39 @@ Schedule
6161
.. automodule:: loopy.schedule.tools
6262
.. automodule:: loopy.schedule.tree
6363

64+
References
65+
----------
66+
67+
Mostly things that Sphinx (our documentation tool) should resolve but won't.
68+
69+
.. class:: constantdict
70+
71+
See :class:`constantdict.constantdict`.
72+
73+
.. class:: DTypeLike
74+
75+
See :data:`numpy.typing.DTypeLike`.
76+
77+
.. currentmodule:: p
78+
79+
.. class:: Call
80+
81+
See :class:`pymbolic.primitives.Call`.
82+
83+
.. class:: CallWithKwargs
84+
85+
See :class:`pymbolic.primitives.CallWithKwargs`.
86+
87+
.. currentmodule:: isl
88+
89+
.. class:: Space
90+
91+
See :class:`islpy.Space`.
92+
93+
.. class:: Aff
94+
95+
See :class:`islpy.Aff`.
96+
97+
.. class:: PwAff
6498

99+
See :class:`islpy.PwAff`.

doc/ref_kernel.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,12 @@ Instructions
270270
Assignment objects
271271
^^^^^^^^^^^^^^^^^^
272272

273+
.. currentmodule:: loopy.kernel.instruction
274+
275+
.. class:: Assignable
276+
277+
.. currentmodule:: loopy
278+
273279
.. autoclass:: Assignment
274280

275281
.. _assignment-syntax:

loopy/auto_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
"""
2525

2626
from dataclasses import dataclass
27-
from typing import TYPE_CHECKING
27+
from typing import TYPE_CHECKING, cast
2828
from warnings import warn
2929

3030
import numpy as np
@@ -37,6 +37,8 @@
3737
if TYPE_CHECKING:
3838
import pyopencl.array as cla
3939

40+
from loopy.types import NumpyType
41+
4042

4143
AUTO_TEST_SKIP_RUN = False
4244

@@ -142,7 +144,7 @@ def make_ref_args(kernel, queue, parameters):
142144
"testing" % arg.name)
143145

144146
shape = evaluate_shape(arg.shape, parameters)
145-
dtype = arg.dtype
147+
dtype = cast("NumpyType", arg.dtype).dtype
146148

147149
is_output = arg.is_output
148150

loopy/check.py

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525

2626
import logging
2727
from collections import defaultdict
28+
from dataclasses import dataclass
2829
from functools import reduce
29-
from typing import TYPE_CHECKING
30+
from typing import TYPE_CHECKING, cast
3031

3132
import numpy as np
33+
from typing_extensions import override
3234

3335
import islpy as isl
3436
from islpy import dim_type
@@ -76,6 +78,7 @@
7678
if TYPE_CHECKING:
7779
from collections.abc import Mapping, Sequence
7880

81+
import pymbolic.primitives as p
7982
from pymbolic.typing import Expression
8083

8184
from loopy.kernel import LoopKernel
@@ -707,10 +710,12 @@ def subst_func(x):
707710
# }}}
708711

709712

710-
class _AccessCheckMapper(WalkMapper):
711-
def __init__(self, kernel, callables_table):
712-
self.kernel = kernel
713-
self.callables_table = callables_table
713+
@dataclass
714+
class _AccessCheckMapper(WalkMapper[[isl.Set, str]]):
715+
kernel: LoopKernel
716+
callables_table: CallablesTable
717+
718+
def __post_init__(self) -> None:
714719
super().__init__()
715720

716721
@memoize_method
@@ -719,14 +724,20 @@ def _make_slab(self, space, iname, start, stop):
719724
return make_slab(space, iname, start, stop)
720725

721726
@memoize_method
722-
def _get_access_range(self, domain, subscript):
723-
from loopy.symbolic import UnableToDetermineAccessRangeError, get_access_map
727+
def _get_access_range(
728+
self,
729+
domain: isl.Set,
730+
subscript: tuple[Expression, ...]
731+
):
732+
from loopy.diagnostic import UnableToDetermineAccessRangeError
733+
from loopy.symbolic import get_access_map
724734
try:
725735
return get_access_map(domain, subscript).range()
726736
except UnableToDetermineAccessRangeError:
727737
return None
728738

729-
def map_subscript(self, expr, domain, insn_id):
739+
@override
740+
def map_subscript(self, expr: p.Subscript, domain: isl.Set, insn_id: str):
730741
WalkMapper.map_subscript(self, expr, domain, insn_id)
731742

732743
from pymbolic.primitives import Variable
@@ -742,6 +753,7 @@ def map_subscript(self, expr, domain, insn_id):
742753
shape = tv.shape
743754

744755
if shape is not None:
756+
assert isinstance(shape, tuple)
745757
subscript = expr.index
746758

747759
if not isinstance(subscript, tuple):
@@ -787,7 +799,8 @@ def map_subscript(self, expr, domain, insn_id):
787799
" establish '%s' is a subset of '%s')."
788800
% (expr, insn_id, access_range, shape_domain))
789801

790-
def map_if(self, expr, domain, insn_id):
802+
@override
803+
def map_if(self, expr: p. If, domain: isl.Set, insn_id: str):
791804
from loopy.symbolic import condition_to_set
792805
then_set = condition_to_set(domain.space, expr.condition)
793806
if then_set is None:
@@ -800,7 +813,8 @@ def map_if(self, expr, domain, insn_id):
800813
self.rec(expr.then, domain & then_set, insn_id)
801814
self.rec(expr.else_, domain & else_set, insn_id)
802815

803-
def map_call(self, expr, domain, insn_id):
816+
@override
817+
def map_call(self, expr: p.Call, domain: isl.Set, insn_id: str):
804818
# perform access checks on the call arguments
805819
super().map_call(expr, domain, insn_id)
806820

@@ -817,7 +831,9 @@ def map_call(self, expr, domain, insn_id):
817831
and isinstance(self.callables_table[expr.function.name],
818832
CallableKernel)):
819833

820-
subkernel = self.callables_table[expr.function.name].subkernel
834+
subkernel = cast(
835+
"CallableKernel",
836+
self.callables_table[expr.function.name]).subkernel
821837

822838
# The plan here is to add the constraints coming from the values
823839
# args passed at a call-site as assumptions to the callee. To avoid
@@ -835,8 +851,8 @@ def map_call(self, expr, domain, insn_id):
835851

836852
kw_space = isl.Space.create_from_names(
837853
subkernel.isl_context, set=[],
838-
params=(get_dependencies(tuple(kwargs.values()))
839-
| set(kwargs.keys())))
854+
params=[*get_dependencies(tuple(kwargs.values())),
855+
*kwargs.keys()])
840856

841857
extra_assumptions = isl.BasicSet.universe(kw_space).params()
842858

@@ -894,7 +910,7 @@ def _check_bounds_inner(kernel: LoopKernel, callables_table: CallablesTable) ->
894910
domain, assumptions = isl.align_two(domain, kernel.assumptions)
895911
domain_with_assumptions = domain & assumptions
896912

897-
def run_acm(expr):
913+
def run_acm(expr: Expression):
898914
acm(expr, domain_with_assumptions, insn.id) # noqa: B023
899915
return expr
900916

@@ -1111,16 +1127,14 @@ def _check_variable_access_ordered_inner(kernel: LoopKernel) -> None:
11111127
# {{{ compute rev_depends, depends_on
11121128

11131129
# depends_on: mapping from insn_ids to their dependencies
1114-
depends_on: dict[str, set[str]] = {
1115-
not_none(insn.id): set() for insn in kernel.instructions}
1130+
depends_on: dict[str, set[str]] = {insn.id: set() for insn in kernel.instructions}
11161131
# rev_depends: mapping from insn_ids to their reverse deps.
1117-
rev_depends: dict[str, set[str]] = {
1118-
not_none(insn.id): set() for insn in kernel.instructions}
1132+
rev_depends: dict[str, set[str]] = {insn.id: set() for insn in kernel.instructions}
11191133

11201134
for insn in kernel.instructions:
1121-
depends_on[not_none(insn.id)].update(insn.depends_on)
1135+
depends_on[insn.id].update(insn.depends_on)
11221136
for dep in insn.depends_on:
1123-
rev_depends[dep].add(not_none(insn.id))
1137+
rev_depends[dep].add(insn.id)
11241138

11251139
# }}}
11261140

@@ -1588,7 +1602,7 @@ def check_that_temporaries_are_defined_in_subkernels_where_used(
15881602
def check_that_all_insns_are_scheduled(kernel: LoopKernel) -> None:
15891603
assert kernel.linearization is not None
15901604

1591-
all_schedulable_insns = {not_none(insn.id) for insn in kernel.instructions}
1605+
all_schedulable_insns = {insn.id for insn in kernel.instructions}
15921606
from loopy.schedule import sched_item_to_insn_id
15931607
scheduled_insns = {
15941608
insn_id
@@ -1659,7 +1673,10 @@ def _get_sub_array_ref_swept_range(
16591673
from loopy.symbolic import get_access_map
16601674
domain = kernel.get_inames_domain(frozenset({iname_var.name
16611675
for iname_var in sar.swept_inames}))
1662-
return get_access_map(domain, sar.swept_inames, kernel.assumptions).range()
1676+
return get_access_map(
1677+
domain.to_set(),
1678+
sar.swept_inames,
1679+
kernel.assumptions.to_set()).range()
16631680

16641681

16651682
def _are_sub_array_refs_equivalent(
@@ -1875,7 +1892,7 @@ def pre_codegen_checks(t_unit: TranslationUnit) -> None:
18751892

18761893
def check_implemented_domains(
18771894
kernel: LoopKernel,
1878-
implemented_domains: Mapping[str, isl.Set],
1895+
implemented_domains: Mapping[str, Sequence[isl.Set]],
18791896
code: str | None = None,
18801897
) -> bool:
18811898
from islpy import align_two, dim_type
@@ -1942,7 +1959,7 @@ def check_implemented_domains(
19421959
d_minus_i = desired_domain - insn_impl_domain
19431960

19441961
parameter_inames = {
1945-
insn_domain.get_dim_name(dim_type.param, i)
1962+
not_none(insn_domain.get_dim_name(dim_type.param, i))
19461963
for i in range(insn_impl_domain.dim(dim_type.param))}
19471964

19481965
lines = []

0 commit comments

Comments
 (0)