diff --git a/README.md b/README.md index 4babafbe8..fe9f6701d 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ or these example programs: * [Basis function regression](https://google-research.github.io/dex-lang/examples/regression.html) * [Brownian bridge](https://google-research.github.io/dex-lang/examples/brownian_motion.html) * [Dynamic programming (Levenshtein distance)](https://google-research.github.io/dex-lang/examples/levenshtein-distance.html) + * [Molecular dynamics simulation](https://google-research.github.io/dex-lang/examples/md.html) Or for a more comprehensive look, there's diff --git a/examples/md.dx b/examples/md.dx index dea1492c9..ebadae8da 100644 --- a/examples/md.dx +++ b/examples/md.dx @@ -4,7 +4,7 @@ import plot 'This is more-or-less a port of Jax MD into Dex to see how molecular dynamics looks in Dex. For now, the structure of the two implementations is pretty close. However, details -look different and the evolution will depend on what is or is not ergonomic in dex vs jax. +look different. '## Math @@ -189,6 +189,7 @@ L_small = box_size_at_number_density (n_to_i N_small) 1.2 (n_to_i d) -- We will simulate in a box of this side length L_small +> 20.412415 R_init_small = rand_mat N_small d (\k. L_small * rand k) (new_key 0) @@ -196,6 +197,10 @@ R_init_small = rand_mat N_small d (\k. L_small * rand k) (new_key 0) %time :html render_svg (draw_system 0.5 R_init_small) ((0.0, 0.0), (L_small, L_small)) +> +> +> Compile time: 277.664 ms +> Run time: 15.644 ms 'Define energy function. Note the `preiodic_displacement`, which means our system will be evolving on a torus. @@ -206,8 +211,10 @@ def energy {n d} (pos: n=>d=>Float) : Float 'Here's the initial energy we compute for our system. :t energy R_init_small +> Float32 energy R_init_small +> 74.689995 'Initialize a simulation @@ -217,6 +224,7 @@ state_small = fire_descent_init 0.1 0.1 energy R_init_small initial, as expected: energy $ get_position $ fire_descent_step free_shift energy state_small +> 71.78412 'Now we can test that our code basically works by running 100 steps of minimization. @@ -225,16 +233,27 @@ minimization. (state_small', energies) = scan state_small \i:(Fin 100) s. s' = fire_descent_step (periodic_shift L_small) energy s (s', energy $ get_position s') +> +> Compile time: 525.343 ms +> Run time: 1.349 s 'Here's how the energy decreases over time. %time :html show_plot $ y_plot energies +> +> +> Compile time: 690.704 ms +> Run time: 2.965 ms 'Here's what the system looks like after minimization. %time :html render_svg (draw_system 0.5 (get_position state_small')) ((0.0, 0.0), (L_small, L_small)) +> +> +> Compile time: 276.340 ms +> Run time: 13.222 ms '## Neighbors optimization @@ -352,14 +371,19 @@ bucket_size = 10 %time tbl = cell_table grid_size bucket_size cell_size $ get_position state_small' +> +> Compile time: 118.997 ms +> Run time: 22.975 us 'We have a table of cells with atoms in them :t tbl +> (((Fin 2) => Fin 20) => (Fin 10 & ((Fin 10) => Fin 500))) 'And here are the atoms in the 0th cell. as_list tbl.(unsafe_from_ordinal _ 0) +> (AsList 2 [(33@(Fin 500)), (237@(Fin 500))]) '### Now let's compute pairs of neighbors from our cell list We'll specialize to two dimensions for now, but broadening to @@ -369,6 +393,7 @@ cell_neighbors_2d = [[-1, -1], [-1, 0], [-1, 1], [0, -1], [0, 0], [0, 1], [1, -1], [1, 0], [1, 1]] :t cell_neighbors_2d +> ((Fin 9) => (Fin 2) => Int32) -- Toroidal index offsetting def torus_offset {n} (ix: (Fin n)) (offset: Int) : (Fin n) = @@ -435,11 +460,15 @@ def periodic_near {atom_ix} %time res = (neighbor_list 4000 tbl (periodic_near 1.0 L_small $ get_position state_small') $ get_position state_small') +> +> Compile time: 95.243 ms +> Run time: 130.425 us 'In that configuration, we find this many pairs of neighbors: (AsList k _) = as_list res k +> 3090 'Now that we have the concept of neighbor lists, we cen define a variant of `pair_energy` that only considers atoms that the neighbor @@ -469,8 +498,10 @@ def energy_nl {n d} original, fully pairwise energy function. energy_nl L_small (as_list res) (get_position state_small') +> 1.230931 energy (get_position state_small') +> 1.230932 -- Package the above up into a function that just computes the -- neighbor list from an array of atoms. @@ -495,7 +526,9 @@ state_nl = fire_descent_init 0.1 0.1 energy_func R_init_small energy R_init_small +> 74.689995 energy $ get_position $ fire_descent_step free_shift energy state_small +> 71.78412 -- A helper for short-circuiting `any` computation def fast_any {n eff} [Ix n] (f: n -> {|eff} Bool) : {|eff} Bool = @@ -547,19 +580,52 @@ def simulate {atom_ix} %time (state_nl', energies_nl) = unsafe_io do simulate (periodic_displacement L_small) 0.5 L_small (Fin 100) state_small +> 4568 initial neighbor list size +> 4614 new neighbor list size +> 4564 new neighbor list size +> 4376 new neighbor list size +> 4346 new neighbor list size +> 4266 new neighbor list size +> 4178 new neighbor list size +> 4140 new neighbor list size +> 4100 new neighbor list size +> 4028 new neighbor list size +> 4006 new neighbor list size +> 3922 new neighbor list size +> 3868 new neighbor list size +> 3810 new neighbor list size +> 3762 new neighbor list size +> 3720 new neighbor list size +> 3656 new neighbor list size +> 3640 new neighbor list size +> 3602 new neighbor list size +> 3572 new neighbor list size +> 3554 new neighbor list size +> +> Compile time: 1.604 s +> Run time: 41.986 ms %time :html show_plot $ y_plot energies_nl +> +> +> Compile time: 674.513 ms +> Run time: 3.613 ms %time :html render_svg (draw_system 0.5 (get_position state_nl')) ((0.0, 0.0), (L_small, L_small)) +> +> +> Compile time: 278.686 ms +> Run time: 15.103 ms 'But of course the point of the exercise is that this now scales up to larger systems because it avoids the quadratic energy computation. -N_large = 50000 +N_large = if not (dex_test_mode ()) then 50000 else 500 L_large = box_size_at_number_density (n_to_i N_large) 1.2 (n_to_i d) L_large +> 20.412415 R_init_large = rand_mat N_large d (\k. L_large * rand k) (new_key 0) @@ -567,6 +633,10 @@ R_init_large = rand_mat N_large d (\k. L_large * rand k) (new_key 0) %time :html render_svg (draw_system 0.2 R_init_large) ((0.0, 0.0), (L_large, L_large)) +> +> +> Compile time: 298.110 ms +> Run time: 15.438 ms state_large = energy_func = (energy_nl L_large $ just_neighbor_list 1.0 L_large R_init_large) @@ -575,14 +645,46 @@ state_large = %time (state_large_nl', energies_large_nl) = unsafe_io do simulate (periodic_displacement L_large) 0.5 L_large (Fin 100) state_large - -'Eneregy decrease +> 4568 initial neighbor list size +> 4614 new neighbor list size +> 4564 new neighbor list size +> 4376 new neighbor list size +> 4346 new neighbor list size +> 4266 new neighbor list size +> 4178 new neighbor list size +> 4140 new neighbor list size +> 4100 new neighbor list size +> 4028 new neighbor list size +> 4006 new neighbor list size +> 3922 new neighbor list size +> 3868 new neighbor list size +> 3810 new neighbor list size +> 3762 new neighbor list size +> 3720 new neighbor list size +> 3656 new neighbor list size +> 3640 new neighbor list size +> 3602 new neighbor list size +> 3572 new neighbor list size +> 3554 new neighbor list size +> +> Compile time: 1.609 s +> Run time: 18.003 ms + +'Energy decrease %time :html show_plot $ y_plot energies_large_nl +> +> +> Compile time: 722.463 ms +> Run time: 3.417 ms 'System state after minimization. %time :html render_svg (draw_system 0.2 (get_position state_large_nl')) ((0.0, 0.0), (L_large, L_large)) +> +> +> Compile time: 288.290 ms +> Run time: 15.147 ms diff --git a/makefile b/makefile index 99fbd9704..cdc2fb945 100644 --- a/makefile +++ b/makefile @@ -209,7 +209,7 @@ example-names := \ isomorphisms fluidsim \ sgd psd kernelregression nn \ quaternions manifold-gradients schrodinger tutorial \ - latex linear-maps dither mcts + latex linear-maps dither mcts md # TODO: re-enable # fft vega-plotting