Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added abstraction layer for boundary condition and the capability to add profiles to boundary conditions #86

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- XLB is now installable via pip
- Complete rewrite of the codebase for better modularity and extensibility based on "Operators" design pattern
- Added NVIDIA's Warp backend for state-of-the-art performance
- Added abstraction layer for boundary condition efficient encoding/decoding of auxiliary data
- Added the capability to add profiles to boundary conditions
55 changes: 54 additions & 1 deletion examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import numpy as np
import jax.numpy as jnp
import time
from functools import partial
from jax import jit


class FlowOverSphere:
Expand All @@ -37,6 +39,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy):
self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape)
self.stepper = None
self.boundary_conditions = []
self.u_max = 0.04

# Setup the simulation BC, its initial conditions, and the stepper
self._setup(omega)
Expand Down Expand Up @@ -69,7 +72,7 @@ def define_boundary_indices(self):

def setup_boundary_conditions(self):
inlet, outlet, walls, sphere = self.define_boundary_indices()
bc_left = RegularizedBC("velocity", (0.04, 0.0, 0.0), indices=inlet)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have a 1-to-1 association between a bc_profile function and a bc object ? Right now, it is not straightforward for example to have two BC's with different profiles. If we add functional to the BC constructor then we would need another operator (like boundary masker) which goes over all BC's and initializes them similar to what initialize_bc_aux_data does right now. This could be done as the first step of the stepper (if iteration is zero). This way we can also remove initilize_bc_aux_data from examples.

Also with these changes, it is kind of difficult to assign constant values to a BC. Is there a simpler way to pass in constant values to BCs without having to define a bc_prfile with jax and warp implementations?

bc_left = RegularizedBC("velocity", indices=inlet)
# bc_left = EquilibriumBC(rho = 1, u=(0.04, 0.0, 0.0), indices=inlet)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Equilibrium BC should also be able to accept variable values. Can you add that as well ?

bc_walls = FullwayBounceBackBC(indices=walls)
# bc_outlet = RegularizedBC("pressure", 1.0, indices=outlet)
Expand All @@ -95,8 +98,58 @@ def initialize_fields(self):
def setup_stepper(self, omega):
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK")

def bc_profile(self):
u_max = self.u_max # u_max = 0.04
# Get the grid dimensions for the y and z directions
H_y = float(self.grid_shape[1] - 1) # Height in y direction
H_z = float(self.grid_shape[2] - 1) # Height in z direction

@wp.func
def bc_profile_warp(index: wp.vec3i):
# Poiseuille flow profile: parabolic velocity distribution
y = self.precision_policy.store_precision.wp_dtype(index[1])
z = self.precision_policy.store_precision.wp_dtype(index[2])

# Calculate normalized distance from center
y_center = y - (H_y / 2.0)
z_center = z - (H_z / 2.0)
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0

# Parabolic profile: u = u_max * (1 - r²)
return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), 0.0, 0.0, 0.0, 0.0, length=5)
# return u_max

# @partial(jit, inline=True)
def bc_profile_jax():
y = jnp.arange(self.grid_shape[1])
z = jnp.arange(self.grid_shape[2])
Y, Z = jnp.meshgrid(y, z, indexing="ij")

# Calculate normalized distance from center
y_center = Y - (H_y / 2.0)
z_center = Z - (H_z / 2.0)
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0

# Parabolic profile for x velocity, zero for y and z
u_x = u_max * jnp.maximum(0.0, 1.0 - r_squared)
u_y = jnp.zeros_like(u_x)
u_z = jnp.zeros_like(u_x)

return jnp.stack([u_x, u_y, u_z])

if self.backend == ComputeBackend.JAX:
return bc_profile_jax
elif self.backend == ComputeBackend.WARP:
return bc_profile_warp

def initialize_bc_aux_data(self):
for bc in self.boundary_conditions:
if bc.needs_aux_init:
self.f_0, self.f_1 = bc.aux_data_init(self.bc_profile(), self.f_0, self.f_1, self.bc_mask, self.missing_mask)

def run(self, num_steps, post_process_interval=100):
start_time = time.time()
self.initialize_bc_aux_data()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think with some minor changes you will be able to make this a stand alone operator (like boundary_masker).

We can either call that operator inside the stepper stepper (if iteration=0) or leave it to the user to call it in the example.

for i in range(num_steps):
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.f_1, self.f_0
Expand Down
17 changes: 12 additions & 5 deletions examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


class LidDrivenCavity2D:
def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy):
def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy):
# initialize backend
xlb.init(
velocity_set=velocity_set,
Expand All @@ -29,6 +29,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy):
self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape)
self.stepper = None
self.boundary_conditions = []
self.prescribed_vel = prescribed_vel

# Setup the simulation BC, its initial conditions, and the stepper
self._setup(omega)
Expand All @@ -49,7 +50,7 @@ def define_boundary_indices(self):

def setup_boundary_conditions(self):
lid, walls = self.define_boundary_indices()
bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0), indices=lid)
bc_top = EquilibriumBC(rho=1.0, u=(self.prescribed_vel, 0.0), indices=lid)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment for EquilibriumBC in all examples.

bc_walls = HalfwayBounceBackBC(indices=walls)
self.boundary_conditions = [bc_walls, bc_top]

Expand Down Expand Up @@ -112,7 +113,13 @@ def post_process(self, i):
precision_policy = PrecisionPolicy.FP32FP32

velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
omega = 1.6

simulation = LidDrivenCavity2D(omega, grid_shape, velocity_set, backend, precision_policy)
simulation.run(num_steps=5000, post_process_interval=1000)
# Setting fluid viscosity and relaxation parameter.
Re = 200.0
prescribed_vel = 0.05
clength = grid_shape[0] - 1
visc = prescribed_vel * clength / Re
omega = 1.0 / (3.0 * visc + 0.5)

simulation = LidDrivenCavity2D(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy)
simulation.run(num_steps=50000, post_process_interval=1000)
16 changes: 11 additions & 5 deletions examples/cfd/lid_driven_cavity_2d_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@


class LidDrivenCavity2D_distributed(LidDrivenCavity2D):
def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy):
super().__init__(omega, grid_shape, velocity_set, backend, precision_policy)
def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy):
super().__init__(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy)

def setup_stepper(self, omega):
stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions)
Expand All @@ -29,7 +29,13 @@ def setup_stepper(self, omega):
precision_policy = PrecisionPolicy.FP32FP32

velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
omega = 1.6

simulation = LidDrivenCavity2D_distributed(omega, grid_shape, velocity_set, backend, precision_policy)
simulation.run(num_steps=5000, post_process_interval=1000)
# Setting fluid viscosity and relaxation parameter.
Re = 200.0
prescribed_vel = 0.05
clength = grid_shape[0] - 1
visc = prescribed_vel * clength / Re
omega = 1.0 / (3.0 * visc + 0.5)

simulation = LidDrivenCavity2D_distributed(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy)
simulation.run(num_steps=50000, post_process_interval=1000)
1 change: 0 additions & 1 deletion examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def define_boundary_indices(self):
def setup_boundary_conditions(self):
inlet, outlet, walls, car = self.define_boundary_indices()
bc_left = EquilibriumBC(rho=1.0, u=(self.wind_speed, 0.0, 0.0), indices=inlet)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if velocity here is also not constant and has a profile. All these boundary conditions should be set up similarly. It would be nice if we could input the profile to the BC constrcutor.

# bc_left = RegularizedBC('velocity', (self.wind_speed, 0.0, 0.0), indices=inlet)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be really nice if we could define BCs like before (simply as constant inputs) as well as with a function as an input to the bc constructor: eg u = (0.04, 0, 0) or u = bc_prfile(). Something like this. I think it could be done in the operator I am proposing for initializing BC aux data.

bc_walls = FullwayBounceBackBC(indices=walls)
bc_do_nothing = ExtrapolationOutflowBC(indices=outlet)
# bc_car = HalfwayBounceBackBC(mesh_vertices=car)
Expand Down
10 changes: 6 additions & 4 deletions tests/boundary_conditions/mask/test_bc_indices_masker_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape):
[test_bc],
bc_mask,
missing_mask,
start_index=(0, 0, 0) if dim == 3 else (0, 0),
)
assert missing_mask.dtype == xlb.Precision.BOOL.wp_dtype

Expand All @@ -69,9 +68,12 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape):
bc_mask = bc_mask.numpy()
missing_mask = missing_mask.numpy()

assert bc_mask.shape == (1,) + grid_shape

assert missing_mask.shape == (velocity_set.q,) + grid_shape
if len(grid_shape) == 2:
assert bc_mask.shape == (1,) + grid_shape + (1,), "bc_mask shape is incorrect got {}".format(bc_mask.shape)
assert missing_mask.shape == (velocity_set.q,) + grid_shape + (1,), "missing_mask shape is incorrect got {}".format(missing_mask.shape)
else:
assert bc_mask.shape == (1,) + grid_shape, "bc_mask shape is incorrect got {}".format(bc_mask.shape)
assert missing_mask.shape == (velocity_set.q,) + grid_shape, "missing_mask shape is incorrect got {}".format(missing_mask.shape)

if dim == 2:
assert np.all(bc_mask[0, indices[0], indices[1]] == test_bc.id)
Expand Down
27 changes: 17 additions & 10 deletions xlb/operator/boundary_condition/bc_extrapolation_outflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,15 @@ def __init__(
mesh_vertices,
)

# Set the flag for auxilary data recovery
self.needs_aux_recovery = True

# find and store the normal vector using indices
self._get_normal_vec(indices)

# Unpack the two warp functionals needed for this BC!
if self.compute_backend == ComputeBackend.WARP:
self.warp_functional, self.prepare_bc_auxilary_data = self.warp_functional
self.warp_functional, self.update_bc_auxilary_data = self.warp_functional
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this renaming.


def _get_normal_vec(self, indices):
# Get the frequency count and most common element directly
Expand Down Expand Up @@ -92,9 +95,9 @@ def _roll(self, fld, vec):
return jnp.roll(fld, (vec[0], vec[1], vec[2]), axis=(1, 2, 3))

@partial(jit, static_argnums=(0,), inline=True)
def prepare_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask):
def update_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask):
"""
Prepare the auxilary distribution functions for the boundary condition.
Update the auxilary distribution functions for the boundary condition.
Since this function is called post-collisiotn: f_pre = f_post_stream and f_post = f_post_collision
"""
sound_speed = 1.0 / jnp.sqrt(3.0)
Expand Down Expand Up @@ -134,7 +137,6 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
def _construct_warp(self):
# Set local constants
sound_speed = self.compute_dtype(1.0 / wp.sqrt(3.0))
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_c = self.velocity_set.c
_q = self.velocity_set.q
_opp_indices = self.velocity_set.opp_indices
Expand All @@ -143,9 +145,14 @@ def _construct_warp(self):
def get_normal_vectors(
missing_mask: Any,
):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -wp.vec3i(_c[0, l], _c[1, l], _c[2, l])
if wp.static(self.velocity_set.d == 3):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -wp.vec3i(_c[0, l], _c[1, l], _c[2, l])
else:
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
return -wp.vec2i(_c[0, l], _c[1, l])

# Construct the functionals for this BC
@wp.func
Expand All @@ -167,7 +174,7 @@ def functional(
return _f

@wp.func
def prepare_bc_auxilary_data(
def update_bc_auxilary_data(
index: Any,
timestep: Any,
missing_mask: Any,
Expand All @@ -176,7 +183,7 @@ def prepare_bc_auxilary_data(
f_pre: Any,
f_post: Any,
):
# Preparing the formulation for this BC using the neighbour's populations stored in f_aux and
# Update the auxilary data for this BC using the neighbour's populations stored in f_aux and
# f_pre (post-streaming values of the current voxel). We use directions that leave the domain
# for storing this prepared data.
_f = f_post
Expand All @@ -195,7 +202,7 @@ def prepare_bc_auxilary_data(

kernel = self._construct_kernel(functional)

return (functional, prepare_bc_auxilary_data), kernel
return (functional, update_bc_auxilary_data), kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
Expand Down
34 changes: 24 additions & 10 deletions xlb/operator/boundary_condition/bc_regularized.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class RegularizedBC(ZouHeBC):
def __init__(
self,
bc_type,
prescribed_value,
velocity_set: VelocitySet = None,
precision_policy: PrecisionPolicy = None,
compute_backend: ComputeBackend = None,
Expand All @@ -54,7 +53,6 @@ def __init__(
# Call the parent constructor
super().__init__(
bc_type,
prescribed_value,
velocity_set,
precision_policy,
compute_backend,
Expand Down Expand Up @@ -127,16 +125,11 @@ def _construct_warp(self):
# assign placeholders for both u and rho based on prescribed_value
_d = self.velocity_set.d
_q = self.velocity_set.q
u = self.prescribed_value if self.bc_type == "velocity" else (0,) * _d
rho = self.prescribed_value if self.bc_type == "pressure" else 0.0

# Set local constants TODO: This is a hack and should be fixed with warp update
# _u_vec = wp.vec(_d, dtype=self.compute_dtype)
# compute Qi tensor and store it in self
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_rho = self.compute_dtype(rho)
_u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1])
_opp_indices = self.velocity_set.opp_indices
_w = self.velocity_set.w
_c = self.velocity_set.c
Expand All @@ -162,9 +155,14 @@ def _get_fsum(
def get_normal_vectors(
missing_mask: Any,
):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
if wp.static(_d == 3):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
else:
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l])

@wp.func
def bounceback_nonequilibrium(
Expand Down Expand Up @@ -218,6 +216,15 @@ def functional_velocity(
# Find normal vector
normals = get_normal_vectors(missing_mask)

# Find the value of u from the missing directions
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there not a performance hit here? The BC is not dynamic like Extrapolation Outflow where its information change in time. Here we have the same value of velocity or pressure being recomputed at each iteration over and over again.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're not recomputing here we're simply reading it. I didn't notice any performance hit before and after the PR.

An improvement for this could be that we can store the missing direction and avoid the for loop for slightly improved performance. I'll try to do that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(this only works for linear/planar BCs like zouhe)

for l in range(wp.static(_q)):
# Since we are only considering normal velocity, we only need to find one value
if missing_mask[l] == wp.uint8(1):
# Create velocity vector by multiplying the prescribed value with the normal vector
prescribed_value = f_1[_opp_indices[l], index[0], index[1], index[2]]
_u = -prescribed_value * normals
break

# calculate rho
fsum = _get_fsum(_f, missing_mask)
unormal = self.compute_dtype(0.0)
Expand Down Expand Up @@ -249,6 +256,13 @@ def functional_pressure(
# Find normal vector
normals = get_normal_vectors(missing_mask)

# Find the value of rho from the missing directions
for q in range(wp.static(_q)):
# Since we need only one scalar value, we only need to find one value
if missing_mask[q] == wp.uint8(1):
_rho = f_0[_opp_indices[q], index[0], index[1], index[2]]
break

# calculate velocity
fsum = _get_fsum(_f, missing_mask)
unormal = -self.compute_dtype(1.0) + fsum / _rho
Expand Down
Loading
Loading