Skip to content

Commit

Permalink
Fix initial type requirements (merge correctly) (#426)
Browse files Browse the repository at this point in the history
  • Loading branch information
monsterkrampe authored Nov 27, 2023
1 parent 08c2dbd commit e993507
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 94 deletions.
60 changes: 54 additions & 6 deletions nemo/src/program_analysis/type_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ type VariableTypesForRules = Vec<HashMap<Variable, PrimitiveType>>;
type TypeInferenceResult = Result<(PredicateTypes, VariableTypesForRules), TypeError>;

pub(super) fn infer_types(program: &ChaseProgram) -> TypeInferenceResult {
let pred_reqs = requirements_from_pred_decls(program.parsed_predicate_declarations());
let import_reqs = requirements_from_imports(program.imports());
let fact_reqs = requirements_from_facts(program.facts());
let literal_reqs = requirements_from_literals_in_rules(program.rules());
let aggregate_reqs = requirements_from_aggregates_in_rules(program.rules());
let existential_reqs = requirements_from_existentials_in_rules(program.rules());
let pred_reqs = requirements_from_pred_decls(program.parsed_predicate_declarations())?;
let import_reqs = requirements_from_imports(program.imports())?;
let fact_reqs = requirements_from_facts(program.facts())?;
let literal_reqs = requirements_from_literals_in_rules(program.rules())?;
let aggregate_reqs = requirements_from_aggregates_in_rules(program.rules())?;
let existential_reqs = requirements_from_existentials_in_rules(program.rules())?;

let mut type_requirements = import_reqs;
merge_type_requirements(&mut type_requirements, fact_reqs)?;
Expand Down Expand Up @@ -858,4 +858,52 @@ mod test {
let inferred_types_res = infer_types(&s_decl_unresolvable_conflict_with_fact_values);
assert!(inferred_types_res.is_err());
}

#[test]
fn infer_types_two_times_same_head_predicate() {
let (
(_basic_rule, exis_rule, _rule_with_constant),
(_fact1, _fact2, _fact3),
(a, _b, _c, r, _s, _t, _q),
) = get_test_rules_and_facts_and_predicates();

let x = Variable::Universal(Identifier("x".to_string()));
let y = Variable::Existential(Identifier("y".to_string()));
let z = Variable::Existential(Identifier("z".to_string()));

let ty = PrimitiveTerm::Variable(y);
let tz = PrimitiveTerm::Variable(z);

// R(!y, !z) :- A(x).
let exis_rule_2 = ChaseRule::new(
vec![PrimitiveAtom::new(r.clone(), vec![ty, tz])],
vec![],
vec![],
vec![VariableAtom::new(a.clone(), vec![x])],
vec![],
vec![],
vec![],
);

let two_times_same_head_predicate = ChaseProgram::builder()
.rule(exis_rule_2)
.rule(exis_rule)
.predicate_declaration(a.clone(), vec![PrimitiveType::String])
.build();

let expected_types: HashMap<Identifier, Vec<PrimitiveType>> = [
(a, vec![PrimitiveType::String]),
(r, vec![PrimitiveType::Any, PrimitiveType::Any]),
(
get_fresh_rule_predicate(0),
vec![PrimitiveType::Any, PrimitiveType::Any],
),
(get_fresh_rule_predicate(1), vec![PrimitiveType::Any]),
]
.into_iter()
.collect();

let inferred_types = infer_types(&two_times_same_head_predicate).unwrap().0;
assert_eq!(inferred_types, expected_types);
}
}
181 changes: 93 additions & 88 deletions nemo/src/program_analysis/type_inference/type_requirement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl TypeRequirement {
pub(super) fn replace_with_max_type_if_compatible(self, other: Self) -> Option<Self> {
match self {
Self::Hard(t1) => match other {
Self::Hard(t2) => (t1 == t2).then_some(self),
Self::Hard(t2) => (t1 >= t2).then_some(self),
Self::Soft(t2) => {
let max = t1.max_type(&t2);
// check if the max type is compatible with both types via partial ord
Expand Down Expand Up @@ -104,11 +104,6 @@ fn add_type_requirements(
.enumerate()
.for_each(|(index, (a, b))| {
let replacement = a.stricter_requirement(*b);
// if force_use_of_stricter_requirement {
// a.stricter_requirement(*b);
// } else {
// a.replace_with_max_type_if_compatible(*b)
// };

match replacement {
Some(replacement) => *a = replacement,
Expand Down Expand Up @@ -151,41 +146,48 @@ pub(super) fn override_type_requirements(

pub(super) fn requirements_from_pred_decls(
decls: &HashMap<Identifier, Vec<PrimitiveType>>,
) -> PredicateTypeRequirements {
decls
.iter()
.map(|(pred, types)| {
(
pred.clone(),
types
.iter()
.copied()
.map(TypeRequirement::Hard)
.collect::<Vec<_>>(),
)
})
.collect()
) -> Result<PredicateTypeRequirements, TypeError> {
let mut type_requirements = HashMap::new();

for (pred, types) in decls {
add_type_requirements(
&mut type_requirements,
pred.clone(),
types
.iter()
.copied()
.map(TypeRequirement::Hard)
.collect::<Vec<_>>(),
)?;
}

Ok(type_requirements)
}

pub(super) fn requirements_from_imports<'a, T: Iterator<Item = &'a ImportSpec>>(
imports: T,
) -> PredicateTypeRequirements {
imports
.map(|import_spec| {
(
import_spec.predicate().clone(),
import_spec
.type_constraint()
.iter()
.cloned()
.map(TypeRequirement::from)
.collect(),
)
})
.collect()
) -> Result<PredicateTypeRequirements, TypeError> {
let mut type_requirements = HashMap::new();

for import_spec in imports {
add_type_requirements(
&mut type_requirements,
import_spec.predicate().clone(),
import_spec
.type_constraint()
.iter()
.cloned()
.map(TypeRequirement::from)
.collect(),
)?;
}

Ok(type_requirements)
}

pub(super) fn requirements_from_facts(facts: &Vec<ChaseFact>) -> PredicateTypeRequirements {
pub(super) fn requirements_from_facts(
facts: &Vec<ChaseFact>,
) -> Result<PredicateTypeRequirements, TypeError> {
let mut fact_decls: PredicateTypeRequirements = HashMap::new();
for fact in facts {
let reqs_for_fact: Vec<TypeRequirement> = fact
Expand All @@ -198,16 +200,15 @@ pub(super) fn requirements_from_facts(facts: &Vec<ChaseFact>) -> PredicateTypeRe
})
.collect();

add_type_requirements(&mut fact_decls, fact.predicate(), reqs_for_fact)
.expect("Since fact requirements are all soft, there should be no conflicts.");
add_type_requirements(&mut fact_decls, fact.predicate(), reqs_for_fact)?;
}

fact_decls
Ok(fact_decls)
}

pub(super) fn requirements_from_literals_in_rules(
rules: &Vec<ChaseRule>,
) -> PredicateTypeRequirements {
) -> Result<PredicateTypeRequirements, TypeError> {
let mut literal_decls: PredicateTypeRequirements = HashMap::new();

for chase_rule in rules {
Expand Down Expand Up @@ -256,68 +257,72 @@ pub(super) fn requirements_from_literals_in_rules(
})
.collect();

add_type_requirements(&mut literal_decls, chase_atom.predicate(), reqs_for_atom)
.expect("Since literal requirements are all soft, there should be no conflicts.");
add_type_requirements(&mut literal_decls, chase_atom.predicate(), reqs_for_atom)?;
}
}

literal_decls
Ok(literal_decls)
}

pub(super) fn requirements_from_aggregates_in_rules(
rules: &[ChaseRule],
) -> PredicateTypeRequirements {
rules
) -> Result<PredicateTypeRequirements, TypeError> {
let mut type_requirements = HashMap::new();

for (rule, atom) in rules
.iter()
.flat_map(|rule| rule.head().iter().map(move |atom| (rule, atom)))
.map(|(rule, atom)| {
(
atom.predicate(),
atom.terms()
.iter()
.map(|term| {
if let PrimitiveTerm::Variable(Variable::Universal(identifier)) = term {
if identifier.name().starts_with(AGGREGATE_VARIABLE_PREFIX) {
let aggregate = rule
.aggregates()
.iter()
.find(|aggregate| aggregate.output_variable.name() == *identifier.0).expect("variable with aggregate prefix is missing an associated aggregate");
if
aggregate.aggregate_operation ==
AggregateOperation::Count
{
return TypeRequirement::Hard(PrimitiveType::Integer)
}
{
add_type_requirements(
&mut type_requirements,
atom.predicate(),
atom.terms()
.iter()
.map(|term| {
if let PrimitiveTerm::Variable(Variable::Universal(identifier)) = term {
if identifier.name().starts_with(AGGREGATE_VARIABLE_PREFIX) {
let aggregate = rule
.aggregates()
.iter()
.find(|aggregate| aggregate.output_variable.name() == *identifier.0).expect("variable with aggregate prefix is missing an associated aggregate");
if
aggregate.aggregate_operation ==
AggregateOperation::Count
{
return TypeRequirement::Hard(PrimitiveType::Integer)
}
}
TypeRequirement::None
})
.collect(),
)
})
.collect()
}
TypeRequirement::None
})
.collect(),
)?;
}

Ok(type_requirements)
}

pub(super) fn requirements_from_existentials_in_rules(
rules: &[ChaseRule],
) -> PredicateTypeRequirements {
rules
.iter()
.flat_map(|r| r.head())
.map(|a| {
(
a.predicate(),
a.terms()
.iter()
.map(|t| {
if matches!(t, PrimitiveTerm::Variable(Variable::Existential(_))) {
TypeRequirement::Hard(PrimitiveType::Any)
} else {
TypeRequirement::None
}
})
.collect::<Vec<_>>(),
)
})
.collect()
) -> Result<PredicateTypeRequirements, TypeError> {
let mut type_requirements = HashMap::new();

for atom in rules.iter().flat_map(|r| r.head()) {
add_type_requirements(
&mut type_requirements,
atom.predicate(),
atom.terms()
.iter()
.map(|t| {
if matches!(t, PrimitiveTerm::Variable(Variable::Existential(_))) {
TypeRequirement::Hard(PrimitiveType::Any)
} else {
TypeRequirement::None
}
})
.collect::<Vec<_>>(),
)?;
}

Ok(type_requirements)
}

0 comments on commit e993507

Please sign in to comment.