Skip to content

Commit b2c5d05

Browse files
committed
Test custom websocket headers and origin
Signed-off-by: Lorenzo Donini <[email protected]>
1 parent fdf0156 commit b2c5d05

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

ws/websocket_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,84 @@ func TestInvalidBasicAuth(t *testing.T) {
336336
wsServer.Stop()
337337
}
338338

339+
func TestInvalidOriginHeader(t *testing.T) {
340+
var wsServer *Server
341+
wsServer = NewWebsocketServer(t, func(data []byte) ([]byte, error) {
342+
assert.Fail(t, "no message should be received from client!")
343+
return nil, nil
344+
})
345+
wsServer.SetNewClientHandler(func(ws Channel) {
346+
assert.Fail(t, "no new connection should be received from client!")
347+
})
348+
go wsServer.Start(serverPort, serverPath)
349+
time.Sleep(500 * time.Millisecond)
350+
351+
// Test message
352+
wsClient := NewWebsocketClient(t, func(data []byte) ([]byte, error) {
353+
assert.Fail(t, "no message should be received from server!")
354+
return nil, nil
355+
})
356+
// Set invalid origin header
357+
wsClient.SetHeaderValue("Origin", "example.org")
358+
host := fmt.Sprintf("localhost:%v", serverPort)
359+
u := url.URL{Scheme: "ws", Host: host, Path: testPath}
360+
// Attempt to connect and expect cross-origin error
361+
err := wsClient.Start(u.String())
362+
require.Error(t, err)
363+
httpErr, ok := err.(HttpConnectionError)
364+
require.True(t, ok)
365+
assert.Equal(t, http.StatusForbidden, httpErr.HttpCode)
366+
assert.Equal(t, http.StatusForbidden, httpErr.HttpCode)
367+
assert.Equal(t, "websocket: bad handshake", httpErr.Message)
368+
// Cleanup
369+
wsServer.Stop()
370+
}
371+
372+
func TestCustomOriginHeaderHandler(t *testing.T) {
373+
var wsServer *Server
374+
origin := "example.org"
375+
connected := make(chan bool)
376+
wsServer = NewWebsocketServer(t, func(data []byte) ([]byte, error) {
377+
assert.Fail(t, "no message should be received from client!")
378+
return nil, nil
379+
})
380+
wsServer.SetNewClientHandler(func(ws Channel) {
381+
connected <- true
382+
})
383+
wsServer.SetCheckOriginHandler(func(r *http.Request) bool {
384+
return r.Header.Get("Origin") == origin
385+
})
386+
go wsServer.Start(serverPort, serverPath)
387+
time.Sleep(500 * time.Millisecond)
388+
389+
// Test message
390+
wsClient := NewWebsocketClient(t, func(data []byte) ([]byte, error) {
391+
assert.Fail(t, "no message should be received from server!")
392+
return nil, nil
393+
})
394+
// Set invalid origin header (not example.org)
395+
wsClient.SetHeaderValue("Origin", "localhost")
396+
host := fmt.Sprintf("localhost:%v", serverPort)
397+
u := url.URL{Scheme: "ws", Host: host, Path: testPath}
398+
// Attempt to connect and expect cross-origin error
399+
err := wsClient.Start(u.String())
400+
require.Error(t, err)
401+
httpErr, ok := err.(HttpConnectionError)
402+
require.True(t, ok)
403+
assert.Equal(t, http.StatusForbidden, httpErr.HttpCode)
404+
assert.Equal(t, http.StatusForbidden, httpErr.HttpCode)
405+
assert.Equal(t, "websocket: bad handshake", httpErr.Message)
406+
407+
// Re-attempt with correct header
408+
wsClient.SetHeaderValue("Origin", "example.org")
409+
err = wsClient.Start(u.String())
410+
require.NoError(t, err)
411+
result := <-connected
412+
assert.True(t, result)
413+
// Cleanup
414+
wsServer.Stop()
415+
}
416+
339417
func TestValidClientTLSCertificate(t *testing.T) {
340418
var wsServer *Server
341419
// Create self-signed TLS certificate

0 commit comments

Comments
 (0)