Skip to content

Commit

Permalink
The cli normalized the given CIDRs by default, so when a user entered…
Browse files Browse the repository at this point in the history
… 10.0.0.1/8 (as a sample) the cli normalized it to 10.0.0.0/8 silent. After this MR we now validate that the given IP is the start of the CIDR block (e.g. 10.0.0.0/8). (#304)

Signed-off-by: Lukas Kämmerling <[email protected]>
  • Loading branch information
LKaemmerling authored Mar 17, 2021
1 parent 6c04c99 commit 5442833
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 28 deletions.
20 changes: 6 additions & 14 deletions internal/cmd/firewall/add_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,6 @@ func runAddRule(cli *state.State, cmd *cobra.Command, args []string) error {
return fmt.Errorf("Firewall not found: %v", idOrName)
}

var sourceNets []net.IPNet
for i, sourceIP := range sourceIPs {
_, sourceNet, err := net.ParseCIDR(sourceIP)
if err != nil {
return fmt.Errorf("invalid CIDR on index %d : %s", i, err)
}
sourceNets = append(sourceNets, *sourceNet)
}
d := hcloud.FirewallRuleDirection(direction)
rule := hcloud.FirewallRule{
Direction: d,
Expand All @@ -75,20 +67,20 @@ func runAddRule(cli *state.State, cmd *cobra.Command, args []string) error {

switch d {
case hcloud.FirewallRuleDirectionOut:
rule.DestinationIPs = make([]net.IPNet, 0, len(destinationIPs))
rule.DestinationIPs = make([]net.IPNet, len(destinationIPs))
for i, ip := range destinationIPs {
_, n, err := net.ParseCIDR(ip)
n, err := validateFirewallIP(ip)
if err != nil {
return fmt.Errorf("invalid CIDR on index %d : %s", i, err)
return fmt.Errorf("destination error on index %d: %s", i, err)
}
rule.DestinationIPs[i] = *n
}
case hcloud.FirewallRuleDirectionIn:
rule.SourceIPs = make([]net.IPNet, 0, len(sourceIPs))
rule.SourceIPs = make([]net.IPNet, len(sourceIPs))
for i, ip := range sourceIPs {
_, n, err := net.ParseCIDR(ip)
n, err := validateFirewallIP(ip)
if err != nil {
return fmt.Errorf("invalid CIDR on index %d : %s", i, err)
return fmt.Errorf("source ips error on index %d: %s", i, err)
}
rule.SourceIPs[i] = *n
}
Expand Down
20 changes: 6 additions & 14 deletions internal/cmd/firewall/delete_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,6 @@ func runDeleteRule(cli *state.State, cmd *cobra.Command, args []string) error {
return fmt.Errorf("Firewall not found: %v", idOrName)
}

var sourceNets []net.IPNet
for i, sourceIP := range sourceIPs {
_, sourceNet, err := net.ParseCIDR(sourceIP)
if err != nil {
return fmt.Errorf("invalid CIDR on index %d : %s", i, err)
}
sourceNets = append(sourceNets, *sourceNet)
}
d := hcloud.FirewallRuleDirection(direction)
rule := hcloud.FirewallRule{
Direction: d,
Expand All @@ -74,20 +66,20 @@ func runDeleteRule(cli *state.State, cmd *cobra.Command, args []string) error {
}
switch d {
case hcloud.FirewallRuleDirectionOut:
rule.DestinationIPs = make([]net.IPNet, 0, len(destinationIPs))
rule.DestinationIPs = make([]net.IPNet, len(destinationIPs))
for i, ip := range destinationIPs {
_, n, err := net.ParseCIDR(ip)
n, err := validateFirewallIP(ip)
if err != nil {
return fmt.Errorf("invalid CIDR on index %d : %s", i, err)
return fmt.Errorf("destination ips error on index %d: %s", i, err)
}
rule.DestinationIPs[i] = *n
}
case hcloud.FirewallRuleDirectionIn:
rule.SourceIPs = make([]net.IPNet, 0, len(sourceIPs))
rule.SourceIPs = make([]net.IPNet, len(sourceIPs))
for i, ip := range sourceIPs {
_, n, err := net.ParseCIDR(ip)
n, err := validateFirewallIP(ip)
if err != nil {
return fmt.Errorf("invalid CIDR on index %d : %s", i, err)
return fmt.Errorf("source ips error on index %d: %s", i, err)
}
rule.SourceIPs[i] = *n
}
Expand Down
18 changes: 18 additions & 0 deletions internal/cmd/firewall/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package firewall

import (
"fmt"
"net"
)

func validateFirewallIP(ip string) (*net.IPNet, error) {
i, n, err := net.ParseCIDR(ip)
if err != nil {
return nil, fmt.Errorf("%s", err)
}
if i.String() != n.IP.String() {
return nil, fmt.Errorf("%s is not the start of the cidr block %s", ip, n)
}

return n, nil
}
64 changes: 64 additions & 0 deletions internal/cmd/firewall/validation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package firewall

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func TestValidateFirewallIP(t *testing.T) {
tests := []struct {
name string
ip string
err error
}{
{
name: "Valid CIDR (IPv4)",
ip: "10.0.0.0/8",
},
{
name: "Valid CIDR (IPv6)",
ip: "fe80::/128",
},
{
name: "Invalid IP",
ip: "test",
err: fmt.Errorf("invalid CIDR address: test"),
},
{
name: "Missing CIDR notation (IPv4)",
ip: "10.0.0.0",
err: fmt.Errorf("invalid CIDR address: 10.0.0.0"),
},
{
name: "Missing CIDR notation (IPv6)",
ip: "fe80::",
err: fmt.Errorf("invalid CIDR address: fe80::"),
},
{
name: "Host bit set (IPv4)",
ip: "10.0.0.5/8",
err: fmt.Errorf("10.0.0.5/8 is not the start of the cidr block 10.0.0.0/8"),
},
{
name: "Host bit set (IPv6)",
ip: "fe80::1337/64",
err: fmt.Errorf("fe80::1337/64 is not the start of the cidr block fe80::/64"),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
net, err := validateFirewallIP(test.ip)

if test.err != nil {
assert.Equal(t, err, test.err)
assert.Nil(t, net)
return
}

assert.NoError(t, err)
assert.NotNil(t, net)
})
}
}

0 comments on commit 5442833

Please sign in to comment.