diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d036de..69ad1f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,3 +18,5 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - XLB is now installable via pip - Complete rewrite of the codebase for better modularity and extensibility based on "Operators" design pattern - Added NVIDIA's Warp backend for state-of-the-art performance +- Added abstraction layer for boundary condition efficient encoding/decoding of auxiliary data +- Added the capability to add profiles to boundary conditions \ No newline at end of file diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 1b5905e..fe1ca50 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -19,6 +19,8 @@ import numpy as np import jax.numpy as jnp import time +from functools import partial +from jax import jit class FlowOverSphere: @@ -37,6 +39,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] + self.u_max = 0.04 # Setup the simulation BC, its initial conditions, and the stepper self._setup(omega) @@ -69,7 +72,7 @@ def define_boundary_indices(self): def setup_boundary_conditions(self): inlet, outlet, walls, sphere = self.define_boundary_indices() - bc_left = RegularizedBC("velocity", (0.04, 0.0, 0.0), indices=inlet) + bc_left = RegularizedBC("velocity", indices=inlet) # bc_left = EquilibriumBC(rho = 1, u=(0.04, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) # bc_outlet = RegularizedBC("pressure", 1.0, indices=outlet) @@ -95,8 +98,58 @@ def initialize_fields(self): def setup_stepper(self, omega): self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK") + def bc_profile(self): + u_max = self.u_max # u_max = 0.04 + # Get the grid dimensions for the y and z directions + H_y = float(self.grid_shape[1] - 1) # Height in y direction + H_z = float(self.grid_shape[2] - 1) # Height in z direction + + @wp.func + def bc_profile_warp(index: wp.vec3i): + # Poiseuille flow profile: parabolic velocity distribution + y = self.precision_policy.store_precision.wp_dtype(index[1]) + z = self.precision_policy.store_precision.wp_dtype(index[2]) + + # Calculate normalized distance from center + y_center = y - (H_y / 2.0) + z_center = z - (H_z / 2.0) + r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0 + + # Parabolic profile: u = u_max * (1 - r²) + return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), 0.0, 0.0, 0.0, 0.0, length=5) + # return u_max + + # @partial(jit, inline=True) + def bc_profile_jax(): + y = jnp.arange(self.grid_shape[1]) + z = jnp.arange(self.grid_shape[2]) + Y, Z = jnp.meshgrid(y, z, indexing="ij") + + # Calculate normalized distance from center + y_center = Y - (H_y / 2.0) + z_center = Z - (H_z / 2.0) + r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0 + + # Parabolic profile for x velocity, zero for y and z + u_x = u_max * jnp.maximum(0.0, 1.0 - r_squared) + u_y = jnp.zeros_like(u_x) + u_z = jnp.zeros_like(u_x) + + return jnp.stack([u_x, u_y, u_z]) + + if self.backend == ComputeBackend.JAX: + return bc_profile_jax + elif self.backend == ComputeBackend.WARP: + return bc_profile_warp + + def initialize_bc_aux_data(self): + for bc in self.boundary_conditions: + if bc.needs_aux_init: + self.f_0, self.f_1 = bc.aux_data_init(self.bc_profile(), self.f_0, self.f_1, self.bc_mask, self.missing_mask) + def run(self, num_steps, post_process_interval=100): start_time = time.time() + self.initialize_bc_aux_data() for i in range(num_steps): self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 96c79f7..c159685 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -95,7 +95,6 @@ def define_boundary_indices(self): def setup_boundary_conditions(self): inlet, outlet, walls, car = self.define_boundary_indices() bc_left = EquilibriumBC(rho=1.0, u=(self.wind_speed, 0.0, 0.0), indices=inlet) - # bc_left = RegularizedBC('velocity', (self.wind_speed, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) bc_do_nothing = ExtrapolationOutflowBC(indices=outlet) # bc_car = HalfwayBounceBackBC(mesh_vertices=car) diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 8a2a482..fa75490 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -56,12 +56,15 @@ def __init__( mesh_vertices, ) + # Set the flag for auxilary data recovery + self.needs_aux_recovery = True + # find and store the normal vector using indices self._get_normal_vec(indices) # Unpack the two warp functionals needed for this BC! if self.compute_backend == ComputeBackend.WARP: - self.warp_functional, self.prepare_bc_auxilary_data = self.warp_functional + self.warp_functional, self.update_bc_auxilary_data = self.warp_functional def _get_normal_vec(self, indices): # Get the frequency count and most common element directly @@ -92,9 +95,9 @@ def _roll(self, fld, vec): return jnp.roll(fld, (vec[0], vec[1], vec[2]), axis=(1, 2, 3)) @partial(jit, static_argnums=(0,), inline=True) - def prepare_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask): + def update_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask): """ - Prepare the auxilary distribution functions for the boundary condition. + Update the auxilary distribution functions for the boundary condition. Since this function is called post-collisiotn: f_pre = f_post_stream and f_post = f_post_collision """ sound_speed = 1.0 / jnp.sqrt(3.0) @@ -171,7 +174,7 @@ def functional( return _f @wp.func - def prepare_bc_auxilary_data( + def update_bc_auxilary_data( index: Any, timestep: Any, missing_mask: Any, @@ -180,7 +183,7 @@ def prepare_bc_auxilary_data( f_pre: Any, f_post: Any, ): - # Preparing the formulation for this BC using the neighbour's populations stored in f_aux and + # Update the auxilary data for this BC using the neighbour's populations stored in f_aux and # f_pre (post-streaming values of the current voxel). We use directions that leave the domain # for storing this prepared data. _f = f_post @@ -199,7 +202,7 @@ def prepare_bc_auxilary_data( kernel = self._construct_kernel(functional) - return (functional, prepare_bc_auxilary_data), kernel + return (functional, update_bc_auxilary_data), kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index d1abd4e..da6ba53 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -44,7 +44,6 @@ class RegularizedBC(ZouHeBC): def __init__( self, bc_type, - prescribed_value, velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, @@ -54,7 +53,6 @@ def __init__( # Call the parent constructor super().__init__( bc_type, - prescribed_value, velocity_set, precision_policy, compute_backend, @@ -127,15 +125,11 @@ def _construct_warp(self): # assign placeholders for both u and rho based on prescribed_value _d = self.velocity_set.d _q = self.velocity_set.q - u = self.prescribed_value if self.bc_type == "velocity" else (0,) * _d - rho = self.prescribed_value if self.bc_type == "pressure" else 0.0 # Set local constants TODO: This is a hack and should be fixed with warp update # _u_vec = wp.vec(_d, dtype=self.compute_dtype) # compute Qi tensor and store it in self _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) - _rho = self.compute_dtype(rho) - _u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1]) _opp_indices = self.velocity_set.opp_indices _w = self.velocity_set.w _c = self.velocity_set.c @@ -222,6 +216,15 @@ def functional_velocity( # Find normal vector normals = get_normal_vectors(missing_mask) + # Find the value of u from the missing directions + for l in range(wp.static(_q)): + # Since we are only considering normal velocity, we only need to find one value + if missing_mask[l] == wp.uint8(1): + # Create velocity vector by multiplying the prescribed value with the normal vector + prescribed_value = f_1[_opp_indices[l], index[0], index[1], index[2]] + _u = -prescribed_value * normals + break + # calculate rho fsum = _get_fsum(_f, missing_mask) unormal = self.compute_dtype(0.0) @@ -253,6 +256,13 @@ def functional_pressure( # Find normal vector normals = get_normal_vectors(missing_mask) + # Find the value of rho from the missing directions + for q in range(wp.static(_q)): + # Since we need only one scalar value, we only need to find one value + if missing_mask[q] == wp.uint8(1): + _rho = f_0[_opp_indices[q], index[0], index[1], index[2]] + break + # calculate velocity fsum = _get_fsum(_f, missing_mask) unormal = -self.compute_dtype(1.0) + fsum / _rho diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 3e89992..4637fbd 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -21,6 +21,7 @@ boundary_condition_registry, ) from xlb.operator.equilibrium import QuadraticEquilibrium +import jax class ZouHeBC(BoundaryCondition): @@ -38,7 +39,6 @@ class ZouHeBC(BoundaryCondition): def __init__( self, bc_type, - prescribed_value, velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, @@ -50,7 +50,6 @@ def __init__( assert bc_type in ["velocity", "pressure"], f"type = {bc_type} not supported! Use 'pressure' or 'velocity'." self.bc_type = bc_type self.equilibrium_operator = QuadraticEquilibrium() - self.prescribed_value = prescribed_value # Call the parent constructor super().__init__( @@ -62,11 +61,14 @@ def __init__( mesh_vertices, ) - # Set the prescribed value for pressure or velocity - dim = self.velocity_set.d - if self.compute_backend == ComputeBackend.JAX: - self.prescribed_value = jnp.atleast_1d(prescribed_value)[(slice(None),) + (None,) * dim] - # TODO: this won't work if the prescribed values are a profile with the length of bdry indices! + # This BC needs auxilary data initialization before streaming + self.needs_aux_init = True + + # This BC needs auxilary data recovery after streaming + self.needs_aux_recovery = True + + # This BC needs one auxilary data for the density or normal velocity + self.num_of_aux_data = 1 # This BC needs padding for finding missing directions when imposed on a geometry that is in the domain interior self.needs_padding = True @@ -87,20 +89,33 @@ def _get_normal_vec(self, missing_mask): @partial(jit, static_argnums=(0,), inline=True) def get_rho(self, fpop, missing_mask): if self.bc_type == "velocity": - vel = self.prescribed_value + target_shape = (self.velocity_set.d,) + fpop.shape[1:] + vel = self._broadcast_prescribed_values(self.prescribed_values, self.prescribed_values.shape, target_shape) rho = self.calculate_rho(fpop, vel, missing_mask) elif self.bc_type == "pressure": - rho = self.prescribed_value + rho = self.prescribed_values else: raise ValueError(f"type = {self.bc_type} not supported! Use 'pressure' or 'velocity'.") return rho + @partial(jit, static_argnums=(0, 2, 3), inline=True) + def _broadcast_prescribed_values(self, prescribed_values, prescribed_values_shape, target_shape): + broadcast_dims = [0] # Always include the leading dimension + p_idx = 1 # Start from the second dimension for prescribed_values + for t_idx, t_dim in enumerate(target_shape): + if p_idx < len(prescribed_values_shape) + 1 and prescribed_values_shape[p_idx] == t_dim: + broadcast_dims.append(t_idx) + p_idx += 1 + broadcast_dims = tuple(broadcast_dims) + return jax.lax.broadcast_in_dim(prescribed_values, target_shape, broadcast_dims) + @partial(jit, static_argnums=(0,), inline=True) def get_vel(self, fpop, missing_mask): if self.bc_type == "velocity": - vel = self.prescribed_value + target_shape = (self.velocity_set.d,) + fpop.shape[1:] + vel = self._broadcast_prescribed_values(self.prescribed_values, self.prescribed_values.shape, target_shape) elif self.bc_type == "pressure": - rho = self.prescribed_value + rho = self.prescribed_values vel = self.calculate_vel(fpop, rho, missing_mask) else: raise ValueError(f"type = {self.bc_type} not supported! Use 'pressure' or 'velocity'.") @@ -134,14 +149,22 @@ def calculate_rho(self, fpop, vel, missing_mask): return rho @partial(jit, static_argnums=(0,), inline=True) - def calculate_equilibrium(self, fpop, missing_mask): + def calculate_equilibrium(self, f_1, missing_mask): """ This is the ZouHe method of calculating the missing macroscopic variables at the boundary. """ - rho = self.get_rho(fpop, missing_mask) - vel = self.get_vel(fpop, missing_mask) + # Get the density and velocity from the f_1 + if self.bc_type == "velocity": + vel = self.prescribed_values + elif self.bc_type == "pressure": + rho = self.prescribed_values + vel = self.calculate_vel(f_1, rho, missing_mask) + else: + raise ValueError(f"type = {self.bc_type} not supported! Use 'pressure' or 'velocity'.") + + rho = self.get_rho(f_1, missing_mask) + vel = self.get_vel(f_1, missing_mask) - # compute feq at the boundary feq = self.equilibrium_operator(rho, vel) return feq @@ -158,32 +181,29 @@ def bounceback_nonequilibrium(self, fpop, feq, missing_mask): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): + def jax_implementation(self, f_0, f_1, bc_mask, missing_mask): # creat a mask to slice boundary cells boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) # compute the equilibrium based on prescribed values and the type of BC - feq = self.calculate_equilibrium(f_post, missing_mask) + feq = self.calculate_equilibrium(f_1, missing_mask) # set the unknown f populations based on the non-equilibrium bounce-back method - f_post_bd = self.bounceback_nonequilibrium(f_post, feq, missing_mask) - f_post = jnp.where(boundary, f_post_bd, f_post) - return f_post + f_1_bd = self.bounceback_nonequilibrium(f_1, feq, missing_mask) + f_1 = jnp.where(boundary, f_1_bd, f_1) + return f_1 def _construct_warp(self): # assign placeholders for both u and rho based on prescribed_value _d = self.velocity_set.d _q = self.velocity_set.q - u = self.prescribed_value if self.bc_type == "velocity" else (0,) * _d - rho = self.prescribed_value if self.bc_type == "pressure" else 0.0 # Set local constants TODO: This is a hack and should be fixed with warp update # _u_vec = wp.vec(_d, dtype=self.compute_dtype) - _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) - _rho = self.compute_dtype(rho) - _u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1]) + _d_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + _rho = self.compute_dtype(0.0) _opp_indices = self.velocity_set.opp_indices _c = self.velocity_set.c _c_float = self.velocity_set.c_float @@ -210,11 +230,11 @@ def get_normal_vectors( if wp.static(_d == 3): for l in range(_q): if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: - return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l]) + return -_d_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l]) else: for l in range(_q): if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - return -_u_vec(_c_float[0, l], _c_float[1, l]) + return -_d_vec(_c_float[0, l], _c_float[1, l]) @wp.func def bounceback_nonequilibrium( @@ -231,62 +251,79 @@ def bounceback_nonequilibrium( def functional_velocity( index: Any, timestep: Any, - missing_mask: Any, + _missing_mask: Any, f_0: Any, f_1: Any, - f_pre: Any, - f_post: Any, + _f_pre: Any, + _f_post: Any, ): # Post-streaming values are only modified at missing direction - _f = f_post + _f = _f_post # Find normal vector - normals = get_normal_vectors(missing_mask) + normals = get_normal_vectors(_missing_mask) # calculate rho - fsum = _get_fsum(_f, missing_mask) + fsum = _get_fsum(_f, _missing_mask) unormal = self.compute_dtype(0.0) + + # Find the value of u from the missing directions + for l in range(wp.static(_q)): + # Since we are only considering normal velocity, we only need to find one value (all values are the same in the missing directions) + if _missing_mask[l] == wp.uint8(1): + # Create velocity vector by multiplying the prescribed value with the normal vector + # TODO: This can be optimized by saving _missing_mask[l] in the bc class later since it is the same for all boundary cells + prescribed_value = f_1[_opp_indices[l], index[0], index[1], index[2]] + _u = -prescribed_value * normals + break + for d in range(_d): unormal += _u[d] * normals[d] + _rho = fsum / (self.compute_dtype(1.0) + unormal) # impose non-equilibrium bounceback - feq = self.equilibrium_operator.warp_functional(_rho, _u) - _f = bounceback_nonequilibrium(_f, feq, missing_mask) + _feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, _feq, _missing_mask) return _f @wp.func def functional_pressure( index: Any, timestep: Any, - missing_mask: Any, + _missing_mask: Any, f_0: Any, f_1: Any, - f_pre: Any, - f_post: Any, + _f_pre: Any, + _f_post: Any, ): # Post-streaming values are only modified at missing direction - _f = f_post + _f = _f_post # Find normal vector - normals = get_normal_vectors(missing_mask) + normals = get_normal_vectors(_missing_mask) + + # Find the value of rho from the missing directions + for q in range(wp.static(_q)): + # Since we need only one scalar value, we only need to find one value (all values are the same in the missing directions) + if _missing_mask[q] == wp.uint8(1): + _rho = f_0[_opp_indices[q], index[0], index[1], index[2]] + break # calculate velocity - fsum = _get_fsum(_f, missing_mask) + fsum = _get_fsum(_f, _missing_mask) unormal = -self.compute_dtype(1.0) + fsum / _rho _u = unormal * normals # impose non-equilibrium bounceback feq = self.equilibrium_operator.warp_functional(_rho, _u) - _f = bounceback_nonequilibrium(_f, feq, missing_mask) + _f = bounceback_nonequilibrium(_f, feq, _missing_mask) return _f if self.bc_type == "velocity": functional = functional_velocity elif self.bc_type == "pressure": functional = functional_pressure - elif self.bc_type == "velocity": - functional = functional_pressure kernel = self._construct_kernel(functional) diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index bf1eef2..c09f594 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -7,6 +7,8 @@ from typing import Any from jax import jit from functools import partial +import jax +import jax.numpy as jnp from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -57,13 +59,25 @@ def __init__( # A flag for BCs that need implicit boundary distance between the grid and a mesh (to be set to True if applicable inside each BC) self.needs_mesh_distance = False + # A flag for BCs that need auxilary data initialization before stepper + self.needs_aux_init = False + + # A flag to track if the BC is initialized with auxilary data + self.is_initialized_with_aux_data = False + + # Number of auxilary data needed for the BC (for prescribed values) + self.num_of_aux_data = 0 + + # A flag for BCs that need auxilary data recovery after streaming + self.needs_aux_recovery = False + if self.compute_backend == ComputeBackend.WARP: # Set local constants TODO: This is a hack and should be fixed with warp update _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool @wp.func - def prepare_bc_auxilary_data( + def update_bc_auxilary_data( index: Any, timestep: Any, missing_mask: Any, @@ -102,10 +116,10 @@ def _get_thread_data( # Construct some helper warp functions for getting tid data if self.compute_backend == ComputeBackend.WARP: self._get_thread_data = _get_thread_data - self.prepare_bc_auxilary_data = prepare_bc_auxilary_data + self.update_bc_auxilary_data = update_bc_auxilary_data @partial(jit, static_argnums=(0,), inline=True) - def prepare_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask): + def update_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask): """ A placeholder function for prepare the auxilary distribution functions for the boundary condition. currently being called after collision only. @@ -137,7 +151,7 @@ def kernel( # Apply the boundary condition if _boundary_id == _id: timestep = 0 - _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) + _f = functional(index, timestep, _missing_mask, _f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -146,3 +160,54 @@ def kernel( f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) return kernel + + def _construct_aux_data_init_kernel(self, functional): + """ + Constructs the warp kernel for the auxilary data recovery. + """ + _id = wp.uint8(self.id) + _opp_indices = self.velocity_set.opp_indices + _num_of_aux_data = self.num_of_aux_data + + # Construct the warp kernel + @wp.kernel + def aux_data_init_kernel( + f_0: wp.array4d(dtype=Any), + f_1: wp.array4d(dtype=Any), + bc_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # read tid data + _f_0, _f_1, _boundary_id, _missing_mask = self._get_thread_data(f_0, f_1, bc_mask, missing_mask, index) + + # Apply the functional + if _boundary_id == _id: + # prescribed_values is a q-sized vector of type wp.vec + prescribed_values = functional(index) + # Write the result for all q directions, but only store up to num_of_aux_data + # TODO: Somehow raise an error if the number of prescribed values does not match the number of missing directions + counter = wp.int32(0) + for l in range(wp.static(self.velocity_set.q)): + if _missing_mask[l] == wp.uint8(1) and counter < _num_of_aux_data: + f_1[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(prescribed_values[counter]) + counter += 1 + + return aux_data_init_kernel + + def aux_data_init(self, functional, f_0, f_1, bc_mask, missing_mask): + if self.compute_backend == ComputeBackend.WARP: + # Launch the warp kernel + wp.launch( + self._construct_aux_data_init_kernel(functional), + inputs=[f_0, f_1, bc_mask, missing_mask], + dim=f_0.shape[1:], + ) + elif self.compute_backend == ComputeBackend.JAX: + # We don't use boundary aux encoding/decoding in JAX + self.prescribed_values = functional() + self.is_initialized_with_aux_data = True + return f_0, f_1 diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index e08e95c..ed0eb4e 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -76,7 +76,7 @@ def jax_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): # Apply collision type boundary conditions for bc in self.boundary_conditions: - f_post_collision = bc.prepare_bc_auxilary_data(f_post_stream, f_post_collision, bc_mask, missing_mask) + f_post_collision = bc.update_bc_auxilary_data(f_post_stream, f_post_collision, bc_mask, missing_mask) if bc.implementation_step == ImplementationStep.COLLISION: f_post_collision = bc( f_post_stream, @@ -108,6 +108,8 @@ def _construct_warp(self): # Group active boundary conditions active_bcs = set(boundary_condition_registry.id_to_bc[bc.id] for bc in self.boundary_conditions) + _opp_indices = self.velocity_set.opp_indices + @wp.func def apply_bc( index: Any, @@ -134,7 +136,7 @@ def apply_bc( f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) if wp.static(self.boundary_conditions[i].id in extrapolation_outflow_bc_ids): if _boundary_id == wp.static(self.boundary_conditions[i].id): - f_result = wp.static(self.boundary_conditions[i].prepare_bc_auxilary_data)( + f_result = wp.static(self.boundary_conditions[i].update_bc_auxilary_data)( index, timestep, missing_mask, f_0, f_1, f_pre, f_post ) return f_result @@ -161,6 +163,23 @@ def get_thread_data( return _f0_thread, _f1_thread, _missing_mask + @wp.func + def apply_aux_recovery_bc( + index: Any, + _boundary_id: Any, + _missing_mask: Any, + f_0: Any, + _f1_thread: Any, + ): + # Unroll the loop over boundary conditions + for i in range(wp.static(len(self.boundary_conditions))): + if wp.static(self.boundary_conditions[i].needs_aux_recovery): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + # Perform the swapping of data + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(_f1_thread[_opp_indices[l]]) + @wp.kernel def kernel( f_0: wp.array4d(dtype=Any), @@ -192,13 +211,11 @@ def kernel( # Apply post-collision boundary conditions _f_post_collision = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision, False) + # Apply auxiliary recovery for boundary conditions (swapping) + apply_aux_recovery_bc(index, _boundary_id, _missing_mask, f_0, _f1_thread) + # Store the result in f_1 for l in range(self.velocity_set.q): - # TODO: Improve this later - if wp.static("GradsApproximationBC" in active_bcs): - if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["GradsApproximationBC"]): - if _missing_mask[l] == wp.uint8(1): - f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(_f1_thread[_opp_indices[l]]) f_1[l, index[0], index[1], index[2]] = self.store_dtype(_f_post_collision[l]) return None, kernel diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index 44aab5f..b28ac7f 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -13,6 +13,11 @@ def __init__(self, operators, boundary_conditions): from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry self.operators = operators + + for bc in boundary_conditions: + if bc.needs_aux_init and not bc.is_initialized_with_aux_data: + raise RuntimeError(f"Boundary condition {bc.__class__.__name__} requires auxiliary data initialization but was not initialized") + self.boundary_conditions = boundary_conditions # Get velocity set, precision policy, and compute backend