diff --git a/README.md b/README.md index e1ff63b..3a42c5f 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ Better implementation of [KVAS](https://github.com/qzeleza/kvas) Roadmap: - [x] DNS Proxy +- [x] DNS Records table - [ ] IPTables rules to remap DNS server [1] - [ ] Rule composer - [ ] List loading/watching (temporary) diff --git a/main.go b/main.go index ee6d08f..1ce4300 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ package main import ( "fmt" dnsProxy "kvas2-go/dns-proxy" + ruleComposer "kvas2-go/rule-composer" "log" ) @@ -13,6 +14,7 @@ var ( ) func main() { + records := ruleComposer.NewRecords() proxy := dnsProxy.New("", ListenPort, UsableDNSServerAddress, UsableDNSServerPort) proxy.MsgHandler = func(msg *dnsProxy.Message) { for _, q := range msg.QD { @@ -22,8 +24,10 @@ func main() { switch v := a.(type) { case dnsProxy.Address: fmt.Printf("%x: -> A: Name: %s; Address: %s; TTL: %d\n", msg.ID, v.Name, v.Address.String(), v.TTL) + records.PutIPv4Address(v.Name.String(), v.Address, int64(v.TTL)) case dnsProxy.CName: fmt.Printf("%x: -> CNAME: Name: %s; CName: %s\n", msg.ID, v.Name, v.CName) + records.PutCName(v.Name.String(), v.CName.String(), int64(v.TTL)) default: fmt.Printf("%x: -> Unknown: %x\n", msg.ID, v.EncodeResource()) } @@ -34,6 +38,13 @@ func main() { for _, a := range msg.AR { fmt.Printf("%x: -> NS: %x\n", msg.ID, a.EncodeResource()) } + + for _, q := range msg.QD { + fmt.Printf("%x: DBG Known addresses for: %s\n", msg.ID, q.QName.String()) + for idx, addr := range records.GetIPv4Addresses(q.QName.String()) { + fmt.Printf("%x: #%d: %s\n", msg.ID, idx, addr.String()) + } + } } err := proxy.Listen() if err != nil { diff --git a/rule-composer/records.go b/rule-composer/records.go new file mode 100644 index 0000000..86ba1c2 --- /dev/null +++ b/rule-composer/records.go @@ -0,0 +1,114 @@ +package ruleComposer + +import ( + "bytes" + "net" + "sync" + "time" +) + +type Records struct { + mutex sync.RWMutex + ipv4Addresses map[string]map[string]time.Time + cNames map[string]map[string]time.Time +} + +func (r *Records) getCNames(domainName string) []string { + _, ok := r.cNames[domainName] + if !ok { + return nil + } + + cNameList := make([]string, 0, len(r.cNames[domainName])) + for cname, ttl := range r.cNames[domainName] { + if time.Now().Sub(ttl).Nanoseconds() < 0 { + cNameList = append(cNameList, cname) + } + } + + origCNameLen := len(cNameList) + for i := 0; i < origCNameLen; i++ { + parentList := r.getCNames(cNameList[i]) + if parentList != nil { + cNameList = append(cNameList, parentList...) + } + } + + return cNameList +} + +func (r *Records) GetIPv4Addresses(domainName string) []net.IP { + r.mutex.RLock() + defer r.mutex.RUnlock() + + cNameList := append([]string{domainName}, r.getCNames(domainName)...) + ipAddresses := make([]net.IP, 0) + for _, cName := range cNameList { + addresses, ok := r.ipv4Addresses[cName] + if !ok { + continue + } + + addressesNetIP := make([]net.IP, 0, len(addresses)) + for addr, ttl := range addresses { + if time.Now().Sub(ttl).Nanoseconds() < 0 { + addressesNetIP = append(addressesNetIP, []byte(addr)) + } + } + + ipAddresses = append(ipAddresses, addressesNetIP...) + } + + return ipAddresses +} + +func (r *Records) PutCName(domainName string, cName string, ttl int64) { + r.mutex.Lock() + defer r.mutex.Unlock() + + if r.cNames[domainName] == nil { + r.cNames[domainName] = make(map[string]time.Time) + } + + skipPut := false + for name, _ := range r.cNames[domainName] { + if name == cName { + r.cNames[domainName][name] = time.Now().Add(time.Second * time.Duration(ttl)) + skipPut = true + break + } + } + + if !skipPut { + r.cNames[domainName][cName] = time.Now().Add(time.Second * time.Duration(ttl)) + } +} + +func (r *Records) PutIPv4Address(domainName string, addr net.IP, ttl int64) { + r.mutex.Lock() + defer r.mutex.Unlock() + + if r.ipv4Addresses[domainName] == nil { + r.ipv4Addresses[domainName] = make(map[string]time.Time) + } + + skipPut := false + for address, _ := range r.ipv4Addresses[domainName] { + if bytes.Compare([]byte(address), addr) == 0 { + r.ipv4Addresses[domainName][address] = time.Now().Add(time.Second * time.Duration(ttl)) + skipPut = true + break + } + } + + if !skipPut { + r.ipv4Addresses[domainName][string(addr)] = time.Now().Add(time.Second * time.Duration(ttl)) + } +} + +func NewRecords() *Records { + return &Records{ + ipv4Addresses: make(map[string]map[string]time.Time), + cNames: make(map[string]map[string]time.Time), + } +}