@@ -4,185 +4,58 @@ import (
4
4
"bytes"
5
5
"context"
6
6
"crypto/rand"
7
- "crypto/sha256"
8
7
"encoding/base64"
9
- "encoding/hex"
10
- "encoding/json"
11
8
"fmt"
12
9
"io"
13
10
"log/slog"
14
- "net/http"
15
- "net/url"
16
11
"os"
17
12
"path/filepath"
18
- "strconv"
19
- "strings"
20
- "time"
21
13
22
14
"golang.org/x/crypto/ssh"
23
-
24
- "github.com/jmorganca/ollama/api"
25
- )
26
-
27
- const (
28
- KeyType = "id_ed25519"
29
15
)
30
16
31
- type AuthRedirect struct {
32
- Realm string
33
- Service string
34
- Scope string
35
- }
17
+ const defaultPrivateKey = "id_ed25519"
36
18
37
- type SignatureData struct {
38
- Method string
39
- Path string
40
- Data []byte
41
- }
42
-
43
- func generateNonce (length int ) (string , error ) {
19
+ func NewNonce (r io.Reader , length int ) (string , error ) {
44
20
nonce := make ([]byte , length )
45
- _ , err := rand .Read (nonce )
46
- if err != nil {
21
+ if _ , err := io .ReadFull (r , nonce ); err != nil {
47
22
return "" , err
48
23
}
49
- return base64 .RawURLEncoding .EncodeToString (nonce ), nil
50
- }
51
-
52
- func (r AuthRedirect ) URL () (* url.URL , error ) {
53
- redirectURL , err := url .Parse (r .Realm )
54
- if err != nil {
55
- return nil , err
56
- }
57
-
58
- values := redirectURL .Query ()
59
-
60
- values .Add ("service" , r .Service )
61
-
62
- for _ , s := range strings .Split (r .Scope , " " ) {
63
- values .Add ("scope" , s )
64
- }
65
24
66
- values .Add ("ts" , strconv .FormatInt (time .Now ().Unix (), 10 ))
67
-
68
- nonce , err := generateNonce (16 )
69
- if err != nil {
70
- return nil , err
71
- }
72
- values .Add ("nonce" , nonce )
73
-
74
- redirectURL .RawQuery = values .Encode ()
75
- return redirectURL , nil
25
+ return base64 .RawURLEncoding .EncodeToString (nonce ), nil
76
26
}
77
27
78
- func SignRequest ( method , url string , data []byte , headers http. Header ) error {
28
+ func Sign ( ctx context. Context , bts []byte ) ( string , error ) {
79
29
home , err := os .UserHomeDir ()
80
30
if err != nil {
81
- return err
82
- }
83
-
84
- keyPath := filepath .Join (home , ".ollama" , KeyType )
85
-
86
- rawKey , err := os .ReadFile (keyPath )
87
- if err != nil {
88
- slog .Info (fmt .Sprintf ("Failed to load private key: %v" , err ))
89
- return err
90
- }
91
-
92
- s := SignatureData {
93
- Method : method ,
94
- Path : url ,
95
- Data : data ,
96
- }
97
-
98
- sig , err := s .Sign (rawKey )
99
- if err != nil {
100
- return err
101
- }
102
-
103
- headers .Set ("Authorization" , sig )
104
- return nil
105
- }
106
-
107
- func GetAuthToken (ctx context.Context , redirData AuthRedirect ) (string , error ) {
108
- redirectURL , err := redirData .URL ()
109
- if err != nil {
110
- return "" , err
111
- }
112
-
113
- headers := make (http.Header )
114
- err = SignRequest (http .MethodGet , redirectURL .String (), nil , headers )
115
- if err != nil {
116
- return "" , err
117
- }
118
- resp , err := MakeRequest (ctx , http .MethodGet , redirectURL , headers , nil , nil )
119
- if err != nil {
120
- slog .Info (fmt .Sprintf ("couldn't get token: %q" , err ))
121
31
return "" , err
122
32
}
123
- defer resp .Body .Close ()
124
-
125
- if resp .StatusCode >= http .StatusBadRequest {
126
- responseBody , err := io .ReadAll (resp .Body )
127
- if err != nil {
128
- return "" , fmt .Errorf ("%d: %v" , resp .StatusCode , err )
129
- } else if len (responseBody ) > 0 {
130
- return "" , fmt .Errorf ("%d: %s" , resp .StatusCode , responseBody )
131
- }
132
33
133
- return "" , fmt .Errorf ("%s" , resp .Status )
134
- }
34
+ keyPath := filepath .Join (home , ".ollama" , defaultPrivateKey )
135
35
136
- respBody , err := io . ReadAll ( resp . Body )
36
+ privateKeyFile , err := os . ReadFile ( keyPath )
137
37
if err != nil {
38
+ slog .Info (fmt .Sprintf ("Failed to load private key: %v" , err ))
138
39
return "" , err
139
40
}
140
41
141
- var tok api.TokenResponse
142
- if err := json .Unmarshal (respBody , & tok ); err != nil {
143
- return "" , err
144
- }
145
-
146
- return tok .Token , nil
147
- }
148
-
149
- // Bytes returns a byte slice of the data to sign for the request
150
- func (s SignatureData ) Bytes () []byte {
151
- // We first derive the content hash of the request body using:
152
- // base64(hex(sha256(request body)))
153
-
154
- hash := sha256 .Sum256 (s .Data )
155
- hashHex := make ([]byte , hex .EncodedLen (len (hash )))
156
- hex .Encode (hashHex , hash [:])
157
- contentHash := base64 .StdEncoding .EncodeToString (hashHex )
158
-
159
- // We then put the entire request together in a serialize string using:
160
- // "<method>,<uri>,<content hash>"
161
- // e.g. "GET,http://localhost,OTdkZjM1O..."
162
-
163
- return []byte (strings .Join ([]string {s .Method , s .Path , contentHash }, "," ))
164
- }
165
-
166
- // SignData takes a SignatureData object and signs it with a raw private key
167
- func (s SignatureData ) Sign (rawKey []byte ) (string , error ) {
168
- signer , err := ssh .ParsePrivateKey (rawKey )
42
+ privateKey , err := ssh .ParsePrivateKey (privateKeyFile )
169
43
if err != nil {
170
44
return "" , err
171
45
}
172
46
173
47
// get the pubkey, but remove the type
174
- pubKey := ssh .MarshalAuthorizedKey (signer .PublicKey ())
175
- parts := bytes .Split (pubKey , []byte (" " ))
48
+ publicKey := ssh .MarshalAuthorizedKey (privateKey .PublicKey ())
49
+ parts := bytes .Split (publicKey , []byte (" " ))
176
50
if len (parts ) < 2 {
177
51
return "" , fmt .Errorf ("malformed public key" )
178
52
}
179
53
180
- signedData , err := signer .Sign (nil , s . Bytes () )
54
+ signedData , err := privateKey .Sign (rand . Reader , bts )
181
55
if err != nil {
182
56
return "" , err
183
57
}
184
58
185
59
// signature is <pubkey>:<signature>
186
- sig := fmt .Sprintf ("%s:%s" , bytes .TrimSpace (parts [1 ]), base64 .StdEncoding .EncodeToString (signedData .Blob ))
187
- return sig , nil
60
+ return fmt .Sprintf ("%s:%s" , bytes .TrimSpace (parts [1 ]), base64 .StdEncoding .EncodeToString (signedData .Blob )), nil
188
61
}
0 commit comments