Skip to content

Commit

Permalink
Added abstraction layer for boundary condition and the capability to …
Browse files Browse the repository at this point in the history
…add profiles to boundary conditions
  • Loading branch information
mehdiataei committed Nov 28, 2024
1 parent 2b6355b commit 62f9ba0
Show file tree
Hide file tree
Showing 9 changed files with 260 additions and 69 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- XLB is now installable via pip
- Complete rewrite of the codebase for better modularity and extensibility based on "Operators" design pattern
- Added NVIDIA's Warp backend for state-of-the-art performance
- Added abstraction layer for boundary condition efficient encoding/decoding of auxiliary data
- Added the capability to add profiles to boundary conditions
55 changes: 54 additions & 1 deletion examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import numpy as np
import jax.numpy as jnp
import time
from functools import partial
from jax import jit


class FlowOverSphere:
Expand All @@ -37,6 +39,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy):
self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape)
self.stepper = None
self.boundary_conditions = []
self.u_max = 0.04

# Setup the simulation BC, its initial conditions, and the stepper
self._setup(omega)
Expand Down Expand Up @@ -69,7 +72,7 @@ def define_boundary_indices(self):

def setup_boundary_conditions(self):
inlet, outlet, walls, sphere = self.define_boundary_indices()
bc_left = RegularizedBC("velocity", (0.04, 0.0, 0.0), indices=inlet)
bc_left = RegularizedBC("velocity", indices=inlet)
# bc_left = EquilibriumBC(rho = 1, u=(0.04, 0.0, 0.0), indices=inlet)
bc_walls = FullwayBounceBackBC(indices=walls)
# bc_outlet = RegularizedBC("pressure", 1.0, indices=outlet)
Expand All @@ -95,8 +98,58 @@ def initialize_fields(self):
def setup_stepper(self, omega):
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK")

def bc_profile(self):
u_max = self.u_max # u_max = 0.04
# Get the grid dimensions for the y and z directions
H_y = float(self.grid_shape[1] - 1) # Height in y direction
H_z = float(self.grid_shape[2] - 1) # Height in z direction

@wp.func
def bc_profile_warp(index: wp.vec3i):
# Poiseuille flow profile: parabolic velocity distribution
y = self.precision_policy.store_precision.wp_dtype(index[1])
z = self.precision_policy.store_precision.wp_dtype(index[2])

# Calculate normalized distance from center
y_center = y - (H_y / 2.0)
z_center = z - (H_z / 2.0)
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0

# Parabolic profile: u = u_max * (1 - r²)
return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), 0.0, 0.0, 0.0, 0.0, length=5)
# return u_max

# @partial(jit, inline=True)
def bc_profile_jax():
y = jnp.arange(self.grid_shape[1])
z = jnp.arange(self.grid_shape[2])
Y, Z = jnp.meshgrid(y, z, indexing="ij")

# Calculate normalized distance from center
y_center = Y - (H_y / 2.0)
z_center = Z - (H_z / 2.0)
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0

# Parabolic profile for x velocity, zero for y and z
u_x = u_max * jnp.maximum(0.0, 1.0 - r_squared)
u_y = jnp.zeros_like(u_x)
u_z = jnp.zeros_like(u_x)

return jnp.stack([u_x, u_y, u_z])

if self.backend == ComputeBackend.JAX:
return bc_profile_jax
elif self.backend == ComputeBackend.WARP:
return bc_profile_warp

def initialize_bc_aux_data(self):
for bc in self.boundary_conditions:
if bc.needs_aux_init:
self.f_0, self.f_1 = bc.aux_data_init(self.bc_profile(), self.f_0, self.f_1, self.bc_mask, self.missing_mask)

def run(self, num_steps, post_process_interval=100):
start_time = time.time()
self.initialize_bc_aux_data()
for i in range(num_steps):
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.f_1, self.f_0
Expand Down
1 change: 0 additions & 1 deletion examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def define_boundary_indices(self):
def setup_boundary_conditions(self):
inlet, outlet, walls, car = self.define_boundary_indices()
bc_left = EquilibriumBC(rho=1.0, u=(self.wind_speed, 0.0, 0.0), indices=inlet)
# bc_left = RegularizedBC('velocity', (self.wind_speed, 0.0, 0.0), indices=inlet)
bc_walls = FullwayBounceBackBC(indices=walls)
bc_do_nothing = ExtrapolationOutflowBC(indices=outlet)
# bc_car = HalfwayBounceBackBC(mesh_vertices=car)
Expand Down
15 changes: 9 additions & 6 deletions xlb/operator/boundary_condition/bc_extrapolation_outflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,15 @@ def __init__(
mesh_vertices,
)

# Set the flag for auxilary data recovery
self.needs_aux_recovery = True

# find and store the normal vector using indices
self._get_normal_vec(indices)

# Unpack the two warp functionals needed for this BC!
if self.compute_backend == ComputeBackend.WARP:
self.warp_functional, self.prepare_bc_auxilary_data = self.warp_functional
self.warp_functional, self.update_bc_auxilary_data = self.warp_functional

def _get_normal_vec(self, indices):
# Get the frequency count and most common element directly
Expand Down Expand Up @@ -92,9 +95,9 @@ def _roll(self, fld, vec):
return jnp.roll(fld, (vec[0], vec[1], vec[2]), axis=(1, 2, 3))

@partial(jit, static_argnums=(0,), inline=True)
def prepare_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask):
def update_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask):
"""
Prepare the auxilary distribution functions for the boundary condition.
Update the auxilary distribution functions for the boundary condition.
Since this function is called post-collisiotn: f_pre = f_post_stream and f_post = f_post_collision
"""
sound_speed = 1.0 / jnp.sqrt(3.0)
Expand Down Expand Up @@ -171,7 +174,7 @@ def functional(
return _f

@wp.func
def prepare_bc_auxilary_data(
def update_bc_auxilary_data(
index: Any,
timestep: Any,
missing_mask: Any,
Expand All @@ -180,7 +183,7 @@ def prepare_bc_auxilary_data(
f_pre: Any,
f_post: Any,
):
# Preparing the formulation for this BC using the neighbour's populations stored in f_aux and
# Update the auxilary data for this BC using the neighbour's populations stored in f_aux and
# f_pre (post-streaming values of the current voxel). We use directions that leave the domain
# for storing this prepared data.
_f = f_post
Expand All @@ -199,7 +202,7 @@ def prepare_bc_auxilary_data(

kernel = self._construct_kernel(functional)

return (functional, prepare_bc_auxilary_data), kernel
return (functional, update_bc_auxilary_data), kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
Expand Down
22 changes: 16 additions & 6 deletions xlb/operator/boundary_condition/bc_regularized.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class RegularizedBC(ZouHeBC):
def __init__(
self,
bc_type,
prescribed_value,
velocity_set: VelocitySet = None,
precision_policy: PrecisionPolicy = None,
compute_backend: ComputeBackend = None,
Expand All @@ -54,7 +53,6 @@ def __init__(
# Call the parent constructor
super().__init__(
bc_type,
prescribed_value,
velocity_set,
precision_policy,
compute_backend,
Expand Down Expand Up @@ -127,15 +125,11 @@ def _construct_warp(self):
# assign placeholders for both u and rho based on prescribed_value
_d = self.velocity_set.d
_q = self.velocity_set.q
u = self.prescribed_value if self.bc_type == "velocity" else (0,) * _d
rho = self.prescribed_value if self.bc_type == "pressure" else 0.0

# Set local constants TODO: This is a hack and should be fixed with warp update
# _u_vec = wp.vec(_d, dtype=self.compute_dtype)
# compute Qi tensor and store it in self
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_rho = self.compute_dtype(rho)
_u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1])
_opp_indices = self.velocity_set.opp_indices
_w = self.velocity_set.w
_c = self.velocity_set.c
Expand Down Expand Up @@ -222,6 +216,15 @@ def functional_velocity(
# Find normal vector
normals = get_normal_vectors(missing_mask)

# Find the value of u from the missing directions
for l in range(wp.static(_q)):
# Since we are only considering normal velocity, we only need to find one value
if missing_mask[l] == wp.uint8(1):
# Create velocity vector by multiplying the prescribed value with the normal vector
prescribed_value = f_1[_opp_indices[l], index[0], index[1], index[2]]
_u = -prescribed_value * normals
break

# calculate rho
fsum = _get_fsum(_f, missing_mask)
unormal = self.compute_dtype(0.0)
Expand Down Expand Up @@ -253,6 +256,13 @@ def functional_pressure(
# Find normal vector
normals = get_normal_vectors(missing_mask)

# Find the value of rho from the missing directions
for q in range(wp.static(_q)):
# Since we need only one scalar value, we only need to find one value
if missing_mask[q] == wp.uint8(1):
_rho = f_0[_opp_indices[q], index[0], index[1], index[2]]
break

# calculate velocity
fsum = _get_fsum(_f, missing_mask)
unormal = -self.compute_dtype(1.0) + fsum / _rho
Expand Down
Loading

0 comments on commit 62f9ba0

Please sign in to comment.