Skip to content

Commit

Permalink
Support private half vectors and change ocl type of half from short t…
Browse files Browse the repository at this point in the history
…o object
  • Loading branch information
mairooni committed Feb 20, 2024
1 parent fe33326 commit e8faf81
Show file tree
Hide file tree
Showing 12 changed files with 361 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1065,6 +1065,8 @@ public static class OCLOp2 extends OCLOp {
public static final OCLOp2 VMOV_FLOAT2 = new OCLOp2("(float2)");
public static final OCLOp2 VMOV_BYTE2 = new OCLOp2("(char2)");
public static final OCLOp2 VMOV_DOUBLE2 = new OCLOp2("(double2)");

public static final OCLOp2 VMOV_HALF2 = new OCLOp2("(half2)");
// @formatter:on

protected OCLOp2(String opcode) {
Expand Down Expand Up @@ -1092,6 +1094,8 @@ public static class OCLOp3 extends OCLOp2 {
public static final OCLOp3 VMOV_BYTE3 = new OCLOp3("(char3)");
public static final OCLOp3 VMOV_DOUBLE3 = new OCLOp3("(double3)");

public static final OCLOp3 VMOV_HALF3 = new OCLOp3("(half3)");

// @formatter:on
public OCLOp3(String opcode) {
super(opcode);
Expand All @@ -1118,6 +1122,8 @@ public static class OCLOp4 extends OCLOp3 {
public static final OCLOp4 VMOV_FLOAT4 = new OCLOp4("(float4)");
public static final OCLOp4 VMOV_BYTE4 = new OCLOp4("(char4)");
public static final OCLOp4 VMOV_DOUBLE4 = new OCLOp4("(double4)");

public static final OCLOp4 VMOV_HALF4 = new OCLOp4("(half4)");
// @formatter:on

protected OCLOp4(String opcode) {
Expand Down Expand Up @@ -1148,6 +1154,8 @@ public static class OCLOp8 extends OCLOp4 {
public static final OCLOp8 VMOV_BYTE8 = new OCLOp8("(char8)");
public static final OCLOp8 VMOV_DOUBLE8 = new OCLOp8("(double8)");

public static final OCLOp8 VMOV_HALF8 = new OCLOp8("(half8)");

// @formatter:on

protected OCLOp8(String opcode) {
Expand Down Expand Up @@ -1184,6 +1192,7 @@ public static class OCLOp16 extends OCLOp8 {
public static final OCLOp16 VMOV_FLOAT16 = new OCLOp16("(float16)");
public static final OCLOp16 VMOV_BYTE16 = new OCLOp16("(char16)");
public static final OCLOp16 VMOV_DOUBLE16 = new OCLOp16("(double16)");
public static final OCLOp16 VMOV_HALF16 = new OCLOp16("(half16)");
// @formatter:on
protected OCLOp16(String opcode) {
super(opcode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
import uk.ac.manchester.tornado.drivers.opencl.graal.phases.OCLFPGAThreadScheduler;
import uk.ac.manchester.tornado.drivers.opencl.graal.phases.TornadoAtomicsParametersPhase;
import uk.ac.manchester.tornado.drivers.opencl.graal.phases.TornadoAtomicsScheduling;
import uk.ac.manchester.tornado.drivers.opencl.graal.phases.TornadoHalfFloatConstantReplacement;
import uk.ac.manchester.tornado.drivers.opencl.graal.phases.TornadoHalfFloatVectorOffset;
import uk.ac.manchester.tornado.runtime.common.TornadoOptions;
import uk.ac.manchester.tornado.runtime.graal.compiler.TornadoLowTier;

Expand Down Expand Up @@ -81,6 +83,10 @@ public OCLLowTier(OptionValues options, TornadoDeviceContext tornadoDeviceContex
appendPhase(new OCLFPGAThreadScheduler());
}

appendPhase(new TornadoHalfFloatConstantReplacement());

appendPhase(new TornadoHalfFloatVectorOffset());

appendPhase(new TornadoLoopCanonicalization());

if (TornadoOptions.ENABLE_FMA) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,15 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
});

r.register(new InvocationPlugin("set", Receiver.class, int.class, storageType) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode laneId, ValueNode value) {
final VectorStoreElementProxyNode store = new VectorStoreElementProxyNode(vectorKind.getElementKind(), receiver.get(), laneId, value);
b.add(b.append(store));
return true;
}
});

r.register(new InvocationPlugin("add", declaringClass, declaringClass) {
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode input1, ValueNode input2) {
final ResolvedJavaType resolvedType = b.getMetaAccess().lookupJavaType(declaringClass);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ public static int lookupTypeIndex(OCLKind kind) {
return 3;
case DOUBLE:
return 4;
case HALF:
return 5;
default:
return -1;
}
Expand Down Expand Up @@ -538,8 +540,9 @@ public JavaKind asJavaKind() {
return JavaKind.Byte;
case SHORT:
case USHORT:
case HALF:
return JavaKind.Short;
case HALF:
return JavaKind.Object;
case INT:
case UINT:
return JavaKind.Int;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ public final int laneId() {
return (lane instanceof ConstantNode) ? lane.asJavaConstant().asInt() : -1;
}

public final ValueNode getLaneId() {
return this.lane;
}

public ValueNode getVector() {
return vector;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,31 @@
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp16.VMOV_BYTE16;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp16.VMOV_DOUBLE16;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp16.VMOV_FLOAT16;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp16.VMOV_HALF16;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp16.VMOV_INT16;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp16.VMOV_SHORT16;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp2.VMOV_BYTE2;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp2.VMOV_DOUBLE2;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp2.VMOV_FLOAT2;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp2.VMOV_HALF2;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp2.VMOV_INT2;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp2.VMOV_SHORT2;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp3.VMOV_BYTE3;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp3.VMOV_DOUBLE3;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp3.VMOV_FLOAT3;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp3.VMOV_HALF3;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp3.VMOV_INT3;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp3.VMOV_SHORT3;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp4.VMOV_BYTE4;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp4.VMOV_DOUBLE4;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp4.VMOV_FLOAT4;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp4.VMOV_HALF4;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp4.VMOV_INT4;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp4.VMOV_SHORT4;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp8.VMOV_BYTE8;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp8.VMOV_DOUBLE8;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp8.VMOV_FLOAT8;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp8.VMOV_HALF8;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp8.VMOV_INT8;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLOp8.VMOV_SHORT8;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLTernaryIntrinsic.VSTORE16;
Expand Down Expand Up @@ -80,11 +85,11 @@ public final class VectorUtil {

private static final OCLUnaryOp[] pointerTable = new OCLUnaryOp[] { CAST_TO_SHORT_PTR, CAST_TO_INT_PTR, CAST_TO_FLOAT_PTR, CAST_TO_BYTE_PTR };

private static final OCLOp2[] assignOp2Table = new OCLOp2[] { VMOV_SHORT2, VMOV_INT2, VMOV_FLOAT2, VMOV_BYTE2, VMOV_DOUBLE2 };
private static final OCLOp3[] assignOp3Table = new OCLOp3[] { VMOV_SHORT3, VMOV_INT3, VMOV_FLOAT3, VMOV_BYTE3, VMOV_DOUBLE3 };
private static final OCLOp4[] assignOp4Table = new OCLOp4[] { VMOV_SHORT4, VMOV_INT4, VMOV_FLOAT4, VMOV_BYTE4, VMOV_DOUBLE4 };
private static final OCLOp8[] assignOp8Table = new OCLOp8[] { VMOV_SHORT8, VMOV_INT8, VMOV_FLOAT8, VMOV_BYTE8, VMOV_DOUBLE8 };
private static final OCLOp16[] assignOp16Table = new OCLOp16[] { VMOV_SHORT16, VMOV_INT16, VMOV_FLOAT16, VMOV_BYTE16, VMOV_DOUBLE16 };
private static final OCLOp2[] assignOp2Table = new OCLOp2[] { VMOV_SHORT2, VMOV_INT2, VMOV_FLOAT2, VMOV_BYTE2, VMOV_DOUBLE2, VMOV_HALF2 };
private static final OCLOp3[] assignOp3Table = new OCLOp3[] { VMOV_SHORT3, VMOV_INT3, VMOV_FLOAT3, VMOV_BYTE3, VMOV_DOUBLE3, VMOV_HALF3 };
private static final OCLOp4[] assignOp4Table = new OCLOp4[] { VMOV_SHORT4, VMOV_INT4, VMOV_FLOAT4, VMOV_BYTE4, VMOV_DOUBLE4, VMOV_HALF4 };
private static final OCLOp8[] assignOp8Table = new OCLOp8[] { VMOV_SHORT8, VMOV_INT8, VMOV_FLOAT8, VMOV_BYTE8, VMOV_DOUBLE8, VMOV_HALF8 };
private static final OCLOp16[] assignOp16Table = new OCLOp16[] { VMOV_SHORT16, VMOV_INT16, VMOV_FLOAT16, VMOV_BYTE16, VMOV_DOUBLE16, VMOV_HALF16 };

private static <T> T lookupValueByLength(T[] array, OCLKind vectorKind) {
final int index = vectorKind.lookupLengthIndex();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package uk.ac.manchester.tornado.drivers.opencl.graal.phases;

import org.graalvm.compiler.nodes.GraphState;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.phases.Phase;
import uk.ac.manchester.tornado.runtime.graal.HalfFloatConstant;

import java.util.Optional;

public class TornadoHalfFloatConstantReplacement extends Phase {

@Override
public Optional<NotApplicable> notApplicableTo(GraphState graphState) {
return ALWAYS_APPLICABLE;
}

protected void run(StructuredGraph graph) {

for (HalfFloatConstant halfFloatConstant : graph.getNodes().filter(HalfFloatConstant.class)) {
ValueNode input = halfFloatConstant.getValue();
halfFloatConstant.replaceAtUsages(input);
halfFloatConstant.safeDelete();
}

}
}
Loading

0 comments on commit e8faf81

Please sign in to comment.