-
Notifications
You must be signed in to change notification settings - Fork 65
/
blocklistloader-http.go
148 lines (127 loc) · 3.68 KB
/
blocklistloader-http.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package rdns
import (
"bufio"
"context"
"crypto/sha256"
"fmt"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"time"
)
// HTTPLoader reads blocklist rules from a server via HTTP(S). When a
// cache-dir is configured, downloaded lists are also written to disk and the
// first load is attempted from that on-disk copy.
type HTTPLoader struct {
	// url is the location of the list, fetched with an HTTP GET request.
	url string
	// opt holds the caching and failure-handling options for this loader.
	opt HTTPLoaderOptions
	// fromDisk is set by NewHTTPLoader when opt.CacheDir is not empty and is
	// cleared by the first Load. While set, Load tries the cache-dir before
	// fetching the list from url.
	fromDisk bool
	// lastSuccess is updated by Load. When opt.AllowFailure is set it holds
	// the last successfully loaded rules, which Load returns in place of an
	// error.
	lastSuccess []string
}
// HTTPLoaderOptions holds options for HTTP blocklist loaders.
type HTTPLoaderOptions struct {
	// CacheDir, if not empty, is a directory in which downloaded lists are
	// cached. Each list is stored in a file named after the hex-encoded SHA256
	// hash of its URL. The first load tries this cached copy before
	// contacting the server.
	CacheDir string
	// AllowFailure makes a failed load non-fatal: the error is logged and the
	// rules from the last successful load are used instead (nil if no load
	// has succeeded yet).
	AllowFailure bool
}
// Compile-time check that *HTTPLoader satisfies BlocklistLoader.
var _ BlocklistLoader = (*HTTPLoader)(nil)

// httpTimeout is the deadline applied to the HTTP(S) request that downloads a
// blocklist.
const httpTimeout = time.Minute * 30
// NewHTTPLoader returns a loader for the blocklist at url. If opt.CacheDir is
// set, the first call to Load tries the copy cached on disk before fetching
// the list from url.
func NewHTTPLoader(url string, opt HTTPLoaderOptions) *HTTPLoader {
	return &HTTPLoader{
		url:      url,
		opt:      opt,
		fromDisk: opt.CacheDir != "",
	}
}
// Load returns the rules of the blocklist, one rule per line of the list.
//
// If a cache-dir is configured, the first call tries to read the list from the
// cache file on disk. It falls back to downloading l.url if that fails. Every
// other call downloads the list from l.url. A successful download is written
// to the cache-dir, and a failure to write the cache is only logged.
//
// A list is only returned if it was read completely. A non-2xx response, a
// transport error or an error while reading the body yields a nil list and an
// error. If AllowFailure is set, such an error is logged instead, and the
// rules from the last successful Load (nil if there was none) are returned
// with a nil error.
func (l *HTTPLoader) Load() (rules []string, err error) {
	log := Log.WithField("url", l.url)
	log.Trace("loading blocklist")

	defer func() {
		if err == nil {
			// Remember only complete, successfully loaded lists.
			l.lastSuccess = rules
			return
		}
		if l.opt.AllowFailure {
			log.WithError(err).Warn("failed to load blocklist, continuing with previous ruleset")
			rules = l.lastSuccess
			err = nil
		}
	}()

	// If a cache-dir was given, try to load the list from disk on first load
	// only; later loads always go to the network.
	if l.fromDisk {
		l.fromDisk = false
		start := time.Now()
		cached, diskErr := l.loadFromDisk()
		if diskErr == nil {
			log.WithField("load-time", time.Since(start)).Trace("loaded blocklist from cache-dir")
			return cached, nil
		}
		log.WithError(diskErr).Warn("unable to load cached list from disk, loading from upstream")
	}

	// The context deadline covers the request and reading the response body,
	// since cancel only runs when Load returns.
	ctx, cancel := context.WithTimeout(context.Background(), httpTimeout)
	defer cancel()
	req, err := http.NewRequestWithContext(ctx, "GET", l.url, nil)
	if err != nil {
		return nil, err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return nil, fmt.Errorf("got unexpected status code %d from %s", resp.StatusCode, l.url)
	}

	start := time.Now()
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		rules = append(rules, scanner.Text())
	}
	if scanErr := scanner.Err(); scanErr != nil {
		// Never hand out (or cache) a truncated list.
		return nil, fmt.Errorf("reading blocklist from %s: %w", l.url, scanErr)
	}
	log.WithField("load-time", time.Since(start)).Trace("completed loading blocklist")

	// Cache the content to disk now that the read from the remote server
	// was successful. A cache write failure does not fail the load.
	if l.opt.CacheDir != "" {
		log.Trace("writing rules to cache-dir")
		if writeErr := l.writeToDisk(rules); writeErr != nil {
			log.WithError(writeErr).Error("failed to write rules to cache")
		}
	}
	return rules, nil
}
// loadFromDisk reads the cached copy of the list from the cache-dir, one rule
// per line. The cache file is the one named by cacheFilename (the SHA256 of
// the URL, inside the cache-dir). On a read error the lines read so far are
// returned along with the error.
func (l *HTTPLoader) loadFromDisk() ([]string, error) {
	file, openErr := os.Open(l.cacheFilename())
	if openErr != nil {
		return nil, openErr
	}
	defer file.Close()

	lineScanner := bufio.NewScanner(file)
	var lines []string
	for lineScanner.Scan() {
		lines = append(lines, lineScanner.Text())
	}
	return lines, lineScanner.Err()
}
// writeToDisk stores rules, one per line, in this list's cache file (see
// cacheFilename). The content is first written to a temporary file in the
// cache-dir. That file is renamed over the cache file only after every write,
// the final flush and the close have succeeded, so a failed write (e.g. a
// full disk) never replaces the cache with a truncated list. The temporary
// file is always removed.
func (l *HTTPLoader) writeToDisk(rules []string) (err error) {
	f, err := ioutil.TempFile(l.opt.CacheDir, "routedns")
	if err != nil {
		return err
	}
	tmpFileName := f.Name()
	fb := bufio.NewWriter(f)
	defer func() {
		// Flush buffered data and close the file before renaming (Windows
		// needs the file closed). Keep the first error encountered. A flush
		// or close failure means the temp file is incomplete.
		if flushErr := fb.Flush(); flushErr != nil && err == nil {
			err = flushErr
		}
		if closeErr := f.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
		if err == nil {
			err = os.Rename(tmpFileName, l.cacheFilename())
		}
		// Clean up in every case. After a successful rename the temp file no
		// longer exists, so the error from Remove is deliberately ignored.
		_ = os.Remove(tmpFileName)
	}()
	for _, r := range rules {
		if _, err = fb.WriteString(r + "\n"); err != nil {
			return err
		}
	}
	return nil
}
// cacheFilename returns the path of this list's cache file: the lowercase
// hex-encoded SHA256 hash of the URL, joined onto the cache-dir.
func (l *HTTPLoader) cacheFilename() string {
	digest := sha256.Sum256([]byte(l.url))
	return filepath.Join(l.opt.CacheDir, fmt.Sprintf("%x", digest[:]))
}