diff --git a/config/config.go b/config/config.go index b62a13d9..4bf8a1e2 100644 --- a/config/config.go +++ b/config/config.go @@ -49,6 +49,13 @@ const ( Initd ) +const ( + // DefaultHostID {0EF230F0-2EAD-4370-B0F9-AFC2D2A039E6} is a fixed string, + // for creating the unique GUID. It's a meaningless unique GUID here to + // make sure only one network profile is created. + DefaultHostID = "0EF230F0-2EAD-4370-B0F9-AFC2D2A039E6" +) + // Initype - the type of init system in use type InitType int diff --git a/dns/config.go b/dns/config.go index 70897e11..a2b291de 100644 --- a/dns/config.go +++ b/dns/config.go @@ -2,7 +2,9 @@ package dns import ( "encoding/json" + "errors" "os" + "strings" "sync" "github.com/gravitl/netclient/config" @@ -10,10 +12,11 @@ import ( ) const ( - DNS_MANAGER_STUB = "stub" // '/run/systemd/resolve/stub-resolv.conf' - DNS_MANAGER_UPLINK = "uplink" // '/run/systemd/resolve/resolv.conf' - DNS_MANAGER_RESOLVECONF = "resolveconf" // 'generated by resolvconf(8)' - DNS_MANAGER_FILE = "file" // other than above + DNS_MANAGER_STUB = "stub" // '/run/systemd/resolve/stub-resolv.conf' + DNS_MANAGER_UPLINK = "uplink" // '/run/systemd/resolve/resolv.conf' + DNS_MANAGER_RESOLVECONF = "resolveconf" // 'generated by resolvconf(8)' + DNS_MANAGER_FILE = "file" // other than above + DNS_MANAGER_WINDOWS_REGISTRY = "registry" // for Windows machines ) var ( @@ -115,3 +118,36 @@ func cleanDNSJsonFile() error { return nil } + +// getDnsIp return the ip address of the dns server +func getDnsIp() (string, error) { + dnsIp := GetDNSServerInstance().AddrStr + if dnsIp == "" { + return "", errors.New("no listener is running") + } + + if len(config.GetNodes()) == 0 { + return "", errors.New("no network joint") + } + + dnsIp = getIpFromServerString(dnsIp) + + return dnsIp, nil +} + +// getIpFromServerString returns ip address from the ip:port +// address pair. +func getIpFromServerString(addrStr string) string { + s := "" + s = addrStr[0:strings.LastIndex(addrStr, ":")] + + if strings.Contains(s, "[") { + s = strings.ReplaceAll(s, "[", "") + } + + if strings.Contains(s, "]") { + s = strings.ReplaceAll(s, "]", "") + } + + return s +} diff --git a/dns/config_linux.go b/dns/config_linux.go index 5d027c72..9a61da8f 100644 --- a/dns/config_linux.go +++ b/dns/config_linux.go @@ -92,13 +92,11 @@ func RestoreDNSConfig() (err error) { } func buildAddConfigContentUplink() ([]string, error) { - - dnsIp := GetDNSServerInstance().AddrStr - if dnsIp == "" { - return []string{}, errors.New("no listener is running") + dnsIp, err := getDnsIp() + if err != nil { + return nil, err } - dnsIp = getIpFromServerString(dnsIp) ns := "DNS=" + dnsIp f, err := os.Open(resolvUplinkPath) @@ -171,16 +169,11 @@ func setupResolveUplink() (err error) { func setupResolvectl() (err error) { - dnsIp := GetDNSServerInstance().AddrStr - if dnsIp == "" { - return errors.New("no listener is running") - } - if len(config.GetNodes()) == 0 { - return errors.New("no network joint") + dnsIp, err := getDnsIp() + if err != nil { + return err } - dnsIp = getIpFromServerString(dnsIp) - _, err = ncutils.RunCmd(fmt.Sprintf("resolvectl dns netmaker %s", dnsIp), false) if err != nil { slog.Warn("add DNS IP for netmaker failed", "error", err.Error()) @@ -206,21 +199,6 @@ func setupResolvectl() (err error) { return nil } -func getIpFromServerString(addrStr string) string { - s := "" - s = addrStr[0:strings.LastIndex(addrStr, ":")] - - if strings.Contains(s, "[") { - s = strings.ReplaceAll(s, "[", "") - } - - if strings.Contains(s, "]") { - s = strings.ReplaceAll(s, "]", "") - } - - return s -} - func backupResolveconfFile(src, dst string) error { _, err := os.Stat(dst) diff --git a/dns/config_windows.go b/dns/config_windows.go index 4f683098..74668479 100644 --- a/dns/config_windows.go +++ b/dns/config_windows.go @@ -1,16 +1,76 @@ package dns -func FlushLocalDnsCache() (err error) { - return nil +import ( + "fmt" + "github.com/gravitl/netclient/config" + "github.com/gravitl/netclient/ncutils" + "golang.org/x/exp/slog" + "golang.org/x/sys/windows/registry" + "net/netip" + "sync" +) + +var dnsConfigMutex sync.Mutex + +func FlushLocalDnsCache() error { + _, err := ncutils.RunCmd("ipconfig /flushdns", false) + if err != nil { + slog.Warn("failed to flush local dns cache", "error", err.Error()) + } + + return err } -func SetupDNSConfig() (err error) { - return nil +func SetupDNSConfig() error { + dnsConfigMutex.Lock() + defer dnsConfigMutex.Unlock() + + // ignore if dns manager is not Windows Registry + if config.Netclient().DNSManagerType != DNS_MANAGER_WINDOWS_REGISTRY { + return nil + } + + dnsIp, err := getDnsIp() + if err != nil { + return err + } + + ip, err := netip.ParseAddr(dnsIp) + if err != nil { + return err + } + + guid := config.Netclient().Host.ID.String() + if guid == "" { + guid = config.DefaultHostID + } + + guid = "{" + guid + "}" + + keyPath := "" + if ip.Is6() { + keyPath = fmt.Sprintf(`SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\%s`, guid) + } else { + keyPath = fmt.Sprintf(`SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\%s`, guid) + } + + // open registry key with permissions to set value + key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE|registry.SET_VALUE) + if err != nil { + return err + } + defer key.Close() + + return key.SetStringValue("NameServer", dnsIp) } -func RestoreDNSConfig() (err error) { +func RestoreDNSConfig() error { return nil } func InitDNSConfig() { + dnsConfigMutex.Lock() + defer dnsConfigMutex.Unlock() + + config.Netclient().DNSManagerType = DNS_MANAGER_WINDOWS_REGISTRY } diff --git a/dns/listener.go b/dns/listener.go index 05354c5c..d5a880b7 100644 --- a/dns/listener.go +++ b/dns/listener.go @@ -33,7 +33,7 @@ func GetDNSServerInstance() *DNSServer { // Start the DNS listener func (dnsServer *DNSServer) Start() { - if runtime.GOOS != "linux" { + if runtime.GOOS != "linux" && runtime.GOOS != "windows" { return } dnsMutex.Lock() @@ -95,10 +95,10 @@ func (dnsServer *DNSServer) Start() { } //Setup DNS config for Linux - if config.Netclient().Host.OS == "linux" { + if config.Netclient().Host.OS == "linux" || config.Netclient().Host.OS == "windows" { err := SetupDNSConfig() if err != nil { - slog.Error("setup DNS conig failed", "error", err.Error()) + slog.Error("setup DNS config failed", "error", err.Error()) } } @@ -107,7 +107,7 @@ func (dnsServer *DNSServer) Start() { // Stop the DNS listener func (dnsServer *DNSServer) Stop() { - if runtime.GOOS != "linux" { + if runtime.GOOS != "linux" && runtime.GOOS != "windows" { return } dnsMutex.Lock() @@ -117,7 +117,7 @@ func (dnsServer *DNSServer) Stop() { } //restore DNS config for Linux - if config.Netclient().Host.OS == "linux" { + if config.Netclient().Host.OS == "linux" || config.Netclient().Host.OS == "windows" { err := RestoreDNSConfig() if err != nil { slog.Warn("Restore DNS conig failed", "error", err.Error()) diff --git a/dns/resolver.go b/dns/resolver.go index 4c0090ce..6cdaf9f8 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -1,6 +1,7 @@ package dns import ( + "errors" "net" "strings" "sync" @@ -16,6 +17,11 @@ const ( var dnsMapMutex = sync.Mutex{} // used to mutex functions of the DNS +var ( + ErrNXDomain = errors.New("non existent domain") + ErrNoQTypeRecord = errors.New("domain exists but no record matching the question type") +) + type DNSResolver struct { DnsEntriesCacheStore map[string]dns.RR DnsEntriesCacheMap map[string][]dnsRecord @@ -53,11 +59,10 @@ func handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { reply.RecursionAvailable = true reply.RecursionDesired = true reply.Rcode = dns.RcodeSuccess + reply.Authoritative = true - resp := GetDNSResolverInstance().Lookup(r) - if resp != nil { - reply.Answer = append(reply.Answer, resp) - } else { + resp, err := GetDNSResolverInstance().Lookup(r) + if err != nil && errors.Is(err, ErrNXDomain) { nslist := config.Netclient().NameServers if config.Netclient().CurrGwNmIP != nil { nslist = []string{} @@ -97,7 +102,11 @@ func handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { } } - err := w.WriteMsg(reply) + if resp != nil { + reply.Answer = append(reply.Answer, resp) + } + + err = w.WriteMsg(reply) if err != nil { slog.Error("write DNS response message error: ", "error", err.Error()) } @@ -136,14 +145,28 @@ func (d *DNSResolver) RegisterAAAA(record dnsRecord) error { } // Lookup DNS entry in local directory -func (d *DNSResolver) Lookup(m *dns.Msg) dns.RR { +func (d *DNSResolver) Lookup(m *dns.Msg) (dns.RR, error) { dnsMapMutex.Lock() defer dnsMapMutex.Unlock() q := m.Question[0] r, ok := d.DnsEntriesCacheStore[buildDNSEntryKey(strings.TrimSuffix(q.Name, "."), q.Qtype)] if !ok { - return nil + if q.Qtype == dns.TypeA { + _, ok = d.DnsEntriesCacheStore[buildDNSEntryKey(strings.TrimSuffix(q.Name, "."), dns.TypeAAAA)] + if ok { + // aware but no ipv6 address + return nil, ErrNoQTypeRecord + } + } else if q.Qtype == dns.TypeAAAA { + _, ok = d.DnsEntriesCacheStore[buildDNSEntryKey(strings.TrimSuffix(q.Name, "."), dns.TypeA)] + if ok { + // aware but no ipv4 address + return nil, ErrNoQTypeRecord + } + } + + return nil, ErrNXDomain } - return r + return r, nil } diff --git a/functions/daemon.go b/functions/daemon.go index 554ddc57..7af400ca 100644 --- a/functions/daemon.go +++ b/functions/daemon.go @@ -173,7 +173,7 @@ func startGoRoutines(wg *sync.WaitGroup) context.CancelFunc { stun.SetDefaultStunServers() } netclientCfg := config.Netclient() - if netclientCfg.Host.OS == "linux" { + if netclientCfg.Host.OS == "linux" || netclientCfg.Host.OS == "windows" { dns.InitDNSConfig() updateConfig = true } diff --git a/functions/mqhandlers.go b/functions/mqhandlers.go index c6a8bd2f..aaa130a6 100644 --- a/functions/mqhandlers.go +++ b/functions/mqhandlers.go @@ -269,8 +269,12 @@ func HostPeerUpdate(client mqtt.Client, msg mqtt.Message) { dns.GetDNSServerInstance().Stop() } } - if server.ManageDNS && config.Netclient().DNSManagerType == dns.DNS_MANAGER_STUB { - dns.SetupDNSConfig() + + if server.ManageDNS { + if (config.Netclient().Host.OS == "linux" && dns.GetDNSServerInstance().AddrStr != "" && config.Netclient().DNSManagerType == dns.DNS_MANAGER_STUB) || + config.Netclient().Host.OS == "windows" { + dns.SetupDNSConfig() + } } reloadStun := false @@ -492,8 +496,10 @@ func resetInterfaceFunc() { if dns.GetDNSServerInstance().AddrStr == "" { dns.GetDNSServerInstance().Start() } - //Setup resolveconf for Linux - if config.Netclient().Host.OS == "linux" && dns.GetDNSServerInstance().AddrStr != "" && config.Netclient().DNSManagerType == dns.DNS_MANAGER_STUB { + + //Setup DNS for Linux and Windows + if (config.Netclient().Host.OS == "linux" && dns.GetDNSServerInstance().AddrStr != "" && config.Netclient().DNSManagerType == dns.DNS_MANAGER_STUB) || + config.Netclient().Host.OS == "windows" { dns.SetupDNSConfig() } } diff --git a/wireguard/wireguard_windows.go b/wireguard/wireguard_windows.go index 43c4106c..16ec6c7b 100644 --- a/wireguard/wireguard_windows.go +++ b/wireguard/wireguard_windows.go @@ -29,11 +29,9 @@ func (nc *NCIface) Create() error { adapter, err := driver.OpenAdapter(ncutils.GetInterfaceName()) if err != nil { slog.Info("creating Windows tunnel") - //{0EF230F0-2EAD-4370-B0F9-AFC2D2A039E6} is a fixed string, for creating the unique GUID. It's meaningless - //unique GUID here to make sure only one network profile created idString := config.Netclient().Host.ID.String() if idString == "" { - idString = "0EF230F0-2EAD-4370-B0F9-AFC2D2A039E6" + idString = config.DefaultHostID } windowsGUID, err := windows.GUIDFromString("{" + idString + "}") if err != nil {