Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for reduction with sum of squares (RSS), compute of RNSNorm, and RNSNorm fused with Matmul #593

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from

Conversation

mikepapadim
Copy link
Member

Description

This PR adds tests to test the code gen and and vailidity of the results for some common operation required in the LLM architecture.

  1. Reduction with Sum of Squares (RSS) - Implements and tests the accuracy and stability of RSS with the generated code to use snippets for local memory.
  2. Computation of RNSNorm - Introduces tests for RNSNorm, validating it can combine the above with a serial kernel.
  3. Fused RNSNorm with Matmul - Adds tests for the RNSNorm function when fused with matrix multiplication (Matmul), ensuring compatibility and efficiency.

Backend/s tested

Mark the backends affected by this PR.

  • OpenCL
  • PTX
  • SPIRV

OS tested

Mark the OS where this PR is tested.

  • Linux
  • OSx
  • Windows

Did you check on FPGAs?

If it is applicable, check your changes on FPGAs.

  • Yes
  • No

How to test the new patch?

tornado-test -V uk.ac.manchester.tornado.unittests.reductions.TestReductionsFloats#testReduceSumSquares

tornado-test -V uk.ac.manchester.tornado.unittests.compute.LLMFusedKernelsTest

}
}

private static void finalSum(KernelContext context, FloatArray reduce, int size, float eps) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix format of the file.

float expected = outputSeqLogits.get(i); // Expected value from the sequential output
float actual = outputLogits.get(i); // Actual value from the RNS output

// assertEquals("Mismatch at index " + i, expected, actual, 1f); // Allow some tolerance
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove comment

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the test should assert the expected and actual values, right?

@jjfumero
Copy link
Member

Include the new test in the test-suite:
tornado-assembly/src/bin/tornado-test

@jjfumero
Copy link
Member

The new test enters in an infinite loop when running with the SPIR-V backend:

tornado-test -V uk.ac.manchester.tornado.unittests.compute.LLMFusedKernelsTest
/home/juan/tornadovm/TornadoVM/bin/sdk/bin/tornado --jvm "-Xmx6g -Dtornado.recover.bailout=False -Dtornado.unittests.verbose=True "  -m  tornado.unittests/uk.ac.manchester.tornado.unittests.tools.TornadoTestRunner  --params "uk.ac.manchester.tornado.unittests.compute.LLMFusedKernelsTest"

The PTX and OpenCL backends run fine.

Copy link
Collaborator

@stratika stratika left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should add the new LLMFusedKernelsTest class in the tornado-test in order to be run when Jenkins runs the unit-tests. In my setup, the tests pass for PTX. But, the tests in the LLMFusedKernelsTest class are not finishing when running with SPIR-V. I guess, that they are not supported for SPIR-V?

* </code>
*/

public class LLMFusedKernelsTest extends TornadoTestBase {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To keep consistency with other test classes, I would suggest to move the "Test" at the beginning of the name of the class.

Comment on lines +147 to +157
public static void normalizeAndScale(KernelContext context,
FloatArray out, FloatArray input, FloatArray weight, FloatArray scalingFactorBuffer,
int size, float eps) {

int globalIdx = context.globalIdx;

if (globalIdx < size) {
float scaledValue = weight.get(globalIdx) * (scalingFactorBuffer.get(0) * input.get(globalIdx));
out.set(globalIdx, scaledValue);
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix code formatting.


@Test
public void testRNSNorm() throws TornadoExecutionPlanException {
final int size = 2048;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should add the following, unless it is supported:

assertNotBackend(TornadoVMBackendType.SPIRV);


@Test
public void testRNSNormFusedWithMatMul() throws TornadoExecutionPlanException {
final int size = 2048;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should add the following, unless it is supported:

assertNotBackend(TornadoVMBackendType.SPIRV);

float expected = outputSeqLogits.get(i); // Expected value from the sequential output
float actual = outputLogits.get(i); // Actual value from the RNS output

// assertEquals("Mismatch at index " + i, expected, actual, 1f); // Allow some tolerance
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the test should assert the expected and actual values, right?

@jjfumero
Copy link
Member

@mikepapadim , is this ready?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

3 participants