Force SSL for all connections of Acceptor (#2231)

* Force SSL for all connections

* Force SSL for all connections of Acceptor

* Force SSL option in ServerOptions
diff --git a/src/brpc/acceptor.cpp b/src/brpc/acceptor.cpp
index 6273288..f2d1c08 100644
--- a/src/brpc/acceptor.cpp
+++ b/src/brpc/acceptor.cpp
@@ -38,6 +38,7 @@
     , _listened_fd(-1)
     , _acception_id(0)
     , _empty_cond(&_map_mutex)
+    , _force_ssl(false)
     , _ssl_ctx(NULL) 
     , _use_rdma(false) {
 }
@@ -48,11 +49,18 @@
 }
 
 int Acceptor::StartAccept(int listened_fd, int idle_timeout_sec,
-                          const std::shared_ptr<SocketSSLContext>& ssl_ctx) {
+                          const std::shared_ptr<SocketSSLContext>& ssl_ctx,
+                          bool force_ssl) {
     if (listened_fd < 0) {
         LOG(FATAL) << "Invalid listened_fd=" << listened_fd;
         return -1;
     }
+
+    if (!ssl_ctx && force_ssl) {
+        LOG(ERROR) << "Fail to force SSL for all connections "
+                      " because ssl_ctx is NULL";
+        return -1;
+    }
     
     BAIDU_SCOPED_LOCK(_map_mutex);
     if (_status == UNINITIALIZED) {
@@ -74,6 +82,7 @@
         }
     }
     _idle_timeout_sec = idle_timeout_sec;
+    _force_ssl = force_ssl;
     _ssl_ctx = ssl_ctx;
     
     // Creation of _acception_id is inside lock so that OnNewConnections
@@ -274,6 +283,7 @@
         options.fd = in_fd;
         butil::sockaddr2endpoint(&in_addr, in_len, &options.remote_side);
         options.user = acception->user();
+        options.force_ssl = am->_force_ssl;
         options.initial_ssl_ctx = am->_ssl_ctx;
 #if BRPC_WITH_RDMA
         if (am->_use_rdma) {
diff --git a/src/brpc/acceptor.h b/src/brpc/acceptor.h
index c442a60..c82cdcc 100644
--- a/src/brpc/acceptor.h
+++ b/src/brpc/acceptor.h
@@ -55,7 +55,8 @@
     // `idle_timeout_sec' > 0
     // Return 0 on success, -1 otherwise.
     int StartAccept(int listened_fd, int idle_timeout_sec,
-                    const std::shared_ptr<SocketSSLContext>& ssl_ctx);
+                    const std::shared_ptr<SocketSSLContext>& ssl_ctx,
+                    bool force_ssl);
 
     // [thread-safe] Stop accepting connections.
     // `closewait_ms' is not used anymore.
@@ -106,6 +107,7 @@
     // The map containing all the accepted sockets
     SocketMap _socket_map;
 
+    bool _force_ssl;
     std::shared_ptr<SocketSSLContext> _ssl_ctx;
 
     // Whether to use rdma or not
diff --git a/src/brpc/server.cpp b/src/brpc/server.cpp
index 4953f88..ce5a0dd 100644
--- a/src/brpc/server.cpp
+++ b/src/brpc/server.cpp
@@ -139,6 +139,7 @@
     , bthread_init_count(0)
     , internal_port(-1)
     , has_builtin_services(true)
+    , force_ssl(false)
     , use_rdma(false)
     , http_master_service(NULL)
     , health_reporter(NULL)
@@ -933,6 +934,10 @@
                 return -1;
             }
         }
+    } else if (_options.force_ssl) {
+        LOG(ERROR) << "Fail to force SSL for all connections "
+                      "without ServerOptions.ssl_options";
+        return -1;
     }
 
     _concurrency = 0;
@@ -1045,7 +1050,8 @@
 
         // Pass ownership of `sockfd' to `_am'
         if (_am->StartAccept(sockfd, _options.idle_timeout_sec,
-                             _default_ssl_ctx) != 0) {
+                             _default_ssl_ctx,
+                             _options.force_ssl) != 0) {
             LOG(ERROR) << "Fail to start acceptor";
             return -1;
         }
@@ -1085,7 +1091,8 @@
         }
         // Pass ownership of `sockfd' to `_internal_am'
         if (_internal_am->StartAccept(sockfd, _options.idle_timeout_sec,
-                                      _default_ssl_ctx) != 0) {
+                                      _default_ssl_ctx,
+                                      false) != 0) {
             LOG(ERROR) << "Fail to start internal_acceptor";
             return -1;
         }
diff --git a/src/brpc/server.h b/src/brpc/server.h
index c00f9dc..e598a6e 100644
--- a/src/brpc/server.h
+++ b/src/brpc/server.h
@@ -217,6 +217,9 @@
     const ServerSSLOptions& ssl_options() const { return *_ssl_options; }
     ServerSSLOptions* mutable_ssl_options();
 
+    // Force ssl for all connections of the port to Start().
+    bool force_ssl;
+
     // Whether the server uses rdma or not
     // Default: false
     bool use_rdma;
diff --git a/src/brpc/socket.cpp b/src/brpc/socket.cpp
index e0a6942..c49ca08 100644
--- a/src/brpc/socket.cpp
+++ b/src/brpc/socket.cpp
@@ -698,6 +698,7 @@
         m->SetFailed(rc2, "Fail to create auth_id: %s", berror(rc2));
         return -1;
     }
+    m->_force_ssl = options.force_ssl;
     // Disable SSL check if there is no SSL context
     m->_ssl_state = (options.initial_ssl_ctx == NULL ? SSL_OFF : SSL_UNKNOWN);
     m->_ssl_session = NULL;
@@ -2026,6 +2027,10 @@
     }
     // _ssl_state has been set
     if (ssl_state() == SSL_OFF) {
+        if (_force_ssl) {
+            errno = ESSL;
+            return -1;
+        }
         CHECK(_rdma_state == RDMA_OFF);
         return _read_buf.append_from_file_descriptor(fd(), size_hint);
     }
diff --git a/src/brpc/socket.h b/src/brpc/socket.h
index bd753f6..eff9474 100644
--- a/src/brpc/socket.h
+++ b/src/brpc/socket.h
@@ -205,6 +205,8 @@
     // one thread at any time.
     void (*on_edge_triggered_events)(Socket*);
     int health_check_interval_s;
+    // Only accept ssl connection.
+    bool force_ssl;
     std::shared_ptr<SocketSSLContext> initial_ssl_ctx;
     bool use_rdma;
     bthread_keytable_pool_t* keytable_pool;
@@ -826,6 +828,8 @@
     // exists in server side
     AuthContext* _auth_context;
 
+    // Only accept ssl connection.
+    bool _force_ssl;
     SSLState _ssl_state;
     // SSL objects cannot be read and written at the same time.
     // Use mutex to protect SSL objects when ssl_state is SSL_CONNECTED.
diff --git a/src/brpc/socket_inl.h b/src/brpc/socket_inl.h
index 9423bfd..df93ac7 100644
--- a/src/brpc/socket_inl.h
+++ b/src/brpc/socket_inl.h
@@ -57,6 +57,7 @@
     , user(NULL)
     , on_edge_triggered_events(NULL)
     , health_check_interval_s(-1)
+    , force_ssl(false)
     , use_rdma(false)
     , keytable_pool(NULL)
     , conn(NULL)
diff --git a/test/brpc_channel_unittest.cpp b/test/brpc_channel_unittest.cpp
index 4de8e35..694f3f7 100644
--- a/test/brpc_channel_unittest.cpp
+++ b/test/brpc_channel_unittest.cpp
@@ -263,7 +263,7 @@
                 return -1;
             }
         }
-        if (_messenger.StartAccept(listening_fd, -1, NULL) != 0) {
+        if (_messenger.StartAccept(listening_fd, -1, NULL, false) != 0) {
             return -1;
         }
         return 0;
diff --git a/test/brpc_input_messenger_unittest.cpp b/test/brpc_input_messenger_unittest.cpp
index 7682b83..00b14ed 100644
--- a/test/brpc_input_messenger_unittest.cpp
+++ b/test/brpc_input_messenger_unittest.cpp
@@ -169,7 +169,7 @@
         ASSERT_TRUE(listening_fd > 0);
         butil::make_non_blocking(listening_fd);
         ASSERT_EQ(0, messenger[i].AddHandler(pairs[0]));
-        ASSERT_EQ(0, messenger[i].StartAccept(listening_fd, -1, NULL));
+        ASSERT_EQ(0, messenger[i].StartAccept(listening_fd, -1, NULL, false));
     }
     
     for (size_t i = 0; i < NCLIENT; ++i) {
diff --git a/test/brpc_socket_unittest.cpp b/test/brpc_socket_unittest.cpp
index 3f08091..36a3b1b 100644
--- a/test/brpc_socket_unittest.cpp
+++ b/test/brpc_socket_unittest.cpp
@@ -339,7 +339,7 @@
     ASSERT_TRUE(listening_fd > 0);
     butil::make_non_blocking(listening_fd);
     ASSERT_EQ(0, messenger->AddHandler(pairs[0]));
-    ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL));
+    ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL, false));
 
     brpc::SocketId id = 8888;
     brpc::SocketOptions options;
@@ -727,7 +727,7 @@
     ASSERT_TRUE(listening_fd > 0);
     butil::make_non_blocking(listening_fd);
     ASSERT_EQ(0, messenger->AddHandler(pairs[0]));
-    ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL));
+    ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL, false));
 
     int64_t start_time = butil::gettimeofday_us();
     nref = -1;
diff --git a/test/brpc_ssl_unittest.cpp b/test/brpc_ssl_unittest.cpp
index f32dbcb..7d58e45 100644
--- a/test/brpc_ssl_unittest.cpp
+++ b/test/brpc_ssl_unittest.cpp
@@ -35,6 +35,7 @@
 #include "echo.pb.h"
 
 namespace brpc {
+
 void ExtractHostnames(X509* x, std::vector<std::string>* hostnames);
 } // namespace brpc
 
@@ -175,6 +176,55 @@
     ASSERT_EQ(0, server.Join());
 }
 
+TEST_F(SSLTest, force_ssl) {
+    const int port = 8613;
+    brpc::Server server;
+    brpc::ServerOptions options;
+    EchoServiceImpl echo_svc;
+    ASSERT_EQ(0, server.AddService(
+        &echo_svc, brpc::SERVER_DOESNT_OWN_SERVICE));
+
+    options.force_ssl = true;
+    ASSERT_EQ(-1, server.Start(port, &options));
+
+    brpc::CertInfo cert;
+    cert.certificate = "cert1.crt";
+    cert.private_key = "cert1.key";
+    options.mutable_ssl_options()->default_cert = cert;
+
+    ASSERT_EQ(0, server.Start(port, &options));
+
+    test::EchoRequest req;
+    req.set_message(EXP_REQUEST);
+    {
+        brpc::Channel channel;
+        brpc::ChannelOptions coptions;
+        coptions.mutable_ssl_options();
+        coptions.mutable_ssl_options()->sni_name = "localhost";
+        ASSERT_EQ(0, channel.Init("localhost", port, &coptions));
+
+        brpc::Controller cntl;
+        test::EchoService_Stub stub(&channel);
+        test::EchoResponse res;
+        stub.Echo(&cntl, &req, &res, NULL);
+        EXPECT_EQ(EXP_RESPONSE, res.message()) << cntl.ErrorText();
+    }
+
+    {
+        brpc::Channel channel;
+        ASSERT_EQ(0, channel.Init("localhost", port, NULL));
+
+        brpc::Controller cntl;
+        test::EchoService_Stub stub(&channel);
+        test::EchoResponse res;
+        stub.Echo(&cntl, &req, &res, NULL);
+        EXPECT_TRUE(cntl.Failed());
+    }
+
+    ASSERT_EQ(0, server.Stop(0));
+    ASSERT_EQ(0, server.Join());
+}
+
 void CheckCert(const char* cname, const char* cert) {
     const int port = 8613;
     brpc::Channel channel;