Cleaning up tests in chlo_ops.mlir that were backported twice
sdasgup3 committed Dec 20, 2024
1 parent 8c7946c commit 3bc119b
Showing 1 changed file with 0 additions and 216 deletions.
216 changes: 0 additions & 216 deletions stablehlo/tests/ops_chlo.mlir
@@ -289,222 +289,6 @@ func.func @ragged_dot_zero_rhs_group_dims_for_ragged_noncontracting(%lhs : tenso

// -----

// ragged_dot mode 1: [b,m,k], [g,b,k,n], [g] -> [b,m,n]
func.func @ragged_dot_non_contracting(%lhs : tensor<2x11x5xf32>, %rhs : tensor<3x2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<2x11x7xf32> {
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
ragged_dot_dimension_numbers = #chlo.ragged_dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [1],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [2],
lhs_ragged_dimensions = [1],
rhs_group_dimensions = [0]
>,
precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
} : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32>
func.return %0 : tensor<2x11x7xf32>
}

// -----

// ragged_dot mode 2: [b,m,k], [b,k,n], [g] -> [g,b,m,n]
func.func @ragged_dot_contracting(%lhs : tensor<2x11x5xf32>, %rhs : tensor<2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<3x2x11x7xf32> {
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
ragged_dot_dimension_numbers = #chlo.ragged_dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1],
lhs_ragged_dimensions = [2],
rhs_group_dimensions = []
>,
precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
} : (tensor<2x11x5xf32>, tensor<2x5x7xf32>, tensor<3xi64>) -> tensor<3x2x11x7xf32>
func.return %0 : tensor<3x2x11x7xf32>
}

// -----

// ragged_dot mode 3: [b,m,k], [b,k,n], [g] -> [b,m,n]
func.func @ragged_dot_batch(%lhs : tensor<3x11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<3x11x7xf32> {
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
ragged_dot_dimension_numbers = #chlo.ragged_dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1],
lhs_ragged_dimensions = [0],
rhs_group_dimensions = []
>,
precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
} : (tensor<3x11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<3x11x7xf32>
func.return %0 : tensor<3x11x7xf32>
}

// -----

func.func @ragged_dot_incompatible_contracting_dims(%lhs : tensor<11x5xf32>, %rhs : tensor<3x2x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> {
// expected-error @+1 {{contracting dimension sizes must match}}
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
ragged_dot_dimension_numbers = #chlo.ragged_dot<
lhs_batching_dimensions = [],
rhs_batching_dimensions = [],
lhs_contracting_dimensions = [1],
rhs_contracting_dimensions = [1],
lhs_ragged_dimensions = [0],
rhs_group_dimensions = [0]
>,
precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
} : (tensor<11x5xf32>, tensor<3x2x7xf32>, tensor<3xi64>) -> tensor<11x7xf32>
func.return %0 : tensor<11x7xf32>
}

// -----

func.func @ragged_dot_group_sizes_incorrect_rank(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3x2xi64>) -> tensor<11x7xf32> {
// expected-error @+1 {{expected rank of group_sizes of ragged dot to be 1, got 2}}
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
ragged_dot_dimension_numbers = #chlo.ragged_dot<
lhs_batching_dimensions = [],
rhs_batching_dimensions = [],
lhs_contracting_dimensions = [1],
rhs_contracting_dimensions = [1],
lhs_ragged_dimensions = [0],
rhs_group_dimensions = [0]
>,
precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
} : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3x2xi64>) -> tensor<11x7xf32>
func.return %0 : tensor<11x7xf32>
}

// -----

func.func @ragged_dot_group_sizes_incorrect_shape(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<2xi64>) -> tensor<11x7xf32> {
// expected-error @+1 {{group_sizes is expected to have shape=[3], got [2]}}
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
ragged_dot_dimension_numbers = #chlo.ragged_dot<
lhs_batching_dimensions = [],
rhs_batching_dimensions = [],
lhs_contracting_dimensions = [1],
rhs_contracting_dimensions = [1],
lhs_ragged_dimensions = [0],
rhs_group_dimensions = [0]
>,
precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
} : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<2xi64>) -> tensor<11x7xf32>
func.return %0 : tensor<11x7xf32>
}

// -----

func.func @ragged_dot_incorrect_number_of_lhs_ragged_dimensions(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> {
// expected-error @+1 {{There must be exactly one ragged dimension in the lhs}}
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
ragged_dot_dimension_numbers = #chlo.ragged_dot<
lhs_batching_dimensions = [],
rhs_batching_dimensions = [],
lhs_contracting_dimensions = [1],
rhs_contracting_dimensions = [1],
lhs_ragged_dimensions = [0, 1],
rhs_group_dimensions = [0]
>,
precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
} : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32>
func.return %0 : tensor<11x7xf32>
}

// -----

func.func @ragged_dot_rhs_group_dim_is_batch(%lhs : tensor<3x11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<3x11x7xf32> {
// expected-error @+1 {{has duplicated dimension from rhs_group_dimensions and rhs_batching_dimensions: 0}}
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
ragged_dot_dimension_numbers = #chlo.ragged_dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1],
lhs_ragged_dimensions = [1],
rhs_group_dimensions = [0]
>,
precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
} : (tensor<3x11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<3x11x7xf32>
func.return %0 : tensor<3x11x7xf32>
}

// -----

func.func @ragged_dot_rhs_group_dim_is_contracting(%lhs : tensor<11x3xf32>, %rhs : tensor<3x3x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> {
// expected-error @+1 {{has duplicated dimension from rhs_group_dimensions and rhs_contracting_dimensions: 1}}
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
ragged_dot_dimension_numbers = #chlo.ragged_dot<
lhs_batching_dimensions = [],
rhs_batching_dimensions = [],
lhs_contracting_dimensions = [1],
rhs_contracting_dimensions = [1],
lhs_ragged_dimensions = [0],
rhs_group_dimensions = [1]
>,
precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
} : (tensor<11x3xf32>, tensor<3x3x7xf32>, tensor<3xi64>) -> tensor<11x7xf32>
func.return %0 : tensor<11x7xf32>
}

// -----

func.func @ragged_dot_nonzero_rhs_group_dims_for_ragged_batch(%lhs : tensor<2x11x5xf32>, %rhs : tensor<3x2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<2x11x7xf32> {
// expected-error @+1 {{There must be zero group dimensions in the rhs when the ragged dimension is batch or contracting}}
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
ragged_dot_dimension_numbers = #chlo.ragged_dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [1],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [2],
lhs_ragged_dimensions = [0],
rhs_group_dimensions = [0]
>,
precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
} : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32>
func.return %0 : tensor<2x11x7xf32>
}

// -----

func.func @ragged_dot_nonzero_rhs_group_dims_for_ragged_contracting(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> {
// expected-error @+1 {{There must be zero group dimensions in the rhs when the ragged dimension is batch or contracting}}
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
ragged_dot_dimension_numbers = #chlo.ragged_dot<
lhs_batching_dimensions = [],
rhs_batching_dimensions = [],
lhs_contracting_dimensions = [1],
rhs_contracting_dimensions = [1],
lhs_ragged_dimensions = [1],
rhs_group_dimensions = [0]
>,
precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
} : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32>
func.return %0 : tensor<11x7xf32>
}

// -----

func.func @ragged_dot_zero_rhs_group_dims_for_ragged_noncontracting(%lhs : tensor<11x5xf32>, %rhs : tensor<5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> {
// expected-error @+1 {{There must be exactly one group dimension in the rhs when the lhs ragged dimension is non-contracting}}
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
ragged_dot_dimension_numbers = #chlo.ragged_dot<
lhs_batching_dimensions = [],
rhs_batching_dimensions = [],
lhs_contracting_dimensions = [1],
rhs_contracting_dimensions = [0],
lhs_ragged_dimensions = [0],
rhs_group_dimensions = []
>,
precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
} : (tensor<11x5xf32>, tensor<5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32>
func.return %0 : tensor<11x7xf32>
}

// -----

func.func @top_k(%arg0 : tensor<f32>) {
// expected-error @+2 {{failed to infer returned types}}
// expected-error @+1 {{operand's rank must be at least 1}}

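Note on the deleted mode-1 test: the shape comment [b,m,k], [g,b,k,n], [g] -> [b,m,n] reads as group_sizes partitioning the ragged lhs dimension m into g consecutive row blocks, with block i contracted against rhs slice i. The NumPy sketch below illustrates that reading only; the function name and the example group sizes [4, 4, 3] are invented for illustration and are not taken from the commit or the CHLO spec.

import numpy as np

def ragged_dot_noncontracting(lhs, rhs, group_sizes):
    # lhs: [b, m, k], rhs: [g, b, k, n], group_sizes: [g] -> out: [b, m, n]
    b, m, k = lhs.shape
    g, _, _, n = rhs.shape
    assert group_sizes.sum() == m, "group sizes must partition the ragged dim"
    out = np.zeros((b, m, n), dtype=lhs.dtype)
    start = 0
    for i, size in enumerate(group_sizes):
        # Rows [start, start+size) of the ragged dimension use rhs group i.
        out[:, start:start + size, :] = np.einsum(
            "bmk,bkn->bmn", lhs[:, start:start + size, :], rhs[i])
        start += size
    return out

# Shapes matching the deleted @ragged_dot_non_contracting test:
lhs = np.ones((2, 11, 5), np.float32)
rhs = np.ones((3, 2, 5, 7), np.float32)
out = ragged_dot_noncontracting(lhs, rhs, np.array([4, 4, 3]))
assert out.shape == (2, 11, 7)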