Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NET-1976: Add support for Managed DNS on Windows. #999

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
7 changes: 7 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ const (
Initd
)

const (
// DefaultHostID {0EF230F0-2EAD-4370-B0F9-AFC2D2A039E6} is a fixed string,
// for creating the unique GUID. It's a meaningless unique GUID here to
// make sure only one network profile is created.
DefaultHostID = "0EF230F0-2EAD-4370-B0F9-AFC2D2A039E6"
)

// Initype - the type of init system in use
type InitType int

Expand Down
44 changes: 40 additions & 4 deletions dns/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@ package dns

import (
"encoding/json"
"errors"
"os"
"strings"
"sync"

"github.com/gravitl/netclient/config"
"golang.org/x/exp/slog"
)

const (
DNS_MANAGER_STUB = "stub" // '/run/systemd/resolve/stub-resolv.conf'
DNS_MANAGER_UPLINK = "uplink" // '/run/systemd/resolve/resolv.conf'
DNS_MANAGER_RESOLVECONF = "resolveconf" // 'generated by resolvconf(8)'
DNS_MANAGER_FILE = "file" // other than above
DNS_MANAGER_STUB = "stub" // '/run/systemd/resolve/stub-resolv.conf'
DNS_MANAGER_UPLINK = "uplink" // '/run/systemd/resolve/resolv.conf'
DNS_MANAGER_RESOLVECONF = "resolveconf" // 'generated by resolvconf(8)'
DNS_MANAGER_FILE = "file" // other than above
DNS_MANAGER_WINDOWS_REGISTRY = "registry" // for Windows machines
)

var (
Expand Down Expand Up @@ -115,3 +118,36 @@ func cleanDNSJsonFile() error {

return nil
}

// getDnsIp return the ip address of the dns server
func getDnsIp() (string, error) {
dnsIp := GetDNSServerInstance().AddrStr
if dnsIp == "" {
return "", errors.New("no listener is running")
}

if len(config.GetNodes()) == 0 {
return "", errors.New("no network joint")
}

dnsIp = getIpFromServerString(dnsIp)

return dnsIp, nil
}

// getIpFromServerString returns ip address from the ip:port
// address pair.
func getIpFromServerString(addrStr string) string {
s := ""
s = addrStr[0:strings.LastIndex(addrStr, ":")]

if strings.Contains(s, "[") {
s = strings.ReplaceAll(s, "[", "")
}

if strings.Contains(s, "]") {
s = strings.ReplaceAll(s, "]", "")
}

return s
}
34 changes: 6 additions & 28 deletions dns/config_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,11 @@ func RestoreDNSConfig() (err error) {
}

func buildAddConfigContentUplink() ([]string, error) {

dnsIp := GetDNSServerInstance().AddrStr
if dnsIp == "" {
return []string{}, errors.New("no listener is running")
dnsIp, err := getDnsIp()
if err != nil {
return nil, err
}

dnsIp = getIpFromServerString(dnsIp)
ns := "DNS=" + dnsIp

f, err := os.Open(resolvUplinkPath)
Expand Down Expand Up @@ -171,16 +169,11 @@ func setupResolveUplink() (err error) {

func setupResolvectl() (err error) {

dnsIp := GetDNSServerInstance().AddrStr
if dnsIp == "" {
return errors.New("no listener is running")
}
if len(config.GetNodes()) == 0 {
return errors.New("no network joint")
dnsIp, err := getDnsIp()
if err != nil {
return err
}

dnsIp = getIpFromServerString(dnsIp)

_, err = ncutils.RunCmd(fmt.Sprintf("resolvectl dns netmaker %s", dnsIp), false)
if err != nil {
slog.Warn("add DNS IP for netmaker failed", "error", err.Error())
Expand All @@ -206,21 +199,6 @@ func setupResolvectl() (err error) {
return nil
}

func getIpFromServerString(addrStr string) string {
s := ""
s = addrStr[0:strings.LastIndex(addrStr, ":")]

if strings.Contains(s, "[") {
s = strings.ReplaceAll(s, "[", "")
}

if strings.Contains(s, "]") {
s = strings.ReplaceAll(s, "]", "")
}

return s
}

func backupResolveconfFile(src, dst string) error {

_, err := os.Stat(dst)
Expand Down
70 changes: 65 additions & 5 deletions dns/config_windows.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,76 @@
package dns

func FlushLocalDnsCache() (err error) {
return nil
import (
"fmt"
"github.com/gravitl/netclient/config"
"github.com/gravitl/netclient/ncutils"
"golang.org/x/exp/slog"
"golang.org/x/sys/windows/registry"
"net/netip"
"sync"
)

var dnsConfigMutex sync.Mutex

func FlushLocalDnsCache() error {
_, err := ncutils.RunCmd("ipconfig /flushdns", false)
if err != nil {
slog.Warn("failed to flush local dns cache", "error", err.Error())
}

return err
}

func SetupDNSConfig() (err error) {
return nil
func SetupDNSConfig() error {
dnsConfigMutex.Lock()
defer dnsConfigMutex.Unlock()

// ignore if dns manager is not Windows Registry
if config.Netclient().DNSManagerType != DNS_MANAGER_WINDOWS_REGISTRY {
return nil
}

dnsIp, err := getDnsIp()
if err != nil {
return err
}

ip, err := netip.ParseAddr(dnsIp)
if err != nil {
return err
}

guid := config.Netclient().Host.ID.String()
if guid == "" {
guid = config.DefaultHostID
}

guid = "{" + guid + "}"

keyPath := ""
if ip.Is6() {
keyPath = fmt.Sprintf(`SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\%s`, guid)
} else {
keyPath = fmt.Sprintf(`SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\%s`, guid)
}

// open registry key with permissions to set value
key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE|registry.SET_VALUE)
if err != nil {
return err
}
defer key.Close()

return key.SetStringValue("NameServer", dnsIp)
}

func RestoreDNSConfig() (err error) {
func RestoreDNSConfig() error {
return nil
}

func InitDNSConfig() {
dnsConfigMutex.Lock()
defer dnsConfigMutex.Unlock()

config.Netclient().DNSManagerType = DNS_MANAGER_WINDOWS_REGISTRY
}
10 changes: 5 additions & 5 deletions dns/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func GetDNSServerInstance() *DNSServer {

// Start the DNS listener
func (dnsServer *DNSServer) Start() {
if runtime.GOOS != "linux" {
if runtime.GOOS != "linux" && runtime.GOOS != "windows" {
return
}
dnsMutex.Lock()
Expand Down Expand Up @@ -95,10 +95,10 @@ func (dnsServer *DNSServer) Start() {
}

//Setup DNS config for Linux
if config.Netclient().Host.OS == "linux" {
if config.Netclient().Host.OS == "linux" || config.Netclient().Host.OS == "windows" {
err := SetupDNSConfig()
if err != nil {
slog.Error("setup DNS conig failed", "error", err.Error())
slog.Error("setup DNS config failed", "error", err.Error())
}
}

Expand All @@ -107,7 +107,7 @@ func (dnsServer *DNSServer) Start() {

// Stop the DNS listener
func (dnsServer *DNSServer) Stop() {
if runtime.GOOS != "linux" {
if runtime.GOOS != "linux" && runtime.GOOS != "windows" {
return
}
dnsMutex.Lock()
Expand All @@ -117,7 +117,7 @@ func (dnsServer *DNSServer) Stop() {
}

//restore DNS config for Linux
if config.Netclient().Host.OS == "linux" {
if config.Netclient().Host.OS == "linux" || config.Netclient().Host.OS == "windows" {
err := RestoreDNSConfig()
if err != nil {
slog.Warn("Restore DNS conig failed", "error", err.Error())
Expand Down
39 changes: 31 additions & 8 deletions dns/resolver.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dns

import (
"errors"
"net"
"strings"
"sync"
Expand All @@ -16,6 +17,11 @@ const (

var dnsMapMutex = sync.Mutex{} // used to mutex functions of the DNS

var (
ErrNXDomain = errors.New("non existent domain")
ErrNoQTypeRecord = errors.New("domain exists but no record matching the question type")
)

type DNSResolver struct {
DnsEntriesCacheStore map[string]dns.RR
DnsEntriesCacheMap map[string][]dnsRecord
Expand Down Expand Up @@ -53,11 +59,10 @@ func handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
reply.RecursionAvailable = true
reply.RecursionDesired = true
reply.Rcode = dns.RcodeSuccess
reply.Authoritative = true

resp := GetDNSResolverInstance().Lookup(r)
if resp != nil {
reply.Answer = append(reply.Answer, resp)
} else {
resp, err := GetDNSResolverInstance().Lookup(r)
if err != nil && errors.Is(err, ErrNXDomain) {
nslist := config.Netclient().NameServers
if config.Netclient().CurrGwNmIP != nil {
nslist = []string{}
Expand Down Expand Up @@ -97,7 +102,11 @@ func handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
}
}

err := w.WriteMsg(reply)
if resp != nil {
reply.Answer = append(reply.Answer, resp)
}

err = w.WriteMsg(reply)
if err != nil {
slog.Error("write DNS response message error: ", "error", err.Error())
}
Expand Down Expand Up @@ -136,14 +145,28 @@ func (d *DNSResolver) RegisterAAAA(record dnsRecord) error {
}

// Lookup DNS entry in local directory
func (d *DNSResolver) Lookup(m *dns.Msg) dns.RR {
func (d *DNSResolver) Lookup(m *dns.Msg) (dns.RR, error) {
dnsMapMutex.Lock()
defer dnsMapMutex.Unlock()
q := m.Question[0]
r, ok := d.DnsEntriesCacheStore[buildDNSEntryKey(strings.TrimSuffix(q.Name, "."), q.Qtype)]
if !ok {
return nil
if q.Qtype == dns.TypeA {
_, ok = d.DnsEntriesCacheStore[buildDNSEntryKey(strings.TrimSuffix(q.Name, "."), dns.TypeAAAA)]
if ok {
// aware but no ipv6 address
return nil, ErrNoQTypeRecord
}
} else if q.Qtype == dns.TypeAAAA {
_, ok = d.DnsEntriesCacheStore[buildDNSEntryKey(strings.TrimSuffix(q.Name, "."), dns.TypeA)]
if ok {
// aware but no ipv4 address
return nil, ErrNoQTypeRecord
}
}

return nil, ErrNXDomain
}

return r
return r, nil
}
2 changes: 1 addition & 1 deletion functions/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func startGoRoutines(wg *sync.WaitGroup) context.CancelFunc {
stun.SetDefaultStunServers()
}
netclientCfg := config.Netclient()
if netclientCfg.Host.OS == "linux" {
if netclientCfg.Host.OS == "linux" || netclientCfg.Host.OS == "windows" {
dns.InitDNSConfig()
updateConfig = true
}
Expand Down
14 changes: 10 additions & 4 deletions functions/mqhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,12 @@ func HostPeerUpdate(client mqtt.Client, msg mqtt.Message) {
dns.GetDNSServerInstance().Stop()
}
}
if server.ManageDNS && config.Netclient().DNSManagerType == dns.DNS_MANAGER_STUB {
dns.SetupDNSConfig()

if server.ManageDNS {
if (config.Netclient().Host.OS == "linux" && dns.GetDNSServerInstance().AddrStr != "" && config.Netclient().DNSManagerType == dns.DNS_MANAGER_STUB) ||
config.Netclient().Host.OS == "windows" {
dns.SetupDNSConfig()
}
}

reloadStun := false
Expand Down Expand Up @@ -492,8 +496,10 @@ func resetInterfaceFunc() {
if dns.GetDNSServerInstance().AddrStr == "" {
dns.GetDNSServerInstance().Start()
}
//Setup resolveconf for Linux
if config.Netclient().Host.OS == "linux" && dns.GetDNSServerInstance().AddrStr != "" && config.Netclient().DNSManagerType == dns.DNS_MANAGER_STUB {

//Setup DNS for Linux and Windows
if (config.Netclient().Host.OS == "linux" && dns.GetDNSServerInstance().AddrStr != "" && config.Netclient().DNSManagerType == dns.DNS_MANAGER_STUB) ||
config.Netclient().Host.OS == "windows" {
dns.SetupDNSConfig()
}
}
Expand Down
4 changes: 1 addition & 3 deletions wireguard/wireguard_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@ func (nc *NCIface) Create() error {
adapter, err := driver.OpenAdapter(ncutils.GetInterfaceName())
if err != nil {
slog.Info("creating Windows tunnel")
//{0EF230F0-2EAD-4370-B0F9-AFC2D2A039E6} is a fixed string, for creating the unique GUID. It's meaningless
//unique GUID here to make sure only one network profile created
idString := config.Netclient().Host.ID.String()
if idString == "" {
idString = "0EF230F0-2EAD-4370-B0F9-AFC2D2A039E6"
idString = config.DefaultHostID
}
windowsGUID, err := windows.GUIDFromString("{" + idString + "}")
if err != nil {
Expand Down