Skip to content

Commit

Permalink
Wire up the molecular dynamics demo to CI and advertising in the README.
Browse files Browse the repository at this point in the history
  • Loading branch information
axch committed Nov 11, 2022
1 parent 7cf290e commit bea0316
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 5 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ or these example programs:
* [Basis function regression](https://google-research.github.io/dex-lang/examples/regression.html)
* [Brownian bridge](https://google-research.github.io/dex-lang/examples/brownian_motion.html)
* [Dynamic programming (Levenshtein distance)](https://google-research.github.io/dex-lang/examples/levenshtein-distance.html)
* [Molecular dynamics simulation](https://google-research.github.io/dex-lang/examples/md.html)

Or for a more comprehensive look, there's

Expand Down
110 changes: 106 additions & 4 deletions examples/md.dx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import plot

'This is more-or-less a port of Jax MD into Dex to see how molecular dynamics looks in
Dex. For now, the structure of the two implementations is pretty close. However, details
look different and the evolution will depend on what is or is not ergonomic in dex vs jax.
look different.

'## Math

Expand Down Expand Up @@ -189,13 +189,18 @@ L_small = box_size_at_number_density (n_to_i N_small) 1.2 (n_to_i d)

-- We will simulate in a box of this side length
L_small
> 20.412415

R_init_small = rand_mat N_small d (\k. L_small * rand k) (new_key 0)

'The initial state of our random system

%time
:html render_svg (draw_system 0.5 R_init_small) ((0.0, 0.0), (L_small, L_small))
> <html output>
>
> Compile time: 277.664 ms
> Run time: 15.644 ms

'Define energy function. Note the `preiodic_displacement`, which means
our system will be evolving on a torus.
Expand All @@ -206,8 +211,10 @@ def energy {n d} (pos: n=>d=>Float) : Float
'Here's the initial energy we compute for our system.

:t energy R_init_small
> Float32

energy R_init_small
> 74.689995

'Initialize a simulation

Expand All @@ -217,6 +224,7 @@ state_small = fire_descent_init 0.1 0.1 energy R_init_small
initial, as expected:

energy $ get_position $ fire_descent_step free_shift energy state_small
> 71.78412

'Now we can test that our code basically works by running 100 steps of
minimization.
Expand All @@ -225,16 +233,27 @@ minimization.
(state_small', energies) = scan state_small \i:(Fin 100) s.
s' = fire_descent_step (periodic_shift L_small) energy s
(s', energy $ get_position s')
>
> Compile time: 525.343 ms
> Run time: 1.349 s

'Here's how the energy decreases over time.

%time
:html show_plot $ y_plot energies
> <html output>
>
> Compile time: 690.704 ms
> Run time: 2.965 ms

'Here's what the system looks like after minimization.

%time
:html render_svg (draw_system 0.5 (get_position state_small')) ((0.0, 0.0), (L_small, L_small))
> <html output>
>
> Compile time: 276.340 ms
> Run time: 13.222 ms

'## Neighbors optimization

Expand Down Expand Up @@ -352,14 +371,19 @@ bucket_size = 10

%time
tbl = cell_table grid_size bucket_size cell_size $ get_position state_small'
>
> Compile time: 118.997 ms
> Run time: 22.975 us

'We have a table of cells with atoms in them

:t tbl
> (((Fin 2) => Fin 20) => (Fin 10 & ((Fin 10) => Fin 500)))

'And here are the atoms in the 0th cell.

as_list tbl.(unsafe_from_ordinal _ 0)
> (AsList 2 [(33@(Fin 500)), (237@(Fin 500))])

'### Now let's compute pairs of neighbors from our cell list
We'll specialize to two dimensions for now, but broadening to
Expand All @@ -369,6 +393,7 @@ cell_neighbors_2d = [[-1, -1], [-1, 0], [-1, 1], [0, -1],
[0, 0], [0, 1], [1, -1], [1, 0], [1, 1]]

:t cell_neighbors_2d
> ((Fin 9) => (Fin 2) => Int32)

-- Toroidal index offsetting
def torus_offset {n} (ix: (Fin n)) (offset: Int) : (Fin n) =
Expand Down Expand Up @@ -435,11 +460,15 @@ def periodic_near {atom_ix}
%time
res = (neighbor_list 4000 tbl
(periodic_near 1.0 L_small $ get_position state_small') $ get_position state_small')
>
> Compile time: 95.243 ms
> Run time: 130.425 us

'In that configuration, we find this many pairs of neighbors:

(AsList k _) = as_list res
k
> 3090

'Now that we have the concept of neighbor lists, we cen define a
variant of `pair_energy` that only considers atoms that the neighbor
Expand Down Expand Up @@ -469,8 +498,10 @@ def energy_nl {n d}
original, fully pairwise energy function.

energy_nl L_small (as_list res) (get_position state_small')
> 1.230931

energy (get_position state_small')
> 1.230932

-- Package the above up into a function that just computes the
-- neighbor list from an array of atoms.
Expand All @@ -495,7 +526,9 @@ state_nl =
fire_descent_init 0.1 0.1 energy_func R_init_small

energy R_init_small
> 74.689995
energy $ get_position $ fire_descent_step free_shift energy state_small
> 71.78412

-- A helper for short-circuiting `any` computation
def fast_any {n eff} [Ix n] (f: n -> {|eff} Bool) : {|eff} Bool =
Expand Down Expand Up @@ -547,26 +580,63 @@ def simulate {atom_ix}
%time
(state_nl', energies_nl) =
unsafe_io do simulate (periodic_displacement L_small) 0.5 L_small (Fin 100) state_small
> 4568 initial neighbor list size
> 4614 new neighbor list size
> 4564 new neighbor list size
> 4376 new neighbor list size
> 4346 new neighbor list size
> 4266 new neighbor list size
> 4178 new neighbor list size
> 4140 new neighbor list size
> 4100 new neighbor list size
> 4028 new neighbor list size
> 4006 new neighbor list size
> 3922 new neighbor list size
> 3868 new neighbor list size
> 3810 new neighbor list size
> 3762 new neighbor list size
> 3720 new neighbor list size
> 3656 new neighbor list size
> 3640 new neighbor list size
> 3602 new neighbor list size
> 3572 new neighbor list size
> 3554 new neighbor list size
>
> Compile time: 1.604 s
> Run time: 41.986 ms

%time
:html show_plot $ y_plot energies_nl
> <html output>
>
> Compile time: 674.513 ms
> Run time: 3.613 ms

%time
:html render_svg (draw_system 0.5 (get_position state_nl')) ((0.0, 0.0), (L_small, L_small))
> <html output>
>
> Compile time: 278.686 ms
> Run time: 15.103 ms

'But of course the point of the exercise is that this now scales up to
larger systems because it avoids the quadratic energy computation.

N_large = 50000
N_large = if not (dex_test_mode ()) then 50000 else 500
L_large = box_size_at_number_density (n_to_i N_large) 1.2 (n_to_i d)
L_large
> 20.412415

R_init_large = rand_mat N_large d (\k. L_large * rand k) (new_key 0)

'Initial state (we render the atoms smaller now so they don't over-plot too badly).

%time
:html render_svg (draw_system 0.2 R_init_large) ((0.0, 0.0), (L_large, L_large))
> <html output>
>
> Compile time: 298.110 ms
> Run time: 15.438 ms

state_large =
energy_func = (energy_nl L_large $ just_neighbor_list 1.0 L_large R_init_large)
Expand All @@ -575,14 +645,46 @@ state_large =
%time
(state_large_nl', energies_large_nl) =
unsafe_io do simulate (periodic_displacement L_large) 0.5 L_large (Fin 100) state_large

'Eneregy decrease
> 4568 initial neighbor list size
> 4614 new neighbor list size
> 4564 new neighbor list size
> 4376 new neighbor list size
> 4346 new neighbor list size
> 4266 new neighbor list size
> 4178 new neighbor list size
> 4140 new neighbor list size
> 4100 new neighbor list size
> 4028 new neighbor list size
> 4006 new neighbor list size
> 3922 new neighbor list size
> 3868 new neighbor list size
> 3810 new neighbor list size
> 3762 new neighbor list size
> 3720 new neighbor list size
> 3656 new neighbor list size
> 3640 new neighbor list size
> 3602 new neighbor list size
> 3572 new neighbor list size
> 3554 new neighbor list size
>
> Compile time: 1.609 s
> Run time: 18.003 ms

'Energy decrease

%time
:html show_plot $ y_plot energies_large_nl
> <html output>
>
> Compile time: 722.463 ms
> Run time: 3.417 ms

'System state after minimization.

%time
:html render_svg (draw_system 0.2 (get_position state_large_nl')) ((0.0, 0.0), (L_large, L_large))
> <html output>
>
> Compile time: 288.290 ms
> Run time: 15.147 ms

2 changes: 1 addition & 1 deletion makefile
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ example-names := \
isomorphisms fluidsim \
sgd psd kernelregression nn \
quaternions manifold-gradients schrodinger tutorial \
latex linear-maps dither mcts
latex linear-maps dither mcts md
# TODO: re-enable
# fft vega-plotting

Expand Down

0 comments on commit bea0316

Please sign in to comment.