diff --git a/Lib/httplib.py b/Lib/httplib.py index 592ee57..b69145b 100644 --- a/Lib/httplib.py +++ b/Lib/httplib.py @@ -735,25 +735,40 @@ class HTTPConnection: self._tunnel_host = None self._tunnel_port = None self._tunnel_headers = {} - - self._set_hostport(host, port) if strict is not None: self.strict = strict + (self.host, self.port) = self._get_hostport(host, port) + + # This is stored as an instance variable to allow unittests + # to replace with a suitable mock + self._create_connection = socket.create_connection + def set_tunnel(self, host, port=None, headers=None): - """ Sets up the host and the port for the HTTP CONNECT Tunnelling. + """ Set up host and port for HTTP CONNECT tunnelling. + + In a connection that uses HTTP CONNECT tunnelling, the host passed to the + constructor is used as a proxy server that relays all communication to the + endpoint passed to set_tunnel. This is done by sending an HTTP CONNECT + request to the proxy server when the connection is established. + + This method must be called before the HTTP connection has been + established. The headers argument should be a mapping of extra HTTP headers to send with the CONNECT request. """ - self._tunnel_host = host - self._tunnel_port = port + # Verify if this is required. + if self.sock: + raise RuntimeError("Can't setup tunnel for established connection.") + + self._tunnel_host, self._tunnel_port = self._get_hostport(host, port) if headers: self._tunnel_headers = headers else: self._tunnel_headers.clear() - def _set_hostport(self, host, port): + def _get_hostport(self, host, port): if port is None: i = host.rfind(':') j = host.rfind(']') # ipv6 addresses have [...] 
@@ -770,15 +785,14 @@ class HTTPConnection: port = self.default_port if host and host[0] == '[' and host[-1] == ']': host = host[1:-1] - self.host = host - self.port = port + return (host, port) def set_debuglevel(self, level): self.debuglevel = level def _tunnel(self): - self._set_hostport(self._tunnel_host, self._tunnel_port) - self.send("CONNECT %s:%d HTTP/1.0\r\n" % (self.host, self.port)) + self.send("CONNECT %s:%d HTTP/1.0\r\n" % (self._tunnel_host, + self._tunnel_port)) for header, value in self._tunnel_headers.iteritems(): self.send("%s: %s\r\n" % (header, value)) self.send("\r\n") @@ -803,8 +817,8 @@ class HTTPConnection: def connect(self): """Connect to the host and port specified in __init__.""" - self.sock = socket.create_connection((self.host,self.port), - self.timeout, self.source_address) + self.sock = self._create_connection((self.host,self.port), + self.timeout, self.source_address) if self._tunnel_host: self._tunnel() @@ -942,17 +956,24 @@ class HTTPConnection: netloc_enc = netloc.encode("idna") self.putheader('Host', netloc_enc) else: + if self._tunnel_host: + host = self._tunnel_host + port = self._tunnel_port + else: + host = self.host + port = self.port + try: - host_enc = self.host.encode("ascii") + host_enc = host.encode("ascii") except UnicodeEncodeError: - host_enc = self.host.encode("idna") + host_enc = host.encode("idna") # Wrap the IPv6 Host Header with [] (RFC 2732) if host_enc.find(':') >= 0: host_enc = "[" + host_enc + "]" - if self.port == self.default_port: + if port == self.default_port: self.putheader('Host', host_enc) else: - self.putheader('Host', "%s:%s" % (host_enc, self.port)) + self.putheader('Host', "%s:%s" % (host_enc, port)) # note: we are assuming that clients will not attempt to set these # headers since *this* library must deal with the @@ -1141,7 +1162,7 @@ class HTTP: "Accept arguments to set the host/port, since the superclass doesn't." 
if host is not None: - self._conn._set_hostport(host, port) + (self._conn.host, self._conn.port) = self._conn._get_hostport(host, port) self._conn.connect() def getfile(self): diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py index 29af589..9db30cc 100644 --- a/Lib/test/test_httplib.py +++ b/Lib/test/test_httplib.py @@ -21,10 +21,12 @@ CERT_selfsigned_pythontestdotnet = os.path.join(here, 'selfsigned_pythontestdotn HOST = test_support.HOST class FakeSocket: - def __init__(self, text, fileclass=StringIO.StringIO): + def __init__(self, text, fileclass=StringIO.StringIO, host=None, port=None): self.text = text self.fileclass = fileclass self.data = '' + self.host = host + self.port = port def sendall(self, data): self.data += ''.join(data) @@ -34,6 +36,9 @@ class FakeSocket: raise httplib.UnimplementedFileMode() return self.fileclass(self.text) + def close(self): + pass + class EPipeSocket(FakeSocket): def __init__(self, text, pipe_trigger): @@ -487,7 +492,11 @@ class OfflineTest(TestCase): self.assertEqual(httplib.responses[httplib.NOT_FOUND], "Not Found") -class SourceAddressTest(TestCase): +class TestServerMixin: + """A limited socket server mixin. + + This is used by test cases for testing http connection end points. + """ def setUp(self): self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.port = test_support.bind_port(self.serv) @@ -502,6 +511,7 @@ class SourceAddressTest(TestCase): self.serv.close() self.serv = None +class SourceAddressTest(TestServerMixin, TestCase): def testHTTPConnectionSourceAddress(self): self.conn = httplib.HTTPConnection(HOST, self.port, source_address=('', self.source_port)) @@ -518,6 +528,24 @@ class SourceAddressTest(TestCase): # for an ssl_wrapped connect() to actually return from. 
+class HTTPTest(TestServerMixin, TestCase): + def testHTTPConnection(self): + self.conn = httplib.HTTP(host=HOST, port=self.port, strict=None) + self.conn.connect() + self.assertEqual(self.conn._conn.host, HOST) + self.assertEqual(self.conn._conn.port, self.port) + + def testHTTPWithConnectHostPort(self): + testhost = 'unreachable.test.domain' + testport = '80' + self.conn = httplib.HTTP(host=testhost, port=testport) + self.conn.connect(host=HOST, port=self.port) + self.assertNotEqual(self.conn._conn.host, testhost) + self.assertNotEqual(self.conn._conn.port, testport) + self.assertEqual(self.conn._conn.host, HOST) + self.assertEqual(self.conn._conn.port, self.port) + + class TimeoutTest(TestCase): PORT = None @@ -716,13 +744,54 @@ class HTTPSTest(TestCase): c = httplib.HTTPSConnection(hp, context=context) self.assertEqual(h, c.host) self.assertEqual(p, c.port) - + +class TunnelTests(TestCase): + def test_connect(self): + response_text = ( + 'HTTP/1.0 200 OK\r\n\r\n' # Reply to CONNECT + 'HTTP/1.1 200 OK\r\n' # Reply to HEAD + 'Content-Length: 42\r\n\r\n' + ) + + def create_connection(address, timeout=None, source_address=None): + return FakeSocket(response_text, host=address[0], port=address[1]) + + conn = httplib.HTTPConnection('proxy.com') + conn._create_connection = create_connection + + # Once connected, we should not be able to tunnel anymore + conn.connect() + self.assertRaises(RuntimeError, conn.set_tunnel, 'destination.com') + + # But if we close the connection, we are good. 
+ conn.close() + conn.set_tunnel('destination.com') + conn.request('HEAD', '/', '') + + self.assertEqual(conn.sock.host, 'proxy.com') + self.assertEqual(conn.sock.port, 80) + self.assertIn('CONNECT destination.com', conn.sock.data) + # issue22095 + self.assertNotIn('Host: destination.com:None', conn.sock.data) + # issue22095 + + self.assertNotIn('Host: proxy.com', conn.sock.data) + + conn.close() + + conn.request('PUT', '/', '') + self.assertEqual(conn.sock.host, 'proxy.com') + self.assertEqual(conn.sock.port, 80) + self.assertTrue('CONNECT destination.com' in conn.sock.data) + self.assertTrue('Host: destination.com' in conn.sock.data) + @test_support.reap_threads def test_main(verbose=None): test_support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest, - HTTPSTest, SourceAddressTest) + HTTPTest, HTTPSTest, SourceAddressTest, + TunnelTests) if __name__ == '__main__': test_main()