Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorAPI] Add initial support for TensorQ8 and TensorQ4 to support quantization and dequantization of Floats #591

Draft
wants to merge 15 commits into
base: develop
Choose a base branch
from
Draft
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.internal.annotations.SegmentElementSize;
import uk.ac.manchester.tornado.api.types.tensors.GGMLType;

/**
* This class represents an array of bytes stored in native memory.
Expand Down Expand Up @@ -61,6 +62,14 @@ public ByteArray(int numberOfElements) {
segment.setAtIndex(JAVA_INT, 0, numberOfElements);
}


public ByteArray(int numberOfElements, long requiredStorageSize) {
this.numberOfElements = numberOfElements;
baseIndex=0;
segment = Arena.ofAuto().allocate(requiredStorageSize, 1);
segment.setAtIndex(JAVA_INT, 0, numberOfElements);
mikepapadim marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Constructs a new {@link ByteArray} instance by concatenating the contents of the given array of {@link ByteArray} instances.
*
Expand Down Expand Up @@ -123,6 +132,14 @@ public static ByteArray fromSegment(MemorySegment segment) {
return byteArray;
}

// Temporary workaround to copy raw memory segment without a tornado header
public static ByteArray fromSegment(MemorySegment segment, int numberOfElements) {
long byteSize = segment.byteSize();
ByteArray byteArray = new ByteArray(numberOfElements, byteSize);
MemorySegment.copy(segment, 0, byteArray.segment, byteArray.baseIndex * BYTE_BYTES, byteSize);
return byteArray;
}

/**
* Creates a new instance of the {@link ByteArray} class from a {@link ByteBuffer}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ public enum DType {
/**
* Represents a quantized 8-bit unsigned integer used in specialized applications like machine learning, using 1 byte.
*/
QUINT8(1, ValueLayout.JAVA_BYTE);
QUINT8(1, ValueLayout.JAVA_BYTE),

Q4_0(1, ValueLayout.JAVA_BYTE);

// @formatter:on

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2024, APT Group, Department of Computer Science,
* The University of Manchester.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package uk.ac.manchester.tornado.api.types.tensors;


public final class Float16 {
public static final int BYTES = 2;
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) 2024, APT Group, Department of Computer Science,
* The University of Manchester.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package uk.ac.manchester.tornado.api.types.tensors;

public enum GGMLType {
F32(Float.BYTES),
F16(Float16.BYTES),
Q4_0(Float16.BYTES + 16 * Byte.BYTES, 32),
Q4_1(2 * Float16.BYTES + 16 * Byte.BYTES, 32),
UNSUPPORTED_Q4_2(Integer.MAX_VALUE), // support has been removed
UNSUPPORTED_Q4_3(Integer.MAX_VALUE), // support has been removed
Q5_0(Integer.MAX_VALUE),
Q5_1(Integer.MAX_VALUE),
Q8_0(Float16.BYTES + 32 * Byte.BYTES, 32),
Q8_1(32 * Byte.BYTES + 2 * Float.BYTES, 32),
// k-quantizations
Q2_K(Integer.MAX_VALUE),
Q3_K(Integer.MAX_VALUE),
Q4_K(2 * Float16.BYTES + ((GGMLType.QK_K / 16) / 8 * 6) + GGMLType.QK_K / 2, GGMLType.QK_K),
Q5_K(2 * Float16.BYTES + ((GGMLType.QK_K / 16) / 8 * 6) + GGMLType.QK_K / 8 + GGMLType.QK_K / 2, GGMLType.QK_K),
Q6_K(GGMLType.QK_K / 2 + GGMLType.QK_K / 4 + GGMLType.QK_K / 16 + Float16.BYTES, GGMLType.QK_K),
Q8_K(Integer.MAX_VALUE),
I8(Byte.BYTES),
I16(Short.BYTES),
I32(Integer.BYTES);

private static final GGMLType[] VALUES = values();

private final int typeSize;

private final int blockSize;

public int getTypeSize() {
return typeSize;
}

public int getBlockSize() {
return blockSize;
}

public static GGMLType fromId(int id) {
return VALUES[id];
}

GGMLType(int typeSize) {
this(typeSize, 1);
}

public long byteSizeFor(int numberOfElements) {
long t = numberOfElements * (long) getTypeSize();
assert t % getBlockSize() == 0;
return Math.toIntExact(t / getBlockSize());
}

public static final int QK_K = 256; // or 64?

GGMLType(int typeSize, int blockSize) {
assert blockSize > 0;
assert typeSize > 0;
assert isPowerOf2(blockSize);
this.typeSize = typeSize;
this.blockSize = blockSize;
}

private static boolean isPowerOf2(int n) {
return n > 0 && (n & (n - 1)) == 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ public long[] getDimensions() {
* @return the total size of the shape as an int
*/
public int getSize() {
return (int) Arrays.stream(dimensions).reduce(1, (a, b) -> a * b);
assert Arrays.stream(dimensions).allMatch(i -> i > 0);
return (int) Arrays.stream(dimensions).reduce(Math::multiplyExact).orElseThrow();
}

@Override
Expand Down
Loading