Commit b3f4b4c5 authored by sergeyu@chromium.org's avatar sergeyu@chromium.org

Fix ChannelMultiplexer to properly handle base channel creation failure.

ChannelMultiplexer may be destroyed from channel connection callback when
it fails to connect the base channel. Previously it didn't handle this case
properly.

BUG=152039


Review URL: https://chromiumcodereview.appspot.com/10981009

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@158740 0039d316-1c4b-4281-b951-d872f2087c98
parent 63a8954d
......@@ -9,7 +9,9 @@
#include "base/bind.h"
#include "base/callback.h"
#include "base/location.h"
#include "base/single_thread_task_runner.h"
#include "base/stl_util.h"
#include "base/thread_task_runner_handle.h"
#include "net/base/net_errors.h"
#include "net/socket/stream_socket.h"
#include "remoting/protocol/util.h"
......@@ -364,7 +366,7 @@ ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory,
: base_channel_factory_(factory),
base_channel_name_(base_channel_name),
next_channel_id_(0),
destroyed_flag_(NULL) {
ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) {
}
ChannelMultiplexer::~ChannelMultiplexer() {
......@@ -374,9 +376,6 @@ ChannelMultiplexer::~ChannelMultiplexer() {
// Cancel creation of the base channel if it hasn't finished.
if (base_channel_factory_)
base_channel_factory_->CancelChannelCreation(base_channel_name_);
if (destroyed_flag_)
*destroyed_flag_ = true;
}
void ChannelMultiplexer::CreateStreamChannel(
......@@ -425,30 +424,37 @@ void ChannelMultiplexer::OnBaseChannelReady(
base_channel_factory_ = NULL;
base_channel_ = socket.Pass();
if (!base_channel_.get()) {
// Notify all callers that we can't create any channels.
for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
it != pending_channels_.end(); ++it) {
it->callback.Run(scoped_ptr<net::StreamSocket>());
}
pending_channels_.clear();
return;
if (base_channel_.get()) {
// Initialize reader and writer.
reader_.Init(base_channel_.get(),
base::Bind(&ChannelMultiplexer::OnIncomingPacket,
base::Unretained(this)));
writer_.Init(base_channel_.get(),
base::Bind(&ChannelMultiplexer::OnWriteFailed,
base::Unretained(this)));
}
// Initialize reader and writer.
reader_.Init(base_channel_.get(),
base::Bind(&ChannelMultiplexer::OnIncomingPacket,
base::Unretained(this)));
writer_.Init(base_channel_.get(),
base::Bind(&ChannelMultiplexer::OnWriteFailed,
base::Unretained(this)));
DoCreatePendingChannels();
}
// Now create all pending channels.
for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
it != pending_channels_.end(); ++it) {
it->callback.Run(GetOrCreateChannel(it->name)->CreateSocket());
}
pending_channels_.clear();
void ChannelMultiplexer::DoCreatePendingChannels() {
if (pending_channels_.empty())
return;
// Every time this function is called it connects a single channel and posts a
// separate task to connect other channels. This is necessary because the
// callback may destroy the multiplexer or somehow else modify
// |pending_channels_| list (e.g. call CancelChannelCreation()).
base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels,
weak_factory_.GetWeakPtr()));
PendingChannel c = pending_channels_.front();
pending_channels_.erase(pending_channels_.begin());
scoped_ptr<net::StreamSocket> socket;
if (base_channel_.get())
socket = GetOrCreateChannel(c.name)->CreateSocket();
c.callback.Run(socket.Pass());
}
ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
......@@ -467,15 +473,19 @@ ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
void ChannelMultiplexer::OnWriteFailed(int error) {
bool destroyed = false;
destroyed_flag_ = &destroyed;
for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
it != channels_.end(); ++it) {
base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::Bind(&ChannelMultiplexer::NotifyWriteFailed,
weak_factory_.GetWeakPtr(), it->second->name()));
}
}
void ChannelMultiplexer::NotifyWriteFailed(const std::string& name) {
std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
if (it != channels_.end()) {
it->second->OnWriteFailed();
if (destroyed)
return;
}
destroyed_flag_ = NULL;
}
void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
......
......@@ -5,6 +5,7 @@
#ifndef REMOTING_PROTOCOL_CHANNEL_MULTIPLEXER_H_
#define REMOTING_PROTOCOL_CHANNEL_MULTIPLEXER_H_
#include "base/memory/weak_ptr.h"
#include "remoting/proto/mux.pb.h"
#include "remoting/protocol/buffered_socket_writer.h"
#include "remoting/protocol/channel_factory.h"
......@@ -40,11 +41,19 @@ class ChannelMultiplexer : public ChannelFactory {
// Callback for |base_channel_| creation.
void OnBaseChannelReady(scoped_ptr<net::StreamSocket> socket);
// Helper to create channels asynchronously.
void DoCreatePendingChannels();
// Helper method used to create channels.
MuxChannel* GetOrCreateChannel(const std::string& name);
// Callbacks for |writer_| and |reader_|.
// Error handling callback for |writer_|.
void OnWriteFailed(int error);
// Failed write notifier, queued asynchronously by OnWriteFailed().
void NotifyWriteFailed(const std::string& name);
// Callback for |reader_;
void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
const base::Closure& done_task);
......@@ -75,8 +84,7 @@ class ChannelMultiplexer : public ChannelFactory {
BufferedSocketWriter writer_;
ProtobufMessageReader<MultiplexPacket> reader_;
// Flag used by OnWriteFailed() to detect when the multiplexer is destroyed.
bool* destroyed_flag_;
base::WeakPtrFactory<ChannelMultiplexer> weak_factory_;
DISALLOW_COPY_AND_ASSIGN(ChannelMultiplexer);
};
......
......@@ -28,6 +28,10 @@ const int kMessageSize = 1024;
const int kMessages = 100;
const char kMuxChannelName[] = "mux";
const char kTestChannelName[] = "test";
const char kTestChannelName2[] = "test2";
void QuitCurrentThread() {
MessageLoop::current()->PostTask(FROM_HERE, MessageLoop::QuitClosure());
}
......@@ -37,6 +41,14 @@ class MockSocketCallback {
MOCK_METHOD1(OnDone, void(int result));
};
class MockConnectCallback {
public:
MOCK_METHOD1(OnConnectedPtr, void(net::StreamSocket* socket));
void OnConnected(scoped_ptr<net::StreamSocket> socket) {
OnConnectedPtr(socket.release());
}
};
} // namespace
class ChannelMultiplexerTest : public testing::Test {
......@@ -50,6 +62,11 @@ class ChannelMultiplexerTest : public testing::Test {
client_mux_.reset();
}
void DeleteAfterSessionFail() {
host_mux_->CancelChannelCreation(kTestChannelName2);
DeleteAll();
}
protected:
virtual void SetUp() OVERRIDE {
// Create pair of multiplexers and connect them to each other.
......@@ -126,7 +143,8 @@ class ChannelMultiplexerTest : public testing::Test {
TEST_F(ChannelMultiplexerTest, OneChannel) {
scoped_ptr<net::StreamSocket> host_socket;
scoped_ptr<net::StreamSocket> client_socket;
ASSERT_NO_FATAL_FAILURE(CreateChannel("test", &host_socket, &client_socket));
ASSERT_NO_FATAL_FAILURE(
CreateChannel(kTestChannelName, &host_socket, &client_socket));
ConnectSockets();
......@@ -141,12 +159,12 @@ TEST_F(ChannelMultiplexerTest, TwoChannels) {
scoped_ptr<net::StreamSocket> host_socket1_;
scoped_ptr<net::StreamSocket> client_socket1_;
ASSERT_NO_FATAL_FAILURE(
CreateChannel("test", &host_socket1_, &client_socket1_));
CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
scoped_ptr<net::StreamSocket> host_socket2_;
scoped_ptr<net::StreamSocket> client_socket2_;
ASSERT_NO_FATAL_FAILURE(
CreateChannel("ch2", &host_socket2_, &client_socket2_));
CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
ConnectSockets();
......@@ -168,12 +186,12 @@ TEST_F(ChannelMultiplexerTest, FourChannels) {
scoped_ptr<net::StreamSocket> host_socket1_;
scoped_ptr<net::StreamSocket> client_socket1_;
ASSERT_NO_FATAL_FAILURE(
CreateChannel("test", &host_socket1_, &client_socket1_));
CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
scoped_ptr<net::StreamSocket> host_socket2_;
scoped_ptr<net::StreamSocket> client_socket2_;
ASSERT_NO_FATAL_FAILURE(
CreateChannel("ch2", &host_socket2_, &client_socket2_));
CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
scoped_ptr<net::StreamSocket> host_socket3;
scoped_ptr<net::StreamSocket> client_socket3;
......@@ -209,16 +227,16 @@ TEST_F(ChannelMultiplexerTest, FourChannels) {
tester4.CheckResults();
}
TEST_F(ChannelMultiplexerTest, SyncFail) {
TEST_F(ChannelMultiplexerTest, WriteFailSync) {
scoped_ptr<net::StreamSocket> host_socket1_;
scoped_ptr<net::StreamSocket> client_socket1_;
ASSERT_NO_FATAL_FAILURE(
CreateChannel("test", &host_socket1_, &client_socket1_));
CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
scoped_ptr<net::StreamSocket> host_socket2_;
scoped_ptr<net::StreamSocket> client_socket2_;
ASSERT_NO_FATAL_FAILURE(
CreateChannel("ch2", &host_socket2_, &client_socket2_));
CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
ConnectSockets();
......@@ -245,12 +263,12 @@ TEST_F(ChannelMultiplexerTest, SyncFail) {
message_loop_.RunAllPending();
}
TEST_F(ChannelMultiplexerTest, AsyncFail) {
TEST_F(ChannelMultiplexerTest, WriteFailAsync) {
ASSERT_NO_FATAL_FAILURE(
CreateChannel("test", &host_socket1_, &client_socket1_));
CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
ASSERT_NO_FATAL_FAILURE(
CreateChannel("ch2", &host_socket2_, &client_socket2_));
CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
ConnectSockets();
......@@ -278,9 +296,9 @@ TEST_F(ChannelMultiplexerTest, AsyncFail) {
TEST_F(ChannelMultiplexerTest, DeleteWhenFailed) {
ASSERT_NO_FATAL_FAILURE(
CreateChannel("test", &host_socket1_, &client_socket1_));
CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
ASSERT_NO_FATAL_FAILURE(
CreateChannel("ch2", &host_socket2_, &client_socket2_));
CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
ConnectSockets();
......@@ -314,5 +332,27 @@ TEST_F(ChannelMultiplexerTest, DeleteWhenFailed) {
EXPECT_FALSE(host_mux_.get());
}
TEST_F(ChannelMultiplexerTest, SessionFail) {
host_session_.set_async_creation(true);
host_session_.set_error(AUTHENTICATION_FAILED);
MockConnectCallback cb1;
MockConnectCallback cb2;
host_mux_->CreateStreamChannel(kTestChannelName, base::Bind(
&MockConnectCallback::OnConnected, base::Unretained(&cb1)));
host_mux_->CreateStreamChannel(kTestChannelName2, base::Bind(
&MockConnectCallback::OnConnected, base::Unretained(&cb2)));
EXPECT_CALL(cb1, OnConnectedPtr(NULL))
.Times(AtMost(1))
.WillOnce(InvokeWithoutArgs(
this, &ChannelMultiplexerTest::DeleteAfterSessionFail));
EXPECT_CALL(cb2, OnConnectedPtr(_))
.Times(0);
message_loop_.RunAllPending();
}
} // namespace protocol
} // namespace remoting
......@@ -27,7 +27,6 @@ class ConnectionToClientTest : public testing::Test {
protected:
virtual void SetUp() OVERRIDE {
session_ = new FakeSession();
session_->set_message_loop(&message_loop_);
// Allocate a ClientConnection object with the mock objects.
viewer_.reset(new ConnectionToClient(session_));
......
......@@ -284,10 +284,12 @@ FakeSession::FakeSession()
: event_handler_(NULL),
candidate_config_(CandidateSessionConfig::CreateDefault()),
config_(SessionConfig::GetDefault()),
message_loop_(NULL),
message_loop_(MessageLoop::current()),
async_creation_(false),
jid_(kTestJid),
error_(OK),
closed_(false) {
closed_(false),
ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) {
}
FakeSession::~FakeSession() { }
......@@ -337,20 +339,60 @@ void FakeSession::Close() {
}
void FakeSession::CreateStreamChannel(
const std::string& name, const StreamChannelCallback& callback) {
scoped_ptr<FakeSocket> channel(new FakeSocket());
stream_channels_[name] = channel.get();
callback.Run(channel.PassAs<net::StreamSocket>());
const std::string& name,
const StreamChannelCallback& callback) {
scoped_ptr<FakeSocket> channel;
// If we are in the error state then we put NULL in the channels list, so that
// NotifyStreamChannelCallback() still calls the callback.
if (error_ == OK)
channel.reset(new FakeSocket());
stream_channels_[name] = channel.release();
if (async_creation_) {
message_loop_->PostTask(FROM_HERE, base::Bind(
&FakeSession::NotifyStreamChannelCallback, weak_factory_.GetWeakPtr(),
name, callback));
} else {
NotifyStreamChannelCallback(name, callback);
}
}
void FakeSession::NotifyStreamChannelCallback(
const std::string& name,
const StreamChannelCallback& callback) {
if (stream_channels_.find(name) != stream_channels_.end())
callback.Run(scoped_ptr<net::StreamSocket>(stream_channels_[name]));
}
void FakeSession::CreateDatagramChannel(
const std::string& name, const DatagramChannelCallback& callback) {
scoped_ptr<FakeUdpSocket> channel(new FakeUdpSocket());
datagram_channels_[name] = channel.get();
callback.Run(channel.PassAs<net::Socket>());
const std::string& name,
const DatagramChannelCallback& callback) {
scoped_ptr<FakeUdpSocket> channel;
// If we are in the error state then we put NULL in the channels list, so that
// NotifyStreamChannelCallback() still calls the callback.
if (error_ == OK)
channel.reset(new FakeUdpSocket());
datagram_channels_[name] = channel.release();
if (async_creation_) {
message_loop_->PostTask(FROM_HERE, base::Bind(
&FakeSession::NotifyDatagramChannelCallback, weak_factory_.GetWeakPtr(),
name, callback));
} else {
NotifyDatagramChannelCallback(name, callback);
}
}
void FakeSession::NotifyDatagramChannelCallback(
const std::string& name,
const DatagramChannelCallback& callback) {
if (datagram_channels_.find(name) != datagram_channels_.end())
callback.Run(scoped_ptr<net::Socket>(datagram_channels_[name]));
}
void FakeSession::CancelChannelCreation(const std::string& name) {
stream_channels_.erase(name);
datagram_channels_.erase(name);
}
} // namespace protocol
......
......@@ -155,8 +155,8 @@ class FakeSession : public Session,
EventHandler* event_handler() { return event_handler_; }
void set_message_loop(MessageLoop* message_loop) {
message_loop_ = message_loop;
void set_async_creation(bool async_creation) {
async_creation_ = async_creation;
}
void set_error(ErrorCode error) { error_ = error; }
......@@ -179,18 +179,28 @@ class FakeSession : public Session,
// ChannelFactory interface.
virtual void CreateStreamChannel(
const std::string& name, const StreamChannelCallback& callback) OVERRIDE;
const std::string& name,
const StreamChannelCallback& callback) OVERRIDE;
virtual void CreateDatagramChannel(
const std::string& name,
const DatagramChannelCallback& callback) OVERRIDE;
virtual void CancelChannelCreation(const std::string& name) OVERRIDE;
public:
void NotifyStreamChannelCallback(
const std::string& name,
const StreamChannelCallback& callback);
void NotifyDatagramChannelCallback(
const std::string& name,
const DatagramChannelCallback& callback);
EventHandler* event_handler_;
scoped_ptr<const CandidateSessionConfig> candidate_config_;
SessionConfig config_;
MessageLoop* message_loop_;
bool async_creation_;
std::map<std::string, FakeSocket*> stream_channels_;
std::map<std::string, FakeUdpSocket*> datagram_channels_;
......@@ -199,6 +209,8 @@ class FakeSession : public Session,
ErrorCode error_;
bool closed_;
base::WeakPtrFactory<FakeSession> weak_factory_;
DISALLOW_COPY_AND_ASSIGN(FakeSession);
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment