Skip to content

Commit

Permalink
Merge pull request google-research#924 from google-research/nat-type
Browse files Browse the repository at this point in the history
Add a `Nat` type for non-negative integers and use it for sizes and ordinals
  • Loading branch information
dougalm authored May 28, 2022
2 parents 396f962 + a629c68 commit 10d4346
Show file tree
Hide file tree
Showing 59 changed files with 657 additions and 444 deletions.
15 changes: 8 additions & 7 deletions examples/ctc.dx
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def clipidx (n:Type) [Ix n] (i:Int) : n =
-- Returns element at 0 if less than zero.
-- Ideally we could have an alternative
-- to Fin that just clips the index at its bounds.
from_ordinal n (select (i < 0) 0 i)
from_ordinal n $ unsafe_i_to_n (select (i < 0) 0 i)

def logaddexp (x:Float) (y:Float) : Float =
m = max x y
Expand Down Expand Up @@ -72,7 +72,7 @@ def ctc {vocab time position} [Eq vocab, Eq position, Eq time]
False -> log 0.000001

same_as_last = \ilabels s.
o = ordinal s
o = n_to_i $ ordinal s
select (o >= 2) (ilabels.s == ilabels.(clipidx _ (o - 2))) False

safe_idx = \prev s.
Expand All @@ -85,25 +85,26 @@ def ctc {vocab time position} [Eq vocab, Eq position, Eq time]
True -> log_prob_seq_t0
False -> for s.
cond = ilabels.s == blank || same_as_last ilabels s
labar = logaddexp prev.s (safe_idx prev ((ordinal s) - 1))
other = logaddexp labar (safe_idx prev ((ordinal s) - 2))
labar = logaddexp prev.s (safe_idx prev (n_to_i (ordinal s) - 1))
other = logaddexp labar (safe_idx prev (n_to_i (ordinal s) - 2))
ans = select cond labar other
ans + normalized_logits.t.(ilabels.s)

log_prob_seq_final = fold log_prob_seq_t0 update

-- Todo: nicer way to get last two elements of log_prob_seq_final.
seq_length = 1 + size (position & (Fin 2))
endlabel = log_prob_seq_final.((seq_length - 2)@_)
endspace = log_prob_seq_final.((seq_length - 1)@_)
endlabel = log_prob_seq_final.((unsafe_nat_diff seq_length 2)@_)
endspace = log_prob_seq_final.((unsafe_nat_diff seq_length 1)@_)
logaddexp endlabel endspace


'### Demo

def randIdxNoZero (n:Type) [Ix n] (k:Key) : n =
unif = rand k
from_ordinal n $ (1 + (FToI (floor ( unif * i_to_f ((size n) - 1)))))
from_ordinal n $ unsafe_i_to_n $
(1 + (f_to_i (floor ( unif * i_to_f ((n_to_i $ size n) - 1)))))

Vocab = Fin 6
position = Fin 3
Expand Down
40 changes: 21 additions & 19 deletions examples/fluidsim.dx
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ import plot

def wrapidx (n:Type) [Ix n] (i:Int) : n =
-- Index wrapping around at ends.
asidx $ mod i $ size n
asidx $ unsafe_i_to_n $ mod i $ n_to_i $ size n

def incwrap {n} [Ix n] (i:n) : n = -- Increment index, wrapping around at ends.
asidx $ mod ((ordinal i) + 1) $ size n
asidx $ unsafe_i_to_n $ mod ((n_to_i $ ordinal i) + 1) $ n_to_i $ size n

def decwrap {n} [Ix n] (i:n) : n = -- Decrement index, wrapping around at ends.
asidx $ mod ((ordinal i) - 1) $ size n
asidx $ unsafe_i_to_n $ mod (n_to_i (ordinal i) - 1) $ n_to_i $ size n

def finite_difference_neighbours {n a} [Add a] (x:n=>a) : n=>a =
def finite_difference_neighbours {n a} [Sub a, Add a] (x:n=>a) : n=>a =
for i. x.(incwrap i) - x.(decwrap i)

def add_neighbours {n a} [Add a] (x:n=>a) : n=>a =
Expand All @@ -27,13 +27,13 @@ def apply_along_axis1 {b c a} (f:b=>a -> b=>a) (x:b=>c=>a) : b=>c=>a =
def apply_along_axis2 {b c a} (f:c=>a -> c=>a) (x:b=>c=>a) : b=>c=>a =
for i. f x.i

def fdx {n m a} [Add a] (x:n=>m=>a) : (n=>m=>a) =
def fdx {n m a} [Sub a, Add a] (x:n=>m=>a) : (n=>m=>a) =
apply_along_axis1 finite_difference_neighbours x

def fdy {n m a} [Add a] (x:n=>m=>a) : (n=>m=>a) =
def fdy {n m a} [Sub a, Add a] (x:n=>m=>a) : (n=>m=>a) =
apply_along_axis2 finite_difference_neighbours x

def divergence {n m a} [Add a] (vx:n=>m=>a) (vy:n=>m=>a) : (n=>m=>a) =
def divergence {n m a} [Sub a, Add a] (vx:n=>m=>a) (vy:n=>m=>a) : (n=>m=>a) =
fdx vx + fdy vy

def add_neighbours_2d {n m a} [Add a] (x:n=>m=>a) : (n=>m=>a) =
Expand All @@ -44,7 +44,7 @@ def add_neighbours_2d {n m a} [Add a] (x:n=>m=>a) : (n=>m=>a) =
def project {n m a} [VSpace a] (v: n=>m=>(Fin 2)=>a) : n=>m=>(Fin 2)=>a =
-- Project the velocity field to be approximately mass-conserving,
-- using a few iterations of Gauss-Seidel.
h = 1.0 / i_to_f (size n)
h = 1.0 / n_to_f (size n)

-- unpack into two scalar fields
vx = for i j. v.i.j.(from_ordinal _ 0)
Expand All @@ -71,8 +71,8 @@ def advect {n m a} [VSpace a] (f: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a =
-- Move field f according to x and y velocities (u and v)
-- using an implicit Euler integrator.

cell_xs = linspace n 0.0 $ i_to_f (size n)
cell_ys = linspace m 0.0 $ i_to_f (size m)
cell_xs = linspace n 0.0 $ n_to_f (size n)
cell_ys = linspace m 0.0 $ n_to_f (size m)

for i j.
-- Location of source of flow for this cell. No meshgrid!
Expand All @@ -88,15 +88,15 @@ def advect {n m a} [VSpace a] (f: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a =
bottom_weight = center_ys - source_row

-- Cast back to indices, wrapping around edges.
l = wrapidx n (FToI source_col)
r = wrapidx n ((FToI source_col) + 1)
t = wrapidx m (FToI source_row)
b = wrapidx m ((FToI source_row) + 1)
l = wrapidx n (f_to_i source_col)
r = wrapidx n ((f_to_i source_col) + 1)
t = wrapidx m (f_to_i source_row)
b = wrapidx m ((f_to_i source_row) + 1)

-- A convex weighting of the 4 surrounding cells.
bilinear_interp right_weight bottom_weight f.l.t f.l.b f.r.t f.r.b

def fluidsim {n m a} [VSpace a] (num_steps: Int) (color_init: n=>m=>a)
def fluidsim {n m a} [VSpace a] (num_steps: Nat) (color_init: n=>m=>a)
(v: n=>m=>(Fin 2)=>Float) : (Fin num_steps)=>n=>m=>a =
with_state (color_init, v) \state.
for i:(Fin num_steps).
Expand All @@ -120,9 +120,11 @@ init_velocity = for i:N j:M k:(Fin 2).

-- Create diagonally-striped color pattern.
init_color = for i:N j:M.
r = b_to_f $ (sin $ (i_to_f $ (ordinal j) + (ordinal i)) / 8.0) > 0.0
b = b_to_f $ (sin $ (i_to_f $ (ordinal j) - (ordinal i)) / 6.0) > 0.0
g = b_to_f $ (sin $ (i_to_f $ (ordinal j) + (ordinal i)) / 4.0) > 0.0
i' = n_to_f $ ordinal i
j' = n_to_f $ ordinal j
r = b_to_f $ (sin $ (j' + i') / 8.0) > 0.0
b = b_to_f $ (sin $ (j' - i') / 6.0) > 0.0
g = b_to_f $ (sin $ (j' + i') / 4.0) > 0.0
[r, g, b]

-- Run fluid sim and plot it.
Expand All @@ -135,7 +137,7 @@ num_steps = 5
target = transpose init_color

-- This is partial
def last {n a} (xs:n=>a) : a = xs.((size n - 1)@_)
def last {n a} (xs:n=>a) : a = xs.((unsafe_nat_diff (size n) 1)@_)

def objective (v:N=>M=>(Fin 2)=>Float) : Float =
final_color = last $ fluidsim num_steps init_color v
Expand Down
6 changes: 4 additions & 2 deletions examples/manifold-gradients.dx
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def manifoldGrad
As a sense check we make sure that these functions give the same output as the
non-manifold versions on functions which are defined on $\mathbb{R}^n$.

x = for i:(Fin 10). i_to_f (ordinal i) / 10.
x = for i:(Fin 10). n_to_f (ordinal i) / 10.

def myFunc {n} (x : (n => Float)) : Float =
sum $ for i. if (mod (ordinal i) 2 == 0) then (exp x.i) else (sin x.i)
Expand Down Expand Up @@ -276,9 +276,11 @@ quatIdent = Q{x=0.0, y=0.0, z=0.0, w=1.0}

instance Add Quaternion
add = \ x y. quatFromVec $ (quatAsVec x) + quatAsVec y
sub = \ x y. quatFromVec $ (quatAsVec x) - quatAsVec y
zero = quatFromVec $ zero

instance Sub Quaternion
sub = \ x y. quatFromVec $ (quatAsVec x) - quatAsVec y

' Scaling a quaternion does not affect the rotation it represents, but is still
useful for numerical computations.

Expand Down
10 changes: 5 additions & 5 deletions examples/mcmc.dx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ LogProb : Type = Float
def runChain {a}
(initialize: Key -> a)
(step: Key -> a -> a)
(numSamples: Int)
(numSamples: Nat)
(k:Key)
: Fin numSamples => a =
[k1, k2] = split_key k
Expand All @@ -29,10 +29,10 @@ def propose {a}
select accept proposal cur

def meanAndCovariance {n d} (xs:n=>d=>Float) : (d=>Float & d=>d=>Float) =
xsMean : d=>Float = (for i. sum for j. xs.j.i) / i_to_f (size n)
xsMean : d=>Float = (for i. sum for j. xs.j.i) / n_to_f (size n)
xsCov : d=>d=>Float = (for i i'. sum for j.
(xs.j.i' - xsMean.i') *
(xs.j.i - xsMean.i ) ) / i_to_f (size n - 1)
(xs.j.i - xsMean.i ) ) / (n_to_f (size n) - 1)
(xsMean, xsCov)

'## Metropolis-Hastings implementation
Expand All @@ -51,7 +51,7 @@ def mhStep {d} [Ix d]

'## HMC implementation

HMCParams : Type = (Int & Float) -- leapfrog steps, step size
HMCParams : Type = (Nat & Float) -- leapfrog steps, step size

def leapfrogIntegrate {a}
[VSpace a]
Expand Down Expand Up @@ -87,7 +87,7 @@ def myLogProb (x:(Fin 2)=>Float) : LogProb =
x' = x - [1.5, 2.5]
neg $ 0.5 * inner x' [[1.,0.],[0.,20.]] x'

numSamples =
numSamples : Nat =
if dex_test_mode ()
then 1000
else 10000
Expand Down
6 changes: 3 additions & 3 deletions examples/particle-filter.dx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def sample {a} (d: Distribution a) (k: Key) : a =
(sampler, _) = d
sampler k

def simulate {s v} (model: Model s v) (t: Int) (key: Key) : Fin t=>(s & v) =
def simulate {s v} (model: Model s v) (t: Nat) (key: Key) : Fin t=>(s & v) =
(init, dynamics, observe) = model
[key, subkey] = split_key key
s0 = sample init subkey
Expand All @@ -25,7 +25,7 @@ def simulate {s v} (model: Model s v) (t: Int) (key: Key) : Fin t=>(s & v) =
(s, v)

def particleFilter {s a v}
(num_particles: Int) (num_timesteps: Int)
(num_particles: Nat) (num_timesteps: Nat)
(model: Model s v)
(summarize: (Fin num_particles => s) -> a)
(obs: Fin num_timesteps=>v)
Expand Down Expand Up @@ -60,7 +60,7 @@ timesteps = 10
num_particles = 10000

truth = for i:(Fin timesteps).
s = i_to_f (ordinal i)
s = n_to_f (ordinal i)
(s, sample (normalDistn s 1.0) $ ixkey (new_key 0) i)

filtered = particleFilter num_particles _ gaussModel mean (map snd truth) (new_key 0)
Expand Down
4 changes: 2 additions & 2 deletions examples/particle-swarm-optimizer.dx
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ We have **arguments**:

def optimize
{d}
(np':Int) -- number of particles
(niter:Int) -- number of iterations
(np':Nat) -- number of particles
(niter:Nat) -- number of iterations
(key:Key) -- random seed
(f:(d=>Float)->Float) -- function to optimize
((lb,ub):(d=>Float & d=>Float)) -- bounds
Expand Down
2 changes: 1 addition & 1 deletion examples/pi.dx
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def estimatePiAvgVal (key:Key) : Float =
x = rand key
4.0 * sqrt (1.0 - sq x)

def meanAndStdDev (n:Int) (f: Key -> Float) (key:Key) : (Float & Float) =
def meanAndStdDev (n:Nat) (f: Key -> Float) (key:Key) : (Float & Float) =
samps = for i:(Fin n). many f key i
(mean samps, std samps)

Expand Down
4 changes: 3 additions & 1 deletion examples/quaternions.dx
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ Scaling a quaternion is multiplying each of its components by a real number (the

instance Add Quaternion
add = \ x y. from_vec $ (as_vec x) + as_vec y
sub = \ x y. from_vec $ (as_vec x) - as_vec y
zero = from_vec $ zero

instance Sub Quaternion
sub = \ x y. from_vec $ (as_vec x) - as_vec y

instance VSpace Quaternion
scale_vec = \ s x. from_vec (s .* as_vec x)

Expand Down
20 changes: 10 additions & 10 deletions examples/raytrace.dx
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import plot
' ### Generic Helper Functions
Some of these should probably go in prelude.

def Vec (n:Int) : Type = Fin n => Float
def Mat (n:Int) (m:Int) : Type = Fin n => Fin m => Float
def Vec (n:Nat) : Type = Fin n => Float
def Mat (n:Nat) (m:Nat) : Type = Fin n => Fin m => Float

def relu (x:Float) : Float = max x 0.0
def length {d} (x: d=>Float) : Float = sqrt $ sum for i. sq x.i
Expand All @@ -25,10 +25,10 @@ def directionAndLength {d} (x: d=>Float) : (d=>Float & Float) =
def randuniform (lower:Float) (upper:Float) (k:Key) : Float =
lower + (rand k) * (upper - lower)

def sampleAveraged {a} [VSpace a] (sample:Key -> a) (n:Int) (k:Key) : a =
def sampleAveraged {a} [VSpace a] (sample:Key -> a) (n:Nat) (k:Key) : a =
yield_state zero \total.
for i:(Fin n).
total := get total + sample (ixkey k i) / i_to_f n
total := get total + sample (ixkey k i) / n_to_f n

def positiveProjection {n} (x:n=>Float) (y:n=>Float) : Bool = dot x y > 0.0

Expand All @@ -41,7 +41,7 @@ def cross (a:Vec 3) (b:Vec 3) : Vec 3 =

-- TODO: Use `data Color = Red | Green | Blue` and ADTs for index sets
data Image =
MkImage height:Int width:Int (Fin height => Fin width => Color)
MkImage height:Nat width:Nat (Fin height => Fin width => Color)

xHat : Vec 3 = [1., 0., 0.]
yHat : Vec 3 = [0., 1., 0.]
Expand Down Expand Up @@ -113,8 +113,8 @@ Filter = Color
-- TODO: use a record
-- num samples, num bounces, share seed?
Params = {
numSamples : Int
& maxBounces : Int
numSamples : Nat
& maxBounces : Nat
& shareSeed : Bool }

-- TODO: use a list instead, once they work
Expand Down Expand Up @@ -248,16 +248,16 @@ def trace {n} (params:Params) (scene:Scene n) (initRay:Ray) (k:Key) : Color =

-- Assumes we're looking towards -z.
Camera = {
numPix : Int
numPix : Nat
& pos : Position -- pinhole position
& halfWidth : Float -- sensor half-width
& sensorDist : Float } -- pinhole-sensor distance

-- TODO: might be better with an anonymous dependent pair for the result
def cameraRays (n:Int) (camera:Camera) : Fin n => Fin n => (Key -> Ray) =
def cameraRays (n:Nat) (camera:Camera) : Fin n => Fin n => (Key -> Ray) =
-- images indexed from top-left
halfWidth = get_at #halfWidth camera
pixHalfWidth = halfWidth / i_to_f n
pixHalfWidth = halfWidth / n_to_f n
ys = reverse $ linspace (Fin n) (neg halfWidth) halfWidth
xs = linspace (Fin n) (neg halfWidth) halfWidth
for i j. \key.
Expand Down
2 changes: 1 addition & 1 deletion examples/regression.dx
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def regress {d n} (featurize: Float -> d=>Float) (xRaw:n=>Float) (y:n=>Float) :
'Fit a third-order polynomial

def poly {d} [Ix d] (x:Float) : d=>Float =
for i. pow x (i_to_f (ordinal i))
for i. pow x (n_to_f (ordinal i))

params : (Fin 4)=>Float = regress poly xs ys

Expand Down
Loading

0 comments on commit 10d4346

Please sign in to comment.