Skip to content

Commit e43648a

Browse files
mxyngjmorganca
authored andcommitted
rerefactor
1 parent 823a520 commit e43648a

File tree

9 files changed

+224
-251
lines changed

9 files changed

+224
-251
lines changed

app/lifecycle/updater.go

+30-8
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@ package lifecycle
22

33
import (
44
"context"
5+
"crypto/rand"
56
"encoding/json"
67
"errors"
78
"fmt"
89
"io"
910
"log/slog"
1011
"mime"
1112
"net/http"
13+
"net/url"
1214
"os"
1315
"path"
1416
"path/filepath"
@@ -21,7 +23,7 @@ import (
2123
)
2224

2325
var (
24-
UpdateCheckURLBase = "https://ollama.ai/api/update"
26+
UpdateCheckURLBase = "https://ollama.com/api/update"
2527
UpdateDownloaded = false
2628
)
2729

@@ -47,22 +49,42 @@ func getClient(req *http.Request) http.Client {
4749

4850
func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
4951
var updateResp UpdateResponse
50-
updateCheckURL := UpdateCheckURLBase + "?os=" + runtime.GOOS + "&arch=" + runtime.GOARCH + "&version=" + version.Version
51-
headers := make(http.Header)
52-
err := auth.SignRequest(http.MethodGet, updateCheckURL, nil, headers)
52+
53+
requestURL, err := url.Parse(UpdateCheckURLBase)
5354
if err != nil {
54-
slog.Info(fmt.Sprintf("failed to sign update request %s", err))
55+
return false, updateResp
56+
}
57+
58+
query := requestURL.Query()
59+
query.Add("os", runtime.GOOS)
60+
query.Add("arch", runtime.GOARCH)
61+
query.Add("version", version.Version)
62+
query.Add("ts", fmt.Sprintf("%d", time.Now().Unix()))
63+
64+
nonce, err := auth.NewNonce(rand.Reader, 16)
65+
if err != nil {
66+
return false, updateResp
5567
}
56-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, updateCheckURL, nil)
68+
69+
query.Add("nonce", nonce)
70+
requestURL.RawQuery = query.Encode()
71+
72+
data := []byte(fmt.Sprintf("%s,%s", http.MethodGet, requestURL.RequestURI()))
73+
signature, err := auth.Sign(ctx, data)
74+
if err != nil {
75+
return false, updateResp
76+
}
77+
78+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
5779
if err != nil {
5880
slog.Warn(fmt.Sprintf("failed to check for update: %s", err))
5981
return false, updateResp
6082
}
61-
req.Header = headers
83+
req.Header.Set("Authorization", signature)
6284
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
6385
client := getClient(req)
6486

65-
slog.Debug(fmt.Sprintf("checking for available update at %s with headers %v", updateCheckURL, headers))
87+
slog.Debug(fmt.Sprintf("checking for available update at %s with headers %v", requestURL, req.Header))
6688
resp, err := client.Do(req)
6789
if err != nil {
6890
slog.Warn(fmt.Sprintf("failed to check for update: %s", err))

app/ollama.iss

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#define MyAppVersion "0.0.0"
1313
#endif
1414
#define MyAppPublisher "Ollama, Inc."
15-
#define MyAppURL "https://ollama.ai/"
15+
#define MyAppURL "https://ollama.com/"
1616
#define MyAppExeName "ollama app.exe"
1717
#define MyIcon ".\assets\app.ico"
1818

auth/auth.go

+13-140
Original file line numberDiff line numberDiff line change
@@ -4,185 +4,58 @@ import (
44
"bytes"
55
"context"
66
"crypto/rand"
7-
"crypto/sha256"
87
"encoding/base64"
9-
"encoding/hex"
10-
"encoding/json"
118
"fmt"
129
"io"
1310
"log/slog"
14-
"net/http"
15-
"net/url"
1611
"os"
1712
"path/filepath"
18-
"strconv"
19-
"strings"
20-
"time"
2113

2214
"golang.org/x/crypto/ssh"
23-
24-
"github.com/jmorganca/ollama/api"
25-
)
26-
27-
const (
28-
KeyType = "id_ed25519"
2915
)
3016

31-
type AuthRedirect struct {
32-
Realm string
33-
Service string
34-
Scope string
35-
}
17+
const defaultPrivateKey = "id_ed25519"
3618

37-
type SignatureData struct {
38-
Method string
39-
Path string
40-
Data []byte
41-
}
42-
43-
func generateNonce(length int) (string, error) {
19+
func NewNonce(r io.Reader, length int) (string, error) {
4420
nonce := make([]byte, length)
45-
_, err := rand.Read(nonce)
46-
if err != nil {
21+
if _, err := io.ReadFull(r, nonce); err != nil {
4722
return "", err
4823
}
49-
return base64.RawURLEncoding.EncodeToString(nonce), nil
50-
}
51-
52-
func (r AuthRedirect) URL() (*url.URL, error) {
53-
redirectURL, err := url.Parse(r.Realm)
54-
if err != nil {
55-
return nil, err
56-
}
57-
58-
values := redirectURL.Query()
59-
60-
values.Add("service", r.Service)
61-
62-
for _, s := range strings.Split(r.Scope, " ") {
63-
values.Add("scope", s)
64-
}
6524

66-
values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
67-
68-
nonce, err := generateNonce(16)
69-
if err != nil {
70-
return nil, err
71-
}
72-
values.Add("nonce", nonce)
73-
74-
redirectURL.RawQuery = values.Encode()
75-
return redirectURL, nil
25+
return base64.RawURLEncoding.EncodeToString(nonce), nil
7626
}
7727

78-
func SignRequest(method, url string, data []byte, headers http.Header) error {
28+
func Sign(ctx context.Context, bts []byte) (string, error) {
7929
home, err := os.UserHomeDir()
8030
if err != nil {
81-
return err
82-
}
83-
84-
keyPath := filepath.Join(home, ".ollama", KeyType)
85-
86-
rawKey, err := os.ReadFile(keyPath)
87-
if err != nil {
88-
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
89-
return err
90-
}
91-
92-
s := SignatureData{
93-
Method: method,
94-
Path: url,
95-
Data: data,
96-
}
97-
98-
sig, err := s.Sign(rawKey)
99-
if err != nil {
100-
return err
101-
}
102-
103-
headers.Set("Authorization", sig)
104-
return nil
105-
}
106-
107-
func GetAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
108-
redirectURL, err := redirData.URL()
109-
if err != nil {
110-
return "", err
111-
}
112-
113-
headers := make(http.Header)
114-
err = SignRequest(http.MethodGet, redirectURL.String(), nil, headers)
115-
if err != nil {
116-
return "", err
117-
}
118-
resp, err := MakeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
119-
if err != nil {
120-
slog.Info(fmt.Sprintf("couldn't get token: %q", err))
12131
return "", err
12232
}
123-
defer resp.Body.Close()
124-
125-
if resp.StatusCode >= http.StatusBadRequest {
126-
responseBody, err := io.ReadAll(resp.Body)
127-
if err != nil {
128-
return "", fmt.Errorf("%d: %v", resp.StatusCode, err)
129-
} else if len(responseBody) > 0 {
130-
return "", fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
131-
}
13233

133-
return "", fmt.Errorf("%s", resp.Status)
134-
}
34+
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
13535

136-
respBody, err := io.ReadAll(resp.Body)
36+
privateKeyFile, err := os.ReadFile(keyPath)
13737
if err != nil {
38+
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
13839
return "", err
13940
}
14041

141-
var tok api.TokenResponse
142-
if err := json.Unmarshal(respBody, &tok); err != nil {
143-
return "", err
144-
}
145-
146-
return tok.Token, nil
147-
}
148-
149-
// Bytes returns a byte slice of the data to sign for the request
150-
func (s SignatureData) Bytes() []byte {
151-
// We first derive the content hash of the request body using:
152-
// base64(hex(sha256(request body)))
153-
154-
hash := sha256.Sum256(s.Data)
155-
hashHex := make([]byte, hex.EncodedLen(len(hash)))
156-
hex.Encode(hashHex, hash[:])
157-
contentHash := base64.StdEncoding.EncodeToString(hashHex)
158-
159-
// We then put the entire request together in a serialize string using:
160-
// "<method>,<uri>,<content hash>"
161-
// e.g. "GET,http://localhost,OTdkZjM1O..."
162-
163-
return []byte(strings.Join([]string{s.Method, s.Path, contentHash}, ","))
164-
}
165-
166-
// SignData takes a SignatureData object and signs it with a raw private key
167-
func (s SignatureData) Sign(rawKey []byte) (string, error) {
168-
signer, err := ssh.ParsePrivateKey(rawKey)
42+
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
16943
if err != nil {
17044
return "", err
17145
}
17246

17347
// get the pubkey, but remove the type
174-
pubKey := ssh.MarshalAuthorizedKey(signer.PublicKey())
175-
parts := bytes.Split(pubKey, []byte(" "))
48+
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
49+
parts := bytes.Split(publicKey, []byte(" "))
17650
if len(parts) < 2 {
17751
return "", fmt.Errorf("malformed public key")
17852
}
17953

180-
signedData, err := signer.Sign(nil, s.Bytes())
54+
signedData, err := privateKey.Sign(rand.Reader, bts)
18155
if err != nil {
18256
return "", err
18357
}
18458

18559
// signature is <pubkey>:<signature>
186-
sig := fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob))
187-
return sig, nil
60+
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
18861
}

auth/request.go

-72
This file was deleted.

0 commit comments

Comments
 (0)