Skip to content

Commit 5841627

Browse files
auhtLPTK
andauthored
Constraint solving for function overloading (#213)
Co-authored-by: Lionel Parreaux <[email protected]>
1 parent c389926 commit 5841627

File tree

14 files changed

+969
-92
lines changed

14 files changed

+969
-92
lines changed

compiler/shared/main/scala/mlscript/compiler/ClassLifter.scala

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ class ClassLifter(logDebugMsg: Boolean = false) {
510510
val nlhs = liftType(lb)
511511
val nrhs = liftType(ub)
512512
Bounds(nlhs._1, nrhs._1) -> (nlhs._2 ++ nrhs._2)
513-
case Constrained(base: Type, bounds, where) =>
513+
case Constrained(base: Type, bounds, where, tscs) =>
514514
val (nTargs, nCtx) = bounds.map { case (tv, Bounds(lb, ub)) =>
515515
val nlhs = liftType(lb)
516516
val nrhs = liftType(ub)
@@ -521,10 +521,18 @@ class ClassLifter(logDebugMsg: Boolean = false) {
521521
val nrhs = liftType(ub)
522522
Bounds(nlhs._1, nrhs._1) -> (nlhs._2 ++ nrhs._2)
523523
}.unzip
524+
val (tscs0, nCtx3) = tscs.map { case (tvs, cs) =>
525+
val (ntvs,c0) = tvs.map { case (p,v) =>
526+
val (nv, c) = liftType(v)
527+
(p,nv) -> c
528+
}.unzip
529+
val (ncs,c1) = cs.map(_.map(liftType).unzip).unzip
530+
(ntvs,ncs) -> (c0 ++ c1.flatten)
531+
}.unzip
524532
val (nBase, bCtx) = liftType(base)
525-
Constrained(nBase, nTargs, bounds2) ->
526-
((nCtx ++ nCtx2).fold(emptyCtx)(_ ++ _) ++ bCtx)
527-
case Constrained(_, _, _) => die
533+
Constrained(nBase, nTargs, bounds2, tscs0) ->
534+
((nCtx ++ nCtx2 ++ nCtx3.flatten).fold(emptyCtx)(_ ++ _) ++ bCtx)
535+
case Constrained(_, _, _, _) => die
528536
case Function(lhs, rhs) =>
529537
val nlhs = liftType(lhs)
530538
val nrhs = liftType(rhs)

shared/src/main/scala/mlscript/ConstraintSolver.scala

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,35 @@ class ConstraintSolver extends NormalForms { self: Typer =>
627627
recLb(ar.inner, b.inner)
628628
rec(b.inner.ub, ar.inner.ub, false)
629629
case (LhsRefined(S(b: ArrayBase), ts, r, _), _) => reportError()
630+
case (LhsRefined(S(ov: Overload), ts, r, trs), RhsBases(_, S(L(f: FunctionType)), _)) if noApproximateOverload =>
631+
TupleSetConstraints.mk(ov, f) match {
632+
case S(tsc) =>
633+
if (tsc.tvs.nonEmpty) {
634+
tsc.tvs.mapValuesIter(_.unwrapProxies).zipWithIndex.flatMap {
635+
case ((true, tv: TV), i) => tv.lowerBounds.iterator.map((_,tv,i,true))
636+
case ((false, tv: TV), i) => tv.upperBounds.iterator.map((_,tv,i,false))
637+
case _ => Nil
638+
}.find {
639+
case (b,_,i,_) =>
640+
tsc.updateImpl(i,b)
641+
tsc.constraints.isEmpty
642+
}.foreach {
643+
case (b,tv,_,p) => if (p) rec(b,tv,false) else rec(tv,b,false)
644+
}
645+
if (tsc.constraints.sizeCompare(1) === 0) {
646+
tsc.tvs.values.map(_.unwrapProxies).foreach {
647+
case tv: TV => tv.tsc.remove(tsc)
648+
case _ => ()
649+
}
650+
tsc.constraints.head.iterator.zip(tsc.tvs).foreach {
651+
case (c, (pol, t)) =>
652+
if (!pol) rec(c, t, false)
653+
if (pol) rec(t, c, false)
654+
}
655+
}
656+
}
657+
case N => reportError(S(msg"is not an instance of `${f.expNeg}`"))
658+
}
630659
case (LhsRefined(S(ov: Overload), ts, r, trs), _) =>
631660
annoying(Nil, LhsRefined(S(ov.approximatePos), ts, r, trs), Nil, done_rs) // TODO remove approx. with ambiguous constraints
632661
case (LhsRefined(S(Without(b, ns)), ts, r, _), RhsBases(pts, N | S(L(_)), _)) =>
@@ -832,13 +861,57 @@ class ConstraintSolver extends NormalForms { self: Typer =>
832861
val newBound = (cctx._1 ::: cctx._2.reverse).foldRight(rhs)((c, ty) =>
833862
if (c.prov is noProv) ty else mkProxy(ty, c.prov))
834863
lhs.upperBounds ::= newBound // update the bound
864+
if (noApproximateOverload) {
865+
lhs.tsc.foreachEntry { (tsc, v) =>
866+
v.foreach { i =>
867+
if (!tsc.tvs(i)._1) {
868+
tsc.updateOn(i, rhs)
869+
if (tsc.constraints.isEmpty) reportError()
870+
}
871+
}
872+
}
873+
val u = lhs.tsc.keysIterator.filter(_.constraints.sizeCompare(1)===0).duplicate
874+
u._1.foreach { k =>
875+
k.tvs.mapValuesIter(_.unwrapProxies).foreach {
876+
case (_,tv: TV) => tv.tsc.remove(k)
877+
case _ => ()
878+
}
879+
}
880+
u._2.foreach { k =>
881+
k.constraints.head.iterator.zip(k.tvs).foreach {
882+
case (c, (pol, t)) => if (pol) rec(t, c, false) else rec(c, t, false)
883+
}
884+
}
885+
}
835886
lhs.lowerBounds.foreach(rec(_, rhs, true)) // propagate from the bound
836887

837888
case (lhs, rhs: TypeVariable) if lhs.level <= rhs.level =>
838889
println(s"NEW $rhs LB (${lhs.level})")
839890
val newBound = (cctx._1 ::: cctx._2.reverse).foldLeft(lhs)((ty, c) =>
840891
if (c.prov is noProv) ty else mkProxy(ty, c.prov))
841892
rhs.lowerBounds ::= newBound // update the bound
893+
if (noApproximateOverload) {
894+
rhs.tsc.foreachEntry { (tsc, v) =>
895+
v.foreach { i =>
896+
if(tsc.tvs(i)._1) {
897+
tsc.updateOn(i, lhs)
898+
if (tsc.constraints.isEmpty) reportError()
899+
}
900+
}
901+
}
902+
val u = rhs.tsc.keysIterator.filter(_.constraints.sizeCompare(1)===0).duplicate
903+
u._1.foreach { k =>
904+
k.tvs.mapValuesIter(_.unwrapProxies).foreach {
905+
case (_,tv: TV) => tv.tsc.remove(k)
906+
case _ => ()
907+
}
908+
}
909+
u._2.foreach { k =>
910+
k.constraints.head.iterator.zip(k.tvs).foreach {
911+
case (c, (pol, t)) => if (pol) rec(t, c, false) else rec(c, t, false)
912+
}
913+
}
914+
}
842915
rhs.upperBounds.foreach(rec(lhs, _, true)) // propagate from the bound
843916

844917

@@ -1562,9 +1635,24 @@ class ConstraintSolver extends NormalForms { self: Typer =>
15621635
assert(lvl <= below, "this condition should be false for the result to be correct")
15631636
lvl
15641637
})
1638+
val freshentsc = tv.tsc.flatMap { case (tsc,_) =>
1639+
if (tsc.tvs.values.map(_.unwrapProxies).forall {
1640+
case tv: TV => !freshened.contains(tv)
1641+
case _ => true
1642+
}) S(tsc) else N
1643+
}
15651644
freshened += tv -> v
15661645
v.lowerBounds = tv.lowerBounds.mapConserve(freshen)
15671646
v.upperBounds = tv.upperBounds.mapConserve(freshen)
1647+
freshentsc.foreach { tsc =>
1648+
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)
1649+
t.constraints = t.constraints.map(_.map(freshen))
1650+
t.tvs = t.tvs.map(x => (x._1,freshen(x._2)))
1651+
t.tvs.values.map(_.unwrapProxies).zipWithIndex.foreach {
1652+
case (tv: TV, i) => tv.tsc.updateWith(t)(_.map(_ + i).orElse(S(Set(i))))
1653+
case _ => ()
1654+
}
1655+
}
15681656
v
15691657
}
15701658

shared/src/main/scala/mlscript/NuTypeDefs.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1165,7 +1165,19 @@ class NuTypeDefs extends ConstraintSolver { self: Typer =>
11651165
ctx.nextLevel { implicit ctx: Ctx =>
11661166
assert(fd.tparams.sizeCompare(tparamsSkolems) === 0, (fd.tparams, tparamsSkolems))
11671167
vars ++ tparamsSkolems |> { implicit vars =>
1168-
typeTerm(body)
1168+
val ty = typeTerm(body)
1169+
if (noApproximateOverload) {
1170+
val ambiguous = ty.getVars.unsorted.flatMap(_.tsc.keys.flatMap(_.tvs))
1171+
.groupBy(_._2)
1172+
.filter { case (v,pvs) => pvs.sizeIs > 1 }
1173+
if (ambiguous.nonEmpty) raise(ErrorReport(
1174+
msg"ambiguous" -> N ::
1175+
ambiguous.map { case (v,_) =>
1176+
msg"cannot determine satisfiability of type ${v.expPos}" -> v.prov.loco
1177+
}.toList
1178+
, true))
1179+
}
1180+
ty
11691181
}
11701182
}
11711183
} else {

shared/src/main/scala/mlscript/TypeSimplifier.scala

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ trait TypeSimplifier { self: Typer =>
2525
println(s"allVarPols: ${printPols(allVarPols)}")
2626

2727
val renewed = MutMap.empty[TypeVariable, TypeVariable]
28+
val renewedtsc = MutMap.empty[TupleSetConstraints, TupleSetConstraints]
2829

2930
def renew(tv: TypeVariable): TypeVariable =
3031
renewed.getOrElseUpdate(tv,
@@ -78,8 +79,17 @@ trait TypeSimplifier { self: Typer =>
7879
).map(process(_, S(false -> tv)))
7980
.reduceOption(_ &- _).filterNot(_.isTop).toList
8081
else Nil
82+
if (noApproximateOverload)
83+
nv.tsc ++= tv.tsc.iterator.map { case (tsc, i) => renewedtsc.get(tsc) match {
84+
case S(tsc) => (tsc, i)
85+
case N if inPlace => (tsc, i)
86+
case N =>
87+
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)
88+
renewedtsc += tsc -> t
89+
t.tvs = t.tvs.map(x => (x._1, process(x._2, N)))
90+
(t, i)
91+
}}
8192
}
82-
8393
nv
8494

8595
case ComposedType(true, l, r) =>
@@ -549,9 +559,17 @@ trait TypeSimplifier { self: Typer =>
549559
analyzed1.setAndIfUnset(tv -> pol(tv).getOrElse(false)) { apply(pol)(ty) }
550560
case N =>
551561
if (pol(tv) =/= S(false))
552-
analyzed1.setAndIfUnset(tv -> true) { tv.lowerBounds.foreach(apply(pol.at(tv.level, true))) }
562+
analyzed1.setAndIfUnset(tv -> true) {
563+
tv.lowerBounds.foreach(apply(pol.at(tv.level, true)))
564+
if (noApproximateOverload)
565+
tv.tsc.keys.flatMap(_.tvs).foreach(u => apply(pol.at(tv.level,u._1))(u._2))
566+
}
553567
if (pol(tv) =/= S(true))
554-
analyzed1.setAndIfUnset(tv -> false) { tv.upperBounds.foreach(apply(pol.at(tv.level, false))) }
568+
analyzed1.setAndIfUnset(tv -> false) {
569+
tv.upperBounds.foreach(apply(pol.at(tv.level, false)))
570+
if (noApproximateOverload)
571+
tv.tsc.keys.flatMap(_.tvs).foreach(u => apply(pol.at(tv.level,u._1))(u._2))
572+
}
555573
}
556574
case _ =>
557575
super.apply(pol)(st)
@@ -643,8 +661,11 @@ trait TypeSimplifier { self: Typer =>
643661
case tv: TypeVariable =>
644662
pol(tv) match {
645663
case S(pol_tv) =>
646-
if (analyzed2.add(pol_tv -> tv))
664+
if (analyzed2.add(pol_tv -> tv)) {
647665
processImpl(st, pol, pol_tv)
666+
if (noApproximateOverload)
667+
tv.tsc.keys.flatMap(_.tvs).foreach(u => processImpl(u._2,pol.at(tv.level,u._1),pol_tv))
668+
}
648669
case N =>
649670
if (analyzed2.add(true -> tv))
650671
// * To compute the positive co-occurrences
@@ -690,6 +711,7 @@ trait TypeSimplifier { self: Typer =>
690711
case S(p) =>
691712
(if (p) tv2.lowerBounds else tv2.upperBounds).foreach(go)
692713
// (if (p) getLbs(tv2) else getUbs(tv2)).foreach(go)
714+
if (noApproximateOverload) tv2.tsc.keys.flatMap(_.tvs).foreach(u => go(u._2))
693715
case N =>
694716
trace(s"Analyzing invar-occ of $tv2") {
695717
analyze2(tv2, pol)
@@ -789,7 +811,7 @@ trait TypeSimplifier { self: Typer =>
789811

790812
// * Remove variables that are 'dominated' by another type or variable
791813
// * A variable v dominated by T if T is in both of v's positive and negative cooccurrences
792-
allVars.foreach { case v => if (v.assignedTo.isEmpty && !varSubst.contains(v)) {
814+
allVars.foreach { case v => if (v.assignedTo.isEmpty && !varSubst.contains(v) && v.tsc.isEmpty) {
793815
println(s"2[v] $v ${coOccurrences.get(true -> v)} ${coOccurrences.get(false -> v)}")
794816

795817
coOccurrences.get(true -> v).iterator.flatMap(_.iterator).foreach {
@@ -807,6 +829,7 @@ trait TypeSimplifier { self: Typer =>
807829

808830
case w: TV if !(w is v) && !varSubst.contains(w) && !varSubst.contains(v) && !recVars(v)
809831
&& coOccurrences.get(false -> v).exists(_(w))
832+
&& w.tsc.isEmpty
810833
=>
811834
// * Here we know that v is 'dominated' by w, so v can be inlined.
812835
// * Note that we don't want to unify the two variables here
@@ -833,7 +856,7 @@ trait TypeSimplifier { self: Typer =>
833856

834857
// * Unify equivalent variables based on polar co-occurrence analysis:
835858
allVars.foreach { case v =>
836-
if (!v.assignedTo.isDefined && !varSubst.contains(v)) // TODO also handle v.assignedTo.isDefined?
859+
if (!v.assignedTo.isDefined && !varSubst.contains(v) && v.tsc.isEmpty) // TODO also handle v.assignedTo.isDefined?
837860
trace(s"3[v] $v +${coOccurrences.get(true -> v).mkString} -${coOccurrences.get(false -> v).mkString}") {
838861

839862
def go(pol: Bool): Unit = coOccurrences.get(pol -> v).iterator.flatMap(_.iterator).foreach {
@@ -850,6 +873,7 @@ trait TypeSimplifier { self: Typer =>
850873
)
851874
&& (v.level === w.level)
852875
// ^ Don't merge variables of differing levels
876+
&& w.tsc.isEmpty
853877
=>
854878
trace(s"[w] $w ${printPol(S(pol))}${coOccurrences.get(pol -> w).mkString}") {
855879

@@ -923,6 +947,7 @@ trait TypeSimplifier { self: Typer =>
923947
println(s"[rec] ${recVars}")
924948

925949
val renewals = MutMap.empty[TypeVariable, TypeVariable]
950+
val renewaltsc = MutMap.empty[TupleSetConstraints, TupleSetConstraints]
926951

927952
val semp = Set.empty[TV]
928953

@@ -999,7 +1024,7 @@ trait TypeSimplifier { self: Typer =>
9991024
nv
10001025
})
10011026
pol(tv) match {
1002-
case S(p) if inlineBounds && !occursInvariantly(tv) && !recVars.contains(tv) =>
1027+
case S(p) if inlineBounds && !occursInvariantly(tv) && !recVars.contains(tv) && tv.tsc.isEmpty =>
10031028
// * Inline the bounds of non-rec non-invar-occ type variables
10041029
println(s"Inlining [${printPol(p)}] bounds of $tv (~> $res)")
10051030
// if (p) mergeTransform(true, pol, tv, Set.single(tv), canDistribForall) | res
@@ -1017,6 +1042,15 @@ trait TypeSimplifier { self: Typer =>
10171042
res.lowerBounds = tv.lowerBounds.map(transform(_, pol.at(tv.level, true), Set.single(tv)))
10181043
if (occNums.contains(false -> tv))
10191044
res.upperBounds = tv.upperBounds.map(transform(_, pol.at(tv.level, false), Set.single(tv)))
1045+
if (noApproximateOverload)
1046+
res.tsc ++= tv.tsc.map { case (tsc, i) => renewaltsc.get(tsc) match {
1047+
case S(tsc) => (tsc, i)
1048+
case N =>
1049+
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)
1050+
renewaltsc += tsc -> t
1051+
t.tvs = t.tvs.map(x => (x._1, transform(x._2, PolMap.neu, Set.empty)))
1052+
(t, i)
1053+
}}
10201054
}
10211055
res
10221056
}()

0 commit comments

Comments
 (0)