You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
250 lines
9.5 KiB
250 lines
9.5 KiB
diff --git a/Lib/httplib.py b/Lib/httplib.py |
|
index 592ee57..b69145b 100644 |
|
--- a/Lib/httplib.py |
|
+++ b/Lib/httplib.py |
|
@@ -735,25 +735,40 @@ class HTTPConnection: |
|
self._tunnel_host = None |
|
self._tunnel_port = None |
|
self._tunnel_headers = {} |
|
- |
|
- self._set_hostport(host, port) |
|
if strict is not None: |
|
self.strict = strict |
|
|
|
+ (self.host, self.port) = self._get_hostport(host, port) |
|
+ |
|
+ # This is stored as an instance variable to allow unittests |
|
+ # to replace with a suitable mock |
|
+ self._create_connection = socket.create_connection |
|
+ |
|
def set_tunnel(self, host, port=None, headers=None): |
|
- """ Sets up the host and the port for the HTTP CONNECT Tunnelling. |
|
+ """ Set up host and port for HTTP CONNECT tunnelling. |
|
+ |
|
+ In a connection that uses HTTP CONNECT tunnelling, the host passed to |
|
+ the constructor is used as a proxy server that relays all communication |
|
+ to the endpoint passed to set_tunnel. This is done by sending an HTTP |
|
+ CONNECT request to the proxy server when the connection is established. |
|
+ |
|
+ This method must be called before the HTTP connection has been |
|
+ established. |
|
|
|
The headers argument should be a mapping of extra HTTP headers |
|
to send with the CONNECT request. |
|
""" |
|
- self._tunnel_host = host |
|
- self._tunnel_port = port |
|
+ # Verify if this is required. |
|
+ if self.sock: |
|
+ raise RuntimeError("Can't setup tunnel for established connection.") |
|
+ |
|
+ self._tunnel_host, self._tunnel_port = self._get_hostport(host, port) |
|
if headers: |
|
self._tunnel_headers = headers |
|
else: |
|
self._tunnel_headers.clear() |
|
|
|
- def _set_hostport(self, host, port): |
|
+ def _get_hostport(self, host, port): |
|
if port is None: |
|
i = host.rfind(':') |
|
j = host.rfind(']') # ipv6 addresses have [...] |
|
@@ -770,15 +785,14 @@ class HTTPConnection: |
|
port = self.default_port |
|
if host and host[0] == '[' and host[-1] == ']': |
|
host = host[1:-1] |
|
- self.host = host |
|
- self.port = port |
|
+ return (host, port) |
|
|
|
def set_debuglevel(self, level): |
|
self.debuglevel = level |
|
|
|
def _tunnel(self): |
|
- self._set_hostport(self._tunnel_host, self._tunnel_port) |
|
- self.send("CONNECT %s:%d HTTP/1.0\r\n" % (self.host, self.port)) |
|
+ self.send("CONNECT %s:%d HTTP/1.0\r\n" % (self._tunnel_host, |
|
+ self._tunnel_port)) |
|
for header, value in self._tunnel_headers.iteritems(): |
|
self.send("%s: %s\r\n" % (header, value)) |
|
self.send("\r\n") |
|
@@ -803,8 +817,8 @@ class HTTPConnection: |
|
|
|
def connect(self): |
|
"""Connect to the host and port specified in __init__.""" |
|
- self.sock = socket.create_connection((self.host,self.port), |
|
- self.timeout, self.source_address) |
|
+ self.sock = self._create_connection((self.host,self.port), |
|
+ self.timeout, self.source_address) |
|
|
|
if self._tunnel_host: |
|
self._tunnel() |
|
@@ -942,17 +956,24 @@ class HTTPConnection: |
|
netloc_enc = netloc.encode("idna") |
|
self.putheader('Host', netloc_enc) |
|
else: |
|
+ if self._tunnel_host: |
|
+ host = self._tunnel_host |
|
+ port = self._tunnel_port |
|
+ else: |
|
+ host = self.host |
|
+ port = self.port |
|
+ |
|
try: |
|
- host_enc = self.host.encode("ascii") |
|
+ host_enc = host.encode("ascii") |
|
except UnicodeEncodeError: |
|
- host_enc = self.host.encode("idna") |
|
+ host_enc = host.encode("idna") |
|
# Wrap the IPv6 Host Header with [] (RFC 2732) |
|
if host_enc.find(':') >= 0: |
|
host_enc = "[" + host_enc + "]" |
|
- if self.port == self.default_port: |
|
+ if port == self.default_port: |
|
self.putheader('Host', host_enc) |
|
else: |
|
- self.putheader('Host', "%s:%s" % (host_enc, self.port)) |
|
+ self.putheader('Host', "%s:%s" % (host_enc, port)) |
|
|
|
# note: we are assuming that clients will not attempt to set these |
|
# headers since *this* library must deal with the |
|
@@ -1141,7 +1162,7 @@ class HTTP: |
|
"Accept arguments to set the host/port, since the superclass doesn't." |
|
|
|
if host is not None: |
|
- self._conn._set_hostport(host, port) |
|
+ (self._conn.host, self._conn.port) = self._conn._get_hostport(host, port) |
|
self._conn.connect() |
|
|
|
def getfile(self): |
|
diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py |
|
index 29af589..9db30cc 100644 |
|
--- a/Lib/test/test_httplib.py |
|
+++ b/Lib/test/test_httplib.py |
|
@@ -21,10 +21,12 @@ CERT_selfsigned_pythontestdotnet = os.path.join(here, 'selfsigned_pythontestdotn |
|
HOST = test_support.HOST |
|
|
|
class FakeSocket: |
|
- def __init__(self, text, fileclass=StringIO.StringIO): |
|
+ def __init__(self, text, fileclass=StringIO.StringIO, host=None, port=None): |
|
self.text = text |
|
self.fileclass = fileclass |
|
self.data = '' |
|
+ self.host = host |
|
+ self.port = port |
|
|
|
def sendall(self, data): |
|
self.data += ''.join(data) |
|
@@ -34,6 +36,9 @@ class FakeSocket: |
|
raise httplib.UnimplementedFileMode() |
|
return self.fileclass(self.text) |
|
|
|
+ def close(self): |
|
+ pass |
|
+ |
|
class EPipeSocket(FakeSocket): |
|
|
|
def __init__(self, text, pipe_trigger): |
|
@@ -487,7 +492,11 @@ class OfflineTest(TestCase): |
|
self.assertEqual(httplib.responses[httplib.NOT_FOUND], "Not Found") |
|
|
|
|
|
-class SourceAddressTest(TestCase): |
|
+class TestServerMixin: |
|
+ """A limited socket server mixin. |
|
+ |
|
+ This is used by test cases for testing http connection end points. |
|
+ """ |
|
def setUp(self): |
|
self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
|
self.port = test_support.bind_port(self.serv) |
|
@@ -502,6 +511,7 @@ class SourceAddressTest(TestCase): |
|
self.serv.close() |
|
self.serv = None |
|
|
|
+class SourceAddressTest(TestServerMixin, TestCase): |
|
def testHTTPConnectionSourceAddress(self): |
|
self.conn = httplib.HTTPConnection(HOST, self.port, |
|
source_address=('', self.source_port)) |
|
@@ -518,6 +528,24 @@ class SourceAddressTest(TestCase): |
|
# for an ssl_wrapped connect() to actually return from. |
|
|
|
|
|
+class HTTPTest(TestServerMixin, TestCase): |
|
+ def testHTTPConnection(self): |
|
+ self.conn = httplib.HTTP(host=HOST, port=self.port, strict=None) |
|
+ self.conn.connect() |
|
+ self.assertEqual(self.conn._conn.host, HOST) |
|
+ self.assertEqual(self.conn._conn.port, self.port) |
|
+ |
|
+ def testHTTPWithConnectHostPort(self): |
|
+ testhost = 'unreachable.test.domain' |
|
+ testport = '80' |
|
+ self.conn = httplib.HTTP(host=testhost, port=testport) |
|
+ self.conn.connect(host=HOST, port=self.port) |
|
+ self.assertNotEqual(self.conn._conn.host, testhost) |
|
+ self.assertNotEqual(self.conn._conn.port, testport) |
|
+ self.assertEqual(self.conn._conn.host, HOST) |
|
+ self.assertEqual(self.conn._conn.port, self.port) |
|
+ |
|
+ |
|
class TimeoutTest(TestCase): |
|
PORT = None |
|
|
|
@@ -716,13 +744,54 @@ class HTTPSTest(TestCase): |
|
c = httplib.HTTPSConnection(hp, context=context) |
|
self.assertEqual(h, c.host) |
|
self.assertEqual(p, c.port) |
|
- |
|
+ |
|
+class TunnelTests(TestCase): |
|
+ def test_connect(self): |
|
+ response_text = ( |
|
+ 'HTTP/1.0 200 OK\r\n\r\n' # Reply to CONNECT |
|
+ 'HTTP/1.1 200 OK\r\n' # Reply to HEAD |
|
+ 'Content-Length: 42\r\n\r\n' |
|
+ ) |
|
+ |
|
+ def create_connection(address, timeout=None, source_address=None): |
|
+ return FakeSocket(response_text, host=address[0], port=address[1]) |
|
+ |
|
+ conn = httplib.HTTPConnection('proxy.com') |
|
+ conn._create_connection = create_connection |
|
+ |
|
+ # Once connected, we should not be able to tunnel anymore |
|
+ conn.connect() |
|
+ self.assertRaises(RuntimeError, conn.set_tunnel, 'destination.com') |
|
+ |
|
+ # But if we close the connection, we are good. |
|
+ conn.close() |
|
+ conn.set_tunnel('destination.com') |
|
+ conn.request('HEAD', '/', '') |
|
+ |
|
+ self.assertEqual(conn.sock.host, 'proxy.com') |
|
+ self.assertEqual(conn.sock.port, 80) |
|
+ self.assertIn('CONNECT destination.com', conn.sock.data) |
|
+ # issue22095 |
|
+ self.assertNotIn('Host: destination.com:None', conn.sock.data) |
|
+ # issue22095 |
|
+ |
|
+ self.assertNotIn('Host: proxy.com', conn.sock.data) |
|
+ |
|
+ conn.close() |
|
+ |
|
+ conn.request('PUT', '/', '') |
|
+ self.assertEqual(conn.sock.host, 'proxy.com') |
|
+ self.assertEqual(conn.sock.port, 80) |
|
+ self.assertTrue('CONNECT destination.com' in conn.sock.data) |
|
+ self.assertTrue('Host: destination.com' in conn.sock.data) |
|
+ |
|
|
|
|
|
@test_support.reap_threads |
|
def test_main(verbose=None): |
|
test_support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest, |
|
- HTTPSTest, SourceAddressTest) |
|
+ HTTPTest, HTTPSTest, SourceAddressTest, |
|
+ TunnelTests) |
|
|
|
if __name__ == '__main__': |
|
test_main()
|
|
|