diff --git a/Makefile b/Makefile index 831dbe36..f1fe29bf 100644 --- a/Makefile +++ b/Makefile @@ -5,10 +5,12 @@ DIR_GUARD=@mkdir -p $(@D) #Name of the final executable file to be generated EX_NAME=birds +TEST_EX_NAME=birds_unit_tests # source folder SOURCE_DIR=src LOGIC_SOURCE_DIR=src/logic +TEST_SOURCE_DIR=src/test # binary folder for compilation BIN_DIR=bin @@ -17,13 +19,15 @@ RELEASE_DIR = release LOGIC_RELEASE_DIR = release/logic OBJ_DIR=${SOURCE_DIR} LOGIC_OBJ_DIR=${LOGIC_SOURCE_DIR} +TEST_OBJ_DIR=${TEST_SOURCE_DIR} -OCAMLC_FLAGS=-bin-annot -w -26 -I $(OBJ_DIR) -I $(LOGIC_OBJ_DIR) +OCAMLC_FLAGS=-bin-annot -w -26 -I $(OBJ_DIR) -I $(LOGIC_OBJ_DIR) -I $(TEST_OBJ_DIR) OCAMLOPT_FLAGS=-bin-annot -w -26 -I $(RELEASE_DIR) -I $(LOGIC_RELEASE_DIR) -OCAMLDEP_FLAGS=-I $(SOURCE_DIR) -I $(LOGIC_SOURCE_DIR) +OCAMLDEP_FLAGS=-I $(SOURCE_DIR) -I $(LOGIC_SOURCE_DIR) -I $(TEST_SOURCE_DIR) #Name of the files that are part of the project MAIN_FILE=main +TEST_MAIN_FILE=test_main LOGIC_FILES=\ lib intro formulas prop fol skolem fol_ex\ @@ -39,17 +43,31 @@ TOP_FILES=\ ast2theorem \ bx\ debugger\ + simplification\ TOP_FILES_WITH_MLI=\ parser expr conversion ast2sql ast2theorem\ + simplification\ + +TEST_ONLY_FILES=\ + ast2sql_operation_based_conversion_test\ + simplification_test\ + FILES=\ $(LOGIC_FILES:%=logic/%)\ $(TOP_FILES) \ -.PHONY: all release clean depend #annot +TEST_FILES=\ + $(FILES)\ + $(TEST_ONLY_FILES:%=test/%)\ + +.PHONY: all release clean depend test #annot all: $(BIN_DIR)/$(EX_NAME) +test: $(BIN_DIR)/$(TEST_EX_NAME) + ./$(BIN_DIR)/$(TEST_EX_NAME) + #Rule for generating the final executable file $(BIN_DIR)/$(EX_NAME): $(FILES:%=$(OBJ_DIR)/%.cmo) $(OBJ_DIR)/$(MAIN_FILE).cmo $(DIR_GUARD) @@ -60,6 +78,14 @@ $(OBJ_DIR)/$(MAIN_FILE).cmo: $(FILES:%=$(OBJ_DIR)/%.cmo) $(SOURCE_DIR)/$(MAIN_FI $(DIR_GUARD) ocamlfind ocamlc $(OCAMLC_FLAGS) -package $(PACKAGES) -thread -o $(OBJ_DIR)/$(MAIN_FILE) -c $(SOURCE_DIR)/$(MAIN_FILE).ml +$(BIN_DIR)/$(TEST_EX_NAME): 
$(TEST_FILES:%=$(OBJ_DIR)/%.cmo) $(TEST_OBJ_DIR)/$(TEST_MAIN_FILE).cmo + $(DIR_GUARD) + ocamlfind ocamlc $(OCAMLC_FLAGS) -package $(PACKAGES) -thread -linkpkg $(TEST_FILES:%=$(OBJ_DIR)/%.cmo) $(TEST_OBJ_DIR)/$(TEST_MAIN_FILE).cmo -o $(BIN_DIR)/$(TEST_EX_NAME) + +$(TEST_OBJ_DIR)/$(TEST_MAIN_FILE).cmo: $(TEST_FILES:%=$(OBJ_DIR)/%.cmo) $(TEST_SOURCE_DIR)/$(TEST_MAIN_FILE).ml + $(DIR_GUARD) + ocamlfind ocamlc $(OCAMLC_FLAGS) -package $(PACKAGES) -thread -o $(TEST_OBJ_DIR)/$(TEST_MAIN_FILE) -c $(TEST_SOURCE_DIR)/$(TEST_MAIN_FILE).ml + #Special rules for creating the lexer and parser $(SOURCE_DIR)/parser.ml $(SOURCE_DIR)/parser.mli: $(SOURCE_DIR)/parser.mly ocamlyacc $< @@ -91,7 +117,7 @@ $(OBJ_DIR)/%.cmi $(OBJ_DIR)/%.cmo $(OBJ_DIR)/%.cmt: $(SOURCE_DIR)/%.ml include depend clean: - rm -r -f $(BIN_DIR)/* $(RELEASE_DIR)/* $(SOURCE_DIR)/parser.mli $(SOURCE_DIR)/parser.ml $(SOURCE_DIR)/lexer.ml $(OBJ_DIR)/*.cmt $(LOGIC_OBJ_DIR)/*.cmt $(OBJ_DIR)/*.cmti $(LOGIC_OBJ_DIR)/*.cmti $(OBJ_DIR)/*.cmo $(LOGIC_OBJ_DIR)/*.cmo $(OBJ_DIR)/*.cmi $(LOGIC_OBJ_DIR)/*.cmi + rm -r -f $(BIN_DIR)/* $(RELEASE_DIR)/* $(SOURCE_DIR)/parser.mli $(SOURCE_DIR)/parser.ml $(SOURCE_DIR)/lexer.ml $(OBJ_DIR)/*.cmt $(LOGIC_OBJ_DIR)/*.cmt $(TEST_OBJ_DIR)/*.cmt $(OBJ_DIR)/*.cmti $(LOGIC_OBJ_DIR)/*.cmti $(TEST_OBJ_DIR)/*.cmti $(OBJ_DIR)/*.cmo $(LOGIC_OBJ_DIR)/*.cmo $(TEST_OBJ_DIR)/*.cmo $(OBJ_DIR)/*.cmi $(LOGIC_OBJ_DIR)/*.cmi $(TEST_OBJ_DIR)/*.cmi depend: ocamlfind ocamldep $(OCAMLDEP_FLAGS) $(FILES:%=$(SOURCE_DIR)/%.ml) $(SOURCE_DIR)/lexer.mll $(SOURCE_DIR)/parser.mli |sed -e 's/$(SOURCE_DIR)/$(BIN_DIR)/g' > depend diff --git a/src/ast2sql.ml b/src/ast2sql.ml index 45f71e4b..cc05b533 100644 --- a/src/ast2sql.ml +++ b/src/ast2sql.ml @@ -35,17 +35,23 @@ type sql_operator = | SqlRelNotEqual | SqlRelGeneral of string +type sql_schema_name = string + +type sql_table_name = string + type sql_column_name = string +type sql_instance_name = string + type sql_vterm = | SqlConst of Expr.const - | SqlColumn of 
sql_column_name + | SqlColumn of sql_instance_name option * sql_column_name | SqlUnaryOp of sql_unary_operator * sql_vterm | SqlBinaryOp of sql_binary_operator * sql_vterm * sql_vterm | SqlAggVar of sql_agg_function * sql_vterm type sql_select_clause = - | SqlSelect of (sql_vterm * string) list + | SqlSelect of (sql_vterm * sql_column_name) list type sql_comp_const = | SqlCompConst of sql_vterm * sql_operator * const @@ -56,12 +62,16 @@ type sql_group_by = type sql_having = | SqlHaving of sql_comp_const list +type sql_union_operation = + | SqlUnionOp + | SqlUnionAllOp + type sql_from_target = - | SqlFromColumn of sql_column_name - | SqlFromOther of sql_union + | SqlFromTable of sql_schema_name option * sql_table_name + | SqlFromOther of sql_query and sql_from_clause_entry = - sql_from_target * sql_column_name + sql_from_target * sql_instance_name and sql_from_clause = | SqlFrom of sql_from_clause_entry list @@ -69,6 +79,7 @@ and sql_from_clause = and sql_constraint = | SqlConstraint of sql_vterm * sql_operator * sql_vterm | SqlNotExist of sql_from_clause * sql_where_clause + | SqlExist of sql_from_clause * sql_where_clause and sql_where_clause = | SqlWhere of sql_constraint list @@ -84,23 +95,26 @@ and sql_query = agg : sql_aggregation_clause; } | SqlQuerySelectWhereFalse - -and sql_union_operation = - | SqlUnionOp - | SqlUnionAllOp - -and sql_union = | SqlUnion of sql_union_operation * sql_query list +type sql_operation = + | SqlCreateTemporaryTable of table_name * sql_query + | SqlCreateView of table_name * sql_query + | SqlInsertInto of table_name * sql_from_clause + | SqlDeleteFrom of table_name * sql_where_clause + let rec stringify_sql_vterm (vt : sql_vterm) : string = match vt with | SqlConst c -> string_of_const c - | SqlColumn column_name -> + | SqlColumn (None, column_name) -> column_name + | SqlColumn (Some instance_name, column_name) -> + Printf.sprintf "%s.%s" instance_name column_name + | SqlUnaryOp (un_op, vt1) -> let s_op = match un_op with @@ 
-156,11 +170,9 @@ let stringify_sql_comp_const (SqlCompConst (vt, op, c) : sql_comp_const) : strin let rec stringify_sql_from_target (target : sql_from_target) : string = match target with - | SqlFromColumn column_name -> - column_name - - | SqlFromOther sql_union -> - Printf.sprintf "(%s)" (stringify_sql_union sql_union) + | SqlFromTable (None, table) -> table + | SqlFromTable (Some schema, table) -> Printf.sprintf "%s.%s" schema table + | SqlFromOther sql_query -> Printf.sprintf "(%s)" (stringify_sql_query sql_query) and stringify_sql_from_clause (SqlFrom froms : sql_from_clause) : string = @@ -174,7 +186,7 @@ and stringify_sql_from_clause (SqlFrom froms : sql_from_clause) : string = Printf.sprintf "%s AS %s" (stringify_sql_from_target target) name ) |> String.concat ", " in - Printf.sprintf "FROM %s" s + Printf.sprintf " FROM %s" s and stringify_sql_constraint (sql_constraint : sql_constraint) : string = @@ -188,7 +200,12 @@ and stringify_sql_constraint (sql_constraint : sql_constraint) : string = | SqlNotExist (from, where) -> let s_from = stringify_sql_from_clause from in let s_where = stringify_sql_where_clause where in - Printf.sprintf "NOT EXISTS ( SELECT * %s %s )" s_from s_where + Printf.sprintf "NOT EXISTS ( SELECT *%s%s )" s_from s_where + + | SqlExist (from, where) -> + let s_from = stringify_sql_from_clause from in + let s_where = stringify_sql_where_clause where in + Printf.sprintf "EXISTS ( SELECT *%s%s )" s_from s_where and stringify_sql_where_clause (SqlWhere constraints : sql_where_clause) : string = @@ -200,7 +217,7 @@ and stringify_sql_where_clause (SqlWhere constraints : sql_where_clause) : strin let s = constraints |> List.map stringify_sql_constraint |> String.concat " AND " in - Printf.sprintf "WHERE %s" s + Printf.sprintf " WHERE %s" s and stringify_sql_aggregation_clause (agg : sql_aggregation_clause) : string = @@ -208,7 +225,7 @@ and stringify_sql_aggregation_clause (agg : sql_aggregation_clause) : string = let s_group_by = match 
column_names with | [] -> "" - | _ :: _ -> Printf.sprintf "GROUP BY %s" (String.concat ", " column_names) + | _ :: _ -> Printf.sprintf " GROUP BY %s" (String.concat ", " column_names) in let s_having = match comp_consts with @@ -217,9 +234,9 @@ and stringify_sql_aggregation_clause (agg : sql_aggregation_clause) : string = | _ :: _ -> let s = comp_consts |> List.map stringify_sql_comp_const |> String.concat " AND " in - Printf.sprintf "HAVING %s" s + Printf.sprintf " HAVING %s" s in - Printf.sprintf "%s %s" s_group_by s_having + Printf.sprintf "%s%s" s_group_by s_having and stringify_sql_query (sql : sql_query) : string = @@ -232,16 +249,30 @@ and stringify_sql_query (sql : sql_query) : string = let s_from = stringify_sql_from_clause from in let s_where = stringify_sql_where_clause where in let s_agg = stringify_sql_aggregation_clause agg in - Printf.sprintf "%s %s %s %s" s_select s_from s_where s_agg + Printf.sprintf "%s%s%s%s" s_select s_from s_where s_agg + | SqlUnion (union_op, queries) -> + let sep = + match union_op with + | SqlUnionOp -> " UNION " + | SqlUnionAllOp -> " UNION ALL " + in + queries |> List.map stringify_sql_query |> String.concat sep -and stringify_sql_union (SqlUnion (union_op, queries) : sql_union) : string = - let sep = - match union_op with - | SqlUnionOp -> " UNION " - | SqlUnionAllOp -> " UNION ALL " - in - queries |> List.map stringify_sql_query |> String.concat sep + +let stringify_sql_operation (sql_op : sql_operation) : string = + match sql_op with + | SqlCreateTemporaryTable (table, sql_query) -> + Printf.sprintf "CREATE TEMPORARY TABLE %s AS %s;" table (stringify_sql_query sql_query) + + | SqlCreateView (table, sql_query) -> + Printf.sprintf "CREATE VIEW %s AS %s;" table (stringify_sql_query sql_query) + + | SqlInsertInto (table, sql_from_clause) -> + Printf.sprintf "INSERT INTO %s SELECT *%s;" table (stringify_sql_from_clause sql_from_clause) + + | SqlDeleteFrom (table, sql_where_clause) -> + Printf.sprintf "DELETE FROM %s%s;" 
table (stringify_sql_where_clause sql_where_clause) (** Given an aggregate function name, checks if it is supported and returns it. *) @@ -285,7 +316,7 @@ let sql_of_vterm (vt : vartab) (eqt : eqtab) (expr : vterm) : sql_vterm = * is the name of the respective rterm's table column *) if Hashtbl.mem vt (string_of_var variable) then let column = List.hd (Hashtbl.find vt (string_of_var variable)) in - SqlColumn column + SqlColumn (None, column) (* If the variable does not appear in a positive rterm, but * it does in an equality value, then the value is the eq's evaluation *) else if Hashtbl.mem eqt (Var variable) then @@ -316,7 +347,7 @@ let var_to_col (vt : vartab) (eqt : eqtab) (key : symtkey) (variable : var) : sq * is the name of the respective rterm's table column *) if Hashtbl.mem vt (string_of_var variable) then let column = List.hd (Hashtbl.find vt (string_of_var variable)) in - SqlColumn column + SqlColumn (None, column) (* If the variable does not appear in a positive rterm, but * it does in an equality value, then the value is the eq's * constant, the var has to be removed from the eqtab *) @@ -428,7 +459,7 @@ let get_aggregation_sql (vt : vartab) (cnt : colnamtab) (head : rterm) (agg_eqs (group_by_sql, having_sql) -let rec non_rec_unfold_sql_of_symtkey (dbschema : string) (idb : symtable) (cnt : colnamtab) (goal : symtkey) : sql_union = +let rec non_rec_unfold_sql_of_symtkey (dbschema : string) (idb : symtable) (cnt : colnamtab) (goal : symtkey) : sql_query = (* get all the rule having this query in head *) (* print_endline ("Reach " ^ (string_of_symtkey goal)); *) if not (Hashtbl.mem idb goal) then @@ -462,9 +493,9 @@ let rec non_rec_unfold_sql_of_symtkey (dbschema : string) (idb : symtable) (cnt in let edb_alias (pname : string) (arity : int) (n : int) : sql_from_clause_entry = if str_contains pname "__tmp_" then - (SqlFromColumn pname, pname ^ "_a" ^ (string_of_int arity) ^ "_" ^ (string_of_int n)) + (SqlFromTable (None, pname), pname ^ "_a" ^ 
(string_of_int arity) ^ "_" ^ (string_of_int n)) else - (SqlFromColumn (dbschema ^ "." ^ pname), pname ^ "_a" ^ (string_of_int arity) ^ "_" ^ (string_of_int n)) + (SqlFromTable (Some dbschema, pname), pname ^ "_a" ^ (string_of_int arity) ^ "_" ^ (string_of_int n)) in let set_alias (rterm : rterm) (a_lst, n) = let pname = get_rterm_predname rterm in @@ -487,7 +518,7 @@ let rec non_rec_unfold_sql_of_symtkey (dbschema : string) (idb : symtable) (cnt | [] -> acc | hd :: tl -> - let eq_rels el = SqlConstraint (SqlColumn hd, SqlRelEqual, SqlColumn el) in + let eq_rels el = SqlConstraint (SqlColumn (None, hd), SqlRelEqual, SqlColumn (None, el)) in (List.map eq_rels tl) :: acc in let fvt = List.flatten (Hashtbl.fold var_const vt []) in @@ -525,19 +556,19 @@ let rec non_rec_unfold_sql_of_symtkey (dbschema : string) (idb : symtable) (cnt if Hashtbl.mem idb key then SqlFrom [ (SqlFromOther (non_rec_unfold_sql_of_symtkey dbschema idb cnt (pname,arity)), alias) ] else if str_contains pname "__tmp_" then - SqlFrom [ (SqlFromColumn pname, alias) ] + SqlFrom [ (SqlFromTable (None, pname), alias) ] else - SqlFrom [ (SqlFromColumn (dbschema ^ "." ^ pname), alias) ] + SqlFrom [ (SqlFromTable (Some dbschema, pname), alias) ] in (* print_endline "___neg sql___"; print_string from_sql; print_endline "___neg sql___"; *) (* Get the where sql of the rterm *) let build_const (acc : sql_constraint list) (col : sql_column_name) (var : var) : sql_constraint list = - let left = SqlColumn (alias ^ "." 
^ col) in + let left = SqlColumn (Some alias, col) in match var with | NamedVar vn -> let right = if Hashtbl.mem vt vn then - SqlColumn (List.hd (Hashtbl.find vt vn)) + SqlColumn (None, List.hd (Hashtbl.find vt vn)) else if Hashtbl.mem eqt (Var var) then sql_of_vterm vt eqt (Hashtbl.find eqt (Var var)) else @@ -553,7 +584,7 @@ let rec non_rec_unfold_sql_of_symtkey (dbschema : string) (idb : symtable) (cnt let vn = string_of_var var in let right = if Hashtbl.mem vt vn then - SqlColumn (List.hd (Hashtbl.find vt vn)) + SqlColumn (None, List.hd (Hashtbl.find vt vn)) else if Hashtbl.mem eqt (Var var) then sql_of_vterm vt eqt (Hashtbl.find eqt (Var var)) else @@ -616,7 +647,7 @@ let non_rec_unfold_sql_of_query (dbschema : string) (idb : symtable) (cnt : coln else let cols = Hashtbl.find cnt (symtkey_of_rterm query) in let sel_lst = - List.map (fun (a, b) -> (SqlColumn (qrule_alias ^ "." ^ a), b)) (List.combine cols cols_by_var) + List.map (fun (a, b) -> (SqlColumn (Some qrule_alias, a), b)) (List.combine cols cols_by_var) in let sql_from = (SqlFromOther (non_rec_unfold_sql_of_symtkey dbschema local_idb cnt (symtkey_of_rterm (rule_head qrule))), qrule_alias) @@ -915,7 +946,7 @@ let non_rec_unfold_sql_of_update (dbschema : string) (log : bool) (optimize : bo SELECT array_agg(tbl) INTO array_"^ (get_rterm_predname delta)^" FROM ("^ "SELECT "^"(ROW("^(String.concat "," (Hashtbl.find cnt (symtkey_of_rterm delta))) ^") :: "^dbschema^"."^ pname ^").* FROM ("^ - (stringify_sql_union (non_rec_unfold_sql_of_symtkey dbschema local_idb cnt (symtkey_of_rterm (rule_head qrule)))) ^") AS "^(get_rterm_predname delta)^"_extra_alias) AS tbl" + (stringify_sql_query (non_rec_unfold_sql_of_symtkey dbschema local_idb cnt (symtkey_of_rterm (rule_head qrule)))) ^") AS "^(get_rterm_predname delta)^"_extra_alias) AS tbl" (* ^" EXCEPT SELECT * FROM "^dbschema^"."^ pname *) @@ -951,7 +982,7 @@ let non_rec_unfold_sql_of_update (dbschema : string) (log : bool) (optimize : bo SELECT array_agg(tbl) 
INTO array_"^ (get_rterm_predname delta)^" FROM (" ^ "SELECT "^"(ROW("^(String.concat "," (Hashtbl.find cnt (symtkey_of_rterm delta))) ^") :: "^dbschema^"."^ pname ^").* FROM ("^ - (stringify_sql_union (non_rec_unfold_sql_of_symtkey dbschema local_idb cnt (symtkey_of_rterm (rule_head qrule))))^") AS "^(get_rterm_predname delta)^"_extra_alias) AS tbl;", + (stringify_sql_query (non_rec_unfold_sql_of_symtkey dbschema local_idb cnt (symtkey_of_rterm (rule_head qrule))))^") AS "^(get_rterm_predname delta)^"_extra_alias) AS tbl;", (* delete each tuple *) " IF array_"^ (get_rterm_predname delta)^" IS DISTINCT FROM NULL THEN @@ -1566,3 +1597,801 @@ AS $$ $$; " in trigger_pgsql + + +module VarMap = Map.Make(String) + +(* A module for substitutions that map variables to + + - a pair of a table instance and a column name, or + - a constant value. *) +module Subst : sig + + type entry = + | Occurrence of instance_name * column_name + | EqualToConst of const + + type t + + val empty : t + + val add : named_var -> entry -> t -> t + + val fold : (named_var -> entry * entry list -> 'a -> 'a) -> t -> 'a -> 'a + +end = struct + + type entry = + | Occurrence of instance_name * column_name + | EqualToConst of const + + module InternalMap = VarMap + + type t = (entry * entry list) InternalMap.t + + + let empty = + InternalMap.empty + + + let add x entry subst = + match subst |> InternalMap.find_opt x with + | None -> + subst |> InternalMap.add x (entry, []) + + | Some (entry0, entry_acc) -> + subst |> InternalMap.add x (entry0, entry :: entry_acc) + + + let fold f subst acc = + InternalMap.fold (fun x (entry0, entry_acc) acc -> + f x (entry0, List.rev entry_acc) acc + ) subst acc + +end + +type argument = + | ArgNamedVar of named_var + | ArgConst of const + | ArgAnon + +type delta_kind = + | Insert + | Delete + +type delta_key = delta_kind * table_name + +type positive_predicate = + | PositivePred of table_name * argument list + | PositiveDelta of delta_key * argument list + +type 
negative_predicate = + | NegativePred of table_name * argument list + | NegativeDelta of delta_key * argument list + +type comparison_operator = + | EqualTo + | NotEqualTo + | LessThan + | GreaterThan + | LessThanOrEqualTo + | GreaterThanOrEqualTo + +let show_comparison_operator = function + | EqualTo -> "==" + | NotEqualTo -> "<>" + | LessThan -> "<" + | GreaterThan -> ">" + | LessThanOrEqualTo -> "<=" + | GreaterThanOrEqualTo -> ">=" + +type comparison = + | Comparison of comparison_operator * vterm * vterm + +module TableEnv = Map.Make(String) + +type table_environment = (column_name list) TableEnv.t + +type error_detail = + | InRule of rule + | InComparison of comparison + | InGroup of delta_key + +type error = + | InvalidArgInHead of { var : var; error_detail : error_detail } + | InvalidArgInBody of { var : var; error_detail : error_detail } + | ArityMismatch of { expected : int; got : int } + | UnknownComparisonOperator of string + | EqualToMoreThanOneConstant of { variable : named_var; const1 : const; const2 : const } + | HeadVariableDoesNotOccurInBody of named_var + | UnexpectedNamedVar of { named_var : named_var; error_detail : error_detail } + | UnexpectedVarForm of { var : var; error_detail : error_detail } + | UnknownBinaryOperator of { op : string; error_detail : error_detail } + | UnknownUnaryOperator of { op : string; error_detail : error_detail } + | UnknownTable of { table : table_name; error_detail : error_detail } + | HasMoreThanOneRuleGroup of delta_key + | DeltaNotFound of delta_key + + +let show_error_detail (error_detail : error_detail) = + match error_detail with + | InRule rule -> + Printf.sprintf "in rule %s" (string_of_rule rule) + | InComparison (Comparison (op, vt1, vt2)) -> + Printf.sprintf "in comparison %s %s %s" + (string_of_vterm vt1) (show_comparison_operator op) (string_of_vterm vt2) + | InGroup (Insert, table) -> + Printf.sprintf "in group +%s" table + | InGroup (Delete, table) -> + Printf.sprintf "in group -%s" table + + +let 
show_error = function + | InvalidArgInHead { var; error_detail } -> + Printf.sprintf "invalid arg %s in the rule head; %s" (string_of_var var) (show_error_detail error_detail) + | InvalidArgInBody { var; error_detail } -> + Printf.sprintf "invalid arg %s in the rule body; %s" (string_of_var var) (show_error_detail error_detail) + | ArityMismatch r -> + Printf.sprintf "arity mismatch (expected: %d, got: %d)" r.expected r.got + | UnknownComparisonOperator op -> + Printf.sprintf "unknown comparison operator %s" op + | EqualToMoreThanOneConstant r -> + Printf.sprintf "variable %s are required to be equal to more than one constants; %s and %s" + r.variable (string_of_const r.const1) (string_of_const r.const2) + | HeadVariableDoesNotOccurInBody named_var -> + Printf.sprintf "variable %s in a rule head does not occur in the rule body" named_var + | UnexpectedNamedVar { named_var; error_detail } -> + Printf.sprintf "unexpected named variable %s; %s" named_var (show_error_detail error_detail) + | UnexpectedVarForm { var; error_detail } -> + Printf.sprintf "unexpected variable form: %s; %s" (string_of_var var) (show_error_detail error_detail) + | UnknownBinaryOperator { op; error_detail } -> + Printf.sprintf "unknown binary operator %s; %s" op (show_error_detail error_detail) + | UnknownUnaryOperator { op; error_detail } -> + Printf.sprintf "unknown unary operator %s; %s" op (show_error_detail error_detail) + | UnknownTable { table; error_detail } -> + Printf.sprintf "unknown table %s; %s" table (show_error_detail error_detail) + | HasMoreThanOneRuleGroup (Insert, table) -> + Printf.sprintf "+%s has more than one rule group" table + | HasMoreThanOneRuleGroup (Delete, table) -> + Printf.sprintf "-%s has more than one rule group" table + | DeltaNotFound (Insert, table) -> + Printf.sprintf "no rule has already been defined for +%s" table + | DeltaNotFound (Delete, table) -> + Printf.sprintf "no rule has already been defined for -%s" table + + +let get_column_names_from_table 
~(error_detail : error_detail) (table_env : table_environment) (table : table_name) : (column_name list, error) result = + let open ResultMonad in + match table_env |> TableEnv.find_opt table with + | None -> err @@ UnknownTable { table; error_detail } + | Some cols -> return cols + + +(* Gets the list `cols` of column names of the table named `table` and zips it with `xs`. + Returns `Error _` when `table` is unknown or when the length of `xs` is different from that of `cols`. *) +let combine_column_names ~(error_detail : error_detail) (table_env : table_environment) (table : table_name) (xs : 'a list) : ((column_name * 'a) list, error) result = + let open ResultMonad in + get_column_names_from_table ~error_detail table_env table >>= fun columns -> + try + return (List.combine columns xs) + with + | _ -> + err @@ ArityMismatch { + expected = List.length columns; + got = List.length xs; + } + + +let validate_args_in_head ~(error_detail : error_detail) (table_env : table_environment) (table : table_name) (args : var list) = + let open ResultMonad in + args |> List.fold_left (fun res arg -> + res >>= fun x_acc -> + match arg with + | NamedVar x -> return @@ x :: x_acc + | _ -> err @@ InvalidArgInHead { var = arg; error_detail } + ) (return []) >>= fun x_acc -> + let vars = List.rev x_acc in + combine_column_names ~error_detail table_env table vars + + +type head_spec = + | PredHead of table_name * (column_name * named_var) list + | DeltaHead of delta_kind * table_name * (column_name * named_var) list + + +let get_spec_from_head ~(error_detail : error_detail) (table_env : table_environment) (head : rterm) : (head_spec, error) result = + let open ResultMonad in + match head with + | Pred (table, args) -> + validate_args_in_head ~error_detail table_env table args >>= fun columns_and_vars -> + return @@ PredHead(table, columns_and_vars) + + | Deltainsert (table, args) -> + validate_args_in_head ~error_detail table_env table args >>= fun columns_and_vars -> + return @@ DeltaHead(Insert, table, 
columns_and_vars) + + | Deltadelete (table, args) -> + validate_args_in_head ~error_detail table_env table args >>= fun columns_and_vars -> + return @@ DeltaHead(Delete, table, columns_and_vars) + + +let get_comparison_operator (op_str : string) : (comparison_operator, error) result = + let open ResultMonad in + match op_str with + | "=" -> return EqualTo + | "<>" -> return NotEqualTo + | "<" -> return LessThan + | ">" -> return GreaterThan + | "<=" -> return LessThanOrEqualTo + | ">=" -> return GreaterThanOrEqualTo + | _ -> err @@ UnknownComparisonOperator op_str + + +let negate_comparison_operator = function + | EqualTo -> NotEqualTo + | NotEqualTo -> EqualTo + | LessThan -> GreaterThanOrEqualTo + | GreaterThan -> LessThanOrEqualTo + | LessThanOrEqualTo -> GreaterThan + | GreaterThanOrEqualTo -> LessThan + + +let validate_args_in_body ~(error_detail : error_detail) (vars : var list) : (argument list, error) result = + let open ResultMonad in + vars |> List.fold_left (fun res var -> + res >>= fun arg_acc -> + match var with + | NamedVar x -> return @@ ArgNamedVar x :: arg_acc + | ConstVar c -> return @@ ArgConst c :: arg_acc + | AnonVar -> return @@ ArgAnon :: arg_acc + | _ -> err @@ InvalidArgInBody { var; error_detail } + ) (return []) >>= fun arg_acc -> + return @@ List.rev arg_acc + + +(* Separate predicates in a given rule body into positive ones, negative ones, and comparisons. 
*) +let decompose_body ~(error_detail : error_detail) (body : term list) : (positive_predicate list * negative_predicate list * comparison list, error) result = + let open ResultMonad in + body |> List.fold_left (fun res term -> + res >>= fun (pos_acc, neg_acc, comp_acc) -> + match term with + | Rel (Pred (table, vars)) -> + validate_args_in_body ~error_detail vars >>= fun args -> + return (PositivePred (table, args) :: pos_acc, neg_acc, comp_acc) + + | Rel (Deltainsert (table, vars)) -> + validate_args_in_body ~error_detail vars >>= fun args -> + let delta_key = (Insert, table) in + return (PositiveDelta (delta_key, args) :: pos_acc, neg_acc, comp_acc) + + | Rel (Deltadelete (table, vars)) -> + validate_args_in_body ~error_detail vars >>= fun args -> + let delta_key = (Delete, table) in + return (PositiveDelta (delta_key, args) :: pos_acc, neg_acc, comp_acc) + + | Not (Pred (table, vars)) -> + validate_args_in_body ~error_detail vars >>= fun args -> + return (pos_acc, NegativePred (table, args) :: neg_acc, comp_acc) + + | Not (Deltainsert (table, vars)) -> + validate_args_in_body ~error_detail vars >>= fun args -> + let delta_key = (Insert, table) in + return (pos_acc, NegativeDelta (delta_key, args) :: neg_acc, comp_acc) + + | Not (Deltadelete (table, vars)) -> + validate_args_in_body ~error_detail vars >>= fun args -> + let delta_key = (Delete, table) in + return (pos_acc, NegativeDelta (delta_key, args) :: neg_acc, comp_acc) + + | Equat (Equation (op_str, t1, t2)) -> + get_comparison_operator op_str >>= fun op -> + return (pos_acc, neg_acc, Comparison (op, t1, t2) :: comp_acc) + + | Noneq (Equation (op_str, t1, t2)) -> + get_comparison_operator op_str >>= fun op -> + let op_dual = negate_comparison_operator op in + return (pos_acc, neg_acc, Comparison (op_dual, t1, t2) :: comp_acc) + + ) (return ([], [], [])) >>= fun (pos_acc, neg_acc, comp_acc) -> + return (List.rev pos_acc, List.rev neg_acc, List.rev comp_acc) + + +module DeltaKey = struct + type t = 
delta_key + + let compare = Stdlib.compare +end + +module DeltaEnv = Map.Make(DeltaKey) + +type delta_environment = (instance_name * column_name list) DeltaEnv.t + + +let assign_or_find_instance_names (delta_env : delta_environment) (poss : positive_predicate list) : ((positive_predicate * instance_name) list * instance_name list, error) result = + let open ResultMonad in + poss |> List.fold_left (fun res pos -> + res >>= fun (index, named_pos_acc, referred_instance_acc) -> + match pos with + | PositivePred (table, _args) -> + let instance = Printf.sprintf "%s_%d" table index in + return (index + 1, (pos, instance) :: named_pos_acc, referred_instance_acc) + + | PositiveDelta (delta_key, _args) -> + begin + match delta_env |> DeltaEnv.find_opt delta_key with + | None -> + err @@ DeltaNotFound delta_key + + | Some (instance, _cols) -> + return (index, (pos, instance) :: named_pos_acc, instance :: referred_instance_acc) + end + + ) (return (0, [], [])) >>= fun (_, named_pos_acc, referred_instance_acc) -> + return @@ (List.rev named_pos_acc, List.rev referred_instance_acc) + + +type as_const_or_var = + | AsNamedVar of named_var + | AsConst of const + | NotConstOrNamedVar + + +let as_const_or_var (vt : vterm) : as_const_or_var = + match vt with + | Const c -> AsConst c + | Var (NamedVar x) -> AsNamedVar x + | Var (ConstVar c) -> AsConst c + | _ -> NotConstOrNamedVar + + +let get_sql_binary_operation ~(error_detail : error_detail) (bin_op_str : string) : (sql_binary_operator, error) result = + let open ResultMonad in + match bin_op_str with + | "+" -> return SqlPlus + | "-" -> return SqlMinus + | "*" -> return SqlTimes + | "/" -> return SqlDivides + | "^" -> return SqlLor + | _ -> err @@ UnknownBinaryOperator { op = bin_op_str; error_detail } + + +let get_sql_unary_operation ~(error_detail : error_detail) (un_op_str : string) : (sql_unary_operator, error) result = + let open ResultMonad in + match un_op_str with + | "-" -> return SqlNegate + | _ -> err @@ 
UnknownUnaryOperator { op = un_op_str; error_detail } + + +let get_named_var (varmap : Subst.entry VarMap.t) (x : named_var) : sql_vterm option = + let open ResultMonad in + match varmap |> VarMap.find_opt x with + | None -> None + | Some (Subst.Occurrence (instance, column)) -> Some (SqlColumn (Some instance, column)) + | Some (Subst.EqualToConst c) -> Some (SqlConst c) + + +let sql_of_vterm_new ~(error_detail : error_detail) (varmap : Subst.entry VarMap.t) (vt : vterm) : (sql_vterm, error) result = + let open ResultMonad in + let rec aux (vt : vterm) = + match vt with + | Const c + | Var (ConstVar c) -> + return @@ SqlConst c + + | Var (NamedVar x) -> + begin + match get_named_var varmap x with + | None -> err @@ UnexpectedNamedVar { named_var = x; error_detail } + | Some sql_vt -> return sql_vt + end + + | Var var -> + err @@ UnexpectedVarForm { var; error_detail } + + | BinaryOp (bin_op_str, vt1, vt2) -> + get_sql_binary_operation ~error_detail bin_op_str >>= fun sql_bin_op -> + aux vt1 >>= fun sql_vt1 -> + aux vt2 >>= fun sql_vt2 -> + return @@ SqlBinaryOp (sql_bin_op, sql_vt1, sql_vt2) + + | UnaryOp (un_op_str, vt1) -> + get_sql_unary_operation ~error_detail un_op_str >>= fun sql_un_op -> + aux vt1 >>= fun sql_vt1 -> + return @@ SqlUnaryOp (sql_un_op, sql_vt1) + in + aux vt + + +let sql_vterm_of_arg ~(error_detail : error_detail) (varmap : Subst.entry VarMap.t) (arg : argument) : (sql_vterm option, error) result = + let open ResultMonad in + match arg with + | ArgNamedVar x -> + begin + match get_named_var varmap x with + | None -> err @@ UnexpectedNamedVar { named_var = x; error_detail } + | Some sql_vt -> return @@ Some sql_vt + end + + | ArgConst c -> + return @@ Some (SqlConst c) + + | ArgAnon -> + return None + + +let combine_delta_column_names (delta_env : delta_environment) (delta_key : delta_key) (args : argument list) : (instance_name * (column_name * argument) list, error) result = + let open ResultMonad in + match delta_env |> DeltaEnv.find_opt 
delta_key with + | None -> + err @@ DeltaNotFound delta_key + + | Some (instance, cols) -> + begin + try + return @@ (instance, List.combine cols args) + with + | _ -> + err @@ ArityMismatch { expected = List.length cols; got = List.length args } + end + + +(* Extends `subst` by traversing occurrences of variables in positive predicates. *) +let extend_substitution_by_traversing_positives ~(error_detail : error_detail) (table_env : table_environment) (delta_env : delta_environment) (named_poss : (positive_predicate * instance_name) list) (subst : Subst.t) : (Subst.t, error) result = + let open ResultMonad in + named_poss |> List.fold_left (fun res (pos, instance) -> + res >>= fun subst -> + begin + match pos with + | PositivePred (table, args) -> + combine_column_names ~error_detail table_env table args + + | PositiveDelta (delta_key, args) -> + combine_delta_column_names delta_env delta_key args >>= fun (_instance, columns_and_args) -> + return columns_and_args + + end >>= fun columns_and_args -> + let subst = + columns_and_args |> List.fold_left (fun subst (column, arg) -> + match arg with + | ArgNamedVar x -> subst |> Subst.add x (Subst.Occurrence (instance, column)) + | ArgConst _ -> subst + | ArgAnon -> subst + ) subst + in + return subst + ) (return subst) >>= fun subst -> + return subst + + +(* Extends `subst` by constraints where a variable is equal to a constant. + Consumed equality constraints are removed from `comps`. 
*)
+let extend_substitution_by_traversing_conparisons (comps : comparison list) (subst : Subst.t) : comparison list * Subst.t =
+  let (comp_acc, subst) =
+    comps |> List.fold_left (fun (comp_acc, subst) comp ->
+      let Comparison (op, vt1, vt2) = comp in
+      match op with
+      | EqualTo ->
+          begin
+            (* An equality between a named variable and a constant (in either order,
+               as classified by `as_const_or_var`) is recorded in `subst` and the
+               comparison is dropped; any other equality is kept as it is. *)
+            match (as_const_or_var vt1, as_const_or_var vt2) with
+            | (AsNamedVar x, AsConst c) -> (comp_acc, subst |> Subst.add x (Subst.EqualToConst c))
+            | (AsConst c, AsNamedVar x) -> (comp_acc, subst |> Subst.add x (Subst.EqualToConst c))
+            | _                         -> (comp :: comp_acc, subst)
+          end
+
+      | _ ->
+          (* Comparisons other than equality are never consumed. *)
+          (comp :: comp_acc, subst)
+    ) ([], subst)
+  in
+  (* `comp_acc` was built by prepending, so it is reversed to restore the input order. *)
+  let comps = List.rev comp_acc in
+  (comps, subst)
+
+
+(* Splits `xs` according to `f`: the payloads of the elements mapped to `Ok _` form
+   the first list and those of the elements mapped to `Error _` form the second one.
+   Both lists keep the order of `xs`. *)
+let partition_map f xs =
+  let (acc1, acc2) =
+    xs |> List.fold_left (fun (acc1, acc2) x ->
+      match f x with
+      | Ok v1    -> (v1 :: acc1, acc2)
+      | Error v2 -> (acc1, v2 :: acc2)
+    ) ([], [])
+  in
+  (List.rev acc1, List.rev acc2)
+
+
+(* The type for representing a rule without its head predicate,
+   i.e., the part `(X_1, ..., X_m) :- C_1, ..., C_n` of
+   a rule `±r(X_1, ..., X_m) :- C_1, ..., C_n`.
*) +type headless_rule = { + columns_and_vars : (column_name * named_var) list; + body : term list; +} + + +let convert_rule_to_operation_based_sql ~(error_detail : error_detail) (table_env : table_environment) (delta_env : delta_environment) (headless_rule : headless_rule) : (sql_query, error) result = + let open ResultMonad in + let columns_and_vars = headless_rule.columns_and_vars in + let body = headless_rule.body in + decompose_body ~error_detail body >>= fun (poss, negs, comps) -> + + assign_or_find_instance_names delta_env poss >>= fun (named_poss, referred_instances) -> + let subst = Subst.empty in + extend_substitution_by_traversing_positives ~error_detail table_env delta_env named_poss subst >>= fun subst -> + let (comps, subst) = extend_substitution_by_traversing_conparisons comps subst in + + (* Converts `subst` into SQL constraints and `varmap`: *) + Subst.fold (fun x (entry, entries) res -> + res >>= fun (sql_constraint_acc, varmap) -> + let (consts, occurrences) = + (entry :: entries) |> partition_map (function + | Subst.EqualToConst c -> Ok c + | Subst.Occurrence (instance, column) -> Error (instance, column) + ) + in + match (consts, occurrences) with + | ([], []) -> + assert false + + | ([], (instance0, column0) :: occurrence_rest) -> + let sql_constraint_acc = + let right = SqlColumn (Some instance0, column0) in + occurrence_rest |> List.fold_left (fun sql_constraint_acc (instance, column) -> + SqlConstraint (SqlColumn (Some instance, column), SqlRelEqual, right) :: sql_constraint_acc + ) sql_constraint_acc + in + let varmap = varmap |> VarMap.add x (Subst.Occurrence (instance0, column0)) in + return (sql_constraint_acc, varmap) + + | ([ c ], _) -> + let sql_constraint_acc = + let right = SqlConst c in + occurrences |> List.fold_left (fun sql_constraint_acc (table, column) -> + SqlConstraint (SqlColumn (Some table, column), SqlRelEqual, right) :: sql_constraint_acc + ) sql_constraint_acc + in + let varmap = varmap |> VarMap.add x 
(Subst.EqualToConst c) in + return (sql_constraint_acc, varmap) + + | (c1 :: c2 :: _, _) -> + err @@ EqualToMoreThanOneConstant { + variable = x; + const1 = c1; + const2 = c2; + } + ) subst (return ([], VarMap.empty)) >>= fun (sql_constraint_acc, varmap) -> + + (* Adds comparison constraints to SQL constraints: *) + comps |> List.fold_left (fun res comp -> + let Comparison (op, vt1, vt2) = comp in + let error_detail = InComparison comp in + res >>= fun sql_constraint_acc -> + sql_of_vterm_new ~error_detail varmap vt1 >>= fun sql_vt1 -> + sql_of_vterm_new ~error_detail varmap vt2 >>= fun sql_vt2 -> + let sql_op = + match op with + | EqualTo -> SqlRelEqual + | NotEqualTo -> SqlRelNotEqual + | LessThan -> SqlRelGeneral "<" + | GreaterThan -> SqlRelGeneral ">" + | LessThanOrEqualTo -> SqlRelGeneral "<=" + | GreaterThanOrEqualTo -> SqlRelGeneral ">=" + in + return @@ SqlConstraint (sql_vt1, sql_op, sql_vt2) :: sql_constraint_acc + ) (return sql_constraint_acc) >>= fun sql_constraint_acc -> + + (* Adds constraints that stem from negative predicates: *) + negs |> List.fold_left (fun res neg -> + res >>= fun sql_constraint_acc -> + begin + match neg with + | NegativePred (table, args) -> + combine_column_names ~error_detail table_env table args >>= fun columns_and_args -> + return (table, columns_and_args) + + | NegativeDelta (delta_key, args) -> + combine_delta_column_names delta_env delta_key args + + end >>= fun (table, columns_and_args) -> + let instance = "t" in + let sql_from = SqlFrom [ (SqlFromTable (None, table), instance) ] in + columns_and_args |> List.fold_left (fun res (column, arg) -> + res >>= fun acc -> + sql_vterm_of_arg ~error_detail varmap arg >>= function + | None -> (* corresponds to underscore *) + return @@ acc + + | Some sql_vt -> + return @@ SqlConstraint (SqlColumn (Some instance, column), SqlRelEqual, sql_vt) :: acc + + ) (return []) >>= fun acc -> + let sql_where = SqlWhere (List.rev acc) in + return @@ SqlNotExist (sql_from, sql_where) :: 
sql_constraint_acc + ) (return sql_constraint_acc) >>= fun sql_constraint_acc -> + + (* Builds the SELECT clause: *) + columns_and_vars |> List.fold_left (fun res (column0, x0) -> + res >>= fun selected_acc -> + match varmap |> VarMap.find_opt x0 with + | None -> + err @@ HeadVariableDoesNotOccurInBody x0 + + | Some (Subst.Occurrence (instance, column)) -> + return @@ (SqlColumn (Some instance, column), column0) :: selected_acc + + | Some (Subst.EqualToConst c) -> + return @@ (SqlConst c, column0) :: selected_acc + + ) (return []) >>= fun selected_acc -> + let sql_select = SqlSelect (List.rev selected_acc) in + + (* Builds the FROM clause: *) + let from_clause_entries = + List.concat [ + named_poss |> List.map (fun (pos, instance) -> + match pos with + | PositivePred (table, _args) -> [ (SqlFromTable (None, table), instance) ] + | _ -> [] + ) |> List.concat; + referred_instances |> List.map (fun instance -> + (SqlFromTable (None, instance), instance) + ) + ] + in + let sql_from = SqlFrom from_clause_entries in + + (* Builds the WHERE clause: *) + let sql_where = SqlWhere (List.rev sql_constraint_acc) in + + return @@ SqlQuery { + select = sql_select; + from = sql_from; + where = sql_where; + agg = (SqlGroupBy [], SqlHaving []); + } + + +module DeltaKeySet = Set.Make(DeltaKey) + +type rule_group = + | PredGroup of table_name * headless_rule + | DeltaGroup of delta_key * headless_rule list + +type delta_grouping_state = { + current_target : delta_key; + current_accumulated : headless_rule list; + already_handled : DeltaKeySet.t; +} + + +let divide_rules_into_groups (table_env : table_environment) (rules : Expr.rule list) : (rule_group list, error) result = + let open ResultMonad in + rules |> List.fold_left (fun res rule -> + res >>= fun (state_opt, group_acc) -> + let (head, body) = rule in + let error_detail = InRule rule in + get_spec_from_head ~error_detail table_env head >>= function + | PredHead(table, columns_and_vars) -> + let group = PredGroup(table, { 
columns_and_vars; body }) in + begin + match state_opt with + | None -> + return (None, group :: group_acc) + + | Some state -> + let group_prev = DeltaGroup(state.current_target, List.rev state.current_accumulated) in + return (None, group :: group_prev :: group_acc) + end + + | DeltaHead(delta_kind, table, columns_and_vars) -> + let delta_key = (delta_kind, table) in + let intermediate = { columns_and_vars; body } in + begin + match state_opt with + | None -> + return (Some { + current_target = delta_key; + current_accumulated = [ intermediate ]; + already_handled = DeltaKeySet.empty; + }, group_acc) + + | Some state -> + if state.already_handled |> DeltaKeySet.mem delta_key then + err @@ HasMoreThanOneRuleGroup delta_key + else if delta_key = state.current_target then + return (Some { state with + current_accumulated = intermediate :: state.current_accumulated; + }, group_acc) + else + let group = DeltaGroup(state.current_target, List.rev state.current_accumulated) in + return @@ (Some { + current_target = delta_key; + current_accumulated = [ intermediate ]; + already_handled = state.already_handled |> DeltaKeySet.add state.current_target; + }, group :: group_acc) + end + + ) (return (None, [])) >>= fun (state_opt, group_acc) -> + match state_opt with + | None -> + return (List.rev group_acc) + + | Some state -> + let group_last = DeltaGroup(state.current_target, List.rev state.current_accumulated) in + let groups = List.rev (group_last :: group_acc) in + return groups + + +let convert_expr_to_operation_based_sql (expr : expr) : (sql_operation list, error) result = + let open ResultMonad in + let table_env = + let defs = + match expr.view with + | None -> expr.sources + | Some view -> view :: expr.sources + in + defs |> List.fold_left (fun table_env (table, cols_and_types) -> + let cols = cols_and_types |> List.map (fun (col, _) -> col) in + table_env |> TableEnv.add table cols + ) TableEnv.empty + in + let rules = List.rev expr.rules in (* `expr` holds its 
rules in the reversed order *)
+  divide_rules_into_groups table_env rules >>= fun rule_groups ->
+  (* Folds over the rule groups while threading:
+     - `i`: the index used for naming the next temporary table,
+     - `creation_acc` / `update_acc`: the operations built so far (in reversed order), and
+     - `delta_env`: the deltas already materialized as temporary tables. *)
+  rule_groups |> List.fold_left (fun res rule_group ->
+    res >>= fun (i, creation_acc, update_acc, delta_env) ->
+    let temporary_table = Printf.sprintf "temp%d" i in
+    match rule_group with
+    | PredGroup(table, headless_rule) ->
+        let error_detail =
+          let rule =
+            let head = Pred (table, headless_rule.columns_and_vars |> List.map (fun (_, x) -> NamedVar x)) in
+            (head, headless_rule.body)
+          in
+          InRule rule
+        in
+        convert_rule_to_operation_based_sql ~error_detail table_env DeltaEnv.empty headless_rule >>= fun sql_query ->
+        let creation = SqlCreateView (table, sql_query) in
+        return (i + 1, creation :: creation_acc, update_acc, delta_env)
+
+    | DeltaGroup(delta_key, headless_rules) ->
+        headless_rules |> List.fold_left (fun res_acc headless_rule ->
+          res_acc >>= fun sql_query_acc ->
+          let error_detail =
+            let rule =
+              let vars = headless_rule.columns_and_vars |> List.map (fun (_, x) -> NamedVar x) in
+              match delta_key with
+              | (Insert, table) -> (Deltainsert (table, vars), headless_rule.body)
+              | (Delete, table) -> (Deltadelete (table, vars), headless_rule.body)
+            in
+            InRule rule
+          in
+          convert_rule_to_operation_based_sql ~error_detail table_env delta_env headless_rule >>= fun sql_query ->
+          return @@ sql_query :: sql_query_acc
+        ) (return []) >>= fun sql_query_acc ->
+        let sql_query =
+          let sql_queries = List.rev sql_query_acc in
+          SqlUnion (SqlUnionOp, sql_queries)
+        in
+        let (delta_kind, table) = delta_key in
+        let error_detail = InGroup delta_key in
+        get_column_names_from_table ~error_detail table_env table >>= fun cols ->
+        let delta_env = delta_env |> DeltaEnv.add delta_key (temporary_table, cols) in
+        let creation = SqlCreateTemporaryTable (temporary_table, sql_query) in
+        (* The update applies the rows collected in `temporary_table` to `table`,
+           i.e. to the table the delta is about; the temporary table is only read. *)
+        let update =
+          let instance_name = "inst" in
+          match delta_kind with
+          | Insert ->
+              SqlInsertInto
+                (table,
+                  SqlFrom [ (SqlFromTable (None, temporary_table), instance_name) ])
+
+          | Delete ->
+              (* A row of `table` is deleted only when the temporary table holds
+                 a row that agrees with it on every column: *)
+              let sql_constraints =
+                cols |> List.map (fun col ->
+                  SqlConstraint (SqlColumn (Some table, col), SqlRelEqual, SqlColumn (Some instance_name, col))
+                )
+              in
+              SqlDeleteFrom
+                (table,
+                  SqlWhere [
+                    SqlExist (SqlFrom [ (SqlFromTable (None, temporary_table), instance_name) ], SqlWhere sql_constraints) ])
+        in
+        return (i + 1, creation :: creation_acc, update :: update_acc, delta_env)
+
+  ) (return (0, [], [], DeltaEnv.empty)) >>= fun (_, creation_acc, update_acc, _) ->
+  return @@ List.concat [
+    List.rev creation_acc;
+    List.rev update_acc;
+  ]
diff --git a/src/ast2sql.mli b/src/ast2sql.mli
index 9483411e..3f6b5f6b 100644
--- a/src/ast2sql.mli
+++ b/src/ast2sql.mli
@@ -1,3 +1,15 @@
+open Utils
+
 val unfold_view_sql : string -> bool -> Expr.expr -> string
 
 val unfold_delta_trigger_stt : string -> bool -> bool -> string -> string -> bool -> bool -> Expr.expr -> string
+
+type error
+
+val show_error : error -> string
+
+type sql_operation
+
+val stringify_sql_operation : sql_operation -> string
+
+val convert_expr_to_operation_based_sql : Expr.expr -> (sql_operation list, error) result
diff --git a/src/main.ml b/src/main.ml
index 1e907296..2b3790ad 100644
--- a/src/main.ml
+++ b/src/main.ml
@@ -390,6 +390,13 @@ let main () =
         if (!verification) then print_endline @@ "-- Program is validated --";
         let oc =if !outputf = "" then stdout else open_out !outputf in
         if (not has_get) then fprintf oc "\n/*view definition (get):\n%s*/\n\n" view_rules_string;
+        let ast2 =
+          match Simplification.simplify ast2.rules with
+          | Ok rules ->
+              { ast2 with rules }
+          | Error e ->
+              failwith "failed to simplify rules" (* TODO: detailed error report *)
+        in
         let sql = Ast2sql.unfold_view_sql (!dbschema) (!log) ast2 in
         fprintf oc "%s\n" sql;
         let trigger_sql = Ast2sql.unfold_delta_trigger_stt (!dbschema) (!log) (!dejima_ud) shell_script (!dejima_user) (!inc) (!optimize) (constraint2rule ast2) in
diff --git a/src/simplification.ml b/src/simplification.ml
new file mode 100644
index 00000000..daef8b1e
--- /dev/null
+++ b/src/simplification.ml
@@ -0,0 +1,690 @@
+
+open Expr
+open Utils
+
+
+module Const = struct
+  type t = const
+
+  let compare
(c1 : const) (c2 : const) : int =
+    (* Total order on constants. Constants built by different constructors are
+       ordered by constructor: `Int` is the greatest, then `Real`, `String`,
+       `Bool`, and `Null` is the least. *)
+    match (c1, c2) with
+    | (Int n1, Int n2)       -> compare n1 n2 (* `Int.compare` can be used for OCaml >= 4.08 *)
+    | (Int _, _)             -> 1
+    | (_, Int _)             -> -1
+    | (Real r1, Real r2)     -> Float.compare r1 r2
+    | (Real _, _)            -> 1
+    | (_, Real _)            -> -1
+    | (String s1, String s2) -> String.compare s1 s2
+    | (String _, _)          -> 1
+    | (_, String _)          -> -1
+    | (Bool b1, Bool b2)     -> compare b1 b2 (* `Bool.compare` can be used for OCaml >= 4.08 *)
+    | (Bool _, _)            -> 1
+    | (_, Bool _)            -> -1
+    | (Null, Null)           -> 0
+
+  (* Note: the equality of floats does NOT conform to IEEE 754’s equality,
+     because it is derived from `Float.compare` via `compare` above. *)
+  let equal (c1 : const) (c2 : const) : bool =
+    compare c1 c2 = 0
+end
+
+module ConstSet = Set.Make(Const)
+
+
+type table_name = string
+
+type var_name = string
+
+(* The predicate symbol of a term: an ordinary table or an insertion/deletion delta of a table. *)
+type intermediate_predicate =
+  | ImPred        of table_name
+  | ImDeltaInsert of table_name
+  | ImDeltaDelete of table_name
+
+(* Head arguments are always named variables. *)
+type intermediate_head_var =
+  | ImHeadVar of var_name
+
+
+module HeadVarSubst = Map.Make(String)
+
+
+(* Maps a variable name to the head variable that replaces it. *)
+type head_var_substitution = intermediate_head_var HeadVarSubst.t
+
+
+let head_var_equal (ImHeadVar x1) (ImHeadVar x2) =
+  String.equal x1 x2
+
+
+(* Body arguments are either named variables or the underscore. *)
+type intermediate_body_var =
+  | ImBodyNamedVar of var_name
+  | ImBodyAnonVar
+
+
+(* Total order on body variables: the underscore is smaller than every named
+   variable, and named variables are ordered by `String.compare`. *)
+let body_var_compare (imbvar1 : intermediate_body_var) (imbvar2 : intermediate_body_var) : int =
+  match (imbvar1, imbvar2) with
+  | (ImBodyNamedVar x1, ImBodyNamedVar x2) -> String.compare x1 x2
+  | (ImBodyNamedVar _, ImBodyAnonVar)      -> 1
+  | (ImBodyAnonVar, ImBodyAnonVar)         -> 0
+  | (ImBodyAnonVar, ImBodyNamedVar _)      -> -1
+
+
+module Predicate = struct
+  type t = intermediate_predicate
+
+  (* Total order on predicates: `ImPred` is the greatest constructor, then
+     `ImDeltaInsert`, then `ImDeltaDelete`; ties are broken by the table name. *)
+  let compare (impred1 : t) (impred2 : t) : int =
+    match (impred1, impred2) with
+    | (ImPred t1, ImPred t2)               -> String.compare t1 t2
+    | (ImPred _, _)                        -> 1
+    | (_, ImPred _)                        -> -1
+    | (ImDeltaInsert t1, ImDeltaInsert t2) -> String.compare t1 t2
+    | (ImDeltaInsert _, _)                 -> 1
+    | (_, ImDeltaInsert _)                 -> -1
+    | (ImDeltaDelete t1, ImDeltaDelete t2) -> String.compare t1 t2
+end
+
+module PredicateMap = Map.Make(Predicate) + +module PredicateSet = Set.Make(Predicate) + +module VariableMap = Map.Make(String) + + +type body_term_arguments = intermediate_body_var list + + +let string_of_body_term_arguments imbvars = + imbvars |> List.map (function + | ImBodyNamedVar x -> x + | ImBodyAnonVar -> "_" + ) |> String.concat ", " + + +module BodyTermArguments = struct + type t = body_term_arguments + + let compare (args1 : t) (args2 : t) : int = + let rec aux args1 args2 = + match (args1, args2) with + | ([], []) -> 0 + | ([], _ :: _) -> 1 + | (_ :: _, []) -> -1 + + | (x1 :: xs1, x2 :: xs2) -> + begin + match body_var_compare x1 x2 with + | 0 -> aux xs1 xs2 + | nonzero -> nonzero + end + in + aux args1 args2 + (* `List.compare` can be used for OCaml >= 4.12 *) +end + +module BodyTermArgumentsSet = Set.Make(BodyTermArguments) + + +type predicate_map = BodyTermArgumentsSet.t PredicateMap.t + +type constant_requirement = + | EqualTo of const + | NotEqualTo of ConstSet.t + +type equation_map = constant_requirement VariableMap.t + +type intermediate_rule = { + head_predicate : intermediate_predicate; + head_arguments : intermediate_head_var list; + positive_terms : predicate_map; + negative_terms : predicate_map; + equations : equation_map; +} + +type error = + | UnexpectedHeadVarForm of var + | UnexpectedBodyVarForm of var + | UnsupportedEquation of eterm + | NonequalityNotSupported of eterm + + +let constant_requirement_equal (cr1 : constant_requirement) (cr2 : constant_requirement) : bool = + match (cr1, cr2) with + | (EqualTo c1, EqualTo c2) -> Const.equal c1 c2 + | (NotEqualTo cset1, NotEqualTo cset2) -> ConstSet.equal cset1 cset2 + | _ -> false + + +let predicate_equal (impred1 : intermediate_predicate) (impred2 : intermediate_predicate) : bool = + match (impred1, impred2) with + | (ImPred t1, ImPred t2) -> String.equal t1 t2 + | (ImDeltaInsert t1, ImDeltaInsert t2) -> String.equal t1 t2 + | (ImDeltaDelete t1, ImDeltaDelete t2) -> String.equal t1 t2 
+ | _ -> false + + +let predicate_map_equal : predicate_map -> predicate_map -> bool = + PredicateMap.equal BodyTermArgumentsSet.equal + + +let head_arguments_equal (args1 : intermediate_head_var list) (args2 : intermediate_head_var list) : bool = + try + List.fold_left2 (fun b x1 x2 -> + b && head_var_equal x1 x2 + ) true args1 args2 + with + | Invalid_argument _ -> false + (* `List.equal` can be used for OCaml >= 4.12 *) + + +(* Checks that `imrule1` and `imrule2` are syntactically equal + (i.e. exactly the same including variable names). *) +let rule_equal (imrule1 : intermediate_rule) (imrule2 : intermediate_rule) : bool = + List.fold_left ( && ) true [ + predicate_equal imrule1.head_predicate imrule2.head_predicate; + head_arguments_equal imrule1.head_arguments imrule2.head_arguments; + predicate_map_equal imrule1.positive_terms imrule2.positive_terms; + predicate_map_equal imrule1.negative_terms imrule2.negative_terms; + VariableMap.equal constant_requirement_equal imrule1.equations imrule2.equations; + ] + + +let convert_head_var (var : var) : (intermediate_head_var, error) result = + let open ResultMonad in + match var with + | NamedVar x -> return (ImHeadVar x) + | _ -> err (UnexpectedHeadVarForm var) + + +let convert_body_var (var : var) : (intermediate_body_var, error) result = + let open ResultMonad in + match var with + | NamedVar x -> return (ImBodyNamedVar x) + | AnonVar -> return ImBodyAnonVar + | _ -> err (UnexpectedBodyVarForm var) + + +let separate_predicate_and_vars (rterm : rterm) : intermediate_predicate * var list = + match rterm with + | Pred (t, vars) -> (ImPred t, vars) + | Deltainsert (t, vars) -> (ImDeltaInsert t, vars) + | Deltadelete (t, vars) -> (ImDeltaDelete t, vars) + + +let convert_head_rterm (rterm : rterm) : (intermediate_predicate * intermediate_head_var list, error) result = + let open ResultMonad in + let (impred, vars) = separate_predicate_and_vars rterm in + vars |> mapM convert_head_var >>= fun imhvars -> + return (impred, 
imhvars) + + +let convert_body_rterm (rterm : rterm) : (intermediate_predicate * body_term_arguments, error) result = + let open ResultMonad in + let (impred, vars) = separate_predicate_and_vars rterm in + vars |> mapM convert_body_var >>= fun imbvars -> + return (impred, imbvars) + + +let convert_eterm ~(negated : bool) (eterm : eterm) : (var_name * constant_requirement, error) result = + let open ResultMonad in + begin + match eterm with + | Equation("=", Var (NamedVar x), Const c) -> return (x, true, c) + | Equation("=", Var (NamedVar x), Var (ConstVar c)) -> return (x, true, c) + | Equation("=", Const c, Var (NamedVar x)) -> return (x, true, c) + | Equation("=", Var (ConstVar c), Var (NamedVar x)) -> return (x, true, c) + | Equation("<>", Var (NamedVar x), Const c) -> return (x, false, c) + | Equation("<>", Var (NamedVar x), Var (ConstVar c)) -> return (x, false, c) + | Equation("<>", Const c, Var (NamedVar x)) -> return (x, false, c) + | Equation("<>", Var (ConstVar c), Var (NamedVar x)) -> return (x, false, c) + | _ -> err (UnsupportedEquation eterm) + end >>= fun (x, equal, c) -> + let equal = if negated then not equal else equal in + if equal then + return (x, EqualTo c) + else + return (x, NotEqualTo (ConstSet.singleton c)) + + +let extend_predicate_map (impred : intermediate_predicate) (args : body_term_arguments) (predmap : predicate_map) : predicate_map = + let argsset = + match predmap |> PredicateMap.find_opt impred with + | None -> BodyTermArgumentsSet.empty + | Some(argsset) -> argsset + in + predmap |> PredicateMap.add impred (argsset |> BodyTermArgumentsSet.add args) + + +let check_equation_map (x : var_name) (cr : constant_requirement) (eqnmap : equation_map) : equation_map option = + match eqnmap |> VariableMap.find_opt x with + | None -> + Some (eqnmap |> VariableMap.add x cr) + + | Some cr0 -> + begin + match (cr0, cr) with + | (EqualTo c0, EqualTo c) -> + if Const.equal c0 c then + Some eqnmap + else + None + + | (NotEqualTo cset0, EqualTo c) 
-> + if cset0 |> ConstSet.mem c then + None + else + Some (eqnmap |> VariableMap.add x (EqualTo c)) + + | (EqualTo c0, NotEqualTo cset) -> + if cset |> ConstSet.mem c0 then + None + else + Some eqnmap + + | (NotEqualTo cset0, NotEqualTo cset) -> + Some (eqnmap |> VariableMap.add x (NotEqualTo (ConstSet.union cset0 cset))) + end + + +(* Converts rules to intermediate ones. + The application `convert_rule rule` returns: + - `Error _` if `rule` is syntactically incorrect (or in unsupported forms), + - `Ok None` if it turns out that `rule` is syntactically correct but + obviously unsatisfiable according to its equations, or + - `Ok (Some imrule)` otherwise, i.e., if `rule` can be successfully converted to `imrule`. *) +let convert_rule (rule : rule) : (intermediate_rule option, error) result = + let open ResultMonad in + let (head, body) = rule in + convert_head_rterm head >>= fun (impred_head, imhvars) -> + body |> foldM (fun opt term -> + match opt with + | None -> + return None + + | Some (predmap_pos, predmap_neg, eqnmap) -> + begin + match term with + | Rel rterm -> + convert_body_rterm rterm >>= fun (impred, imbvars) -> + let predmap_pos = predmap_pos |> extend_predicate_map impred imbvars in + return (Some (predmap_pos, predmap_neg, eqnmap)) + + | Not rterm -> + convert_body_rterm rterm >>= fun (impred, imbvars) -> + let predmap_neg = predmap_neg |> extend_predicate_map impred imbvars in + return (Some (predmap_pos, predmap_neg, eqnmap)) + + | Equat eterm -> + convert_eterm ~negated:false eterm >>= fun (x, cr) -> + begin + match eqnmap |> check_equation_map x cr with + | None -> + (* If it turns out that the list of equations is unsatisfiable: *) + return None + + | Some eqnmap -> + return (Some (predmap_pos, predmap_neg, eqnmap)) + end + + | Noneq eterm -> + convert_eterm ~negated:true eterm >>= fun (x, cr) -> + begin + match eqnmap |> check_equation_map x cr with + | None -> + (* If it turns out that the list of equations is unsatisfiable: *) + return None + + 
| Some eqnmap -> + return (Some (predmap_pos, predmap_neg, eqnmap)) + end + end + ) (Some (PredicateMap.empty, PredicateMap.empty, VariableMap.empty)) >>= function + | None -> + return None + + | Some (predmap_pos, predmap_neg, eqnmap) -> + return (Some { + head_predicate = impred_head; + head_arguments = imhvars; + positive_terms = predmap_pos; + negative_terms = predmap_neg; + equations = eqnmap; + }) + + +let revert_head (impred : intermediate_predicate) (imhvars : intermediate_head_var list) : rterm = + let vars = imhvars |> List.map (function ImHeadVar x -> NamedVar x) in + match impred with + | ImPred t -> Pred (t, vars) + | ImDeltaInsert t -> Deltainsert (t, vars) + | ImDeltaDelete t -> Deltadelete (t, vars) + + +let revert_body_term ~(positive : bool) (impred : intermediate_predicate) (args : body_term_arguments) : term = + let vars = + args |> List.map (function + | ImBodyNamedVar x -> NamedVar x + | ImBodyAnonVar -> AnonVar + ) + in + let rterm = + match impred with + | ImPred t -> Pred (t, vars) + | ImDeltaInsert t -> Deltainsert (t, vars) + | ImDeltaDelete t -> Deltadelete (t, vars) + in + if positive then + Rel rterm + else + Not rterm + + +let revert_body_terms ~(positive : bool) ((impred, argsset) : intermediate_predicate * BodyTermArgumentsSet.t) : term list = + let argss = argsset |> BodyTermArgumentsSet.elements in + argss |> List.map (revert_body_term ~positive impred) + + +let revert_rule (imrule : intermediate_rule) : rule = + let { head_predicate; head_arguments; positive_terms; negative_terms; equations } = imrule in + let head = revert_head head_predicate head_arguments in + let terms_pos = + positive_terms |> PredicateMap.bindings |> List.map (revert_body_terms ~positive:true) |> List.concat + in + let terms_neg = + negative_terms |> PredicateMap.bindings |> List.map (revert_body_terms ~positive:false) |> List.concat + in + let terms_eq = + equations |> VariableMap.bindings |> List.map (fun (x, cr) -> + match cr with + | EqualTo c -> + [ 
Equat (Equation ("=", Var (NamedVar x), Const c)) ] + + | NotEqualTo cset -> + cset |> ConstSet.elements |> List.map (fun c -> + Noneq (Equation ("=", Var (NamedVar x), Const c)) + ) + ) |> List.concat + (* `List.concat_map` can be used for OCaml >= 4.10 *) + in + let body = List.concat [ terms_pos; terms_neg; terms_eq ] in + (head, body) + + +type occurrence_count_map = int VariableMap.t + + +let increment_occurrence_count (x : var_name) (count_map : occurrence_count_map) : occurrence_count_map = + match count_map |> VariableMap.find_opt x with + | None -> count_map |> VariableMap.add x 1 + | Some count -> count_map |> VariableMap.add x (count + 1) + + +let has_only_one_occurrence (count_map : occurrence_count_map) (x : var_name) : bool = + match count_map |> VariableMap.find_opt x with + | Some 1 -> true + | _ -> false + + +let fold_predicate_map_for_counting (predmap : predicate_map) (count_map : occurrence_count_map) : occurrence_count_map = + PredicateMap.fold (fun impred argsset count_map -> + BodyTermArgumentsSet.fold (fun args count_map -> + args |> List.fold_left (fun count_map arg -> + match arg with + | ImBodyNamedVar x -> count_map |> increment_occurrence_count x + | ImBodyAnonVar -> count_map + ) count_map + ) argsset count_map + ) predmap count_map + + +let erase_sole_occurrences_in_predicate_map (count_map : occurrence_count_map) (predmap : predicate_map) : predicate_map = + predmap |> PredicateMap.map (fun argsset -> + argsset |> BodyTermArgumentsSet.map (fun args -> + args |> List.map (fun arg -> + match arg with + | ImBodyNamedVar x -> + if x |> has_only_one_occurrence count_map then + ImBodyAnonVar + else + arg + + | ImBodyAnonVar -> + arg + ) + ) + ) + + +let erase_sole_occurrences (imrule : intermediate_rule) : intermediate_rule = + let { head_predicate; head_arguments; positive_terms; negative_terms; equations } = imrule in + + (* Counts occurrence of each variables: *) + let count_map = + VariableMap.empty + |> fold_predicate_map_for_counting 
positive_terms + |> fold_predicate_map_for_counting negative_terms + |> VariableMap.fold (fun x _c count_map -> count_map |> increment_occurrence_count x) equations + in + + (* Removes variables occurring in head arguments: *) + let count_map = + head_arguments |> List.fold_left (fun count_map (ImHeadVar x) -> + count_map |> VariableMap.remove x + ) count_map + in + + (* Converts variables that have only one occurrence with the underscore: *) + let positive_terms = positive_terms |> erase_sole_occurrences_in_predicate_map count_map in + let negative_terms = negative_terms |> erase_sole_occurrences_in_predicate_map count_map in + let equations = + VariableMap.fold (fun x c equations_new -> + if x |> has_only_one_occurrence count_map then + equations_new + else + equations_new |> VariableMap.add x c + ) equations VariableMap.empty + in + { head_predicate; head_arguments; positive_terms; negative_terms; equations } + + +let is_looser ~than:(args1 : body_term_arguments) (args2 : body_term_arguments) : bool = + match List.combine args1 args2 with + | exception Invalid_argument _ -> + false + + | zipped -> + zipped |> List.for_all (function + | (_, ImBodyAnonVar) -> true + | (ImBodyAnonVar, ImBodyNamedVar _) -> false + | (ImBodyNamedVar x1, ImBodyNamedVar x2) -> String.equal x1 x2 + ) + + +let remove_looser_positive_terms (argsset : BodyTermArgumentsSet.t) : BodyTermArgumentsSet.t = + let rec aux (acc : body_term_arguments list) ~(criterion : body_term_arguments) (targets : body_term_arguments list) = + match targets |> List.filter (fun target -> not (is_looser ~than:criterion target)) with + | [] -> + BodyTermArgumentsSet.of_list (criterion :: acc) + + | head :: tail -> + aux (criterion :: acc) ~criterion:head tail + in + (* Sorted in descending lexicographical order as to variable name lists: *) + let argss_sorted_desc = argsset |> BodyTermArgumentsSet.elements |> List.rev in + match argss_sorted_desc with + | [] -> + argsset + + | head :: tail -> + aux [] 
~criterion:head tail
+
+
+(* Removes redundant negative terms that share one predicate.
+   For negated terms, looser arguments (i.e. more underscores) give a stronger
+   condition: `not p(X, _)` implies `not p(X, Y)`. A term is therefore dropped
+   exactly when an already selected term is looser than it; terms that are
+   unrelated to the selected ones (e.g. `not p(Z, W)` next to `not p(X, Y)`)
+   are kept. *)
+let remove_looser_negative_terms (argsset : BodyTermArgumentsSet.t) : BodyTermArgumentsSet.t =
+  let rec aux (acc : body_term_arguments list) ~(criterion : body_term_arguments) (targets : body_term_arguments list) =
+    (* Keeps the targets that `criterion` does not subsume,
+       i.e. those `criterion` is NOT looser than: *)
+    match targets |> List.filter (fun target -> not (is_looser ~than:target criterion)) with
+    | [] ->
+        BodyTermArgumentsSet.of_list (criterion :: acc)
+
+    | head :: tail ->
+        aux (criterion :: acc) ~criterion:head tail
+  in
+  (* Sorted in ascending lexicographical order as to variable name lists.
+     Since the underscore is the least argument, a term always precedes
+     the terms it is looser than: *)
+  let argss_sorted_asc = argsset |> BodyTermArgumentsSet.elements in
+  match argss_sorted_asc with
+  | [] ->
+      argsset
+
+  | head :: tail ->
+      aux [] ~criterion:head tail
+
+
+(* Removes redundant positive and negative terms, predicate by predicate. *)
+let remove_looser_terms (imrule : intermediate_rule) : intermediate_rule =
+  let { head_predicate; head_arguments; positive_terms; negative_terms; equations } = imrule in
+  let positive_terms = positive_terms |> PredicateMap.map remove_looser_positive_terms in
+  let negative_terms = negative_terms |> PredicateMap.map remove_looser_negative_terms in
+  { head_predicate; head_arguments; positive_terms; negative_terms; equations }
+
+
+(* One round of per-rule simplification: variables occurring only once are
+   replaced first, and then redundant terms are removed. *)
+let simplify_rule_step (imrule : intermediate_rule) : intermediate_rule =
+  let imrule = erase_sole_occurrences imrule in
+  remove_looser_terms imrule
+
+
+(* Repeats `simplify_rule_step` until the rule no longer changes
+   (as judged by `rule_equal`). *)
+let rec simplify_rule_recursively (imrule1 : intermediate_rule) : intermediate_rule =
+  let imrule2 = simplify_rule_step imrule1 in
+  if rule_equal imrule1 imrule2 then
+    (* If the simplification reaches a fixpoint: *)
+    imrule2
+  else
+    simplify_rule_recursively imrule2
+
+
+let has_contradicting_body (imrule : intermediate_rule) : bool =
+  let { positive_terms; negative_terms; _ } = imrule in
+  let dom =
+    PredicateSet.empty
+      |> PredicateMap.fold (fun impred _ dom -> dom |> PredicateSet.add impred) positive_terms
+      |> PredicateMap.fold (fun impred _ dom -> dom |> PredicateSet.add impred) negative_terms
+  in
+  PredicateSet.fold (fun impred found ->
+    if found then
+      true
+    else
+
match (positive_terms |> PredicateMap.find_opt impred, negative_terms |> PredicateMap.find_opt impred) with + | (Some argsset_pos, Some argsset_neg) -> + let argss_pos = BodyTermArgumentsSet.elements argsset_pos in + let argss_neg = BodyTermArgumentsSet.elements argsset_neg in + let posnegs = + argss_pos |> List.map (fun args_pos -> + argss_neg |> List.map (fun args_neg -> (args_pos, args_neg)) + ) |> List.concat + in + posnegs |> List.exists (fun (args_pos, args_neg) -> + is_looser ~than:args_pos args_neg + ) + + | _ -> + false + ) dom false + +(* Renames the named variables in every argument list of predmap1 according to subst; anonymous variables and named variables without an entry in subst are kept unchanged. *) +let substitute_predicate_map (subst : head_var_substitution) (predmap1 : predicate_map) : predicate_map = + predmap1 |> PredicateMap.map (fun argsset -> + BodyTermArgumentsSet.fold (fun args acc -> + let args = + args |> List.map (function + | ImBodyNamedVar x1 -> + begin + match subst |> HeadVarSubst.find_opt x1 with + | None -> ImBodyNamedVar x1 + | Some (ImHeadVar x2) -> ImBodyNamedVar x2 + end + + | ImBodyAnonVar -> + ImBodyAnonVar + ) + in + acc |> BodyTermArgumentsSet.add args + ) argsset BodyTermArgumentsSet.empty + ) + +(* Re-keys eqns1 according to subst: a variable with an entry in subst is stored under its image, any other variable keeps its own key. The associated requirements are not modified. *) +let substitute_equation_map (subst : head_var_substitution) (eqns1 : equation_map) : equation_map = + VariableMap.fold (fun x1 cr acc -> + match subst |> HeadVarSubst.find_opt x1 with + | None -> acc |> VariableMap.add x1 cr + | Some (ImHeadVar x2) -> acc |> VariableMap.add x2 cr + ) eqns1 VariableMap.empty + +(* Returns true iff both rules have equal head predicates and head argument lists of the same length, and the body of imrule1 (positive terms, negative terms, equations), after renaming its head variables position-wise to those of imrule2, equals the body of imrule2. Only head variables are renamed, so bodies that differ in the names of other variables are reported as different. *) +let are_alpha_equivalent_rules (imrule1 : intermediate_rule) (imrule2 : intermediate_rule) : bool = + let + { + head_predicate = hp1; head_arguments = hvars1; + positive_terms = poss1; negative_terms = negs1; equations = eqns1; + } = imrule1 + in + let + { + head_predicate = hp2; head_arguments = hvars2; + positive_terms = poss2; negative_terms = negs2; equations = eqns2; + } = imrule2 + in + if not (predicate_equal hp1 hp2) then + false + else + match List.combine hvars1 hvars2 with + | exception Invalid_argument(_) -> + false + + | zipped -> + let subst = + zipped |> List.fold_left (fun
subst (ImHeadVar hvar1, ImHeadVar hvar2) -> + subst |> HeadVarSubst.add hvar1 (ImHeadVar hvar2) + ) HeadVarSubst.empty + in + let poss1subst = poss1 |> substitute_predicate_map subst in + let negs1subst = negs1 |> substitute_predicate_map subst in + let eqns1subst = eqns1 |> substitute_equation_map subst in + List.fold_left ( && ) true [ (* each renamed component of imrule1 must equal the matching component of imrule2 *) + PredicateMap.equal BodyTermArgumentsSet.equal poss1subst poss2; + PredicateMap.equal BodyTermArgumentsSet.equal negs1subst negs2; + VariableMap.equal constant_requirement_equal eqns1subst eqns2; + ] + +(* Keeps a rule only if no earlier kept rule is related to it by are_alpha_equivalent_rules; the surviving rules keep their original relative order. *) +let remove_duplicate_rules (imrules : intermediate_rule list) : intermediate_rule list = + let rec aux acc imrules = + match imrules with + | [] -> + List.rev acc + + | imrule_head :: imrules_tail -> + let imrules_tail = + imrules_tail |> List.filter (fun imrule -> + not (are_alpha_equivalent_rules imrule_head imrule) + ) + in + aux (imrule_head :: acc) imrules_tail + in + aux [] imrules + +(* Converts the rules to intermediate rules (dropping those for which convert_rule returns None), simplifies each one, removes rules with a contradicting body and duplicate rules, and converts the remaining ones back. An error returned by convert_rule is propagated as Error. *) +let simplify (rules : rule list) : (rule list, error) result = + let open ResultMonad in + + (* Converts each rule to an intermediate rule (with unsatisfiable ones removed): *) + rules |> foldM (fun imrule_acc rule -> + convert_rule rule >>= function + | None -> return imrule_acc + | Some imrule -> return (imrule :: imrule_acc) + ) [] >>= fun imrule_acc -> + let imrules = List.rev imrule_acc in + + (* Performs per-rule simplification: *) + let imrules = imrules |> List.map simplify_rule_recursively in + + (* Removes rules that have a contradicting body: *) + let imrules = imrules |> List.filter (fun imrule -> not (has_contradicting_body imrule)) in + + (* Removes duplicate rules here *) + let imrules = imrules |> remove_duplicate_rules in + + let rules = imrules |> List.map revert_rule in + return rules diff --git a/src/simplification.mli b/src/simplification.mli new file mode 100644 index 00000000..6d04300f --- /dev/null +++ b/src/simplification.mli @@ -0,0 +1,6 @@ + +open Expr + +type error + +val simplify : rule list -> (rule list, error)
result diff --git a/src/test/ast2sql_operation_based_conversion_test.ml b/src/test/ast2sql_operation_based_conversion_test.ml new file mode 100644 index 00000000..1cc4ea7d --- /dev/null +++ b/src/test/ast2sql_operation_based_conversion_test.ml @@ -0,0 +1,139 @@ + +open Utils +open Expr + +(* One conversion test: title is printed in the report, expr is the program to convert, and expected is the exact SQL text the conversion must produce. *) +type test_case = { + title : string; + expr : expr; + expected : string; +} + +type test_result = + | Pass + | Fail of { expected : string; got : string } + +(* Converts test_case.expr with Ast2sql.convert_expr_to_operation_based_sql, stringifies the resulting operations and joins them with single spaces, then compares that text with test_case.expected. A conversion error is returned as Error. *) +let run_test (test_case : test_case) : (test_result, Ast2sql.error) result = + let open ResultMonad in + let expr = test_case.expr in + let expected = test_case.expected in + + Ast2sql.convert_expr_to_operation_based_sql expr >>= fun sql_operations -> + let got = sql_operations |> List.map Ast2sql.stringify_sql_operation |> String.concat " " in + + if String.equal got expected then + return Pass + else + return (Fail { expected; got }) + + +(* Runs all the test cases in the given list, prints every result, + and returns whether a failure has occurred. *) +let run_tests (test_cases : test_case list) : bool = + test_cases |> List.fold_left (fun has_failed test_case -> + let title = test_case.title in + match run_test test_case with + | Ok Pass -> + Printf.printf "- %s: OK\n" title; + has_failed + + | Ok (Fail { expected; got }) -> + Printf.printf "! %s: FAILED\n" title; + Printf.printf "expected:\n\"%s\"\n" expected; + Printf.printf "got:\n\"%s\"\n" got; + true + + | Error _ -> + Printf.printf "! %s: FAILED (error)\n" title; + true + ) false + +(* Builds the test cases of this module and returns the result of run_tests on them, i.e. whether some case failed. *) +let main () = + let test_cases = + [ + { + title = + "ed and eed"; + expr = + { + rules = begin + let rule1 = + (* "+eed(E, D) :- ed(E, D), D = 'A', E != 'Joe', ¬eed(E, D)."
*) + Deltainsert ("eed", [ NamedVar "E"; NamedVar "D" ]), [ + Rel (Pred ("ed", [ NamedVar "E"; NamedVar "D" ])); + Equat (Equation ("=", Var (NamedVar "D"), Const (String "'A'"))); + Equat (Equation ("<>", Var (NamedVar "E"), Const (String "'Joe'"))); + Not (Pred ("eed", [ NamedVar "E"; NamedVar "D" ])); + ] + in + let rule2 = + (* "-eed(E, D) :- ed(V1, D), eed(E, D), E = 'Joe', D = 'A', V1 != 'Joe', ¬eed(V1, D)." *) + Deltadelete ("eed", [ NamedVar "E"; NamedVar "D" ]), [ + Rel (Pred ("ed", [ NamedVar "V1"; NamedVar "D" ])); + Rel (Pred ("eed", [ NamedVar "E"; NamedVar "D" ])); + Equat (Equation ("=", Var (NamedVar "E"), Const (String "'Joe'"))); + Equat (Equation ("=", Var (NamedVar "D"), Const (String "'A'"))); + Equat (Equation ("<>", Var (NamedVar "V1"), Const (String "'Joe'"))); + Not (Pred ("eed", [ NamedVar "V1"; NamedVar "D" ])); + ] + in + let rule3 = + (* "+ed(E, D) :- ed(V1, D), E = 'Joe', D = 'A', V1 != 'Joe', ¬ed(E, D), ¬eed(V1, D)." *) + Deltainsert ("ed", [ NamedVar "E"; NamedVar "D" ]), [ + Rel (Pred ("ed", [ NamedVar "V1"; NamedVar "D" ])); + Equat (Equation ("=", Var (NamedVar "E"), Const (String "'Joe'"))); + Equat (Equation ("=", Var (NamedVar "D"), Const (String "'A'"))); + Equat (Equation ("<>", Var (NamedVar "V1"), Const (String "'Joe'"))); + Not (Pred ("ed", [ NamedVar "E"; NamedVar "D" ])); + Not (Pred ("eed", [ NamedVar "V1"; NamedVar "D" ])); + ] + in + [ rule3; rule2; rule1 ] (* `expr` holds its rules in the reversed order *) + end; + facts = []; + query = None; + sources = [ + ("ed", [ ("emp_name", Sstring); ("dept_name", Sstring) ]); + ("eed", [ ("emp_name", Sstring); ("dept_name", Sstring) ]); + ]; + view = None; + constraints = []; + primary_keys = []; + }; + expected = + let query1 = + String.concat " " [ + "SELECT ed_0.emp_name AS emp_name, 'A' AS dept_name FROM ed AS ed_0 WHERE"; + "ed_0.dept_name = 'A' AND ed_0.emp_name <> 'Joe' AND"; + "NOT EXISTS ( SELECT * FROM eed AS t WHERE t.emp_name = ed_0.emp_name AND t.dept_name = 'A'
)"; + ] + in + let query2 = + String.concat " " [ + "SELECT 'Joe' AS emp_name, 'A' AS dept_name FROM ed AS ed_0, eed AS eed_1 WHERE"; + "ed_0.dept_name = 'A' AND eed_1.dept_name = 'A' AND eed_1.emp_name = 'Joe' AND ed_0.emp_name <> 'Joe' AND"; + "NOT EXISTS ( SELECT * FROM eed AS t WHERE t.emp_name = ed_0.emp_name AND t.dept_name = 'A' )"; + ] + in + let query3 = + String.concat " " [ + "SELECT 'Joe' AS emp_name, 'A' AS dept_name FROM ed AS ed_0 WHERE"; + "ed_0.dept_name = 'A' AND ed_0.emp_name <> 'Joe' AND"; + "NOT EXISTS ( SELECT * FROM ed AS t WHERE t.emp_name = 'Joe' AND t.dept_name = 'A' ) AND"; + "NOT EXISTS ( SELECT * FROM eed AS t WHERE t.emp_name = ed_0.emp_name AND t.dept_name = 'A' )"; + ] + in + String.concat " " [ (* NOTE(review): the last three statements insert into / delete from the temp tables themselves rather than eed or ed; presumably this pins the current output of the converter - confirm it is the intended SQL. *) + Printf.sprintf "CREATE TEMPORARY TABLE temp0 AS %s;" query1; + Printf.sprintf "CREATE TEMPORARY TABLE temp1 AS %s;" query2; + Printf.sprintf "CREATE TEMPORARY TABLE temp2 AS %s;" query3; + "INSERT INTO temp0 SELECT * FROM temp0 AS inst;"; + "DELETE FROM temp1 WHERE EXISTS ( SELECT * FROM temp1 AS inst );"; + "INSERT INTO temp2 SELECT * FROM temp2 AS inst;"; + ] + }; + ] + in + run_tests test_cases diff --git a/src/test/simplification_test.ml b/src/test/simplification_test.ml new file mode 100644 index 00000000..4b276c7f --- /dev/null +++ b/src/test/simplification_test.ml @@ -0,0 +1,185 @@ + +open Utils +open Expr + +(* One simplification test: title is printed in the report, input is passed to Simplification.simplify, and expected is the rule list the result is compared against. *) +type test_case = { + title : string; + input : rule list; + expected : rule list; +} + +type test_result = + | Pass + | Fail of { expected : string; got : string } + +(* Simplifies test_case.input and compares the result with test_case.expected through their printed forms (string_of_rule on each rule, joined in list order), so the order of the rules matters. An error from simplify is returned as Error. *) +let run_test (test_case : test_case) = + let open ResultMonad in + Simplification.simplify test_case.input >>= fun got -> + let s_got = got |> List.map string_of_rule |> String.concat "; " in + let s_expected = test_case.expected |> List.map string_of_rule |> String.concat "; " in + if String.equal s_got s_expected then + return Pass + else + return (Fail { expected = s_expected; got = s_got }) + + +(* Runs all the test cases in the given list, prints every
result, + and returns whether a failure has occurred. *) +let run_tests (test_cases : test_case list) : bool = + test_cases |> List.fold_left (fun has_failed test_case -> + let title = test_case.title in + match run_test test_case with + | Ok Pass -> + Printf.printf "- %s: OK\n" title; + has_failed (* keeps a failure recorded by an earlier case *) + + | Ok (Fail { expected; got }) -> + Printf.printf "! %s: FAILED\n" title; + Printf.printf "expected:\n\"%s\"\n" expected; + Printf.printf "got:\n\"%s\"\n" got; + true + + | Error _ -> (* Simplification.error is abstract in the interface, so only the fact that an error occurred is reported *) + Printf.printf "! %s: FAILED (error)\n" title; + true + ) false + +(* Runs the simplification test cases listed below through run_tests and returns its result. The comment above each rule shows the rule in textual form. *) +let main () = + let track = NamedVar "TRACK" in + let date = NamedVar "DATE" in + let rating = NamedVar "RATING" in + let album = NamedVar "ALBUM" in + let quantity = NamedVar "QUANTITY" in + run_tests [ + { + title = "empty"; + input = []; + expected = []; + }; + { + title = "(1)"; + input = [ + (* (1): + -tracks(TRACK, DATE, RATING, ALBUM) :- + albums(ALBUM, _), + albums(ALBUM, V6845), + tracks(TRACK, DATE, RATING, ALBUM), + tracks(TRACK, DATE, RATING, ALBUM), + RATING = 1. *) + (Deltadelete ("tracks", [ track; date; rating; album ]), [ + Rel (Pred ("albums", [ album; AnonVar ])); + Rel (Pred ("albums", [ album; NamedVar "V6845" ])); + Rel (Pred ("tracks", [ track; date; rating; album ])); + Rel (Pred ("tracks", [ track; date; rating; album ])); + Equat (Equation ("=", Var rating, Const (Int 1))); + ]); + ]; + expected = [ + (* (1) simplified: + -tracks(TRACK, DATE, RATING, ALBUM) :- + albums(ALBUM, _), + tracks(TRACK, DATE, RATING, ALBUM), + RATING = 1.
*) + (Deltadelete ("tracks", [ track; date; rating; album ]), [ + Rel (Pred ("albums", [ album; AnonVar ])); + Rel (Pred ("tracks", [ track; date; rating; album ])); + Equat (Equation ("=", Var rating, Const (Int 1))); + ]) + ]; + }; + { + title = "(2): erased by contradiction"; + input = [ + (* (2): + -tracks(TRACK, DATE, RATING, ALBUM) :- + albums(ALBUM, V34), + albums(ALBUM, V6846), + tracks(TRACK, DATE, RATING, ALBUM), + tracks(V31, V32, V33, ALBUM), + tracks(TRACK, DATE, RATING, ALBUM), + RATING = 1, + not tracks(V31, V32, V33, ALBUM). *) + (Deltadelete ("tracks", [ track; date; rating; album ]), [ + Rel (Pred ("albums", [ album; NamedVar "V34" ])); + Rel (Pred ("albums", [ album; NamedVar "V6846" ])); + Rel (Pred ("tracks", [ track; date; rating; album ])); + Rel (Pred ("tracks", [ NamedVar "V31"; NamedVar "V32"; NamedVar "V33"; album ])); + Rel (Pred ("tracks", [ track; date; rating; album ])); + Equat (Equation ("=", Var rating, Const (Int 1))); + Not (Pred ("tracks", [ NamedVar "V31"; NamedVar "V32"; NamedVar "V33"; album ])); + ]); + ]; + expected = []; + }; + { + title = "(7)"; + input = [ + (* (7): + -albums(ALBUM, QUANTITY) :- + albums(ALBUM, QUANTITY), + albums(ALBUM, QUANTITY), + tracks(_, _, _, ALBUM), + tracks(V6853, V6854, V6855, ALBUM), + V6855 = 1. *) + (Deltadelete ("albums", [ album; quantity ]), [ + Rel (Pred ("albums", [ album; quantity ])); + Rel (Pred ("albums", [ album; quantity ])); + Rel (Pred ("tracks", [ AnonVar; AnonVar; AnonVar; album ])); + Rel (Pred ("tracks", [ NamedVar "V6853"; NamedVar "V6854"; NamedVar "V6855"; album ])); + Equat (Equation ("=", Var (NamedVar "V6855"), Const (Int 1))); + ]); + ]; + expected = [ + (* (7) simplified: + -albums(ALBUM, QUANTITY) :- + albums(ALBUM, QUANTITY), + tracks(_, _, V6855, ALBUM), + V6855 = 1.
*) + (Deltadelete ("albums", [ album; quantity ]), [ + Rel (Pred ("albums", [ album; quantity ])); + Rel (Pred ("tracks", [ AnonVar; AnonVar; NamedVar "V6855"; album ])); + Equat (Equation ("=", Var (NamedVar "V6855"), Const (Int 1))); + ]); + ]; + }; + { + title = "(32)"; + input = [ + (* (32): + -albums(ALBUM, QUANTITY) :- + albums(ALBUM, QUANTITY), + albums(ALBUM, QUANTITY), + tracks(TRACK, DATE, RATING, ALBUM), + tracks(V6847, V6848, V6849, ALBUM), + V6849 = 1, + not RATING = 1. *) + (Deltadelete ("albums", [ album; quantity ]), [ + Rel (Pred ("albums", [ album; quantity ])); + Rel (Pred ("albums", [ album; quantity ])); + Rel (Pred ("tracks", [ track; date; rating; album ])); + Rel (Pred ("tracks", [ NamedVar "V6847"; NamedVar "V6848"; NamedVar "V6849"; album ])); + Equat (Equation ("=", Var (NamedVar "V6849"), Const (Int 1))); + Noneq (Equation ("=", Var rating, Const (Int 1))); + ]); + ]; + expected = [ + (* (32) simplified: + -albums(ALBUM, QUANTITY) :- + albums(ALBUM, QUANTITY), + tracks(_, _, RATING, ALBUM), + tracks(_, _, V6849, ALBUM), + not RATING = 1, + V6849 = 1.
*) + (Deltadelete ("albums", [ album; quantity ]), [ + Rel (Pred ("albums", [ album; quantity ])); + Rel (Pred ("tracks", [ AnonVar; AnonVar; rating; album ])); + Rel (Pred ("tracks", [ AnonVar; AnonVar; NamedVar "V6849"; album ])); + Noneq (Equation ("=", Var rating, Const (Int 1))); + Equat (Equation ("=", Var (NamedVar "V6849"), Const (Int 1))); + ]); + ]; + }; + ] diff --git a/src/test/test_main.ml b/src/test/test_main.ml new file mode 100644 index 00000000..ac069633 --- /dev/null +++ b/src/test/test_main.ml @@ -0,0 +1,10 @@ +(* Test entry point: runs the main function of every test module and exits with status 1 if any of them reported a failure, 0 otherwise. *) +let () = + let has_failed = + List.exists (fun b -> b) [ (* the list is built in full before List.exists is applied, so every module runs even if another one has failed *) + Ast2sql_operation_based_conversion_test.main (); + Simplification_test.main (); + (* You can add more tests here *) + ] + in + exit (if has_failed then 1 else 0) diff --git a/src/utils.ml b/src/utils.ml index bd0048c2..df96160a 100644 --- a/src/utils.ml +++ b/src/utils.ml @@ -619,3 +619,52 @@ let colored_string color str = match color with | "purple" -> "\027[35m"^str^"\027[0m" | "brown" -> "\027[33m"^str^"\027[0m" | _ -> str + +(* A small monadic interface over the result type: computations are chained with ( >>= ), and an Error is passed along without running the rest of the chain. *) +module ResultMonad : sig + val return : 'a -> ('a, 'e) result (* wraps a value in Ok *) + val err : 'e -> ('a, 'e) result (* wraps a value in Error *) + val map_err : ('e1 -> 'e2) -> ('a, 'e1) result -> ('a, 'e2) result (* transforms the Error payload; Ok is returned unchanged *) + val foldM : ('a -> 'b -> ('a, 'e) result) -> 'a -> 'b list -> ('a, 'e) result (* left fold; once a step returns Error, the function is not applied to the remaining elements *) + val mapM : ('a -> ('b, 'e) result) -> 'a list -> ('b list, 'e) result (* maps over the list keeping element order; yields the first Error encountered *) + val ( >>= ) : ('a, 'e) result -> ('a -> ('b, 'e) result) -> ('b, 'e) result (* applies the function to an Ok value; an Error is returned as is *) +end = struct + + let return v = + Ok v + + let err e = + Error e + + let ( >>= ) v f = + match v with + | Ok x -> f x + | Error e -> Error e + + let map_err f v = + match v with + | Ok x -> Ok x + | Error e -> Error (f e) + + let foldM f acc vs = + vs |> List.fold_left (fun res v -> + res >>= fun acc -> + f acc v + ) (return acc) + + let mapM f vs = + vs |> foldM (fun acc v -> + f v >>= fun y -> + return (y :: acc) (* results are accumulated in reverse and put back in order below *) + ) [] >>= fun acc -> + return (List.rev acc) +end + + +type named_var = string + +type table_name = string + +type column_name = string + +type
instance_name = string