Skip to content

Commit 8ea409b

Browse files
committed
Final changes
1 parent 8400964 commit 8ea409b

File tree

3 files changed

+36
-33
lines changed

3 files changed

+36
-33
lines changed

internal/ghmcp/server.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ type MCPServerConfig struct {
6262

6363
const stdioServerLogPrefix = "stdioserver"
6464

65-
func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) {
65+
func NewMCPServer(cfg MCPServerConfig, logger *slog.Logger) (*server.MCPServer, error) {
6666
apiHost, err := parseAPIHost(cfg.Host)
6767
if err != nil {
6868
return nil, fmt.Errorf("failed to parse API host: %w", err)
@@ -88,6 +88,9 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) {
8888
if cfg.RepoAccessTTL != nil {
8989
repoAccessOpts = append(repoAccessOpts, lockdown.WithTTL(*cfg.RepoAccessTTL))
9090
}
91+
92+
repoAccessLogger := logger.With("component", "lockdown")
93+
repoAccessOpts = append(repoAccessOpts, lockdown.WithLogger(repoAccessLogger))
9194
var repoAccessCache *lockdown.RepoAccessCache
9295
if cfg.LockdownMode {
9396
repoAccessCache = lockdown.GetInstance(gqlClient, repoAccessOpts...)
@@ -273,7 +276,7 @@ func RunStdioServer(cfg StdioServerConfig) error {
273276
ContentWindowSize: cfg.ContentWindowSize,
274277
LockdownMode: cfg.LockdownMode,
275278
RepoAccessTTL: cfg.RepoAccessCacheTTL,
276-
})
279+
}, logger)
277280
if err != nil {
278281
return fmt.Errorf("failed to create MCP server: %w", err)
279282
}

pkg/lockdown/lockdown.go

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,13 @@ type repoAccessCacheEntry struct {
2727
isPrivate bool
2828
knownUsers map[string]bool // normalized login -> has push access
2929
viewerLogin string
30-
viewerType string
3130
}
3231

3332
// RepoAccessInfo captures repository metadata needed for lockdown decisions.
3433
type RepoAccessInfo struct {
3534
IsPrivate bool
3635
HasPushAccess bool
3736
ViewerLogin string
38-
ViewerType string
3937
}
4038

4139
const (
@@ -89,10 +87,7 @@ func GetInstance(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessC
8987
cache: cache2go.Cache(defaultRepoAccessCacheKey),
9088
ttl: defaultRepoAccessTTL,
9189
trustedBotLogins: map[string]struct{}{
92-
"dependabot[bot]": {},
93-
"dependabot-preview[bot]": {},
94-
"github-actions[bot]": {},
95-
"github-copilot[bot]": {},
90+
"copilot": {},
9691
},
9792
}
9893
for _, opt := range opts {
@@ -121,11 +116,13 @@ type CacheStats struct {
121116
func (c *RepoAccessCache) IsSafeContent(ctx context.Context, username, owner, repo string) (bool, error) {
122117
repoInfo, err := c.getRepoAccessInfo(ctx, username, owner, repo)
123118
if err != nil {
124-
c.logDebug("error checking repo access info for content filtering", "owner", owner, "repo", repo, "user", username, "error", err)
125119
return false, err
126120
}
127121

128-
if c.isTrustedBot(username, repoInfo.ViewerType) || repoInfo.IsPrivate || repoInfo.ViewerLogin == strings.ToLower(username) {
122+
c.logInfo(ctx, fmt.Sprintf("evaluated repo access fur user %s to %s/%s for content filtering, result: hasPushAccess=%t, isPrivate=%t",
123+
username, owner, repo, repoInfo.HasPushAccess, repoInfo.IsPrivate))
124+
125+
if c.isTrustedBot(username) || repoInfo.IsPrivate || repoInfo.ViewerLogin == strings.ToLower(username) {
129126
return true, nil
130127
}
131128
return repoInfo.HasPushAccess, nil
@@ -146,32 +143,34 @@ func (c *RepoAccessCache) getRepoAccessInfo(ctx context.Context, username, owner
146143
if err == nil {
147144
entry := cacheItem.Data().(*repoAccessCacheEntry)
148145
if cachedHasPush, known := entry.knownUsers[userKey]; known {
149-
c.logDebug("repo access cache hit", "owner", owner, "repo", repo, "user", username)
146+
c.logDebug(ctx, "repo access cache hit")
150147
return RepoAccessInfo{
151148
IsPrivate: entry.isPrivate,
152149
HasPushAccess: cachedHasPush,
153150
ViewerLogin: entry.viewerLogin,
154151
}, nil
155152
}
156-
c.logDebug("known users cache miss", "owner", owner, "repo", repo, "user", username)
153+
154+
c.logDebug(ctx, "known users cache miss")
155+
157156
info, queryErr := c.queryRepoAccessInfo(ctx, username, owner, repo)
158157
if queryErr != nil {
159158
return RepoAccessInfo{}, queryErr
160159
}
160+
161161
entry.knownUsers[userKey] = info.HasPushAccess
162162
entry.viewerLogin = info.ViewerLogin
163-
entry.viewerType = info.ViewerType
164163
entry.isPrivate = info.IsPrivate
165164
c.cache.Add(key, c.ttl, entry)
165+
166166
return RepoAccessInfo{
167167
IsPrivate: entry.isPrivate,
168168
HasPushAccess: entry.knownUsers[userKey],
169169
ViewerLogin: entry.viewerLogin,
170-
ViewerType: entry.viewerType,
171170
}, nil
172171
}
173172

174-
c.logDebug("repo access cache miss", "owner", owner, "repo", repo, "user", username)
173+
c.logDebug(ctx, "repo access cache miss")
175174

176175
info, queryErr := c.queryRepoAccessInfo(ctx, username, owner, repo)
177176
if queryErr != nil {
@@ -183,15 +182,13 @@ func (c *RepoAccessCache) getRepoAccessInfo(ctx context.Context, username, owner
183182
knownUsers: map[string]bool{userKey: info.HasPushAccess},
184183
isPrivate: info.IsPrivate,
185184
viewerLogin: info.ViewerLogin,
186-
viewerType: info.ViewerType,
187185
}
188186
c.cache.Add(key, c.ttl, entry)
189187

190188
return RepoAccessInfo{
191189
IsPrivate: entry.isPrivate,
192190
HasPushAccess: entry.knownUsers[userKey],
193191
ViewerLogin: entry.viewerLogin,
194-
ViewerType: entry.viewerType,
195192
}, nil
196193
}
197194

@@ -202,8 +199,7 @@ func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, own
202199

203200
var query struct {
204201
Viewer struct {
205-
Typename string `graphql:"__typename"`
206-
Login githubv4.String
202+
Login githubv4.String
207203
}
208204
Repository struct {
209205
IsPrivate githubv4.Boolean
@@ -242,20 +238,28 @@ func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, own
242238
IsPrivate: bool(query.Repository.IsPrivate),
243239
HasPushAccess: hasPush,
244240
ViewerLogin: string(query.Viewer.Login),
245-
ViewerType: query.Viewer.Typename,
246241
}, nil
247242
}
248243

249-
func (c *RepoAccessCache) logDebug(msg string, args ...any) {
250-
if c != nil && c.logger != nil {
251-
c.logger.Debug(msg, args...)
244+
func (c *RepoAccessCache) log(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr) {
245+
if c == nil || c.logger == nil {
246+
return
247+
}
248+
if !c.logger.Enabled(ctx, level) {
249+
return
252250
}
251+
c.logger.LogAttrs(ctx, level, msg, attrs...)
253252
}
254253

255-
func (c *RepoAccessCache) isTrustedBot(username string, viewerType string) bool {
256-
if viewerType != "Bot" {
257-
return false
258-
}
254+
func (c *RepoAccessCache) logDebug(ctx context.Context, msg string, attrs ...slog.Attr) {
255+
c.log(ctx, slog.LevelDebug, msg, attrs...)
256+
}
257+
258+
func (c *RepoAccessCache) logInfo(ctx context.Context, msg string, attrs ...slog.Attr) {
259+
c.log(ctx, slog.LevelInfo, msg, attrs...)
260+
}
261+
262+
func (c *RepoAccessCache) isTrustedBot(username string) bool {
259263
_, ok := c.trustedBotLogins[strings.ToLower(username)]
260264
return ok
261265
}

pkg/lockdown/lockdown_test.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ const (
1919

2020
type repoAccessQuery struct {
2121
Viewer struct {
22-
Typename string `graphql:"__typename"`
23-
Login githubv4.String
22+
Login githubv4.String
2423
}
2524
Repository struct {
2625
IsPrivate githubv4.Boolean
@@ -67,8 +66,7 @@ func newMockRepoAccessCache(t *testing.T, ttl time.Duration) (*RepoAccessCache,
6766

6867
response := githubv4mock.DataResponse(map[string]any{
6968
"viewer": map[string]any{
70-
"__typename": "User",
71-
"login": testUser,
69+
"login": testUser,
7270
},
7371
"repository": map[string]any{
7472
"isPrivate": false,
@@ -101,7 +99,6 @@ func TestRepoAccessCacheEvictsAfterTTL(t *testing.T) {
10199
info, err := cache.getRepoAccessInfo(ctx, testUser, testOwner, testRepo)
102100
require.NoError(t, err)
103101
require.Equal(t, testUser, info.ViewerLogin)
104-
require.Equal(t, "User", info.ViewerType)
105102
require.True(t, info.HasPushAccess)
106103
require.EqualValues(t, 1, transport.CallCount())
107104

@@ -110,7 +107,6 @@ func TestRepoAccessCacheEvictsAfterTTL(t *testing.T) {
110107
info, err = cache.getRepoAccessInfo(ctx, testUser, testOwner, testRepo)
111108
require.NoError(t, err)
112109
require.Equal(t, testUser, info.ViewerLogin)
113-
require.Equal(t, "User", info.ViewerType)
114110
require.True(t, info.HasPushAccess)
115111
require.EqualValues(t, 2, transport.CallCount())
116112
}

0 commit comments

Comments
 (0)