@@ -336,6 +336,84 @@ func TestInvalidBasicAuth(t *testing.T) {
336
336
wsServer .Stop ()
337
337
}
338
338
339
+ func TestInvalidOriginHeader (t * testing.T ) {
340
+ var wsServer * Server
341
+ wsServer = NewWebsocketServer (t , func (data []byte ) ([]byte , error ) {
342
+ assert .Fail (t , "no message should be received from client!" )
343
+ return nil , nil
344
+ })
345
+ wsServer .SetNewClientHandler (func (ws Channel ) {
346
+ assert .Fail (t , "no new connection should be received from client!" )
347
+ })
348
+ go wsServer .Start (serverPort , serverPath )
349
+ time .Sleep (500 * time .Millisecond )
350
+
351
+ // Test message
352
+ wsClient := NewWebsocketClient (t , func (data []byte ) ([]byte , error ) {
353
+ assert .Fail (t , "no message should be received from server!" )
354
+ return nil , nil
355
+ })
356
+ // Set invalid origin header
357
+ wsClient .SetHeaderValue ("Origin" , "example.org" )
358
+ host := fmt .Sprintf ("localhost:%v" , serverPort )
359
+ u := url.URL {Scheme : "ws" , Host : host , Path : testPath }
360
+ // Attempt to connect and expect cross-origin error
361
+ err := wsClient .Start (u .String ())
362
+ require .Error (t , err )
363
+ httpErr , ok := err .(HttpConnectionError )
364
+ require .True (t , ok )
365
+ assert .Equal (t , http .StatusForbidden , httpErr .HttpCode )
366
+ assert .Equal (t , http .StatusForbidden , httpErr .HttpCode )
367
+ assert .Equal (t , "websocket: bad handshake" , httpErr .Message )
368
+ // Cleanup
369
+ wsServer .Stop ()
370
+ }
371
+
372
+ func TestCustomOriginHeaderHandler (t * testing.T ) {
373
+ var wsServer * Server
374
+ origin := "example.org"
375
+ connected := make (chan bool )
376
+ wsServer = NewWebsocketServer (t , func (data []byte ) ([]byte , error ) {
377
+ assert .Fail (t , "no message should be received from client!" )
378
+ return nil , nil
379
+ })
380
+ wsServer .SetNewClientHandler (func (ws Channel ) {
381
+ connected <- true
382
+ })
383
+ wsServer .SetCheckOriginHandler (func (r * http.Request ) bool {
384
+ return r .Header .Get ("Origin" ) == origin
385
+ })
386
+ go wsServer .Start (serverPort , serverPath )
387
+ time .Sleep (500 * time .Millisecond )
388
+
389
+ // Test message
390
+ wsClient := NewWebsocketClient (t , func (data []byte ) ([]byte , error ) {
391
+ assert .Fail (t , "no message should be received from server!" )
392
+ return nil , nil
393
+ })
394
+ // Set invalid origin header (not example.org)
395
+ wsClient .SetHeaderValue ("Origin" , "localhost" )
396
+ host := fmt .Sprintf ("localhost:%v" , serverPort )
397
+ u := url.URL {Scheme : "ws" , Host : host , Path : testPath }
398
+ // Attempt to connect and expect cross-origin error
399
+ err := wsClient .Start (u .String ())
400
+ require .Error (t , err )
401
+ httpErr , ok := err .(HttpConnectionError )
402
+ require .True (t , ok )
403
+ assert .Equal (t , http .StatusForbidden , httpErr .HttpCode )
404
+ assert .Equal (t , http .StatusForbidden , httpErr .HttpCode )
405
+ assert .Equal (t , "websocket: bad handshake" , httpErr .Message )
406
+
407
+ // Re-attempt with correct header
408
+ wsClient .SetHeaderValue ("Origin" , "example.org" )
409
+ err = wsClient .Start (u .String ())
410
+ require .NoError (t , err )
411
+ result := <- connected
412
+ assert .True (t , result )
413
+ // Cleanup
414
+ wsServer .Stop ()
415
+ }
416
+
339
417
func TestValidClientTLSCertificate (t * testing.T ) {
340
418
var wsServer * Server
341
419
// Create self-signed TLS certificate
0 commit comments