Support for rigid variables in the solver.

This commit is contained in:
Olivier 2022-01-17 09:36:13 +01:00
parent 92eed0ebbc
commit b7b0bd9fe8
13 changed files with 275 additions and 50 deletions

View File

@ -306,25 +306,6 @@ let convert env params ty =
exception VariableConflict of string
let convert_annot env ((rigidity, flex_vars, ty) : ML.type_annotation)
: (variable * F.nominal_type co, 'r) binder
= fun k ->
assert (rigidity = ML.Flexible);
let@ params, witnesses =
flex_vars |> mapM_both (fun alpha k ->
let@ v = exist in
k (
(alpha, v),
let+ ty' = witness v in (alpha, ty'))
)
in
let@ v = convert env params ty in
k (v,
let+ ws = witnesses in
let translate_var alpha = List.assoc alpha ws in
ML2F.translate_type translate_var ty
)
(* -------------------------------------------------------------------------- *)
(* We will need a type environment to keep trace of term variables that must
@ -402,11 +383,30 @@ let hastype (typedecl_env : ML.datatype_env) (t : ML.term) (w : variable) : F.no
u'))
| ML.Annot (pos, t, annot) ->
let@ v, annot' = convert_annot typedecl_env annot in
let+ () = v -- w
and+ t' = hastype t v
and+ annot' = annot' in
F.Annot (pos, t', annot')
let convert_annot typedecl_env params w t ty =
let@ v = convert typedecl_env params ty in
let+ () = v -- w
and+ t' = hastype t v
and+ ty' = witness v
in F.Annot (pos, t', ty')
in
begin match annot with
| (_, [], ty) ->
convert_annot typedecl_env [] w t ty
| (ML.Flexible, vs, ty) ->
let@ params =
vs |> mapM_now (fun alpha k -> let@ v = exist in k (alpha, v)) in
convert_annot typedecl_env params w t ty
| (ML.Rigid, vs, ty) ->
let+ (alphas, t', tys) =
letr1 vs
(fun tyvs w -> convert_annot typedecl_env tyvs w t ty)
(fun s -> instance' s w)
in
F.ftyapp (F.ftyabs alphas t') tys
end
| ML.Tuple (pos, ts) ->
let on_term (t:ML.term) : ('b * 'c co, 'r) binder =
@ -553,15 +553,25 @@ let hastype (typedecl_env : ML.datatype_env) (t : ML.term) (w : variable) : F.no
| ML.PWildcard pos ->
k ([], pure (F.PWildcard pos))
| ML.PAnnot (pos, pat, annot) ->
let@ v, annot' = convert_annot typedecl_env annot in
let@ (pat_env, pat) = hastype_pat typedecl_env pat v in
let+ () = v -- w
and+ res = k (pat_env,
let+ pat = pat
and+ annot' = annot'
in F.PAnnot(pos, pat, annot'))
in res
| ML.PAnnot (pos, pat, (rigidity, vars, ty)) ->
begin match rigidity with
| ML.Rigid ->
failwith "Rigid variables are not supported in pattern annotation"
| ML.Flexible ->
let@ params =
vars |> mapM_now (fun alpha k ->
let@ v = exist in k(alpha,v)
)
in
let@ v = convert typedecl_env params ty in
let@ (pat_env, pat) = hastype_pat typedecl_env pat v in
let+ () = v -- w
and+ res = k (pat_env,
let+ pat = pat
and+ ty' = witness v
in F.PAnnot(pos, pat, ty'))
in res
end
| ML.PTuple (pos, pats) ->

View File

@ -427,6 +427,22 @@ let test_regression2 () =
"let f = fun x -> let g = fun y -> (x, y) in g in fun x -> fun y -> f"
regression2
let a = ML.TyVar (dummy_pos, "'a")
let b = ML.TyVar (dummy_pos, "'b")
let id_annot annot =
ML.(Annot (dummy_pos, Abs(dummy_pos, "x", Var (dummy_pos, "x")), annot))
let test_id_rigid () =
test_ok
"(fun x -> x : for 'a. 'a -> 'a)"
(id_annot (ML.Rigid, ["'a"], ML.TyArrow (dummy_pos, a, a)))
let test_id_flexible () =
test_ok
"(fun x -> x : some 'a 'b. 'a -> 'b)"
(id_annot (ML.Flexible, ["'a"; "'b"], ML.TyArrow (dummy_pos, a, b)))
let test_suite =
let open Alcotest in
test_suite ::
@ -461,10 +477,16 @@ let test_suite =
"pattern matching",
[
test_case "match tuple" `Quick test_match_tuple;
test_case "match none" `Quick test_match_tuple;
test_case "match none" `Quick test_match_none;
test_case "match some" `Quick test_match_some;
test_case "match some annotated" `Quick test_match_some_annotated;
]
) ; (
"rigid",
[
test_case "id rigid" `Quick test_id_rigid;
test_case "id flexible" `Quick test_id_flexible;
]
)
]

View File

@ -0,0 +1,2 @@
let _ =
(fun x -> x : some 'a 'b. 'a -> 'b)

View File

@ -0,0 +1,2 @@
let _ =
(fun x -> x : for 'a 'b. 'a -> 'b)

View File

@ -0,0 +1,2 @@
let _ =
(fun x -> x : for 'a. 'a -> 'a)

View File

@ -0,0 +1,2 @@
let _ =
(fun x -> x : for 'a. 'a -> 'a) ()

View File

@ -0,0 +1,2 @@
let f =
fun x -> (x : for 'a. 'a)

View File

@ -190,3 +190,48 @@ Pattern-matching
end
Converting the System F term to de Bruijn style...
Type-checking the System F term...
$ cat id-rigid.midml
let _ =
(fun x -> x : for 'a. 'a -> 'a)
$ ../TestMidML.exe id-rigid.midml
Type inference and translation to System F...
Formatting the System F term...
Pretty-printing the System F term...
FUN a0 -> (FUN a1 -> (fun (x : a1) -> x : a1 -> a1)) [a0]
Converting the System F term to de Bruijn style...
Type-checking the System F term...
$ cat id-rigid2.midml
let _ =
(fun x -> x : for 'a. 'a -> 'a) ()
$ ../TestMidML.exe id-rigid2.midml
Type inference and translation to System F...
Formatting the System F term...
Pretty-printing the System F term...
(FUN a0 -> (fun (x : a0) -> x : a0 -> a0)) [{}] ()
Converting the System F term to de Bruijn style...
Type-checking the System F term...
$ cat id-rigid-wrong.midml
let _ =
(fun x -> x : for 'a 'b. 'a -> 'b)
$ ../TestMidML.exe id-rigid-wrong.midml
Type inference and translation to System F...
$ cat unit-rigid-wrong.midml
let _ =
(() : for 'a. 'a)
$ ../TestMidML.exe unit-rigid-wrong.midml
Type inference and translation to System F...
$ cat rigid-level-escape-wrong.midml
let f =
fun x -> (x : for 'a. 'a)
$ ../TestMidML.exe rigid-level-escape-wrong.midml
Type inference and translation to System F...

View File

@ -0,0 +1,2 @@
let _ =
(() : for 'a. 'a)

View File

@ -80,6 +80,7 @@ type 'a data = {
mutable structure: 'a OS.structure;
mutable rank: rank;
mutable generic: generic;
mutable is_rigid: bool
}
(* The module [Data] satisfies the signature [USTRUCTURE] required by the
@ -90,10 +91,10 @@ module Data = struct
type 'a structure =
'a data
let make structure rank =
let make structure rank is_rigid =
(* A fresh variable is ordinary, not generic. *)
let generic = false in
{ structure; rank; generic }
{ structure; rank; generic; is_rigid }
let map f data =
{ data with structure = OS.map f data.structure }
@ -104,6 +105,23 @@ module Data = struct
exception InconsistentConjunction =
OS.InconsistentConjunction
let merge d1 d2 : int * bool =
match (d1.is_rigid, d2.is_rigid) with
| false, false ->
( min d1.rank d2.rank, false )
| false, true ->
if d1.rank < d2.rank then
raise InconsistentConjunction
else
( d2.rank, true )
| true, false ->
if d1.rank > d2.rank then
raise InconsistentConjunction
else
( d1.rank, true )
| true, true ->
raise InconsistentConjunction
(* [conjunction] is invoked by the unifier when two equivalence classes are
unified. It is in charge of computing the data associated with the new
class. *)
@ -114,11 +132,13 @@ module Data = struct
let structure = OS.conjunction equate data1.structure data2.structure in
(* The rank of the new class is the minimum of the ranks of the two
original classes. *)
let rank = min data1.rank data2.rank in
let (rank, is_rigid) = merge data1 data2 in
if is_rigid && structure <> None then
raise InconsistentConjunction;
(* The unifier never acts on generic variables. *)
assert (not data1.generic && not data2.generic);
let generic = false in
{ structure; rank; generic }
{ structure; rank; generic; is_rigid }
end
@ -239,7 +259,13 @@ let register { pool; _ } v r =
let fresh state structure =
let r = state.young in
let v = U.fresh (Data.make structure r) in
let v = U.fresh (Data.make structure r false) in
register state v r;
v
let fresh_rigid state structure =
let r = state.young in
let v = U.fresh (Data.make structure r true) in
register state v r;
v
@ -617,6 +643,28 @@ let exit ~rectypes state roots =
(* Done. *)
quantifiers, schemes
let exit_rigid ~rectypes state root =
let young_vars = discover_young_generation state in
let on_var w =
if U.is_representative w && rank w = state.young then
match structure w with
| None ->
if not (U.structure w).is_rigid then
set_rank w (state.young - 1)
| Some _ ->
()
in
List.iter on_var young_vars.inhabitants;
let _, schemes = exit ~rectypes state [root] in
match schemes with
| [s] ->
s
| _ ->
assert false
(* -------------------------------------------------------------------------- *)
(* Instantiation amounts to copying a fragment of a graph. The variables that

View File

@ -75,6 +75,8 @@ module Make (S : GSTRUCTURE) : sig
[enter]/[exit] balance is at least one. *)
val fresh: state -> variable S.structure option -> variable
val fresh_rigid: state -> variable S.structure option -> variable
(* A variable can be turned into a trivial scheme, with no quantifiers and
no generic part: in other words, a monomorphic type scheme. Non-trivial
type schemes are created by the functions [enter] and [exit] below. *)
@ -120,6 +122,8 @@ module Make (S : GSTRUCTURE) : sig
val exit: rectypes:bool -> state -> variable list -> variable list * scheme list
val exit_rigid: rectypes:bool -> state -> variable -> scheme
(* [instantiate] takes a fresh copy of a type scheme. The generic variables
of the type scheme are replaced with freshly created variables, which are
automatically registered (hence, the current state is updated). The function

View File

@ -25,9 +25,15 @@ module Make
type tevar =
X.tevar
type svar =
int
module TeVarMap =
Map.Make(struct include X type t = tevar end)
module SVMap =
Map.Make(Int)
(* -------------------------------------------------------------------------- *)
(* The type variables that appear in constraints are immutable: they
@ -39,6 +45,9 @@ type variable =
let fresh : unit -> variable =
Utils.gensym()
let fresh_svar : unit -> svar =
Utils.gensym()
module VarTable = Hashtbl.Make(struct
type t = variable
let hash = Hashtbl.hash
@ -96,6 +105,8 @@ type _ co =
term variable [x]. Its result is a list of types that indicates
how the type scheme was instantiated. *)
| CInstance' : svar * variable -> O.ty list co
| CDef : tevar * variable * 'a co -> 'a co
(**The constraint [CDef (x, v, c)] binds the term variable [x] to
the trivial (monomorphic) type scheme [v] in the constraint [c]. *)
@ -118,6 +129,9 @@ type _ co =
- the value [a2] produced by solving the constraint [c2].
*)
| CLetRigid : variable list * variable * 'a co * svar * 'b co ->
(O.tyvar list * 'a * 'b) co
(* -------------------------------------------------------------------------- *)
(* A pretty-printer for constraints, used while debugging. *)
@ -176,6 +190,14 @@ module Printer = struct
next c1 ^^
string " in")) ^/^
self c2
| CLetRigid (vs, z, c1, s, c2) ->
string "letr " ^^
separate_map (string " and ") var (vs @ [z]) ^^
string " where" ^^ group (nest 2 (break 1 ^^
next c1 ^^
string " in")) ^/^
self c2 ^^
string " : " ^^ var s
| _ ->
next c
@ -207,9 +229,12 @@ module Printer = struct
separate space [var v1; string "="; var v2]
| CInstance (x, v) ->
tevar x ^^ utf8string "" ^^ var v
| CInstance'(sv, w) ->
var sv ^^ utf8string "" ^^ var w
| CExist _
| CDef _
| CLet _
| CLetRigid _
| CConj _
->
(* Introduce parentheses. *)
@ -288,6 +313,12 @@ end) = struct
VarTable.add table v (Some uv);
uv
let ubind_rigid (v : variable) so =
assert (not (VarTable.mem table v));
let uv = G.fresh_rigid X.state (Option.map (S.map uvar) so) in
VarTable.add table v (Some uv);
uv
end (* UVar *)
(* -------------------------------------------------------------------------- *)
@ -386,6 +417,13 @@ let exit range ~rectypes state vs =
let decode = new_decoder ~rectypes:true in
raise (Cycle (range, decode v))
let exit_rigid range ~rectypes state z =
try
G.exit_rigid ~rectypes state z
with U.Cycle v ->
let decode = new_decoder ~rectypes:true in
raise (Cycle (range, decode v))
(* -------------------------------------------------------------------------- *)
(* The toplevel constraint that is passed to the solver must have been
@ -411,12 +449,15 @@ let rec ok : type a . a co -> bool =
(* The left-hand constraint [c1] does not need to be [ok], since it is
examined after a call to [G.enter]. *)
ok c2
| CLetRigid (_vs, _z, c1, _s, c2) ->
ok c1 && ok c2
| CConj (c1, c2) ->
ok c1 && ok c2
| CEq _
| CExist _
| CWitness _
| CInstance _
| CInstance' _
| CDef _ ->
(* These forms are not [ok], as they involve (free or binding
occurrences of) type variables. *)
@ -468,25 +509,25 @@ let solve ~(rectypes : bool) (type a) (c : a co) : a =
range (the range annotation that was most recently encountered on the
way down). *)
let rec solve : type a . env -> range -> a co -> a on_sol =
fun env range c -> match c with
let rec solve : type a . env -> 'b SVMap.t -> range -> a co -> a on_sol =
fun env senv range c -> match c with
| CRange (range, c) ->
solve env range c
solve env senv range c
| CTrue ->
On_sol (fun () -> ())
| CMap (c, f) ->
let (On_sol r) = solve env range c in
let (On_sol r) = solve env senv range c in
On_sol (fun () -> f (r ()))
| CConj (c1, c2) ->
let (On_sol r1) = solve env range c1 in
let (On_sol r2) = solve env range c2 in
let (On_sol r1) = solve env senv range c1 in
let (On_sol r2) = solve env senv range c2 in
On_sol (fun () -> (r1 (), r2 ()))
| CEq (v, w) ->
unify range (uvar v) (uvar w);
On_sol (fun () -> ())
| CExist (v, s, c) ->
ignore (ubind v s);
solve env range c
solve env senv range c
| CWitness v ->
On_sol (fun () -> decode (uvar v))
| CInstance (x, w) ->
@ -498,9 +539,14 @@ let solve ~(rectypes : bool) (type a) (c : a co) : a =
let witnesses, v = G.instantiate state s in
unify range v (uvar w);
On_sol (fun () -> List.map decode witnesses)
| CInstance' (sv, w) ->
let s = SVMap.find sv senv in
let witnesses, v = G.instantiate state s in
unify range v (uvar w);
On_sol (fun () -> List.map decode witnesses)
| CDef (x, v, c) ->
let env = Env.bind x (G.trivial (uvar v)) env in
solve env range c
solve env senv range c
| CLet (xvs, c1, c2) ->
(* Warn the generalization engine that we are entering the left-hand
side of a [let] construct. *)
@ -510,7 +556,7 @@ let solve ~(rectypes : bool) (type a) (c : a co) : a =
basically, but they also serve as named entry points. *)
let vs = List.map (fun (_, v) -> ubind v None) xvs in
(* Solve the constraint [c1]. *)
let (On_sol r1) = solve env range c1 in
let (On_sol r1) = solve env senv range c1 in
(* Ask the generalization engine to perform an occurs check, to adjust
the ranks of the type variables in the young generation (i.e., all
of the type variables that were registered since the call to
@ -526,18 +572,31 @@ let solve ~(rectypes : bool) (type a) (c : a co) : a =
) xvs ss (env, [])
in
(* Proceed to solve [c2] in the extended environment. *)
let (On_sol r2) = solve env range c2 in
let (On_sol r2) = solve env senv range c2 in
On_sol (fun () ->
List.map decode_variable generalizable,
List.map (fun (x, s) -> (x, decode_scheme decode s)) xss,
r1 (),
r2 ())
| CLetRigid (vs, z, c1, sv, c2) ->
G.enter state;
let vs = List.map (fun v -> ubind_rigid v None) vs in
let z = ubind z None in
let (On_sol r1) = solve env senv range c1 in
let s = exit_rigid range ~rectypes state z in
let senv = SVMap.add sv s senv in
let (On_sol r2) = solve env senv range c2 in
On_sol (fun () ->
List.map decode_variable vs,
r1 (),
r2 ())
in
let env = Env.empty
and senv = SVMap.empty
and range = Lexing.(dummy_pos, dummy_pos) in
(* Phase 1: solve the constraint. *)
let (On_sol r) = solve env range c in
let (On_sol r) = solve env senv range c in
(* Phase 2: elaborate. *)
r()
@ -680,6 +739,8 @@ let instance x v =
let instance_ x v =
CMap (instance x v, ignore)
let instance' sv v =
CInstance' (sv, v)
(* -------------------------------------------------------------------------- *)
(* Constraint abstractions. *)
@ -734,6 +795,21 @@ let let0 c1 =
letn [] (fun _ -> c1) CTrue <$$>
fun (_, generalizable, v1, ()) -> (generalizable, v1)
let letr1
: 'tyvar list
-> (('tyvar * variable) list -> variable -> 'a co)
-> (svar -> 'b co)
-> (O.tyvar list * 'a * 'b) co
= fun alphas f1 f2 ->
let xvss = List.map (fun a ->
a, fresh ()
) alphas in
let z = fresh () in
let c1 = f1 xvss z in
let sv = fresh_svar () in
let c2 = f2 sv in
CLetRigid (List.map snd xvss, z, c1, sv, c2)
(* -------------------------------------------------------------------------- *)
(* Correlation with the source code. *)

View File

@ -21,6 +21,8 @@ module Make
(* The type [tevar] of term variables is provided by [X]. *)
open X
type svar
(* The type ['a structure] of shallow types is provided by [S]
(and repeated by [O]). *)
@ -128,6 +130,7 @@ module Make
(* [instance_ x v] is equivalent to [instance x v <$$> ignore]. *)
val instance_: tevar -> variable -> unit co
val instance': svar -> variable -> ty list co
(* ---------------------------------------------------------------------- *)
(* Construction of constraint abstractions, a.k.a. generalization. *)
@ -150,6 +153,11 @@ module Make
val let1: tevar -> (variable -> 'a co) -> 'b co ->
(scheme * tyvar list * 'a * 'b) co
val letr1: 'tyvar list
-> (('tyvar * variable) list -> variable -> 'a co)
-> (svar -> 'b co)
-> (tyvar list * 'a * 'b) co
(* [let0 c] has the same meaning as [c], but, like [let1], produces a list [vs]
of the type variables that may appear in the result of [c]. *)
val let0: 'a co -> (tyvar list * 'a) co