Skip to content

Commit 5d0398e

Browse files
committed
blockservice: add WithContentBlocker option
The goal is to help with ipfs-shipyard/nopfs#34.
1 parent b8ac21b commit 5d0398e

File tree

3 files changed

+142
-4
lines changed

3 files changed

+142
-4
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ The following emojis are used to highlight certain changes:
1717
### Added
1818

1919
- `blockservice` now has `ContextWithSession` and `EmbedSessionInContext` functions, which allows to embed a session in a context. Future calls to `BlockGetter.GetBlock`, `BlockGetter.GetBlocks` and `NewSession` will use the session in the context.
20+
- `blockservice` now has `WithContentBlocker` option which allows to filter Add and Get requests by CID.
2021

2122
### Changed
2223

blockservice/blockservice.go

+71-4
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,24 @@ type BoundedBlockService interface {
7171
Allowlist() verifcid.Allowlist
7272
}
7373

74+
// Blocker returns err != nil if the CID is disallowed to be fetched or stored in blockservice.
75+
// It returns an error so error messages could be passed.
76+
type Blocker func(cid.Cid) error
77+
78+
// BlockedBlockService is a Blockservice bounded via an arbitrary cid [Blocker].
79+
type BlockedBlockService interface {
80+
BlockService
81+
82+
// Blocker might return [nil], then no blocking is to be done.
83+
Blocker() Blocker
84+
}
85+
7486
var _ BoundedBlockService = (*blockService)(nil)
87+
var _ BlockedBlockService = (*blockService)(nil)
7588

7689
type blockService struct {
7790
allowlist verifcid.Allowlist
91+
blocker Blocker
7892
blockstore blockstore.Blockstore
7993
exchange exchange.Interface
8094
// If checkFirst is true then first check that a block doesn't
@@ -99,6 +113,13 @@ func WithAllowlist(allowlist verifcid.Allowlist) Option {
99113
}
100114
}
101115

116+
// WithContentBlocker allows to filter what blocks can be fetched or added to the blockservice.
117+
func WithContentBlocker(blocker Blocker) Option {
118+
return func(bs *blockService) {
119+
bs.blocker = blocker
120+
}
121+
}
122+
102123
// New creates a BlockService with given datastore instance.
103124
func New(bs blockstore.Blockstore, exchange exchange.Interface, opts ...Option) BlockService {
104125
if exchange == nil {
@@ -141,6 +162,10 @@ func (s *blockService) Allowlist() verifcid.Allowlist {
141162
return s.allowlist
142163
}
143164

165+
func (s *blockService) Blocker() Blocker {
166+
return s.blocker
167+
}
168+
144169
// NewSession creates a new session that allows for
145170
// controlled exchange of wantlists to decrease the bandwidth overhead.
146171
// If the current exchange is a SessionExchange, a new exchange
@@ -171,6 +196,13 @@ func (s *blockService) AddBlock(ctx context.Context, o blocks.Block) error {
171196
if err != nil {
172197
return err
173198
}
199+
200+
if s.blocker != nil {
201+
if err := s.blocker(c); err != nil {
202+
return err
203+
}
204+
}
205+
174206
if s.checkFirst {
175207
if has, err := s.blockstore.Has(ctx, c); has || err != nil {
176208
return err
@@ -198,10 +230,17 @@ func (s *blockService) AddBlocks(ctx context.Context, bs []blocks.Block) error {
198230

199231
// hash security
200232
for _, b := range bs {
201-
err := verifcid.ValidateCid(s.allowlist, b.Cid())
233+
c := b.Cid()
234+
err := verifcid.ValidateCid(s.allowlist, c)
202235
if err != nil {
203236
return err
204237
}
238+
239+
if s.blocker != nil {
240+
if err := s.blocker(c); err != nil {
241+
return err
242+
}
243+
}
205244
}
206245
var toput []blocks.Block
207246
if s.checkFirst {
@@ -261,6 +300,12 @@ func getBlock(ctx context.Context, c cid.Cid, bs BlockService, fetchFactory func
261300
return nil, err
262301
}
263302

303+
if blocker := grabBlockerFromBlockservice(bs); blocker != nil {
304+
if err := blocker(c); err != nil {
305+
return nil, err
306+
}
307+
}
308+
264309
blockstore := bs.Blockstore()
265310

266311
block, err := blockstore.Get(ctx, c)
@@ -320,13 +365,20 @@ func getBlocks(ctx context.Context, ks []cid.Cid, blockservice BlockService, fet
320365
defer close(out)
321366

322367
allowlist := grabAllowlistFromBlockservice(blockservice)
368+
blocker := grabBlockerFromBlockservice(blockservice)
323369

324370
var lastAllValidIndex int
325371
var c cid.Cid
326372
for lastAllValidIndex, c = range ks {
327373
if err := verifcid.ValidateCid(allowlist, c); err != nil {
328374
break
329375
}
376+
377+
if blocker != nil {
378+
if err := blocker(c); err != nil {
379+
break
380+
}
381+
}
330382
}
331383

332384
if lastAllValidIndex != len(ks) {
@@ -335,11 +387,19 @@ func getBlocks(ctx context.Context, ks []cid.Cid, blockservice BlockService, fet
335387
copy(ks2, ks[:lastAllValidIndex]) // fast path for already filtered elements
336388
for _, c := range ks[lastAllValidIndex:] { // don't rescan already scanned elements
337389
// hash security
338-
if err := verifcid.ValidateCid(allowlist, c); err == nil {
339-
ks2 = append(ks2, c)
340-
} else {
390+
if err := verifcid.ValidateCid(allowlist, c); err != nil {
341391
logger.Errorf("unsafe CID (%s) passed to blockService.GetBlocks: %s", c, err)
392+
continue
393+
}
394+
395+
if blocker != nil {
396+
if err := blocker(c); err != nil {
397+
logger.Errorf("blocked CID (%s) passed to blockService.GetBlocks: %s", c, err)
398+
continue
399+
}
342400
}
401+
402+
ks2 = append(ks2, c)
343403
}
344404
ks = ks2
345405
}
@@ -526,3 +586,10 @@ func grabAllowlistFromBlockservice(bs BlockService) verifcid.Allowlist {
526586
}
527587
return verifcid.DefaultAllowlist
528588
}
589+
590+
func grabBlockerFromBlockservice(bs BlockService) Blocker {
591+
if bbs, ok := bs.(BlockedBlockService); ok {
592+
return bbs.Blocker()
593+
}
594+
return nil
595+
}

blockservice/blockservice_test.go

+70
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package blockservice
22

33
import (
44
"context"
5+
"errors"
56
"testing"
67

78
blockstore "github.com/ipfs/boxo/blockstore"
@@ -353,3 +354,72 @@ func TestContextSession(t *testing.T) {
353354
"session must be deduped in all invocations on the same context",
354355
)
355356
}
357+
358+
func TestBlocker(t *testing.T) {
359+
t.Parallel()
360+
a := assert.New(t)
361+
362+
ctx, cancel := context.WithCancel(context.Background())
363+
defer cancel()
364+
365+
bgen := butil.NewBlockGenerator()
366+
allowed := bgen.Next()
367+
notAllowed := bgen.Next()
368+
369+
var disallowed = errors.New("disallowed")
370+
371+
bs := blockstore.NewBlockstore(dssync.MutexWrap(ds.NewMapDatastore()))
372+
service := New(bs, nil, WithContentBlocker(func(c cid.Cid) error {
373+
if c == notAllowed.Cid() {
374+
return disallowed
375+
}
376+
return nil
377+
}))
378+
379+
// try putting
380+
a.NoError(service.AddBlock(ctx, allowed))
381+
has, err := bs.Has(ctx, allowed.Cid())
382+
a.NoError(err)
383+
a.True(has, "block was not added even tho it is not blocked")
384+
a.NoError(service.DeleteBlock(ctx, allowed.Cid()))
385+
386+
a.ErrorIs(service.AddBlock(ctx, notAllowed), disallowed)
387+
has, err = bs.Has(ctx, notAllowed.Cid())
388+
a.NoError(err)
389+
a.False(has, "block was added even tho it is blocked")
390+
391+
a.NoError(service.AddBlocks(ctx, []blocks.Block{allowed}))
392+
has, err = bs.Has(ctx, allowed.Cid())
393+
a.NoError(err)
394+
a.True(has, "block was not added even tho it is not blocked")
395+
a.NoError(service.DeleteBlock(ctx, allowed.Cid()))
396+
397+
a.ErrorIs(service.AddBlocks(ctx, []blocks.Block{notAllowed}), disallowed)
398+
has, err = bs.Has(ctx, notAllowed.Cid())
399+
a.NoError(err)
400+
a.False(has, "block was added even tho it is blocked")
401+
402+
// now try fetch
403+
a.NoError(bs.Put(ctx, allowed))
404+
a.NoError(bs.Put(ctx, notAllowed))
405+
406+
block, err := service.GetBlock(ctx, allowed.Cid())
407+
a.NoError(err)
408+
a.Equal(block.RawData(), allowed.RawData())
409+
410+
_, err = service.GetBlock(ctx, notAllowed.Cid())
411+
a.ErrorIs(err, disallowed)
412+
413+
var gotAllowed bool
414+
for block := range service.GetBlocks(ctx, []cid.Cid{allowed.Cid(), notAllowed.Cid()}) {
415+
switch block.Cid() {
416+
case allowed.Cid():
417+
gotAllowed = true
418+
case notAllowed.Cid():
419+
t.Error("got disallowed block")
420+
default:
421+
t.Fatalf("got unrelated block: %s", block.Cid())
422+
}
423+
}
424+
a.True(gotAllowed, "did not got allowed block")
425+
}

0 commit comments

Comments
 (0)