p2p/nat: limit UPNP request concurrency (#21390)
This adds a lock around requests because some routers can't handle concurrent requests. Requests are also rate-limited. The Map function request a new mapping exactly when the map timeout occurs instead of 5 minutes earlier. This should prevent duplicate mappings.
This commit is contained in:
		| @@ -91,15 +91,14 @@ func Parse(spec string) (Interface, error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	mapTimeout        = 20 * time.Minute | 	mapTimeout = 10 * time.Minute | ||||||
| 	mapUpdateInterval = 15 * time.Minute |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // Map adds a port mapping on m and keeps it alive until c is closed. | // Map adds a port mapping on m and keeps it alive until c is closed. | ||||||
| // This function is typically invoked in its own goroutine. | // This function is typically invoked in its own goroutine. | ||||||
| func Map(m Interface, c chan struct{}, protocol string, extport, intport int, name string) { | func Map(m Interface, c <-chan struct{}, protocol string, extport, intport int, name string) { | ||||||
| 	log := log.New("proto", protocol, "extport", extport, "intport", intport, "interface", m) | 	log := log.New("proto", protocol, "extport", extport, "intport", intport, "interface", m) | ||||||
| 	refresh := time.NewTimer(mapUpdateInterval) | 	refresh := time.NewTimer(mapTimeout) | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		refresh.Stop() | 		refresh.Stop() | ||||||
| 		log.Debug("Deleting port mapping") | 		log.Debug("Deleting port mapping") | ||||||
| @@ -121,7 +120,7 @@ func Map(m Interface, c chan struct{}, protocol string, extport, intport int, na | |||||||
| 			if err := m.AddMapping(protocol, extport, intport, name, mapTimeout); err != nil { | 			if err := m.AddMapping(protocol, extport, intport, name, mapTimeout); err != nil { | ||||||
| 				log.Debug("Couldn't add port mapping", "err", err) | 				log.Debug("Couldn't add port mapping", "err", err) | ||||||
| 			} | 			} | ||||||
| 			refresh.Reset(mapUpdateInterval) | 			refresh.Reset(mapTimeout) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -21,6 +21,7 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/huin/goupnp" | 	"github.com/huin/goupnp" | ||||||
| @@ -28,12 +29,17 @@ import ( | |||||||
| 	"github.com/huin/goupnp/dcps/internetgateway2" | 	"github.com/huin/goupnp/dcps/internetgateway2" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const soapRequestTimeout = 3 * time.Second | const ( | ||||||
|  | 	soapRequestTimeout = 3 * time.Second | ||||||
|  | 	rateLimit          = 200 * time.Millisecond | ||||||
|  | ) | ||||||
|  |  | ||||||
| type upnp struct { | type upnp struct { | ||||||
| 	dev         *goupnp.RootDevice | 	dev         *goupnp.RootDevice | ||||||
| 	service     string | 	service     string | ||||||
| 	client      upnpClient | 	client      upnpClient | ||||||
|  | 	mu          sync.Mutex | ||||||
|  | 	lastReqTime time.Time | ||||||
| } | } | ||||||
|  |  | ||||||
| type upnpClient interface { | type upnpClient interface { | ||||||
| @@ -43,8 +49,23 @@ type upnpClient interface { | |||||||
| 	GetNATRSIPStatus() (sip bool, nat bool, err error) | 	GetNATRSIPStatus() (sip bool, nat bool, err error) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (n *upnp) natEnabled() bool { | ||||||
|  | 	var ok bool | ||||||
|  | 	var err error | ||||||
|  | 	n.withRateLimit(func() error { | ||||||
|  | 		_, ok, err = n.client.GetNATRSIPStatus() | ||||||
|  | 		return err | ||||||
|  | 	}) | ||||||
|  | 	return err == nil && ok | ||||||
|  | } | ||||||
|  |  | ||||||
| func (n *upnp) ExternalIP() (addr net.IP, err error) { | func (n *upnp) ExternalIP() (addr net.IP, err error) { | ||||||
| 	ipString, err := n.client.GetExternalIPAddress() | 	var ipString string | ||||||
|  | 	n.withRateLimit(func() error { | ||||||
|  | 		ipString, err = n.client.GetExternalIPAddress() | ||||||
|  | 		return err | ||||||
|  | 	}) | ||||||
|  |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @@ -63,7 +84,10 @@ func (n *upnp) AddMapping(protocol string, extport, intport int, desc string, li | |||||||
| 	protocol = strings.ToUpper(protocol) | 	protocol = strings.ToUpper(protocol) | ||||||
| 	lifetimeS := uint32(lifetime / time.Second) | 	lifetimeS := uint32(lifetime / time.Second) | ||||||
| 	n.DeleteMapping(protocol, extport, intport) | 	n.DeleteMapping(protocol, extport, intport) | ||||||
|  |  | ||||||
|  | 	return n.withRateLimit(func() error { | ||||||
| 		return n.client.AddPortMapping("", uint16(extport), protocol, uint16(intport), ip.String(), true, desc, lifetimeS) | 		return n.client.AddPortMapping("", uint16(extport), protocol, uint16(intport), ip.String(), true, desc, lifetimeS) | ||||||
|  | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (n *upnp) internalAddress() (net.IP, error) { | func (n *upnp) internalAddress() (net.IP, error) { | ||||||
| @@ -90,36 +114,51 @@ func (n *upnp) internalAddress() (net.IP, error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (n *upnp) DeleteMapping(protocol string, extport, intport int) error { | func (n *upnp) DeleteMapping(protocol string, extport, intport int) error { | ||||||
|  | 	return n.withRateLimit(func() error { | ||||||
| 		return n.client.DeletePortMapping("", uint16(extport), strings.ToUpper(protocol)) | 		return n.client.DeletePortMapping("", uint16(extport), strings.ToUpper(protocol)) | ||||||
|  | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (n *upnp) String() string { | func (n *upnp) String() string { | ||||||
| 	return "UPNP " + n.service | 	return "UPNP " + n.service | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (n *upnp) withRateLimit(fn func() error) error { | ||||||
|  | 	n.mu.Lock() | ||||||
|  | 	defer n.mu.Unlock() | ||||||
|  |  | ||||||
|  | 	lastreq := time.Since(n.lastReqTime) | ||||||
|  | 	if lastreq < rateLimit { | ||||||
|  | 		time.Sleep(rateLimit - lastreq) | ||||||
|  | 	} | ||||||
|  | 	err := fn() | ||||||
|  | 	n.lastReqTime = time.Now() | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
| // discoverUPnP searches for Internet Gateway Devices | // discoverUPnP searches for Internet Gateway Devices | ||||||
| // and returns the first one it can find on the local network. | // and returns the first one it can find on the local network. | ||||||
| func discoverUPnP() Interface { | func discoverUPnP() Interface { | ||||||
| 	found := make(chan *upnp, 2) | 	found := make(chan *upnp, 2) | ||||||
| 	// IGDv1 | 	// IGDv1 | ||||||
| 	go discover(found, internetgateway1.URN_WANConnectionDevice_1, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp { | 	go discover(found, internetgateway1.URN_WANConnectionDevice_1, func(sc goupnp.ServiceClient) *upnp { | ||||||
| 		switch sc.Service.ServiceType { | 		switch sc.Service.ServiceType { | ||||||
| 		case internetgateway1.URN_WANIPConnection_1: | 		case internetgateway1.URN_WANIPConnection_1: | ||||||
| 			return &upnp{dev, "IGDv1-IP1", &internetgateway1.WANIPConnection1{ServiceClient: sc}} | 			return &upnp{service: "IGDv1-IP1", client: &internetgateway1.WANIPConnection1{ServiceClient: sc}} | ||||||
| 		case internetgateway1.URN_WANPPPConnection_1: | 		case internetgateway1.URN_WANPPPConnection_1: | ||||||
| 			return &upnp{dev, "IGDv1-PPP1", &internetgateway1.WANPPPConnection1{ServiceClient: sc}} | 			return &upnp{service: "IGDv1-PPP1", client: &internetgateway1.WANPPPConnection1{ServiceClient: sc}} | ||||||
| 		} | 		} | ||||||
| 		return nil | 		return nil | ||||||
| 	}) | 	}) | ||||||
| 	// IGDv2 | 	// IGDv2 | ||||||
| 	go discover(found, internetgateway2.URN_WANConnectionDevice_2, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp { | 	go discover(found, internetgateway2.URN_WANConnectionDevice_2, func(sc goupnp.ServiceClient) *upnp { | ||||||
| 		switch sc.Service.ServiceType { | 		switch sc.Service.ServiceType { | ||||||
| 		case internetgateway2.URN_WANIPConnection_1: | 		case internetgateway2.URN_WANIPConnection_1: | ||||||
| 			return &upnp{dev, "IGDv2-IP1", &internetgateway2.WANIPConnection1{ServiceClient: sc}} | 			return &upnp{service: "IGDv2-IP1", client: &internetgateway2.WANIPConnection1{ServiceClient: sc}} | ||||||
| 		case internetgateway2.URN_WANIPConnection_2: | 		case internetgateway2.URN_WANIPConnection_2: | ||||||
| 			return &upnp{dev, "IGDv2-IP2", &internetgateway2.WANIPConnection2{ServiceClient: sc}} | 			return &upnp{service: "IGDv2-IP2", client: &internetgateway2.WANIPConnection2{ServiceClient: sc}} | ||||||
| 		case internetgateway2.URN_WANPPPConnection_1: | 		case internetgateway2.URN_WANPPPConnection_1: | ||||||
| 			return &upnp{dev, "IGDv2-PPP1", &internetgateway2.WANPPPConnection1{ServiceClient: sc}} | 			return &upnp{service: "IGDv2-PPP1", client: &internetgateway2.WANPPPConnection1{ServiceClient: sc}} | ||||||
| 		} | 		} | ||||||
| 		return nil | 		return nil | ||||||
| 	}) | 	}) | ||||||
| @@ -134,7 +173,7 @@ func discoverUPnP() Interface { | |||||||
| // finds devices matching the given target and calls matcher for all | // finds devices matching the given target and calls matcher for all | ||||||
| // advertised services of each device. The first non-nil service found | // advertised services of each device. The first non-nil service found | ||||||
| // is sent into out. If no service matched, nil is sent. | // is sent into out. If no service matched, nil is sent. | ||||||
| func discover(out chan<- *upnp, target string, matcher func(*goupnp.RootDevice, goupnp.ServiceClient) *upnp) { | func discover(out chan<- *upnp, target string, matcher func(goupnp.ServiceClient) *upnp) { | ||||||
| 	devs, err := goupnp.DiscoverDevices(target) | 	devs, err := goupnp.DiscoverDevices(target) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		out <- nil | 		out <- nil | ||||||
| @@ -157,16 +196,17 @@ func discover(out chan<- *upnp, target string, matcher func(*goupnp.RootDevice, | |||||||
| 				Service:    service, | 				Service:    service, | ||||||
| 			} | 			} | ||||||
| 			sc.SOAPClient.HTTPClient.Timeout = soapRequestTimeout | 			sc.SOAPClient.HTTPClient.Timeout = soapRequestTimeout | ||||||
| 			upnp := matcher(devs[i].Root, sc) | 			upnp := matcher(sc) | ||||||
| 			if upnp == nil { | 			if upnp == nil { | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
|  | 			upnp.dev = devs[i].Root | ||||||
|  |  | ||||||
| 			// check whether port mapping is enabled | 			// check whether port mapping is enabled | ||||||
| 			if _, nat, err := upnp.client.GetNATRSIPStatus(); err != nil || !nat { | 			if upnp.natEnabled() { | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 				out <- upnp | 				out <- upnp | ||||||
| 				found = true | 				found = true | ||||||
|  | 			} | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| 	if !found { | 	if !found { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user