Skip to content

Commit

Permalink
Array contract deduplication (#1674)
Browse files Browse the repository at this point in the history
* Contract deduplication for arrays as well

* Fix reverted boolean check

* Fix other boolean error + clippy warning
  • Loading branch information
yannham authored Oct 12, 2023
1 parent a3505a1 commit 47823f5
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 31 deletions.
74 changes: 62 additions & 12 deletions core/src/eval/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use crate::{
*,
},
transform::Closurizable,
typecheck::eq::{contract_eq, EvalEnvsRef},
};

use malachite::{
Expand Down Expand Up @@ -1823,17 +1824,56 @@ impl<R: ImportResolver, C: Cache> VirtualMachine<R, C> {
// applied because we don't have a way of tracking which elements
// should take which contracts.

let (ctrs_left, ctrs_common) : (Vec<_>, Vec<_>) = attrs1
.pending_contracts
.into_iter()
.partition(|ctr| !attrs2.pending_contracts.contains(ctr));

let ctrs_right = attrs2
.pending_contracts
// Separate contracts between the parts that aren't common, and
// must be applied right away, and the common part, which can be
// kept lazy.
let mut ctrs_left = attrs1.pending_contracts;
// We use a vector of `Option` so that we can set the elements to
// remove to `None` and make a single pass at the end
// to retain the remaining ones.
let mut ctrs_right_sieve : Vec<_> = attrs2.pending_contracts.into_iter().map(Some).collect();
let mut ctrs_common = Vec::new();

// We basically compute the intersection (`ctr_common`),
// `ctrs_left - ctr_common`, and `ctrs_right - ctr_common`.
let ctrs_left : Vec<_> =
ctrs_left
.into_iter()
.filter(|ctr| {
!ctrs_left.contains(ctr) && !ctrs_common.contains(ctr)
});
// We don't deduplicate polymorphic contracts, because
// they're not idempotent.
if ctr.can_have_poly_ctrs() {
return true;
}

let envs_left = EvalEnvsRef {
eval_env: &env1,
initial_env: &self.initial_env,
};

let twin_index = ctrs_right_sieve
.iter()
.filter_map(|ctr| ctr.as_ref())
.position(|other_ctr| {
let envs_right = EvalEnvsRef {
eval_env: &env2,
initial_env: &self.initial_env,
};

contract_eq::<EvalEnvsRef>(0, &ctr.contract, envs_left, &other_ctr.contract, envs_right)
});

if let Some(index) = twin_index {
ctrs_right_sieve[index] = None;
false
}
else {
true
}
})
.collect();

let ctrs_right = ctrs_right_sieve.into_iter().flatten();

ts.extend(ts1.into_iter().map(|t|
RuntimeContract::apply_all(t, ctrs_left.iter().cloned(), pos1)
Expand Down Expand Up @@ -2098,18 +2138,28 @@ impl<R: ImportResolver, C: Cache> VirtualMachine<R, C> {
match_sharedterm! {t2,
with {
Term::Array(ts, attrs) => {
let mut attrs = attrs;
let mut final_env = env2;

// Preserve the environment of the contract in the resulting array.
let rt3 = rt3.closurize(&mut self.cache, &mut env2, env3);
let contract = rt3.closurize(&mut self.cache, &mut final_env, env3);
RuntimeContract::push_dedup(
&self.initial_env,
&mut attrs.pending_contracts,
&final_env,
RuntimeContract::new(contract, lbl),
&final_env
);

let array_with_ctr = Closure {
body: RichTerm::new(
Term::Array(
ts,
attrs.with_extra_contracts([RuntimeContract::new(rt3, lbl)])
attrs,
),
pos2,
),
env: env2,
env: final_env,
};

Ok(array_with_ctr)
Expand Down
18 changes: 16 additions & 2 deletions core/src/label.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@ pub struct Label {
/// The path of the type being currently checked in the original type.
pub path: ty_path::Path,

/// An environment mapping type variables to [`TypeVarData`]. Used by
/// polymorphic contracts to decide which actions to take when encountering a `forall`.
/// An environment mapping type variables to [`TypeVarData`]. Used by polymorphic contracts to
/// decide which actions to take when encountering a `forall`.
pub type_environment: HashMap<SealingKey, TypeVarData>,

/// The name of the record field to report in blame errors. This is set
Expand Down Expand Up @@ -502,6 +502,20 @@ impl Label {
pub fn with_field_name(self, field_name: Option<LocIdent>) -> Self {
Label { field_name, ..self }
}

/// Tests if the contract associated to this label might have polymorphic subcontracts
/// (equivalently, if the contract is derived from a type which has free type variables). Such
/// contracts are special, in particular because they aren't idempotent and thus can't be
/// freely deduplicated.
///
/// This check is an over approximation and might return `true` even if the contract is not
/// polymorphic, in exchange of being fast (constant time).
pub fn can_have_poly_ctrs(&self) -> bool {
// Checking that the type environment is not empty is a bit coarse: what it actually checks
// is that this contract is derived from the body of a `forall`. For example, in `forall a.
// a -> Number`, `Number` isn't polymorphic, but `has_polymorphic_ctrs` will return `true`.
!self.type_environment.is_empty()
}
}

impl Default for Label {
Expand Down
17 changes: 0 additions & 17 deletions core/src/term/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,6 @@ impl ArrayAttrs {
self.pending_contracts.clear();
self
}

/// Extend contracts from an iterator of `PendingContract`.
/// De-duplicate equal contracts. Note that current contract
/// equality testing is very limited, but this may change in the
/// future
pub fn with_extra_contracts<I>(mut self, iter: I) -> Self
where
I: IntoIterator<Item = RuntimeContract>,
{
for ctr in iter {
if !self.pending_contracts.contains(&ctr) {
self.pending_contracts.push(ctr)
}
}

self
}
}

/// A Nickel array, represented as a view (slice) into a shared backing array. The view is
Expand Down
6 changes: 6 additions & 0 deletions core/src/term/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,12 @@ impl RuntimeContract {

contracts.push(ctr);
}

/// Check if this contract might have polymorphic subcontracts. See
/// [crate::label::Label::can_have_poly_ctrs].
pub fn can_have_poly_ctrs(&self) -> bool {
self.label.can_have_poly_ctrs()
}
}

impl Traverse<RichTerm> for RuntimeContract {
Expand Down

0 comments on commit 47823f5

Please sign in to comment.