Skip to content

Commit 0ef8ace

Browse files
Merge #2567
2567: Handle impl Trait more correctly r=flodiebold a=flodiebold When calling a function, argument-position impl Trait is transparent; same for return-position impl Trait when inside the function. So in these cases, we need to represent that type not by `Ty::Opaque`, but by a type variable that can be unified with whatever flows into there. Co-authored-by: Florian Diebold <[email protected]>
2 parents 4e24b25 + 9185359 commit 0ef8ace

File tree

4 files changed

+91
-3
lines changed

4 files changed

+91
-3
lines changed

crates/ra_hir_ty/src/infer.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ use hir_def::{
3232
use hir_expand::{diagnostics::DiagnosticSink, name::name};
3333
use ra_arena::map::ArenaMap;
3434
use ra_prof::profile;
35+
use test_utils::tested_by;
3536

3637
use super::{
3738
primitive::{FloatTy, IntTy},
@@ -274,6 +275,29 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
274275
self.normalize_associated_types_in(ty)
275276
}
276277

278+
/// Replaces `impl Trait` in `ty` by type variables and obligations for
279+
/// those variables. This is done for function arguments when calling a
280+
/// function, and for return types when inside the function body, i.e. in
281+
/// the cases where the `impl Trait` is 'transparent'. In other cases, `impl
282+
/// Trait` is represented by `Ty::Opaque`.
283+
fn insert_vars_for_impl_trait(&mut self, ty: Ty) -> Ty {
284+
ty.fold(&mut |ty| match ty {
285+
Ty::Opaque(preds) => {
286+
tested_by!(insert_vars_for_impl_trait);
287+
let var = self.table.new_type_var();
288+
let var_subst = Substs::builder(1).push(var.clone()).build();
289+
self.obligations.extend(
290+
preds
291+
.iter()
292+
.map(|pred| pred.clone().subst_bound_vars(&var_subst))
293+
.filter_map(Obligation::from_predicate),
294+
);
295+
var
296+
}
297+
_ => ty,
298+
})
299+
}
300+
277301
/// Replaces Ty::Unknown by a new type var, so we can maybe still infer it.
278302
fn insert_type_vars_shallow(&mut self, ty: Ty) -> Ty {
279303
match ty {
@@ -414,7 +438,8 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
414438

415439
self.infer_pat(*pat, &ty, BindingMode::default());
416440
}
417-
self.return_ty = self.make_ty(&data.ret_type);
441+
let return_ty = self.make_ty(&data.ret_type);
442+
self.return_ty = self.insert_vars_for_impl_trait(return_ty);
418443
}
419444

420445
fn infer_body(&mut self) {

crates/ra_hir_ty/src/infer/expr.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,7 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
613613
continue;
614614
}
615615

616+
let param_ty = self.insert_vars_for_impl_trait(param_ty);
616617
let param_ty = self.normalize_associated_types_in(param_ty);
617618
self.infer_expr_coerce(arg, &Expectation::has_type(param_ty.clone()));
618619
}

crates/ra_hir_ty/src/marks.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ test_utils::marks!(
66
type_var_resolves_to_int_var
77
match_ergonomics_ref
88
coerce_merge_fail_fallback
9+
insert_vars_for_impl_trait
910
);

crates/ra_hir_ty/src/tests/traits.rs

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
use super::{infer, type_at, type_at_pos};
2-
use crate::test_db::TestDB;
31
use insta::assert_snapshot;
2+
43
use ra_db::fixture::WithFixture;
4+
use test_utils::covers;
5+
6+
use super::{infer, infer_with_mismatches, type_at, type_at_pos};
7+
use crate::test_db::TestDB;
58

69
#[test]
710
fn infer_await() {
@@ -1486,3 +1489,61 @@ fn test<T, U>() where T: Trait<U::Item>, U: Trait<T::Item> {
14861489
// this is a legitimate cycle
14871490
assert_eq!(t, "{unknown}");
14881491
}
1492+
1493+
#[test]
1494+
fn unify_impl_trait() {
1495+
covers!(insert_vars_for_impl_trait);
1496+
assert_snapshot!(
1497+
infer_with_mismatches(r#"
1498+
trait Trait<T> {}
1499+
1500+
fn foo(x: impl Trait<u32>) { loop {} }
1501+
fn bar<T>(x: impl Trait<T>) -> T { loop {} }
1502+
1503+
struct S<T>(T);
1504+
impl<T> Trait<T> for S<T> {}
1505+
1506+
fn default<T>() -> T { loop {} }
1507+
1508+
fn test() -> impl Trait<i32> {
1509+
let s1 = S(default());
1510+
foo(s1);
1511+
let x: i32 = bar(S(default()));
1512+
S(default())
1513+
}
1514+
"#, true),
1515+
@r###"
1516+
[27; 28) 'x': impl Trait<u32>
1517+
[47; 58) '{ loop {} }': ()
1518+
[49; 56) 'loop {}': !
1519+
[54; 56) '{}': ()
1520+
[69; 70) 'x': impl Trait<T>
1521+
[92; 103) '{ loop {} }': T
1522+
[94; 101) 'loop {}': !
1523+
[99; 101) '{}': ()
1524+
[172; 183) '{ loop {} }': T
1525+
[174; 181) 'loop {}': !
1526+
[179; 181) '{}': ()
1527+
[214; 310) '{ ...t()) }': S<i32>
1528+
[224; 226) 's1': S<u32>
1529+
[229; 230) 'S': S<u32>(T) -> S<T>
1530+
[229; 241) 'S(default())': S<u32>
1531+
[231; 238) 'default': fn default<u32>() -> T
1532+
[231; 240) 'default()': u32
1533+
[247; 250) 'foo': fn foo(impl Trait<u32>) -> ()
1534+
[247; 254) 'foo(s1)': ()
1535+
[251; 253) 's1': S<u32>
1536+
[264; 265) 'x': i32
1537+
[273; 276) 'bar': fn bar<i32>(impl Trait<T>) -> T
1538+
[273; 290) 'bar(S(...lt()))': i32
1539+
[277; 278) 'S': S<i32>(T) -> S<T>
1540+
[277; 289) 'S(default())': S<i32>
1541+
[279; 286) 'default': fn default<i32>() -> T
1542+
[279; 288) 'default()': i32
1543+
[296; 297) 'S': S<i32>(T) -> S<T>
1544+
[296; 308) 'S(default())': S<i32>
1545+
[298; 305) 'default': fn default<i32>() -> T
1546+
[298; 307) 'default()': i32
1547+
"###
1548+
);
1549+
}

0 commit comments

Comments
 (0)