diff --git a/nemo/src/program_analysis/analysis.rs b/nemo/src/program_analysis/analysis.rs index bfc1fde2d..425b0be8e 100644 --- a/nemo/src/program_analysis/analysis.rs +++ b/nemo/src/program_analysis/analysis.rs @@ -65,7 +65,7 @@ pub struct RuleAnalysis { } /// Errors than can occur during rule analysis -#[derive(Error, Debug, Copy, Clone)] +#[derive(Error, Debug, Copy, Clone, PartialEq, Eq)] #[allow(clippy::enum_variant_names)] pub enum RuleAnalysisError { /// Unsupported feature: Overloading of predicate names by arity/type @@ -744,10 +744,26 @@ impl ChaseProgram { /// Check if the program contains rules with unsupported features pub fn check_for_unsupported_features(&self) -> Result<(), RuleAnalysisError> { - let mut arities = HashMap::new(); + let mut arities = self + .parsed_predicate_declarations() + .iter() + .map(|(predicate, types)| (predicate.clone(), types.len())) + .collect::>(); for source in self.sources() { - arities.insert(source.predicate.clone(), source.input_types().arity()); + match arities.entry(source.predicate.clone()) { + std::collections::hash_map::Entry::Occupied(slot) => { + // both declared and in a source + let arity = slot.get(); + + if *arity != source.input_types().arity() { + return Err(RuleAnalysisError::UnsupportedFeaturePredicateOverloading); + } + } + std::collections::hash_map::Entry::Vacant(slot) => { + slot.insert(source.input_types().arity()); + } + } } for rule in self.rules() { @@ -924,12 +940,13 @@ mod test { use std::collections::HashMap; use crate::{ + io::parser::parse_program, model::{ chase_model::{ChaseAtom, ChaseProgram, ChaseRule}, DataSourceDeclaration, DsvFile, Identifier, NativeDataSource, PrimitiveType, Term, TupleConstraint, Variable, }, - program_analysis::analysis::get_fresh_rule_predicate, + program_analysis::analysis::{get_fresh_rule_predicate, RuleAnalysisError}, }; fn get_test_rules_and_predicates() -> ( @@ -1412,4 +1429,81 @@ mod test { .unwrap(); assert_eq!(inferred_types, expected_types); } + + #[test] + fn overloading_is_unsupported() { + let program = ChaseProgram::try_from( + parse_program( + r#" + @source q[3]: load-rdf("dummy.nt") . + p(?x, ?y) :- q(?x), q(?y) . + "#, + ) + .unwrap(), + ) + .unwrap(); + + assert_eq!( + program.check_for_unsupported_features(), + Err(RuleAnalysisError::UnsupportedFeaturePredicateOverloading) + ); + + let program = ChaseProgram::try_from( + parse_program( + r#" + @source q[3]: load-rdf("dummy.nt") . + @declare q(integer, integer) . + "#, + ) + .unwrap(), + ) + .unwrap(); + + assert_eq!( + program.check_for_unsupported_features(), + Err(RuleAnalysisError::UnsupportedFeaturePredicateOverloading) + ); + + let program = + ChaseProgram::try_from(parse_program(r#"q(?x, ?y) :- q(?x), q(?y) ."#).unwrap()) + .unwrap(); + + assert_eq!( + program.check_for_unsupported_features(), + Err(RuleAnalysisError::UnsupportedFeaturePredicateOverloading) + ); + + let program = ChaseProgram::try_from( + parse_program( + r#" + p(?x, ?y) :- q(?x), q(?y) . + q(23, 42) . + "#, + ) + .unwrap(), + ) + .unwrap(); + + assert_eq!( + program.check_for_unsupported_features(), + Err(RuleAnalysisError::UnsupportedFeaturePredicateOverloading) + ); + + let program = ChaseProgram::try_from( + parse_program( + r#" + @declare q(integer, integer) . + p(?x, ?y) :- q(?x), q(?y) . + q(23) . + "#, + ) + .unwrap(), + ) + .unwrap(); + + assert_eq!( + program.check_for_unsupported_features(), + Err(RuleAnalysisError::UnsupportedFeaturePredicateOverloading) + ); + } }