From e9935075a0079865a692a2b3d50af842ddcdb894 Mon Sep 17 00:00:00 2001 From: Lukas Gerlach <12497479+monsterkrampe@users.noreply.github.com> Date: Mon, 27 Nov 2023 13:28:26 +0100 Subject: [PATCH] Fix initial type requirements (merge correctly) (#426) --- nemo/src/program_analysis/type_inference.rs | 60 +++++- .../type_inference/type_requirement.rs | 181 +++++++++--------- 2 files changed, 147 insertions(+), 94 deletions(-) diff --git a/nemo/src/program_analysis/type_inference.rs b/nemo/src/program_analysis/type_inference.rs index 46798b8a6..762299cce 100644 --- a/nemo/src/program_analysis/type_inference.rs +++ b/nemo/src/program_analysis/type_inference.rs @@ -27,12 +27,12 @@ type VariableTypesForRules = Vec>; type TypeInferenceResult = Result<(PredicateTypes, VariableTypesForRules), TypeError>; pub(super) fn infer_types(program: &ChaseProgram) -> TypeInferenceResult { - let pred_reqs = requirements_from_pred_decls(program.parsed_predicate_declarations()); - let import_reqs = requirements_from_imports(program.imports()); - let fact_reqs = requirements_from_facts(program.facts()); - let literal_reqs = requirements_from_literals_in_rules(program.rules()); - let aggregate_reqs = requirements_from_aggregates_in_rules(program.rules()); - let existential_reqs = requirements_from_existentials_in_rules(program.rules()); + let pred_reqs = requirements_from_pred_decls(program.parsed_predicate_declarations())?; + let import_reqs = requirements_from_imports(program.imports())?; + let fact_reqs = requirements_from_facts(program.facts())?; + let literal_reqs = requirements_from_literals_in_rules(program.rules())?; + let aggregate_reqs = requirements_from_aggregates_in_rules(program.rules())?; + let existential_reqs = requirements_from_existentials_in_rules(program.rules())?; let mut type_requirements = import_reqs; merge_type_requirements(&mut type_requirements, fact_reqs)?; @@ -858,4 +858,52 @@ mod test { let inferred_types_res = infer_types(&s_decl_unresolvable_conflict_with_fact_values); assert!(inferred_types_res.is_err()); } + + #[test] + fn infer_types_two_times_same_head_predicate() { + let ( + (_basic_rule, exis_rule, _rule_with_constant), + (_fact1, _fact2, _fact3), + (a, _b, _c, r, _s, _t, _q), + ) = get_test_rules_and_facts_and_predicates(); + + let x = Variable::Universal(Identifier("x".to_string())); + let y = Variable::Existential(Identifier("y".to_string())); + let z = Variable::Existential(Identifier("z".to_string())); + + let ty = PrimitiveTerm::Variable(y); + let tz = PrimitiveTerm::Variable(z); + + // R(!y, !z) :- A(x). + let exis_rule_2 = ChaseRule::new( + vec![PrimitiveAtom::new(r.clone(), vec![ty, tz])], + vec![], + vec![], + vec![VariableAtom::new(a.clone(), vec![x])], + vec![], + vec![], + vec![], + ); + + let two_times_same_head_predicate = ChaseProgram::builder() + .rule(exis_rule_2) + .rule(exis_rule) + .predicate_declaration(a.clone(), vec![PrimitiveType::String]) + .build(); + + let expected_types: HashMap> = [ + (a, vec![PrimitiveType::String]), + (r, vec![PrimitiveType::Any, PrimitiveType::Any]), + ( + get_fresh_rule_predicate(0), + vec![PrimitiveType::Any, PrimitiveType::Any], + ), + (get_fresh_rule_predicate(1), vec![PrimitiveType::Any]), + ] + .into_iter() + .collect(); + + let inferred_types = infer_types(&two_times_same_head_predicate).unwrap().0; + assert_eq!(inferred_types, expected_types); + } } diff --git a/nemo/src/program_analysis/type_inference/type_requirement.rs b/nemo/src/program_analysis/type_inference/type_requirement.rs index f155874b0..a0020ba66 100644 --- a/nemo/src/program_analysis/type_inference/type_requirement.rs +++ b/nemo/src/program_analysis/type_inference/type_requirement.rs @@ -38,7 +38,7 @@ impl TypeRequirement { pub(super) fn replace_with_max_type_if_compatible(self, other: Self) -> Option { match self { Self::Hard(t1) => match other { - Self::Hard(t2) => (t1 == t2).then_some(self), + Self::Hard(t2) => (t1 >= t2).then_some(self), Self::Soft(t2) => { let max = t1.max_type(&t2); // check if the max type is compatible with both types via partial ord @@ -104,11 +104,6 @@ fn add_type_requirements( .enumerate() .for_each(|(index, (a, b))| { let replacement = a.stricter_requirement(*b); - // if force_use_of_stricter_requirement { - // a.stricter_requirement(*b); - // } else { - // a.replace_with_max_type_if_compatible(*b) - // }; match replacement { Some(replacement) => *a = replacement, @@ -151,41 +146,48 @@ pub(super) fn override_type_requirements( pub(super) fn requirements_from_pred_decls( decls: &HashMap>, -) -> PredicateTypeRequirements { - decls - .iter() - .map(|(pred, types)| { - ( - pred.clone(), - types - .iter() - .copied() - .map(TypeRequirement::Hard) - .collect::>(), - ) - }) - .collect() +) -> Result { + let mut type_requirements = HashMap::new(); + + for (pred, types) in decls { + add_type_requirements( + &mut type_requirements, + pred.clone(), + types + .iter() + .copied() + .map(TypeRequirement::Hard) + .collect::>(), + )?; + } + + Ok(type_requirements) } pub(super) fn requirements_from_imports<'a, T: Iterator>( imports: T, -) -> PredicateTypeRequirements { - imports - .map(|import_spec| { - ( - import_spec.predicate().clone(), - import_spec - .type_constraint() - .iter() - .cloned() - .map(TypeRequirement::from) - .collect(), - ) - }) - .collect() +) -> Result { + let mut type_requirements = HashMap::new(); + + for import_spec in imports { + add_type_requirements( + &mut type_requirements, + import_spec.predicate().clone(), + import_spec + .type_constraint() + .iter() + .cloned() + .map(TypeRequirement::from) + .collect(), + )?; + } + + Ok(type_requirements) } -pub(super) fn requirements_from_facts(facts: &Vec) -> PredicateTypeRequirements { +pub(super) fn requirements_from_facts( + facts: &Vec, +) -> Result { let mut fact_decls: PredicateTypeRequirements = HashMap::new(); for fact in facts { let reqs_for_fact: Vec = fact @@ -198,16 +200,15 @@ pub(super) fn requirements_from_facts(facts: &Vec) -> PredicateTypeRe }) .collect(); - add_type_requirements(&mut fact_decls, fact.predicate(), reqs_for_fact) - .expect("Since fact requirements are all soft, there should be no conflicts."); + add_type_requirements(&mut fact_decls, fact.predicate(), reqs_for_fact)?; } - fact_decls + Ok(fact_decls) } pub(super) fn requirements_from_literals_in_rules( rules: &Vec, -) -> PredicateTypeRequirements { +) -> Result { let mut literal_decls: PredicateTypeRequirements = HashMap::new(); for chase_rule in rules { @@ -256,68 +257,72 @@ pub(super) fn requirements_from_literals_in_rules( }) .collect(); - add_type_requirements(&mut literal_decls, chase_atom.predicate(), reqs_for_atom) - .expect("Since literal requirements are all soft, there should be no conflicts."); + add_type_requirements(&mut literal_decls, chase_atom.predicate(), reqs_for_atom)?; } } - literal_decls + Ok(literal_decls) } pub(super) fn requirements_from_aggregates_in_rules( rules: &[ChaseRule], -) -> PredicateTypeRequirements { - rules +) -> Result { + let mut type_requirements = HashMap::new(); + + for (rule, atom) in rules .iter() .flat_map(|rule| rule.head().iter().map(move |atom| (rule, atom))) - .map(|(rule, atom)| { - ( - atom.predicate(), - atom.terms() - .iter() - .map(|term| { - if let PrimitiveTerm::Variable(Variable::Universal(identifier)) = term { - if identifier.name().starts_with(AGGREGATE_VARIABLE_PREFIX) { - let aggregate = rule - .aggregates() - .iter() - .find(|aggregate| aggregate.output_variable.name() == *identifier.0).expect("variable with aggregate prefix is missing an associated aggregate"); - if - aggregate.aggregate_operation == - AggregateOperation::Count - { - return TypeRequirement::Hard(PrimitiveType::Integer) - } + { + add_type_requirements( + &mut type_requirements, + atom.predicate(), + atom.terms() + .iter() + .map(|term| { + if let PrimitiveTerm::Variable(Variable::Universal(identifier)) = term { + if identifier.name().starts_with(AGGREGATE_VARIABLE_PREFIX) { + let aggregate = rule + .aggregates() + .iter() + .find(|aggregate| aggregate.output_variable.name() == *identifier.0).expect("variable with aggregate prefix is missing an associated aggregate"); + if + aggregate.aggregate_operation == + AggregateOperation::Count + { + return TypeRequirement::Hard(PrimitiveType::Integer) } } - TypeRequirement::None - }) - .collect(), - ) - }) - .collect() + } + TypeRequirement::None + }) + .collect(), + )?; + } + + Ok(type_requirements) } pub(super) fn requirements_from_existentials_in_rules( rules: &[ChaseRule], -) -> PredicateTypeRequirements { - rules - .iter() - .flat_map(|r| r.head()) - .map(|a| { - ( - a.predicate(), - a.terms() - .iter() - .map(|t| { - if matches!(t, PrimitiveTerm::Variable(Variable::Existential(_))) { - TypeRequirement::Hard(PrimitiveType::Any) - } else { - TypeRequirement::None - } - }) - .collect::>(), - ) - }) - .collect() +) -> Result { + let mut type_requirements = HashMap::new(); + + for atom in rules.iter().flat_map(|r| r.head()) { + add_type_requirements( + &mut type_requirements, + atom.predicate(), + atom.terms() + .iter() + .map(|t| { + if matches!(t, PrimitiveTerm::Variable(Variable::Existential(_))) { + TypeRequirement::Hard(PrimitiveType::Any) + } else { + TypeRequirement::None + } + }) + .collect::>(), + )?; + } + + Ok(type_requirements) }