2525
2626import logging
2727from collections import defaultdict
28+ from dataclasses import dataclass
2829from functools import reduce
29- from typing import TYPE_CHECKING
30+ from typing import TYPE_CHECKING , cast
3031
3132import numpy as np
33+ from typing_extensions import override
3234
3335import islpy as isl
3436from islpy import dim_type
7678if TYPE_CHECKING :
7779 from collections .abc import Mapping , Sequence
7880
81+ import pymbolic .primitives as p
7982 from pymbolic .typing import Expression
8083
8184 from loopy .kernel import LoopKernel
@@ -707,10 +710,12 @@ def subst_func(x):
707710# }}}
708711
709712
710- class _AccessCheckMapper (WalkMapper ):
711- def __init__ (self , kernel , callables_table ):
712- self .kernel = kernel
713- self .callables_table = callables_table
713+ @dataclass
714+ class _AccessCheckMapper (WalkMapper [[isl .Set , str ]]):
715+ kernel : LoopKernel
716+ callables_table : CallablesTable
717+
718+ def __post_init__ (self ) -> None :
714719 super ().__init__ ()
715720
716721 @memoize_method
@@ -719,14 +724,20 @@ def _make_slab(self, space, iname, start, stop):
719724 return make_slab (space , iname , start , stop )
720725
721726 @memoize_method
722- def _get_access_range (self , domain , subscript ):
723- from loopy .symbolic import UnableToDetermineAccessRangeError , get_access_map
727+ def _get_access_range (
728+ self ,
729+ domain : isl .Set ,
730+ subscript : tuple [Expression , ...]
731+ ):
732+ from loopy .diagnostic import UnableToDetermineAccessRangeError
733+ from loopy .symbolic import get_access_map
724734 try :
725735 return get_access_map (domain , subscript ).range ()
726736 except UnableToDetermineAccessRangeError :
727737 return None
728738
729- def map_subscript (self , expr , domain , insn_id ):
739+ @override
740+ def map_subscript (self , expr : p .Subscript , domain : isl .Set , insn_id : str ):
730741 WalkMapper .map_subscript (self , expr , domain , insn_id )
731742
732743 from pymbolic .primitives import Variable
@@ -742,6 +753,7 @@ def map_subscript(self, expr, domain, insn_id):
742753 shape = tv .shape
743754
744755 if shape is not None :
756+ assert isinstance (shape , tuple )
745757 subscript = expr .index
746758
747759 if not isinstance (subscript , tuple ):
@@ -787,7 +799,8 @@ def map_subscript(self, expr, domain, insn_id):
787799 " establish '%s' is a subset of '%s')."
788800 % (expr , insn_id , access_range , shape_domain ))
789801
790- def map_if (self , expr , domain , insn_id ):
802+ @override
803+ def map_if (self , expr : p . If , domain : isl .Set , insn_id : str ):
791804 from loopy .symbolic import condition_to_set
792805 then_set = condition_to_set (domain .space , expr .condition )
793806 if then_set is None :
@@ -800,7 +813,8 @@ def map_if(self, expr, domain, insn_id):
800813 self .rec (expr .then , domain & then_set , insn_id )
801814 self .rec (expr .else_ , domain & else_set , insn_id )
802815
803- def map_call (self , expr , domain , insn_id ):
816+ @override
817+ def map_call (self , expr : p .Call , domain : isl .Set , insn_id : str ):
804818 # perform access checks on the call arguments
805819 super ().map_call (expr , domain , insn_id )
806820
@@ -817,7 +831,9 @@ def map_call(self, expr, domain, insn_id):
817831 and isinstance (self .callables_table [expr .function .name ],
818832 CallableKernel )):
819833
820- subkernel = self .callables_table [expr .function .name ].subkernel
834+ subkernel = cast (
835+ "CallableKernel" ,
836+ self .callables_table [expr .function .name ]).subkernel
821837
822838 # The plan here is to add the constraints coming from the values
823839 # args passed at a call-site as assumptions to the callee. To avoid
@@ -835,8 +851,8 @@ def map_call(self, expr, domain, insn_id):
835851
836852 kw_space = isl .Space .create_from_names (
837853 subkernel .isl_context , set = [],
838- params = ( get_dependencies (tuple (kwargs .values ()))
839- | set ( kwargs .keys ())) )
854+ params = [ * get_dependencies (tuple (kwargs .values ())),
855+ * kwargs .keys ()] )
840856
841857 extra_assumptions = isl .BasicSet .universe (kw_space ).params ()
842858
@@ -894,7 +910,7 @@ def _check_bounds_inner(kernel: LoopKernel, callables_table: CallablesTable) ->
894910 domain , assumptions = isl .align_two (domain , kernel .assumptions )
895911 domain_with_assumptions = domain & assumptions
896912
897- def run_acm (expr ):
913+ def run_acm (expr : Expression ):
898914 acm (expr , domain_with_assumptions , insn .id ) # noqa: B023
899915 return expr
900916
@@ -1111,16 +1127,14 @@ def _check_variable_access_ordered_inner(kernel: LoopKernel) -> None:
11111127 # {{{ compute rev_depends, depends_on
11121128
11131129 # depends_on: mapping from insn_ids to their dependencies
1114- depends_on : dict [str , set [str ]] = {
1115- not_none (insn .id ): set () for insn in kernel .instructions }
1130+ depends_on : dict [str , set [str ]] = {insn .id : set () for insn in kernel .instructions }
11161131 # rev_depends: mapping from insn_ids to their reverse deps.
1117- rev_depends : dict [str , set [str ]] = {
1118- not_none (insn .id ): set () for insn in kernel .instructions }
1132+ rev_depends : dict [str , set [str ]] = {insn .id : set () for insn in kernel .instructions }
11191133
11201134 for insn in kernel .instructions :
1121- depends_on [not_none ( insn .id ) ].update (insn .depends_on )
1135+ depends_on [insn .id ].update (insn .depends_on )
11221136 for dep in insn .depends_on :
1123- rev_depends [dep ].add (not_none ( insn .id ) )
1137+ rev_depends [dep ].add (insn .id )
11241138
11251139 # }}}
11261140
@@ -1588,7 +1602,7 @@ def check_that_temporaries_are_defined_in_subkernels_where_used(
15881602def check_that_all_insns_are_scheduled (kernel : LoopKernel ) -> None :
15891603 assert kernel .linearization is not None
15901604
1591- all_schedulable_insns = {not_none ( insn .id ) for insn in kernel .instructions }
1605+ all_schedulable_insns = {insn .id for insn in kernel .instructions }
15921606 from loopy .schedule import sched_item_to_insn_id
15931607 scheduled_insns = {
15941608 insn_id
@@ -1659,7 +1673,10 @@ def _get_sub_array_ref_swept_range(
16591673 from loopy .symbolic import get_access_map
16601674 domain = kernel .get_inames_domain (frozenset ({iname_var .name
16611675 for iname_var in sar .swept_inames }))
1662- return get_access_map (domain , sar .swept_inames , kernel .assumptions ).range ()
1676+ return get_access_map (
1677+ domain .to_set (),
1678+ sar .swept_inames ,
1679+ kernel .assumptions .to_set ()).range ()
16631680
16641681
16651682def _are_sub_array_refs_equivalent (
@@ -1875,7 +1892,7 @@ def pre_codegen_checks(t_unit: TranslationUnit) -> None:
18751892
18761893def check_implemented_domains (
18771894 kernel : LoopKernel ,
1878- implemented_domains : Mapping [str , isl .Set ],
1895+ implemented_domains : Mapping [str , Sequence [ isl .Set ] ],
18791896 code : str | None = None ,
18801897 ) -> bool :
18811898 from islpy import align_two , dim_type
@@ -1942,7 +1959,7 @@ def check_implemented_domains(
19421959 d_minus_i = desired_domain - insn_impl_domain
19431960
19441961 parameter_inames = {
1945- insn_domain .get_dim_name (dim_type .param , i )
1962+ not_none ( insn_domain .get_dim_name (dim_type .param , i ) )
19461963 for i in range (insn_impl_domain .dim (dim_type .param ))}
19471964
19481965 lines = []
0 commit comments