-
Notifications
You must be signed in to change notification settings - Fork 46
/
algo_es.go
140 lines (121 loc) · 2.8 KB
/
algo_es.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package jwt
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"math/big"
)
// NewSignerES returns a new ECDSA-based signer.
//
// alg must be ES256, ES384 or ES512, and the size of key's curve must match
// it; otherwise the error reported by getParamsES is returned. A nil key
// yields ErrNilKey.
//
// The returned ESAlg also keeps a pointer to key's public half. Because
// *ESAlg exposes Verify, leaving the public key nil would make Verify on a
// signer dereference a nil *ecdsa.PublicKey inside ecdsa.Verify and panic.
func NewSignerES(alg Algorithm, key *ecdsa.PrivateKey) (*ESAlg, error) {
	if key == nil {
		return nil, ErrNilKey
	}
	// Raw signature is r||s, each integer padded to the curve's byte length.
	signSize := roundBytes(key.PublicKey.Params().BitSize) * 2
	hash, err := getParamsES(alg, signSize)
	if err != nil {
		return nil, err
	}
	return &ESAlg{
		alg:        alg,
		hash:       hash,
		privateKey: key,
		publicKey:  &key.PublicKey,
		signSize:   signSize,
	}, nil
}
// NewVerifierES returns a new ECDSA-based verifier.
//
// A nil key yields ErrNilKey. Otherwise alg and the byte size of key's curve
// are validated by getParamsES, whose error is returned unchanged. The
// resulting ESAlg holds only the public key; its private key is nil.
func NewVerifierES(alg Algorithm, key *ecdsa.PublicKey) (*ESAlg, error) {
	if key == nil {
		return nil, ErrNilKey
	}
	size := 2 * roundBytes(key.Params().BitSize)
	h, err := getParamsES(alg, size)
	if err != nil {
		return nil, err
	}
	v := &ESAlg{alg: alg, hash: h, publicKey: key, signSize: size}
	return v, nil
}
// getParamsES resolves the hash function for an ECDSA algorithm and checks
// that size (the raw r||s signature length derived from the key) is the one
// the algorithm expects: 64 bytes for ES256, 96 for ES384 and 132 for ES512
// (twice the byte length of a 256-, 384- and 521-bit value respectively).
//
// It returns ErrUnsupportedAlg for any other algorithm and ErrInvalidKey when
// size does not match; in both cases the returned hash is zero.
func getParamsES(alg Algorithm, size int) (crypto.Hash, error) {
	var expected int
	hash := crypto.Hash(0)
	switch alg {
	case ES256:
		expected = 64
		hash = crypto.SHA256
	case ES384:
		expected = 96
		hash = crypto.SHA384
	case ES512:
		expected = 132
		hash = crypto.SHA512
	default:
		return 0, ErrUnsupportedAlg
	}
	if size == expected {
		return hash, nil
	}
	return 0, ErrInvalidKey
}
// ESAlg signs and verifies tokens using ECDSA for the ES256, ES384 and ES512
// algorithms. Instances are built with NewSignerES or NewVerifierES.
type ESAlg struct {
	alg        Algorithm         // algorithm this instance was built for
	hash       crypto.Hash       // hash applied to the payload before signing or verifying
	publicKey  *ecdsa.PublicKey  // key used by verify to check signatures
	privateKey *ecdsa.PrivateKey // key used by Sign; nil when built by NewVerifierES
	signSize   int               // length in bytes of the raw r||s signature
}
// Algorithm reports the JWT algorithm this ESAlg was constructed with.
func (es *ESAlg) Algorithm() Algorithm { return es.alg }
// SignSize reports the length in bytes of the signatures produced by Sign
// and the only length accepted when verifying.
func (es *ESAlg) SignSize() int { return es.signSize }
// Sign hashes payload with the configured hash, signs the digest with the
// private key using crypto/rand as the entropy source, and returns the
// fixed-width encoding r||s. Each of r and s is written big-endian and
// right-aligned (left-padded with zero bytes) within its SignSize()/2 half.
// Errors from hashing or from ecdsa.Sign are returned unchanged.
func (es *ESAlg) Sign(payload []byte) ([]byte, error) {
	digest, err := hashPayload(es.hash, payload)
	if err != nil {
		return nil, err
	}
	r, s, err := ecdsa.Sign(rand.Reader, es.privateKey, digest)
	if err != nil {
		return nil, err
	}

	size := es.SignSize()
	half := size / 2
	out := make([]byte, size)

	// Right-align each integer so shorter values keep their leading zeros.
	rb := r.Bytes()
	copy(out[half-len(rb):half], rb)
	sb := s.Bytes()
	copy(out[2*half-len(sb):], sb)

	return out, nil
}
// Verify checks token against this ESAlg.
//
// It returns ErrUninitializedToken when token.isValid reports false, and
// ErrAlgorithmMismatch when the token header's algorithm differs from this
// instance's (compared via constTimeAlgEqual). Otherwise it returns the result
// of checking the token's signature over its payload part.
func (es *ESAlg) Verify(token *Token) error {
	if !token.isValid() {
		return ErrUninitializedToken
	}
	if !constTimeAlgEqual(token.Header().Algorithm, es.alg) {
		return ErrAlgorithmMismatch
	}
	return es.verify(token.PayloadPart(), token.Signature())
}
// verify checks a raw r||s signature over payload with the public key.
//
// A signature whose length is not SignSize() is rejected with
// ErrInvalidSignature before any hashing. The first half of the signature is
// decoded as big-endian r and the second half as big-endian s. Hashing errors
// are returned unchanged, and a failed ECDSA check yields ErrInvalidSignature.
func (es *ESAlg) verify(payload, signature []byte) error {
	size := es.SignSize()
	if len(signature) != size {
		return ErrInvalidSignature
	}

	digest, err := hashPayload(es.hash, payload)
	if err != nil {
		return err
	}

	half := size / 2
	r := new(big.Int).SetBytes(signature[:half])
	s := new(big.Int).SetBytes(signature[half:])
	if ecdsa.Verify(es.publicKey, digest, r, s) {
		return nil
	}
	return ErrInvalidSignature
}
// roundBytes converts a bit count to a byte count, rounding up whenever n
// leaves a positive remainder modulo 8 (e.g. 521 bits -> 66 bytes).
func roundBytes(n int) int {
	q, rem := n/8, n%8
	if rem > 0 {
		q++
	}
	return q
}