11use crate :: MirPass ;
2- use rustc_hir:: def_id:: DefId ;
32use rustc_hir:: lang_items:: LangItem ;
43use rustc_index:: IndexVec ;
54use rustc_middle:: mir:: * ;
65use rustc_middle:: mir:: {
76 interpret:: Scalar ,
8- visit:: { PlaceContext , Visitor } ,
7+ visit:: { MutatingUseContext , NonMutatingUseContext , PlaceContext , Visitor } ,
98} ;
10- use rustc_middle:: ty:: { Ty , TyCtxt , TypeAndMut } ;
9+ use rustc_middle:: ty:: { self , ParamEnv , Ty , TyCtxt , TypeAndMut } ;
1110use rustc_session:: Session ;
1211
1312pub struct CheckAlignment ;
@@ -30,30 +29,32 @@ impl<'tcx> MirPass<'tcx> for CheckAlignment {
3029
3130 let basic_blocks = body. basic_blocks . as_mut ( ) ;
3231 let local_decls = & mut body. local_decls ;
32+ let param_env = tcx. param_env_reveal_all_normalized ( body. source . def_id ( ) ) ;
3333
34+ // This pass inserts new blocks. Each insertion changes the Location for all
35+ // statements/blocks after. Iterating or visiting the MIR in order would require updating
36+ // our current location after every insertion. By iterating backwards, we dodge this issue:
37+ // The only Locations that an insertion changes have already been handled.
3438 for block in ( 0 ..basic_blocks. len ( ) ) . rev ( ) {
3539 let block = block. into ( ) ;
3640 for statement_index in ( 0 ..basic_blocks[ block] . statements . len ( ) ) . rev ( ) {
3741 let location = Location { block, statement_index } ;
3842 let statement = & basic_blocks[ block] . statements [ statement_index] ;
3943 let source_info = statement. source_info ;
4044
41- let mut finder = PointerFinder {
42- local_decls,
43- tcx,
44- pointers : Vec :: new ( ) ,
45- def_id : body. source . def_id ( ) ,
46- } ;
47- for ( pointer, pointee_ty) in finder. find_pointers ( statement) {
48- debug ! ( "Inserting alignment check for {:?}" , pointer. ty( & * local_decls, tcx) . ty) ;
45+ let mut finder =
46+ PointerFinder { tcx, local_decls, param_env, pointers : Vec :: new ( ) } ;
47+ finder. visit_statement ( statement, location) ;
4948
49+ for ( local, ty) in finder. pointers {
50+ debug ! ( "Inserting alignment check for {:?}" , ty) ;
5051 let new_block = split_block ( basic_blocks, location) ;
5152 insert_alignment_check (
5253 tcx,
5354 local_decls,
5455 & mut basic_blocks[ block] ,
55- pointer ,
56- pointee_ty ,
56+ local ,
57+ ty ,
5758 source_info,
5859 new_block,
5960 ) ;
@@ -63,69 +64,71 @@ impl<'tcx> MirPass<'tcx> for CheckAlignment {
6364 }
6465}
6566
66- impl < ' tcx , ' a > PointerFinder < ' tcx , ' a > {
67- fn find_pointers ( & mut self , statement : & Statement < ' tcx > ) -> Vec < ( Place < ' tcx > , Ty < ' tcx > ) > {
68- self . pointers . clear ( ) ;
69- self . visit_statement ( statement, Location :: START ) ;
70- core:: mem:: take ( & mut self . pointers )
71- }
72- }
73-
7467struct PointerFinder < ' tcx , ' a > {
75- local_decls : & ' a mut LocalDecls < ' tcx > ,
7668 tcx : TyCtxt < ' tcx > ,
77- def_id : DefId ,
69+ local_decls : & ' a mut LocalDecls < ' tcx > ,
70+ param_env : ParamEnv < ' tcx > ,
7871 pointers : Vec < ( Place < ' tcx > , Ty < ' tcx > ) > ,
7972}
8073
8174impl < ' tcx , ' a > Visitor < ' tcx > for PointerFinder < ' tcx , ' a > {
82- fn visit_rvalue ( & mut self , rvalue : & Rvalue < ' tcx > , location : Location ) {
83- if let Rvalue :: AddressOf ( ..) = rvalue {
84- // Ignore dereferences inside of an AddressOf
85- return ;
75+ fn visit_place ( & mut self , place : & Place < ' tcx > , context : PlaceContext , location : Location ) {
76+ // We want to only check reads and writes to Places, so we specifically exclude
77+ // Borrows and AddressOf.
78+ match context {
79+ PlaceContext :: MutatingUse (
80+ MutatingUseContext :: Store
81+ | MutatingUseContext :: AsmOutput
82+ | MutatingUseContext :: Call
83+ | MutatingUseContext :: Yield
84+ | MutatingUseContext :: Drop ,
85+ ) => { }
86+ PlaceContext :: NonMutatingUse (
87+ NonMutatingUseContext :: Copy | NonMutatingUseContext :: Move ,
88+ ) => { }
89+ _ => {
90+ return ;
91+ }
8692 }
87- self . super_rvalue ( rvalue, location) ;
88- }
8993
90- fn visit_place ( & mut self , place : & Place < ' tcx > , context : PlaceContext , _location : Location ) {
91- if let PlaceContext :: NonUse ( _) = context {
92- return ;
93- }
9494 if !place. is_indirect ( ) {
9595 return ;
9696 }
9797
98+ // Since Deref projections must come first and only once, the pointer for an indirect place
99+ // is the Local that the Place is based on.
98100 let pointer = Place :: from ( place. local ) ;
99- let pointer_ty = pointer . ty ( & * self . local_decls , self . tcx ) . ty ;
101+ let pointer_ty = self . local_decls [ place . local ] . ty ;
100102
101- // We only want to check unsafe pointers
103+ // We only want to check places based on unsafe pointers
102104 if !pointer_ty. is_unsafe_ptr ( ) {
103- trace ! ( "Indirect, but not an unsafe ptr, not checking {:?}" , pointer_ty ) ;
105+ trace ! ( "Indirect, but not based on an unsafe ptr, not checking {:?}" , place ) ;
104106 return ;
105107 }
106108
107- let Some ( pointee) = pointer_ty. builtin_deref ( true ) else {
108- debug ! ( "Indirect but no builtin deref: {:?}" , pointer_ty) ;
109+ let pointee_ty =
110+ pointer_ty. builtin_deref ( true ) . expect ( "no builtin_deref for an unsafe pointer" ) . ty ;
111+ // Ideally we'd support this in the future, but for now we are limited to sized types.
112+ if !pointee_ty. is_sized ( self . tcx , self . param_env ) {
113+ debug ! ( "Unsafe pointer, but pointee is not known to be sized: {:?}" , pointer_ty) ;
109114 return ;
110- } ;
111- let mut pointee_ty = pointee. ty ;
112- if pointee_ty. is_array ( ) || pointee_ty. is_slice ( ) || pointee_ty. is_str ( ) {
113- pointee_ty = pointee_ty. sequence_element_type ( self . tcx ) ;
114115 }
115116
116- if !pointee_ty. is_sized ( self . tcx , self . tcx . param_env_reveal_all_normalized ( self . def_id ) ) {
117- debug ! ( "Unsafe pointer, but unsized: {:?}" , pointer_ty) ;
117+ // Try to detect types we are sure have an alignment of 1 and skip the check
118+ // We don't need to look for str and slices, we already rejected unsized types above
119+ let element_ty = match pointee_ty. kind ( ) {
120+ ty:: Array ( ty, _) => * ty,
121+ _ => pointee_ty,
122+ } ;
123+ if [ self . tcx . types . bool , self . tcx . types . i8 , self . tcx . types . u8 ] . contains ( & element_ty) {
124+ debug ! ( "Trivially aligned place type: {:?}" , pointee_ty) ;
118125 return ;
119126 }
120127
121- if [ self . tcx . types . bool , self . tcx . types . i8 , self . tcx . types . u8 , self . tcx . types . str_ ]
122- . contains ( & pointee_ty)
123- {
124- debug ! ( "Trivially aligned pointee type: {:?}" , pointer_ty) ;
125- return ;
126- }
128+ // Ensure that this place is based on an aligned pointer.
129+ self . pointers . push ( ( pointer, pointee_ty) ) ;
127130
128- self . pointers . push ( ( pointer , pointee_ty ) )
131+ self . super_place ( place , context , location ) ;
129132 }
130133}
131134
0 commit comments