2525
2626import logging
2727from collections import defaultdict
28+ from dataclasses import dataclass
2829from functools import reduce
29- from typing import TYPE_CHECKING
30+ from typing import TYPE_CHECKING , cast
3031
3132import numpy as np
33+ from typing_extensions import override
3234
3335import islpy as isl
3436from islpy import dim_type
7678if TYPE_CHECKING :
7779 from collections .abc import Mapping , Sequence
7880
81+ import pymbolic .primitives as p
7982 from pymbolic .typing import Expression
8083
8184 from loopy .kernel import LoopKernel
@@ -707,26 +710,31 @@ def subst_func(x):
707710# }}}
708711
709712
710- class _AccessCheckMapper (WalkMapper ):
711- def __init__ (self , kernel , callables_table ):
712- self .kernel = kernel
713- self .callables_table = callables_table
714- super ().__init__ ()
713+ @dataclass
714+ class _AccessCheckMapper (WalkMapper [[isl .BasicSet , str ]]):
715+ kernel : LoopKernel
716+ callables_table : CallablesTable
715717
716718 @memoize_method
717719 def _make_slab (self , space , iname , start , stop ):
718720 from loopy .isl_helpers import make_slab
719721 return make_slab (space , iname , start , stop )
720722
721723 @memoize_method
722- def _get_access_range (self , domain , subscript ):
723- from loopy .symbolic import UnableToDetermineAccessRangeError , get_access_map
724+ def _get_access_range (
725+ self ,
726+ domain : isl .BasicSet ,
727+ subscript : tuple [Expression , ...]
728+ ):
729+ from loopy .diagnostic import UnableToDetermineAccessRangeError
730+ from loopy .symbolic import get_access_map
724731 try :
725732 return get_access_map (domain , subscript ).range ()
726733 except UnableToDetermineAccessRangeError :
727734 return None
728735
729- def map_subscript (self , expr , domain , insn_id ):
736+ @override
737+ def map_subscript (self , expr : p .Subscript , domain : isl .BasicSet , insn_id : str ):
730738 WalkMapper .map_subscript (self , expr , domain , insn_id )
731739
732740 from pymbolic .primitives import Variable
@@ -742,6 +750,7 @@ def map_subscript(self, expr, domain, insn_id):
742750 shape = tv .shape
743751
744752 if shape is not None :
753+ assert isinstance (shape , tuple )
745754 subscript = expr .index
746755
747756 if not isinstance (subscript , tuple ):
@@ -787,7 +796,8 @@ def map_subscript(self, expr, domain, insn_id):
787796 " establish '%s' is a subset of '%s')."
788797 % (expr , insn_id , access_range , shape_domain ))
789798
790- def map_if (self , expr , domain , insn_id ):
799+ @override
800+ def map_if (self , expr : p . If , domain : isl .BasicSet , insn_id : str ):
791801 from loopy .symbolic import condition_to_set
792802 then_set = condition_to_set (domain .space , expr .condition )
793803 if then_set is None :
@@ -800,7 +810,8 @@ def map_if(self, expr, domain, insn_id):
800810 self .rec (expr .then , domain & then_set , insn_id )
801811 self .rec (expr .else_ , domain & else_set , insn_id )
802812
803- def map_call (self , expr , domain , insn_id ):
813+ @override
814+ def map_call (self , expr : p .Call , domain : isl .BasicSet , insn_id : str ):
804815 # perform access checks on the call arguments
805816 super ().map_call (expr , domain , insn_id )
806817
@@ -817,7 +828,9 @@ def map_call(self, expr, domain, insn_id):
817828 and isinstance (self .callables_table [expr .function .name ],
818829 CallableKernel )):
819830
820- subkernel = self .callables_table [expr .function .name ].subkernel
831+ subkernel = cast (
832+ "CallableKernel" ,
833+ self .callables_table [expr .function .name ]).subkernel
821834
822835 # The plan here is to add the constraints coming from the values
823836 # args passed at a call-site as assumptions to the callee. To avoid
@@ -835,8 +848,8 @@ def map_call(self, expr, domain, insn_id):
835848
836849 kw_space = isl .Space .create_from_names (
837850 subkernel .isl_context , set = [],
838- params = ( get_dependencies (tuple (kwargs .values ()))
839- | set ( kwargs .keys ())) )
851+ params = [ * get_dependencies (tuple (kwargs .values ())),
852+ * kwargs .keys ()] )
840853
841854 extra_assumptions = isl .BasicSet .universe (kw_space ).params ()
842855
@@ -894,8 +907,8 @@ def _check_bounds_inner(kernel: LoopKernel, callables_table: CallablesTable) ->
894907 domain , assumptions = isl .align_two (domain , kernel .assumptions )
895908 domain_with_assumptions = domain & assumptions
896909
897- def run_acm (expr ):
898- acm (expr , domain_with_assumptions , insn .id ) # noqa: B023
910+ def run_acm (expr : Expression ):
911+ acm (expr , domain_with_assumptions , not_none ( insn .id ) ) # noqa: B023
899912 return expr
900913
901914 insn .with_transformed_expressions (run_acm )
@@ -1875,7 +1888,7 @@ def pre_codegen_checks(t_unit: TranslationUnit) -> None:
18751888
18761889def check_implemented_domains (
18771890 kernel : LoopKernel ,
1878- implemented_domains : Mapping [str , isl .Set ],
1891+ implemented_domains : Mapping [str , Sequence [ isl .Set ] ],
18791892 code : str | None = None ,
18801893 ) -> bool :
18811894 from islpy import align_two , dim_type
@@ -1942,7 +1955,7 @@ def check_implemented_domains(
19421955 d_minus_i = desired_domain - insn_impl_domain
19431956
19441957 parameter_inames = {
1945- insn_domain .get_dim_name (dim_type .param , i )
1958+ not_none ( insn_domain .get_dim_name (dim_type .param , i ) )
19461959 for i in range (insn_impl_domain .dim (dim_type .param ))}
19471960
19481961 lines = []
0 commit comments