You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
251 lines
9.5 KiB
251 lines
9.5 KiB
6 years ago
|
diff --git a/Lib/httplib.py b/Lib/httplib.py
|
||
|
index 592ee57..b69145b 100644
|
||
|
--- a/Lib/httplib.py
|
||
|
+++ b/Lib/httplib.py
|
||
|
@@ -735,25 +735,40 @@ class HTTPConnection:
|
||
|
self._tunnel_host = None
|
||
|
self._tunnel_port = None
|
||
|
self._tunnel_headers = {}
|
||
|
-
|
||
|
- self._set_hostport(host, port)
|
||
|
if strict is not None:
|
||
|
self.strict = strict
|
||
|
|
||
|
+ (self.host, self.port) = self._get_hostport(host, port)
|
||
|
+
|
||
|
+ # This is stored as an instance variable to allow unittests
|
||
|
+ # to replace with a suitable mock
|
||
|
+ self._create_connection = socket.create_connection
|
||
|
+
|
||
|
def set_tunnel(self, host, port=None, headers=None):
|
||
|
- """ Sets up the host and the port for the HTTP CONNECT Tunnelling.
|
||
|
+ """ Set up host and port for HTTP CONNECT tunnelling.
|
||
|
+
|
||
|
+ In a connection that uses HTTP CONNECT tunnelling, the host passed to the
|
||
|
+ constructor is used as a proxy server that relays all communication to the
|
||
|
+ endpoint passed to set_tunnel. This is done by sending an HTTP CONNECT
|
||
|
+ request to the proxy server when the connection is established.
|
||
|
+
|
||
|
+ This method must be called before the HTTP connection has been
|
||
|
+ established.
|
||
|
|
||
|
The headers argument should be a mapping of extra HTTP headers
|
||
|
to send with the CONNECT request.
|
||
|
"""
|
||
|
- self._tunnel_host = host
|
||
|
- self._tunnel_port = port
|
||
|
+ # Verify if this is required.
|
||
|
+ if self.sock:
|
||
|
+ raise RuntimeError("Can't setup tunnel for established connection.")
|
||
|
+
|
||
|
+ self._tunnel_host, self._tunnel_port = self._get_hostport(host, port)
|
||
|
if headers:
|
||
|
self._tunnel_headers = headers
|
||
|
else:
|
||
|
self._tunnel_headers.clear()
|
||
|
|
||
|
- def _set_hostport(self, host, port):
|
||
|
+ def _get_hostport(self, host, port):
|
||
|
if port is None:
|
||
|
i = host.rfind(':')
|
||
|
j = host.rfind(']') # ipv6 addresses have [...]
|
||
|
@@ -770,15 +785,14 @@ class HTTPConnection:
|
||
|
port = self.default_port
|
||
|
if host and host[0] == '[' and host[-1] == ']':
|
||
|
host = host[1:-1]
|
||
|
- self.host = host
|
||
|
- self.port = port
|
||
|
+ return (host, port)
|
||
|
|
||
|
def set_debuglevel(self, level):
|
||
|
self.debuglevel = level
|
||
|
|
||
|
def _tunnel(self):
|
||
|
- self._set_hostport(self._tunnel_host, self._tunnel_port)
|
||
|
- self.send("CONNECT %s:%d HTTP/1.0\r\n" % (self.host, self.port))
|
||
|
+ self.send("CONNECT %s:%d HTTP/1.0\r\n" % (self._tunnel_host,
|
||
|
+ self._tunnel_port))
|
||
|
for header, value in self._tunnel_headers.iteritems():
|
||
|
self.send("%s: %s\r\n" % (header, value))
|
||
|
self.send("\r\n")
|
||
|
@@ -803,8 +817,8 @@ class HTTPConnection:
|
||
|
|
||
|
def connect(self):
|
||
|
"""Connect to the host and port specified in __init__."""
|
||
|
- self.sock = socket.create_connection((self.host,self.port),
|
||
|
- self.timeout, self.source_address)
|
||
|
+ self.sock = self._create_connection((self.host,self.port),
|
||
|
+ self.timeout, self.source_address)
|
||
|
|
||
|
if self._tunnel_host:
|
||
|
self._tunnel()
|
||
|
@@ -942,17 +956,24 @@ class HTTPConnection:
|
||
|
netloc_enc = netloc.encode("idna")
|
||
|
self.putheader('Host', netloc_enc)
|
||
|
else:
|
||
|
+ if self._tunnel_host:
|
||
|
+ host = self._tunnel_host
|
||
|
+ port = self._tunnel_port
|
||
|
+ else:
|
||
|
+ host = self.host
|
||
|
+ port = self.port
|
||
|
+
|
||
|
try:
|
||
|
- host_enc = self.host.encode("ascii")
|
||
|
+ host_enc = host.encode("ascii")
|
||
|
except UnicodeEncodeError:
|
||
|
- host_enc = self.host.encode("idna")
|
||
|
+ host_enc = host.encode("idna")
|
||
|
# Wrap the IPv6 Host Header with [] (RFC 2732)
|
||
|
if host_enc.find(':') >= 0:
|
||
|
host_enc = "[" + host_enc + "]"
|
||
|
- if self.port == self.default_port:
|
||
|
+ if port == self.default_port:
|
||
|
self.putheader('Host', host_enc)
|
||
|
else:
|
||
|
- self.putheader('Host', "%s:%s" % (host_enc, self.port))
|
||
|
+ self.putheader('Host', "%s:%s" % (host_enc, port))
|
||
|
|
||
|
# note: we are assuming that clients will not attempt to set these
|
||
|
# headers since *this* library must deal with the
|
||
|
@@ -1141,7 +1162,7 @@ class HTTP:
|
||
|
"Accept arguments to set the host/port, since the superclass doesn't."
|
||
|
|
||
|
if host is not None:
|
||
|
- self._conn._set_hostport(host, port)
|
||
|
+ (self._conn.host, self._conn.port) = self._conn._get_hostport(host, port)
|
||
|
self._conn.connect()
|
||
|
|
||
|
def getfile(self):
|
||
|
diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py
|
||
|
index 29af589..9db30cc 100644
|
||
|
--- a/Lib/test/test_httplib.py
|
||
|
+++ b/Lib/test/test_httplib.py
|
||
|
@@ -21,10 +21,12 @@ CERT_selfsigned_pythontestdotnet = os.path.join(here, 'selfsigned_pythontestdotn
|
||
|
HOST = test_support.HOST
|
||
|
|
||
|
class FakeSocket:
|
||
|
- def __init__(self, text, fileclass=StringIO.StringIO):
|
||
|
+ def __init__(self, text, fileclass=StringIO.StringIO, host=None, port=None):
|
||
|
self.text = text
|
||
|
self.fileclass = fileclass
|
||
|
self.data = ''
|
||
|
+ self.host = host
|
||
|
+ self.port = port
|
||
|
|
||
|
def sendall(self, data):
|
||
|
self.data += ''.join(data)
|
||
|
@@ -34,6 +36,9 @@ class FakeSocket:
|
||
|
raise httplib.UnimplementedFileMode()
|
||
|
return self.fileclass(self.text)
|
||
|
|
||
|
+ def close(self):
|
||
|
+ pass
|
||
|
+
|
||
|
class EPipeSocket(FakeSocket):
|
||
|
|
||
|
def __init__(self, text, pipe_trigger):
|
||
|
@@ -487,7 +492,11 @@ class OfflineTest(TestCase):
|
||
|
self.assertEqual(httplib.responses[httplib.NOT_FOUND], "Not Found")
|
||
|
|
||
|
|
||
|
-class SourceAddressTest(TestCase):
|
||
|
+class TestServerMixin:
|
||
|
+ """A limited socket server mixin.
|
||
|
+
|
||
|
+ This is used by test cases for testing http connection end points.
|
||
|
+ """
|
||
|
def setUp(self):
|
||
|
self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||
|
self.port = test_support.bind_port(self.serv)
|
||
|
@@ -502,6 +511,7 @@ class SourceAddressTest(TestCase):
|
||
|
self.serv.close()
|
||
|
self.serv = None
|
||
|
|
||
|
+class SourceAddressTest(TestServerMixin, TestCase):
|
||
|
def testHTTPConnectionSourceAddress(self):
|
||
|
self.conn = httplib.HTTPConnection(HOST, self.port,
|
||
|
source_address=('', self.source_port))
|
||
|
@@ -518,6 +528,24 @@ class SourceAddressTest(TestCase):
|
||
|
# for an ssl_wrapped connect() to actually return from.
|
||
|
|
||
|
|
||
|
+class HTTPTest(TestServerMixin, TestCase):
|
||
|
+ def testHTTPConnection(self):
|
||
|
+ self.conn = httplib.HTTP(host=HOST, port=self.port, strict=None)
|
||
|
+ self.conn.connect()
|
||
|
+ self.assertEqual(self.conn._conn.host, HOST)
|
||
|
+ self.assertEqual(self.conn._conn.port, self.port)
|
||
|
+
|
||
|
+ def testHTTPWithConnectHostPort(self):
|
||
|
+ testhost = 'unreachable.test.domain'
|
||
|
+ testport = '80'
|
||
|
+ self.conn = httplib.HTTP(host=testhost, port=testport)
|
||
|
+ self.conn.connect(host=HOST, port=self.port)
|
||
|
+ self.assertNotEqual(self.conn._conn.host, testhost)
|
||
|
+ self.assertNotEqual(self.conn._conn.port, testport)
|
||
|
+ self.assertEqual(self.conn._conn.host, HOST)
|
||
|
+ self.assertEqual(self.conn._conn.port, self.port)
|
||
|
+
|
||
|
+
|
||
|
class TimeoutTest(TestCase):
|
||
|
PORT = None
|
||
|
|
||
|
@@ -716,13 +744,54 @@ class HTTPSTest(TestCase):
|
||
|
c = httplib.HTTPSConnection(hp, context=context)
|
||
|
self.assertEqual(h, c.host)
|
||
|
self.assertEqual(p, c.port)
|
||
|
-
|
||
|
+
|
||
|
+class TunnelTests(TestCase):
|
||
|
+ def test_connect(self):
|
||
|
+ response_text = (
|
||
|
+ 'HTTP/1.0 200 OK\r\n\r\n' # Reply to CONNECT
|
||
|
+ 'HTTP/1.1 200 OK\r\n' # Reply to HEAD
|
||
|
+ 'Content-Length: 42\r\n\r\n'
|
||
|
+ )
|
||
|
+
|
||
|
+ def create_connection(address, timeout=None, source_address=None):
|
||
|
+ return FakeSocket(response_text, host=address[0], port=address[1])
|
||
|
+
|
||
|
+ conn = httplib.HTTPConnection('proxy.com')
|
||
|
+ conn._create_connection = create_connection
|
||
|
+
|
||
|
+ # Once connected, we should not be able to tunnel anymore
|
||
|
+ conn.connect()
|
||
|
+ self.assertRaises(RuntimeError, conn.set_tunnel, 'destination.com')
|
||
|
+
|
||
|
+ # But if we close the connection, we are good.
|
||
|
+ conn.close()
|
||
|
+ conn.set_tunnel('destination.com')
|
||
|
+ conn.request('HEAD', '/', '')
|
||
|
+
|
||
|
+ self.assertEqual(conn.sock.host, 'proxy.com')
|
||
|
+ self.assertEqual(conn.sock.port, 80)
|
||
|
+ self.assertIn('CONNECT destination.com', conn.sock.data)
|
||
|
+ # issue22095
|
||
|
+ self.assertNotIn('Host: destination.com:None', conn.sock.data)
|
||
|
+ # issue22095
|
||
|
+
|
||
|
+ self.assertNotIn('Host: proxy.com', conn.sock.data)
|
||
|
+
|
||
|
+ conn.close()
|
||
|
+
|
||
|
+ conn.request('PUT', '/', '')
|
||
|
+ self.assertEqual(conn.sock.host, 'proxy.com')
|
||
|
+ self.assertEqual(conn.sock.port, 80)
|
||
|
+ self.assertTrue('CONNECT destination.com' in conn.sock.data)
|
||
|
+ self.assertTrue('Host: destination.com' in conn.sock.data)
|
||
|
+
|
||
|
|
||
|
|
||
|
@test_support.reap_threads
|
||
|
def test_main(verbose=None):
|
||
|
test_support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest,
|
||
|
- HTTPSTest, SourceAddressTest)
|
||
|
+ HTTPTest, HTTPSTest, SourceAddressTest,
|
||
|
+ TunnelTests)
|
||
|
|
||
|
if __name__ == '__main__':
|
||
|
test_main()
|