kernels.cl
// ===============================================================
// Llama2 core OpenCL Kernels
// ===============================================================
// ===============================================================
// Kernels: rmsNorm
// ===============================================================
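// RMSNorm is split into two kernels: rmsnormReduction produces one partial
// sum of squares per work-group, and rmsnormNormalization applies the final
// scale once the host has combined those partial sums.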
__kernel void rmsnormReduction(__global float *partialSums, __global float *x, __local float *localSums) {
    int idx = get_global_id(0);
    int localIdx = get_local_id(0);
    int groupSize = get_local_size(0);
    int groupID = get_group_id(0);

    // Each work-item squares its element, then the work-group tree-reduces
    // the squares in local memory (assumes groupSize is a power of two).
    float v = x[idx];
    localSums[localIdx] = v * v;
    for (int stride = groupSize / 2; stride > 0; stride /= 2) {
        barrier(CLK_LOCAL_MEM_FENCE);
        if (localIdx < stride) {
            localSums[localIdx] += localSums[localIdx + stride];
        }
    }

    // Work-item 0 publishes this group's partial sum of squares.
    if (localIdx == 0) {
        partialSums[groupID] = localSums[0];
    }
}
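// The scale factor ss is computed on the host from the partial sums above;
// presumably something like ss = 1.0f / sqrt(sumOfSquares / n + epsilon),
// matching the usual RMSNorm definition.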
__kernel void rmsnormNormalization(__global float *output, __global float *x, __global float *weight, const float ss) {
    uint idx = get_global_id(0);
    // Apply the precomputed scale ss together with the learned weight.
    output[idx] = weight[idx] * (ss * x[idx]);
}
// ===============================================================
// Kernels: softmax
// ===============================================================
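// Softmax runs in three passes: a max-reduction, an exponentiate-and-sum
// pass, and a final normalization. Subtracting the global maximum before
// exp() keeps the intermediate values from overflowing.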
__kernel void softMaxReduction(__global float *partialMax, __global float *x, __local float *locals) {
    uint idx = get_global_id(0);
    uint localIdx = get_local_id(0);
    uint groupSize = get_local_size(0);

    // Tree-reduce the group's elements to their maximum in local memory
    // (assumes groupSize is a power of two).
    locals[localIdx] = x[idx];
    for (int stride = groupSize / 2; stride > 0; stride /= 2) {
        barrier(CLK_LOCAL_MEM_FENCE);
        if (localIdx < stride) {
            locals[localIdx] = fmax(locals[localIdx], locals[localIdx + stride]);
        }
    }

    // Work-item 0 publishes this group's partial maximum.
    if (localIdx == 0) {
        partialMax[get_group_id(0)] = locals[0];
    }
}
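// maxValue is the global maximum, which the host presumably obtains by
// reducing the per-group partialMax values from the kernel above.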
__kernel void softMaxExpAndSum(__global float *partialSums, __global float *x, __local float *locals, const float maxValue) {
    uint idx = get_global_id(0);
    uint localIdx = get_local_id(0);
    uint groupSize = get_local_size(0);

    // Exponentiate with the global max subtracted, write the value back to x
    // for the normalization pass, and keep a local copy for the sum reduction.
    float e = exp(x[idx] - maxValue);
    x[idx] = e;
    locals[localIdx] = e;
    for (int stride = groupSize / 2; stride > 0; stride /= 2) {
        barrier(CLK_LOCAL_MEM_FENCE);
        if (localIdx < stride) {
            locals[localIdx] += locals[localIdx + stride];
        }
    }

    // Work-item 0 publishes this group's partial sum of exponentials.
    if (localIdx == 0) {
        partialSums[get_group_id(0)] = locals[0];
    }
}
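// sum is the global sum of exponentials, which the host presumably obtains
// by adding up the per-group partialSums from the kernel above.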
__kernel void softMaxNormalization(__global float *x, const float sum) {
    uint idx = get_global_id(0);
    x[idx] = x[idx] / sum;
}
// ===============================================================
// Kernels: matMul
// ===============================================================
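// One work-item computes one output element: row idx of the row-major
// weight matrix w (n columns per row) dotted with the input vector x, so
// the kernel is expected to be launched with one work-item per output row.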
__kernel void matMul(__global float *xout, __global const float *x, __global const float *w, const int n) {
    uint idx = get_global_id(0);
    float val = 0.0f;
    // Dot product of row idx of w with x, read directly from global memory.
    #pragma unroll 8
    for (int j = 0; j < n; j++) {
        // val = fma(w[idx * n + j], x[j], val);
        val += w[idx * n + j] * x[j];
    }
    xout[idx] = val;
}
// Second approach using OpenCL local memory
__kernel void matMulLocal(__global float *xout,
                          __global const float *x,
                          __global const float *w,
                          __local float *lvector,
                          const int n) {
    // Get the row index for this thread
    int idx = get_global_id(0);

    // Load the vector into local memory to reduce global memory access
    int localIdx = get_local_id(0);
    int localSize = get_local_size(0);

    // Load parts of the vector into local memory in chunks, strided by the group size
    for (int i = localIdx; i < n; i += localSize) {
        lvector[i] = x[i];
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    // Compute the matrix-vector multiplication against the staged vector
    float acc = 0.0f;
    for (int j = 0; j < n; j++) {
        acc += w[idx * n + j] * lvector[j];
    }

    // Store the final result
    xout[idx] = acc;
}
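// Note: lvector must hold all n elements of x, so the host is expected to
// allocate at least n * sizeof(float) bytes of local memory for that
// argument, which also bounds n by the device's local memory capacity.
/*
 * A minimal host-side sketch of launching matMulLocal, assuming a built
 * cl_kernel `kernel`, a cl_command_queue `queue`, cl_mem buffers `dXout`,
 * `dX`, `dW`, and dimensions `rows` and `n` already exist; all of those
 * names and the work-group size are illustrative assumptions.
 *
 *   size_t global = rows;   // one work-item per output row
 *   size_t local  = 64;     // must evenly divide `global`
 *   clSetKernelArg(kernel, 0, sizeof(cl_mem), &dXout);
 *   clSetKernelArg(kernel, 1, sizeof(cl_mem), &dX);
 *   clSetKernelArg(kernel, 2, sizeof(cl_mem), &dW);
 *   clSetKernelArg(kernel, 3, n * sizeof(float), NULL);  // __local lvector
 *   clSetKernelArg(kernel, 4, sizeof(int), &n);
 *   clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local,
 *                          0, NULL, NULL);
 */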