This repository has been archived by the owner on Jun 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
/
gptj.go
99 lines (78 loc) · 2.62 KB
/
gptj.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
package gptj
// #cgo CFLAGS: -I./gpt4all-j/ggml/include/ggml/ -I./gpt4all-j/ggml/examples/ -I./gpt4all-j/llmodel -I./gpt4all-j/llmodel/llama.cpp/ -I./
// #cgo CXXFLAGS: -std=c++17 -I./gpt4all-j/ggml/include/ggml/ -I./gpt4all-j/ggml/examples/ -I./gpt4all-j/llmodel -I./ -I./gpt4all-j/llmodel/llama.cpp/
// #cgo darwin LDFLAGS: -framework Accelerate
// #cgo darwin CXXFLAGS: -std=c++17
// #cgo LDFLAGS: -lgptj -lm -lstdc++
// #include <stdlib.h>
// #include <binding.h>
import "C"
import (
	"fmt"
	"runtime"
	"strings"
	"sync"
	"unsafe"
)
// The following code is https://github.com/go-skynet/go-llama.cpp/blob/master/llama.go with small changes

// GPTJ is a handle to a model loaded through the native binding.
// Obtain one with New and release it with Free.
type GPTJ struct {
	// state is the opaque pointer returned by binding_load_gptj_model; it is
	// also the key under which a token callback is registered.
	state unsafe.Pointer
}
// New loads the model file at path model through the native binding and
// returns a handle to it. Options built from opts supply, among other things,
// the thread count passed to the loader. It returns an error if the binding
// reports a failed load (a nil state).
//
// The returned handle should be released with Free; the finalizer installed
// here only unregisters a token callback, it does not free the model.
func New(model string, opts ...ModelOption) (*GPTJ, error) {
	ops := NewModelOptions(opts...)

	// C.CString allocates with malloc, so the copy must be released once the
	// loader returns; previously it leaked on every call.
	// NOTE(review): assumes binding_load_gptj_model does not retain the path
	// pointer after returning — confirm against the C binding.
	modelPath := C.CString(model)
	defer C.free(unsafe.Pointer(modelPath))

	state := C.binding_load_gptj_model(modelPath, C.int(ops.Threads))
	if state == nil {
		return nil, fmt.Errorf("failed loading model")
	}

	gpt := &GPTJ{state: state}
	// Remove any registered token callback when the struct is reclaimed by
	// the garbage collector, so the callbacks map does not keep stale entries.
	runtime.SetFinalizer(gpt, func(g *GPTJ) {
		setCallback(g.state, nil)
	})
	return gpt, nil
}
// Predict runs the model on the prompt text and returns the generated output.
//
// Sampling parameters (repeat window and penalty, context size, token limit,
// top-k, top-p, temperature, batch size, context erase) come from opts. A
// Tokens value of 0 is replaced by a very large limit.
//
// Before returning, Predict post-processes the output in this order:
//   - it strips one leading space;
//   - it strips an echoed copy of the prompt;
//   - it strips one leading newline;
//   - it strips a trailing "<|endoftext|>" marker.
//
// It returns an error if the model is not loaded (for example after Free).
func (l *GPTJ) Predict(text string, opts ...PredictOption) (string, error) {
	if l.state == nil {
		return "", fmt.Errorf("model is not loaded")
	}

	po := NewPredictOptions(opts...)

	// C.CString mallocs a copy of the prompt; free it when we are done
	// (previously leaked on every call).
	input := C.CString(text)
	defer C.free(unsafe.Pointer(input))

	if po.Tokens == 0 {
		po.Tokens = 99999999
	}
	// The binding writes its result into out, which is then read back as a
	// NUL-terminated string.
	// NOTE(review): out is sized in bytes from a token count; confirm the
	// native side never writes more than len(out) bytes (tokens can span
	// several bytes), and note Tokens==0 allocates ~100 MB per call.
	out := make([]byte, po.Tokens)
	C.binding_model_prompt(input, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.int(po.RepeatLastN), C.float(po.RepeatPenalty), C.int(po.ContextSize),
		C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.int(po.Batch), C.float(po.ContextErase))

	res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))
	res = strings.TrimPrefix(res, " ")
	res = strings.TrimPrefix(res, text)
	res = strings.TrimPrefix(res, "\n")
	res = strings.TrimSuffix(res, "<|endoftext|>")
	return res, nil
}
// Free releases the native model via binding_gptj_free_model and unregisters
// any token callback associated with it. Calling Free more than once is a
// no-op; l must not be used for prediction afterwards.
func (l *GPTJ) Free() {
	if l.state == nil {
		// Already freed (or never loaded): avoid a double free.
		return
	}
	// Drop the callback before releasing the state: once freed, the same
	// address may be handed to a newly loaded model, which must neither
	// inherit this callback nor have its own removed by our finalizer.
	setCallback(l.state, nil)
	C.binding_gptj_free_model(l.state)
	l.state = nil
	// The finalizer installed by New only unregisters the callback, which
	// has just been done; clear it so it cannot touch a reused address.
	runtime.SetFinalizer(l, nil)
}
// SetTokenCallback registers callback to be invoked with each token the
// native binding reports for this model. The callback's boolean result is
// handed back to the binding. Passing nil removes the current callback.
func (l *GPTJ) SetTokenCallback(callback func(token string) bool) {
	handle := l.state
	setCallback(handle, callback)
}
// m guards every read and write of callbacks.
var m sync.Mutex

// callbacks maps a native model state pointer (stored as uintptr) to the
// token callback registered for that model via setCallback.
var callbacks = map[uintptr]func(string) bool{}
// bindingTokenCallback is called from C for each token produced for the model
// identified by statePtr. It forwards the token to the registered Go callback
// and returns that callback's result, or true when none is registered.
//
// The mutex is held only for the map lookup, not while the user callback
// runs: holding it would deadlock a callback that calls SetTokenCallback
// (sync.Mutex is not reentrant) and would serialize callbacks across models.
// As a consequence, a callback removed concurrently may still be invoked once.
//
//export bindingTokenCallback
func bindingTokenCallback(statePtr unsafe.Pointer, token *C.char) bool {
	m.Lock()
	callback, ok := callbacks[uintptr(statePtr)]
	m.Unlock()

	if !ok {
		return true
	}
	return callback(C.GoString(token))
}
// setCallback registers callback as the token callback for the model whose
// native state is statePtr, replacing any previous one. A nil callback
// unregisters whatever is currently stored for that model.
func setCallback(statePtr unsafe.Pointer, callback func(string) bool) {
	key := uintptr(statePtr)

	m.Lock()
	defer m.Unlock()

	switch {
	case callback == nil:
		delete(callbacks, key)
	default:
		callbacks[key] = callback
	}
}