diff --git a/.changeset/fair-beds-take.md b/.changeset/fair-beds-take.md new file mode 100644 index 00000000..7336d7e9 --- /dev/null +++ b/.changeset/fair-beds-take.md @@ -0,0 +1,5 @@ +--- +"github.com/livekit/protocol": patch +--- + +Added warning prints to SIP headers diff --git a/livekit/sip.go b/livekit/sip.go index ea7d531c..2832e838 100644 --- a/livekit/sip.go +++ b/livekit/sip.go @@ -11,6 +11,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/livekit/protocol/logger" "github.com/livekit/protocol/utils/xtwirp" "golang.org/x/text/language" ) @@ -263,6 +264,29 @@ func validateHeaderValues(headers map[string]string) error { return nil } +// validateHeaders makes sure header names/keys and values are per SIP specifications +func validateHeaders(headers map[string]string) error { + for headerName, headerValue := range headers { + if err := ValidateHeaderName(headerName); err != nil { + return fmt.Errorf("invalid header name: %w", err) + } + if err := ValidateHeaderValue(headerName, headerValue); err != nil { + return fmt.Errorf("invalid header value for %s: %w", headerName, err) + } + } + return nil +} + +// validateHeaderNames Makes sure the values of the given map correspond to valid SIP header names +func validateHeaderNames(attributesToHeaders map[string]string) error { + for _, headerName := range attributesToHeaders { + if err := ValidateHeaderName(headerName); err != nil { + return fmt.Errorf("invalid header name: %w", err) + } + } + return nil +} + func (p *SIPTrunkInfo) Validate() error { if len(p.InboundNumbersRegex) != 0 { return fmt.Errorf("trunks with InboundNumbersRegex are deprecated") @@ -368,6 +392,15 @@ func (p *SIPInboundTrunkInfo) Validate() error { if err := validateHeaderValues(p.AttributesToHeaders); err != nil { return err } + if err := validateHeaders(p.Headers); err != nil { + logger.Warnw("Header validation failed for Headers field", err) + // TODO: Once we're happy with the validation, we want this to error out + } + // Don't bother with HeadersToAttributes. If they're invalid, we just won't match + if err := validateHeaderNames(p.AttributesToHeaders); err != nil { + logger.Warnw("Header validation failed for AttributesToHeaders field", err) + // TODO: Once we're happy with the validation, we want this to error out + } return nil } @@ -455,6 +488,15 @@ func (p *SIPOutboundTrunkInfo) Validate() error { if err := validateHeaderValues(p.AttributesToHeaders); err != nil { return err } + if err := validateHeaders(p.Headers); err != nil { + logger.Warnw("Header validation failed for Headers field", err) + // TODO: Once we're happy with the validation, we want this to error out + } + // Don't bother with HeadersToAttributes. If they're invalid, we just won't match + if err := validateHeaderNames(p.AttributesToHeaders); err != nil { + logger.Warnw("Header validation failed for AttributesToHeaders field", err) + // TODO: Once we're happy with the validation, we want this to error out + } return nil } @@ -472,6 +514,11 @@ func (p *SIPOutboundConfig) Validate() error { if err := validateHeaderValues(p.AttributesToHeaders); err != nil { return err } + // Don't bother with HeadersToAttributes. If they're invalid, we just won't match + if err := validateHeaderNames(p.AttributesToHeaders); err != nil { + logger.Warnw("Header validation failed for AttributesToHeaders field", err) + // No error, just a warning for SIP RFC validation for now + } return nil } @@ -677,13 +724,18 @@ func (p *CreateSIPParticipantRequest) Validate() error { return err } + if err := validateHeaders(p.Headers); err != nil { + logger.Warnw("Header validation failed for Headers field", err) + // TODO: Once we're happy with the validation, we want this to error out + } + // Validate display_name if provided if p.DisplayName != nil { if len(*p.DisplayName) > 128 { return errors.New("display_name too long (max 128 characters)") } - // TODO: Validate display name doesn't contain invalid characters + // TODO: Once we're happy with the validation, we want this to error out } // Validate destination if provided @@ -775,6 +827,11 @@ func (p *TransferSIPParticipantRequest) Validate() error { return err } + if err := validateHeaders(p.Headers); err != nil { + logger.Warnw("Header validation failed for Headers field", err) + // TODO: Once we're happy with the validation, we want this to error out + } + return nil } diff --git a/livekit/sip_validation.go b/livekit/sip_validation.go new file mode 100644 index 00000000..56386e7a --- /dev/null +++ b/livekit/sip_validation.go @@ -0,0 +1,359 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package livekit + +import ( + "errors" + fmt "fmt" + "strconv" + "strings" +) + +// RFC 3261 compliant validation functions for SIP headers and messages + +type allowedCharacters struct { + ascii [127]bool + utf8 bool +} + +func NewAllowedCharacters() *allowedCharacters { + return &allowedCharacters{} +} + +func (a *allowedCharacters) AddUTF8() error { + a.utf8 = true + return nil +} + +func (a *allowedCharacters) AddNumbers() error { + for r := '0'; r <= '9'; r++ { + a.ascii[r] = true + } + return nil +} + +func (a *allowedCharacters) AddLowercaseASCII() error { + for r := 'a'; r <= 'z'; r++ { + a.ascii[r] = true + } + return nil +} + +func (a *allowedCharacters) AddUppercaseASCII() error { + for r := 'A'; r <= 'Z'; r++ { + a.ascii[r] = true + } + return nil +} + +func (a *allowedCharacters) AddPrintableLienarASCII() { + // Anything between 0x20 and 0x7E + for i := 0x20; i <= 0x7E; i++ { + a.ascii[i] = true + } +} + +func (a *allowedCharacters) Add(chars string) error { + for _, char := range chars { + if int(char) >= len(a.ascii) { + return fmt.Errorf("char %d out of range, consider explicilty adding utf8 characters", char) + } + a.ascii[char] = true + } + return nil +} + +func (a *allowedCharacters) Remove(chars string) error { + for _, char := range chars { + if int(char) >= len(a.ascii) { + return fmt.Errorf("char %d out of range, consider explicilty adding utf8 characters", char) + } + a.ascii[char] = false + } + return nil +} + +func (a *allowedCharacters) Copy() *allowedCharacters { + return &allowedCharacters{ + ascii: a.ascii, + utf8: a.utf8, + } +} + +func (a *allowedCharacters) Validate(target string) error { + for _, char := range target { + if int(char) >= len(a.ascii) && !a.utf8 { + return fmt.Errorf("char %d out of range, consider explicilty adding utf8 characters", char) + } + if !a.ascii[char] { + return fmt.Errorf("char %d not allowed", char) + } + } + return nil +} + +var tokenCharacters *allowedCharacters +var displayNameCharacters *allowedCharacters +var headerValuesCharacters *allowedCharacters + +func init() { + // Per RFC 3261 Section 25.1 + // SIP-message = Request / Response + // Request = Request-Line *( message-header ) CRLF [ message-body ] + // Response = Status-Line *( message-header ) CRLF [ message-body ] + // Request-Line = Method SP Request-URI SP SIP-Version CRLF + // Method = (CAPITAL ASCII) + // Request-URI = SIP-URI / SIPS-URI / absoluteURI + // SIP-Version = "SIP" "/" 1*DIGIT "." 1*DIGIT (CAPITAL ASCII, DIGITS, "/.") + // Status-Line = SIP-Version SP Status-Code SP Reason-Phrase CRLF + // Status-Code = (Alphanum + "-") + // Reason-Phrase = (Basically whatever...) + // extension-header = header-name (token) ":" header-value (Basically whatever...) + + // URIs + // SIP-URI = "sip:" [ userinfo ] hostport uri-parameters [ headers ] + // SIPS-URI = "sips:" [ userinfo ] hostport uri-parameters [ headers ] + + // One specific header form we care about: + // name-addr = [ display-name ] LAQUOT addr-spec RAQUOT + // display-name = *(token LWS)/ quoted-string + // addr-spec = SIP-URI / SIPS-URI / absoluteURI + + tokenCharacters = NewAllowedCharacters() + tokenCharacters.AddNumbers() + tokenCharacters.AddLowercaseASCII() + tokenCharacters.AddUppercaseASCII() + tokenCharacters.Add("-.!%*_+`'~") + + displayNameCharacters = tokenCharacters.Copy() + displayNameCharacters.Add(" \t") + + headerValuesCharacters = NewAllowedCharacters() + headerValuesCharacters.AddPrintableLienarASCII() // Specifically not adding UTF8 for now +} + +// Required headers for SIP requests per RFC 3261 Section 8.1.1 +var RequiredRequestHeaders = map[string]bool{ + "via": true, + "from": true, + "to": true, + "call-id": true, + "cseq": true, + "max-forwards": true, +} + +// Required headers for SIP responses per RFC 3261 Section 8.2.1 +var RequiredResponseHeaders = map[string]bool{ + "via": true, + "from": true, + "to": true, + "call-id": true, + "cseq": true, +} + +// Crucial headers that can't be overridden by the user, and their shorthands +var FrobiddenSipHeaderNames = map[string]bool{ + "accept": true, + "accept-encoding": true, + "accept-language": true, + "allow": true, + "allow-events": true, // rfc3903 + "call-id": true, + "contact": true, + "content-encoding": true, + "content-length": true, + "content-type": true, + "cseq": true, + "event": true, // rfc3903 + "expires": true, + "from": true, // We might allow this in the future, but for now we're printing + "max-forwards": true, + "record-route": true, + "refer-to": true, // rfc3515 + "referred-by": true, // rfc3892sipUriCharacters + "reply-to": true, + "k": true, // Supported + "l": true, // Content-Length + "m": true, // Contact + "o": true, // Event; rfc3903 + "r": true, // Refer-To; rfc3515 + "t": true, // To + "u": true, // Allow-Events; rfc3903 + "v": true, // Via +} + +// Headers that must comply with name-addr specification per RFC 3261 Section 20.10 +// name-addr = [display-name] +// addr-spec = SIP-URI / SIPS-URI / absoluteURI +var nameAddrHeaders = map[string]bool{ + "from": true, + "to": true, + "contact": true, + "route": true, + "record-route": true, + "reply-to": true, + "p-asserted-identity": true, // RFC 3325 Section 9.1 +} + +// ValidateHeaderName validates a SIP header name per RFC 3261 Section 25.1 +func ValidateHeaderName(name string) error { + if name == "" { + return errors.New("header name cannot be empty") + } + + if len(name) > 255 { + return errors.New("header name too long (max 255 characters)") + } + + if err := tokenCharacters.Validate(name); err != nil { + return fmt.Errorf("header name %s contains invalid characters: %w", name, err) + } + + // Convert to lowercase for case-insensitive comparison + lowerName := strings.ToLower(name) + if forbidden, exists := FrobiddenSipHeaderNames[lowerName]; exists && forbidden { + return fmt.Errorf("header name %s not supported", name) + } + + return nil +} + +// ValidateHeaderValue validates a SIP header value per RFC 3261 Section 25.1 +func ValidateHeaderValue(name, value string) error { + if value == "" { + return fmt.Errorf("header %s: value cannot be empty", name) + } + + if len(value) > 1024 { + return fmt.Errorf("header %s: value too long (max 1024 characters)", name) + } + + // Basic character validation - printable ASCII. We're stricter than the spec here - no UTF-8 for now + if err := headerValuesCharacters.Validate(value); err != nil { + return fmt.Errorf("header %s: value: %w", name, err) + } + + // Convert to lowercase for case-insensitive comparison + lowerName := strings.ToLower(name) + if _, exists := nameAddrHeaders[lowerName]; exists && false { + // TODO: Disabled since all supported headers are forbidden, re-enable when we allow some + if err := validateNameAddrHeader(value); err != nil { + return fmt.Errorf("header %s: value: %w", name, err) + } + } + + return nil +} + +// findAngleBrackets efficiently finds angle brackets in a single scan +// Returns: start, end positions (-1 = missing), and error status +func findAngleBrackets(value string) (int, int, error) { + start := -1 + end := -1 + + for i, r := range value { + switch r { + case '<': + if start != -1 { + return -1, -1, errors.New("multiple opening brackets") + } + start = i + case '>': + if end != -1 { + return -1, -1, errors.New("multiple closing brackets") + } + end = i + } + } + + // Check for mismatched brackets + if (start == -1) != (end == -1) { + return -1, -1, errors.New("mismatched angle brackets") + } + + // Check that < comes before > + if start > end { + return -1, -1, errors.New("malformed angle brackets") + } + + return start, end, nil +} + +// validateNameAddrHeader validates headers that use name-addr format per RFC 3261 Section 20.10 +func validateNameAddrHeader(value string) error { + // RFC 3261 Section 20.10 - name-addr format + // name-addr = [display-name] + // addr-spec = SIP-URI / SIPS-URI / absoluteURI + + uri := value + start, end, err := findAngleBrackets(value) + if err != nil { + return err + } + if start >= 0 || end >= 0 { + uri = value[start+1 : end] + if err := validateDisplayName(strings.TrimSpace(value[:start])); err != nil { + return err + } + } else { + // This is a bare URI, and should comply with addr-spec, no special characters + if strings.ContainsAny(value, ";,? ") { + return errors.New("bare URI with special characters") + } + } + return validateURI(uri) +} + +// validateDisplayName validates a display name in name-addr format +func validateDisplayName(displayName string) error { + if displayName == "" { + return nil + } + + // Check if display name is quoted + if strings.HasPrefix(displayName, `"`) && strings.HasSuffix(displayName, `"`) { + // Quoted display name - use strconv.Unquote to validate proper escaping + _, err := strconv.Unquote(displayName) + if err != nil { + return fmt.Errorf("display name: %w", err) + } + return nil + } + + // Unquoted display name - must not contain special characters + if err := displayNameCharacters.Validate(displayName); err != nil { + return fmt.Errorf("display name: %w", err) + } + + return nil +} + +// validateURI validates URIs that can appear in name-addr format +func validateURI(uri string) error { + // Just do the basics, full validation should be done by sip service + scheme := strings.SplitN(uri, ":", 2)[0] + if scheme != "sip" && scheme != "sips" && scheme != "tel" { + // Technically, it either needs to be sip/s: or scheme://... + // Thus, tel: uri should not be supported here... but we allow it because of de-facto usage. + return errors.New("uri: scheme not one of sip, sips, or tel") + } + + // Just no spaces, proper validation should be done in sip service + if strings.Contains(uri, " ") { + return errors.New("uri: contains spaces") + } + + return nil +} diff --git a/livekit/sip_validation_test.go b/livekit/sip_validation_test.go new file mode 100644 index 00000000..9c9be283 --- /dev/null +++ b/livekit/sip_validation_test.go @@ -0,0 +1,259 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package livekit + +import ( + "fmt" + "strings" + "testing" +) + +// Valid Header Test Cases + +// ValidHeaderNames contains valid SIP header names +var ValidHeaderNames = []string{ + "Q", // single uppercase + "q", // single lowercase + "Qrom", // keyword + "qrom", // keyword + "Qall-ID", // hyphenated keyword + "P-Asserted-Identity", // multiple hyphens + "X-", // hyphen at end + "-X", // hyphen at start + "X123", // alphanumeric + "X_123", // underscore + "X.123", // period + "X!123", // exclamation + "X%123", // percent + "X*123", // asterisk + "X+123", // plus + "X`123", // backtick + "X'123", // single quote + "X~123", // tilde +} + +// InvalidHeaderNames contains invalid SIP header names +var InvalidHeaderNames = []string{ + "", // empty + "From To", // space in name + "From:To", // colon in name + "From,To", // comma in name + "From;To", // semicolon in name + "FromTo", // angle bracket in name + "From@To", // at symbol in name + "From\"To", // quote in name + "From\\To", // backslash in name + "From/To", // forward slash + "From[To", // square bracket + "From]To", // square bracket + "From{To", // curly brace + "From}To", // curly brace + "From(To", // parenthesis + "From)To", // parenthesis + "From?To", // question mark + "From=To", // equals sign + "From#To", // hash + "From$To", // dollar sign + "From&To", // ampersand + "From|To", // pipe + "From^To", // caret + "From\000To", // null byte + "From\nTo", // newline + "From\rTo", // carriage return + "From\tTo", // tab +} + +// ValidHeaderValues contains valid SIP header values (implementation-specific restrictions) +// Note: These restrictions are NOT in RFC 3261 but are applied for security/performance +var ValidHeaderValues = []string{ + "u1@example.com", // basic email + "", // SIP URI with brackets + "Alice ", // display name + URI + "\"Alice Smith\" ", // quoted display name + "SIP/2.0/UDP 192.168.1.1:5060", // Via header + "1 INVITE", // CSeq header + "255", // Max-Forwards (max valid) + "0", // Max-Forwards (min valid) + "application/sdp", // Content-Type + "123", // Content-Length + "3600", // Expires + "call-123@example.com", // Call-ID + "text/plain; charset=utf-8", // Content-Type with params + "", // IPv6 URI + "\"Alice & Bob\" ", // display name with & symbol + strings.Repeat("a", 1024), // max length +} + +// Note: These restrictions are NOT in RFC 3261 but are applied for security/performance +var InvalidHeaderValues = []string{ + "", // empty + "Header with\nnewline", // newline + "Header with\rreturn", // carriage return + "Header with\ttab", // tab + "Header with\x00null", // null byte + "Header with\x01control", // control character + "Header with\x1Funit separator", // control character + "Header with\x7Fdelete", // delete character + "Header with\x80extended", // extended ASCII + "Header with\xFFextended", // extended ASCII + "Header with unicode café", // Unicode + "Header with unicode 世界", // Unicode + "Header with unicode émojis 🎉", // Unicode with emojis + strings.Repeat("a", 1025), // too long +} + +// testCaseName truncates a test case name to maxLen and adds dots with total size +func testCaseName(name string, maxLen int, index int) string { + if len(name) <= maxLen { + return fmt.Sprintf("%d/%s)", index+1, name) + } + // Truncate to make room for "..." and size info + truncated := name[:maxLen-10] // Reserve space for "..." and "(1234)" + return fmt.Sprintf("%d/%s...(%d)", index+1, truncated, len(name)) +} + +// ValidNameAddrHeaders contains valid Name-addr format headers with parameters +var ValidNameAddrHeaders = []string{ + `"Alice Johnson" `, + `"Alice \"Ace\" Johnson's device\\" `, + `Alice Johnson `, + `sip:u4@example.com`, // basic SIP URI (no brackets needed) + `sips:u5@example.com`, // secure SIP URI (no brackets needed) + `tel:+1-555-123-4567`, // TEL URI (no brackets needed) + ``, // basic SIP URI with brackets + ``, // secure SIP URI with brackets + ``, // TEL URI with brackets + `Alice `, // display name + SIP URI + `"Alice Johnson" `, // quoted display name + ``, // SIP URI with transport + ``, // SIP URI with flag param + ``, // SIP URI with port + ``, // SIP URI with multiple params + `Alice `, // display name + params + `"Alice \"Ace\"" `, // quoted display + ``, // IPv6 with params + `;expires=60`, // SIPS URI with expires parameter + `Alice `, // display name + params + `"Alice & Bob" `, // display name with & symbol +} + +// InvalidNameAddrHeaders contains invalid Name-addr format headers +var InvalidNameAddrHeaders = []string{ + `"Alice "Ace" Johnson" `, // unescaped quotes + `"\Alice" `, // unescaped backslashes + `"Alice" Johnson `, // unmatched quotes + `Alice "Ace" Johnson `, // unescaped quotes in unquoted + `"Alice Johnson `, // unterminated quote + `Alice Johnson" `, // unmatched quote + ``, // missing opening bracket + ` `, // multiple URIs + `Alice `, // multiple URIs with display + `Alice sip:u13@example.com`, // display name without brackets + `Alice sips:u14@example.com`, // display name without brackets + `Alice & Bob `, // display name with & symbol + `sip:u16@example.com;transport=tcp`, // special chars without brackets + `sip:u17@example.com,transport=tcp`, // comma without brackets + `sip:u18@example.com?transport=tcp`, // question mark without brackets + ``, // missing equals sign + ``, // space in parameters +} + +// TestValidateHeaderName_ValidHeaders tests that all valid header names pass validation +func TestValidateHeaderName_ValidHeaders(t *testing.T) { + for i, headerName := range ValidHeaderNames { + t.Run(testCaseName(headerName, 32, i), func(t *testing.T) { + err := ValidateHeaderName(headerName) + if err != nil { + t.Errorf("ValidateHeaderName(%q) = %v, want nil", headerName, err) + } + }) + } +} + +// TestValidateHeaderName_InvalidHeaders tests that all invalid header names fail validation +func TestValidateHeaderName_InvalidHeaders(t *testing.T) { + for i, headerName := range InvalidHeaderNames { + t.Run(testCaseName(headerName, 32, i), func(t *testing.T) { + err := ValidateHeaderName(headerName) + if err == nil { + t.Errorf("ValidateHeaderName(%q) = nil, want error", headerName) + } + }) + } +} + +// TestValidateHeaderValue_ValidValues tests that all valid header values pass validation +func TestValidateHeaderValue_ValidValues(t *testing.T) { + for i, headerValue := range ValidHeaderValues { + t.Run(testCaseName(headerValue, 32, i), func(t *testing.T) { + err := ValidateHeaderValue("Test-Header", headerValue) + if err != nil { + t.Errorf("ValidateHeaderValue(%q) = %v, want nil", headerValue, err) + } + }) + } +} + +// TestValidateHeaderValue_InvalidValues tests that all invalid header values fail validation +// Note: These restrictions are implementation-specific, NOT from RFC 3261 +func TestValidateHeaderValue_InvalidValues(t *testing.T) { + for i, headerValue := range InvalidHeaderValues { + t.Run(testCaseName(headerValue, 32, i), func(t *testing.T) { + err := ValidateHeaderValue("Test-Header", headerValue) + if err == nil { + t.Errorf("ValidateHeaderValue(%q) = nil, want error", headerValue) + } + }) + } +} + +// TestValidateNameAddr_ValidHeaders tests that all valid Name-addr headers pass validation +func TestValidateNameAddr_ValidHeaders(t *testing.T) { + for i, nameAddr := range ValidNameAddrHeaders { + t.Run(testCaseName(nameAddr, 32, i), func(t *testing.T) { + err := validateNameAddrHeader(nameAddr) + if err != nil { + t.Errorf("validateNameAddrHeader(%q) = %v, want nil", nameAddr, err) + } + }) + } +} + +// TestValidateNameAddr_InvalidHeaders tests that all invalid Name-addr headers fail validation +func TestValidateNameAddr_InvalidHeaders(t *testing.T) { + for i, nameAddr := range InvalidNameAddrHeaders { + t.Run(testCaseName(nameAddr, 32, i), func(t *testing.T) { + err := validateNameAddrHeader(nameAddr) + if err == nil { + t.Errorf("validateNameAddrHeader(%q) = nil, want error", nameAddr) + } + }) + } +} + +func TestFrobiddenSipHeaderNames(t *testing.T) { + i := 0 + for name := range FrobiddenSipHeaderNames { + i++ + t.Run(testCaseName(name, 32, i), func(t *testing.T) { + err := ValidateHeaderName(name) + if err == nil { + t.Errorf("ValidateHeaderName(%q) = nil, want error", name) + } + }) + } +}