Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for reduction with sum of squares (RSS), compute of RNSNorm, and RNSNorm fused with Matmul #593

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
/*
* Copyright (c) 2024, APT Group, Department of Computer Science,
* The University of Manchester.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package uk.ac.manchester.tornado.unittests.compute;

import org.junit.Test;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.math.TornadoMath;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

import java.util.stream.IntStream;

import static org.junit.Assert.assertEquals;


/**
* Test to validate the functionality of benchmark compute layers or components of a Large Language Model (LLM) transformer,
* focusing on operations like RMS Normalization (RNSNorm) and Matrix-Vector Multiplication (MatMul).
*
* The test verifies that the fused kernel approach, which combines RMSNorm and MatMul in a single execution flow,
* produces correct results by comparing the output of a parallel (fused) execution with a sequential (reference) implementation.
*
* Specifically, the test checks:
* 1. **RMS Normalization (RNSNorm)**:
* - RMSNorm is computed by normalizing input values using the inverse square root of the sum of squares, scaled by a weight factor.
* - The test compares the result of this normalization from the parallel execution against a sequential computation.
*
* 2. **Matrix-Vector Multiplication (MatMul)**:
* - The test also ensures the matrix-vector multiplication operation (MatMul) works correctly when combined with the RMSNorm step, ensuring that the fusion of these operations does not compromise their correctness.
*
* The execution is tested in two forms:
* - **Sequential Execution**: Using traditional CPU-based loops for reference results.
* - **Fused Parallel Execution**: Leveraging TornadoVM's parallel computing capabilities to execute the same operations using fused kernels.
*
* For both the RMSNorm and MatMul operations, the test compares the result from the parallel execution with the sequential reference to ensure that the fused parallel approach produces identical outputs, within an acceptable error tolerance.
*
* <p>
* How to run the tests?
* </p>
*
* <code>
* tornado-test -V uk.ac.manchester.tornado.unittests.compute.LLMFusedKernelsTest
* </code>
*/

public class LLMFusedKernelsTest extends TornadoTestBase {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To keep consistency with other test classes, I would suggest to move the "Test" at the beginning of the name of the class.


public static FloatArray matmul(FloatArray weights, FloatArray x, FloatArray out, int dim0, int dim1) {
IntStream.range(0, dim0).forEach(i -> {
float result = 0f;
int thisOffset = i * dim1;
for (int j = 0; j < dim1; j++) {
result += weights.get(thisOffset + j) * x.get(j);
}
out.set(i, result);
});
return out;
}

public static void matrixVector(KernelContext context, FloatArray a, FloatArray b, FloatArray c, int dim0, int dim1) {
int idx = context.globalIdx;
float sum = 0.0f;
for (int k = 0; k < dim0; k++) {
sum += a.get((idx * dim1) + k) * b.get(k);
}
c.set(idx, sum);
}

public FloatArray rmsnorm(FloatArray x, FloatArray weight, int size, float rmsNormEps) {
FloatArray out = new FloatArray(size);
// Step 1: Calculate sum of squares
float ss = 0.0f;
for (int i = 0; i < size; i++) {
ss += x.get(i) * x.get(i); // Sum of squares
}

ss /= size; // Normalize by the size
ss += rmsNormEps; // Add epsilon
ss = (float) (1.0 / Math.sqrt(ss)); // Inverse square root

// Step 2: Normalize and scale
for (int i = 0; i < size; i++) {
float normalizedValue = ss * x.get(i);
out.set(i, weight.get(i) * normalizedValue);
}
return out;
}

public static void reduceSquareSums(KernelContext context, FloatArray a, FloatArray reduce) {
int globalIdx = context.globalIdx;
int localIdx = context.localIdx;
int localGroupSize = context.localGroupSizeX;
int groupID = context.groupIdx; // Expose Group ID

float[] localA = context.allocateFloatLocalArray(256);
localA[localIdx] = a.get(globalIdx) * a.get(globalIdx);
for (int stride = (localGroupSize / 2); stride > 0; stride /= 2) {
context.localBarrier();
if (localIdx < stride) {
localA[localIdx] += localA[localIdx + stride];
}
}
if (localIdx == 0) {
reduce.set(groupID, localA[0]);
}
}

private static void finalSum(KernelContext context, FloatArray reduce, int size, float eps) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix format of the file.

int globalIdx = context.globalIdx;

float sum = 0.0f;
if (globalIdx == 0) {
for (int i = 0; i < size; i++) {
sum += reduce.get(i);
}
}

float ss = sum / (float) size;
ss += eps;
ss = 1.0f / TornadoMath.sqrt(ss);
reduce.set(0,ss);
}

public static void normalizeAndScale(KernelContext context,
FloatArray out, FloatArray input, FloatArray weight, FloatArray scalingFactorBuffer,
int size, float eps) {

int globalIdx = context.globalIdx;

if (globalIdx < size) {
float scaledValue = weight.get(globalIdx) * (scalingFactorBuffer.get(0) * input.get(globalIdx));
out.set(globalIdx, scaledValue);
}
}
Comment on lines +147 to +157
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix code formatting.


@Test
public void testRNSNorm() throws TornadoExecutionPlanException {
final int size = 2048;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should add the following, unless it is supported:

assertNotBackend(TornadoVMBackendType.SPIRV);

final int localSize = 256;
float eps = 1e-5f;


FloatArray x = new FloatArray(size);
FloatArray rnsOutput = new FloatArray(size);
FloatArray weight = new FloatArray(size);
FloatArray sequentialOutput;

FloatArray reduce = new FloatArray(size / localSize);


for (int i = 0; i < size; i++) {
x.set(i, (float) (i + 1));
weight.set(i, 0.5f);
}

sequentialOutput = rmsnorm(x, weight, size, eps);

WorkerGrid worker = new WorkerGrid1D(size);
// Change the Grid
worker.setGlobalWork(size, 1, 1);
worker.setLocalWork(localSize, 1, 1);
GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);
gridScheduler.setWorkerGrid("s0.t1", new WorkerGrid1D(1));
gridScheduler.setWorkerGrid("s0.t2", worker);

KernelContext context = new KernelContext();

TaskGraph taskGraph = new TaskGraph("s0") //
.transferToDevice(DataTransferMode.FIRST_EXECUTION, size, eps, x, reduce) //
.transferToDevice(DataTransferMode.EVERY_EXECUTION, localSize, weight) //
.task("t0", LLMFusedKernelsTest::reduceSquareSums, context, x, reduce) //
.task("t1", LLMFusedKernelsTest::finalSum, context, reduce, size, eps) //
.task("t2", LLMFusedKernelsTest::normalizeAndScale, context, rnsOutput, x, weight, reduce, size, eps) //
.transferToHost(DataTransferMode.EVERY_EXECUTION, rnsOutput); //


ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
executionPlan.withGridScheduler(gridScheduler) //
.execute();
}


for (int i = 0; i < rnsOutput.getSize(); i++) {
float expected = sequentialOutput.get(i); // Expected value from the sequential output
float actual = rnsOutput.get(i); // Actual value from the RNS output

// Perform the comparison with assertion
assertEquals("Mismatch at index " + i, expected, actual, 1e-6f); // Allow some tolerance
}

}

@Test
public void testRNSNormFusedWithMatMul() throws TornadoExecutionPlanException {
final int size = 2048;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should add the following, unless it is supported:

assertNotBackend(TornadoVMBackendType.SPIRV);

final int localSize = 256;
float eps = 1e-5f;


FloatArray x = new FloatArray(size);
FloatArray state = new FloatArray(size);
FloatArray rnsOutput = new FloatArray(size);
FloatArray sequentialOutput;

FloatArray weights = new FloatArray(size * size);
FloatArray outputLogits = new FloatArray(size);
FloatArray outputSeqLogits = new FloatArray( size);

FloatArray reduce = new FloatArray(size / localSize);

outputLogits.init(0f);
outputSeqLogits.init(0f);

for (int i = 0; i < size; i++) {
x.set(i, (float) (i + 1));
state.set(i, 0.5f+i);
}

for (int i = 0; i < size * size; i++) {
weights.set(i, 0.8f+i);
}

sequentialOutput = rmsnorm(x, weights, size, eps);
outputSeqLogits = matmul(weights, sequentialOutput, outputSeqLogits, size, size);


KernelContext context = new KernelContext();
TaskGraph taskGraph = new TaskGraph("fused") //
.transferToDevice(DataTransferMode.FIRST_EXECUTION, size, eps, x, reduce, weights) //
.transferToDevice(DataTransferMode.EVERY_EXECUTION, localSize, state) //
.task("reduce", LLMFusedKernelsTest::reduceSquareSums, context, x, reduce) //
.task("sum", LLMFusedKernelsTest::finalSum, context, reduce, size, eps) //
.task("ns", LLMFusedKernelsTest::normalizeAndScale, context, rnsOutput, x, state, reduce, size, eps) //
.task("mv", LLMFusedKernelsTest::matrixVector, context, weights, rnsOutput, outputLogits, size, size) //
.transferToHost(DataTransferMode.EVERY_EXECUTION, outputLogits); //


WorkerGrid worker = new WorkerGrid1D(size);
worker.setGlobalWork(size, 1, 1);
worker.setLocalWork(localSize, 1, 1);
GridScheduler gridScheduler = new GridScheduler("fused.reduce", worker);
gridScheduler.setWorkerGrid("fused.sum", new WorkerGrid1D(1));
gridScheduler.setWorkerGrid("fused.ns", worker);
gridScheduler.setWorkerGrid("fused.mv", worker);

ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
executionPlan.withGridScheduler(gridScheduler) //
.execute();
}


for (int i = 0; i < outputSeqLogits.getSize(); i++) {
float expected = outputSeqLogits.get(i); // Expected value from the sequential output
float actual = outputLogits.get(i); // Actual value from the RNS output

// assertEquals("Mismatch at index " + i, expected, actual, 1f); // Allow some tolerance
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove comment

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the test should assert the expected and actual values, right?

}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,14 @@ private static void multiplyFloats(FloatArray input, @Reduce FloatArray result)
}
}

private static void sumSquareReduce(FloatArray input, @Reduce FloatArray result) {
result.set(0, 1.0f);
for (@Parallel int i = 0; i < input.getSize(); i++) {
result.set(0, result.get(0) + (input.get(i) * input.get(i)));
}
}


private static void computeSum(final FloatArray values, @Reduce FloatArray result) {
result.set(0, 0.0f);
for (@Parallel int i = 0; i < values.getSize(); i++) {
Expand Down Expand Up @@ -320,6 +328,38 @@ public void testMultFloats() throws TornadoExecutionPlanException {
assertEquals(sequential.get(0), result.get(0), 0.1f);
}

@Test
public void testReduceSumSquares() throws TornadoExecutionPlanException {
FloatArray input = new FloatArray(SIZE);
FloatArray result = new FloatArray(1);
final int neutral = 1;
result.init(neutral);

Random r = new Random();
IntStream.range(0, SIZE).sequential().forEach(i -> {
input.set(i, 1.0f);
});

input.set(10, r.nextFloat());
input.set(12, r.nextFloat());

TaskGraph taskGraph = new TaskGraph("s0") //
.transferToDevice(DataTransferMode.EVERY_EXECUTION, input) //
.task("t0", TestReductionsFloats::sumSquareReduce, input, result) //
.transferToHost(DataTransferMode.EVERY_EXECUTION, result);
ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
executionPlan.execute();
}

FloatArray sequential = new FloatArray(1);
sequential.init(1.0f);
sumSquareReduce(input, sequential);

// Check result
assertEquals(sequential.get(0), result.get(0), 0.1f);
}

private static void reductionAddFloatsConditionally(FloatArray input, @Reduce FloatArray result) {
result.set(0, 0.0f);
for (@Parallel int i = 0; i < input.getSize(); i++) {
Expand Down