Commit 3e9187d0 authored by sergeyu@chromium.org's avatar sergeyu@chromium.org

Add support for multiplexed channels in remoting::protocol::Session interface.

Now the Session interface has two methods that return channel factories - one for regular channels and one for multiplexed channels.
Also refactored AudioReader and AudioWriter to inherit from 
ChannelDispatcherBase.

BUG=137135


Review URL: https://chromiumcodereview.appspot.com/10823323

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@152240 0039d316-1c4b-4281-b951-d872f2087c98
parent 2c539b89
......@@ -14,14 +14,12 @@ namespace remoting {
namespace protocol {
AudioReader::AudioReader(AudioPacket::Encoding encoding)
: session_(NULL),
: ChannelDispatcherBase(kAudioChannelName),
encoding_(encoding),
audio_stub_(NULL) {
}
AudioReader::~AudioReader() {
if (session_)
session_->CancelChannelCreation(kAudioChannelName);
}
// static
......@@ -32,33 +30,9 @@ scoped_ptr<AudioReader> AudioReader::Create(const SessionConfig& config) {
return scoped_ptr<AudioReader>(new AudioReader(AudioPacket::ENCODING_RAW));
}
void AudioReader::Init(protocol::Session* session,
AudioStub* audio_stub,
const InitializedCallback& callback) {
session_ = session;
initialized_callback_ = callback;
audio_stub_ = audio_stub;
session_->CreateStreamChannel(
kAudioChannelName,
base::Bind(&AudioReader::OnChannelReady, base::Unretained(this)));
}
bool AudioReader::is_connected() {
return channel_.get() != NULL;
}
void AudioReader::OnChannelReady(scoped_ptr<net::StreamSocket> socket) {
if (!socket.get()) {
initialized_callback_.Run(false);
return;
}
DCHECK(!channel_.get());
channel_ = socket.Pass();
reader_.Init(channel_.get(), base::Bind(&AudioReader::OnNewData,
base::Unretained(this)));
initialized_callback_.Run(true);
void AudioReader::OnInitialized() {
reader_.Init(channel(), base::Bind(&AudioReader::OnNewData,
base::Unretained(this)));
}
void AudioReader::OnNewData(scoped_ptr<AudioPacket> packet,
......
......@@ -10,6 +10,7 @@
#include "remoting/proto/audio.pb.h"
#include "remoting/protocol/audio_stub.h"
#include "remoting/protocol/message_reader.h"
#include "remoting/protocol/channel_dispatcher_base.h"
namespace net {
class StreamSocket;
......@@ -21,38 +22,25 @@ namespace protocol {
class Session;
class SessionConfig;
class AudioReader {
class AudioReader : public ChannelDispatcherBase {
public:
// The callback is called when initialization is finished. The
// parameter is set to true on success.
typedef base::Callback<void(bool)> InitializedCallback;
static scoped_ptr<AudioReader> Create(const SessionConfig& config);
virtual ~AudioReader();
static scoped_ptr<AudioReader> Create(const SessionConfig& config);
void set_audio_stub(AudioStub* audio_stub) { audio_stub_ = audio_stub; }
// Initializies the reader.
void Init(Session* session,
AudioStub* audio_stub,
const InitializedCallback& callback);
bool is_connected();
protected:
virtual void OnInitialized() OVERRIDE;
private:
explicit AudioReader(AudioPacket::Encoding encoding);
void OnChannelReady(scoped_ptr<net::StreamSocket> socket);
void OnNewData(scoped_ptr<AudioPacket> packet,
const base::Closure& done_task);
Session* session_;
InitializedCallback initialized_callback_;
AudioPacket::Encoding encoding_;
// TODO(sergeyu): Remove |channel_| and let |reader_| own it.
scoped_ptr<net::StreamSocket> channel_;
ProtobufMessageReader<AudioPacket> reader_;
// The stub that processes all received packets.
......
......@@ -16,53 +16,20 @@ namespace remoting {
namespace protocol {
AudioWriter::AudioWriter()
: session_(NULL) {
: ChannelDispatcherBase(kAudioChannelName) {
}
AudioWriter::~AudioWriter() {
Close();
}
void AudioWriter::Init(protocol::Session* session,
const InitializedCallback& callback) {
session_ = session;
initialized_callback_ = callback;
session_->CreateStreamChannel(
kAudioChannelName,
base::Bind(&AudioWriter::OnChannelReady, base::Unretained(this)));
}
void AudioWriter::OnChannelReady(scoped_ptr<net::StreamSocket> socket) {
if (!socket.get()) {
initialized_callback_.Run(false);
return;
}
DCHECK(!channel_.get());
channel_ = socket.Pass();
// TODO(sergeyu): Provide WriteFailedCallback for the buffered writer.
void AudioWriter::OnInitialized() {
// TODO(sergeyu): Provide a non-null WriteFailedCallback for the writer.
buffered_writer_.Init(
channel_.get(), BufferedSocketWriter::WriteFailedCallback());
initialized_callback_.Run(true);
}
void AudioWriter::Close() {
buffered_writer_.Close();
channel_.reset();
if (session_) {
session_->CancelChannelCreation(kAudioChannelName);
session_ = NULL;
}
}
bool AudioWriter::is_connected() {
return channel_.get() != NULL;
channel(), BufferedSocketWriter::WriteFailedCallback());
}
void AudioWriter::ProcessAudioPacket(scoped_ptr<AudioPacket> packet,
const base::Closure& done) {
const base::Closure& done) {
buffered_writer_.Write(SerializeAndFrameMessage(*packet), done);
}
......
......@@ -13,6 +13,7 @@
#include "base/memory/scoped_ptr.h"
#include "remoting/protocol/audio_stub.h"
#include "remoting/protocol/buffered_socket_writer.h"
#include "remoting/protocol/channel_dispatcher_base.h"
namespace net {
class StreamSocket;
......@@ -24,42 +25,25 @@ namespace protocol {
class Session;
class SessionConfig;
class AudioWriter : public AudioStub {
class AudioWriter : public ChannelDispatcherBase,
public AudioStub {
public:
virtual ~AudioWriter();
// The callback is called when initialization is finished. The
// parameter is set to true on success.
typedef base::Callback<void(bool)> InitializedCallback;
// Once AudioWriter is created, the Init() method of ChannelDispatcherBase
// should be used to initialize it for the session.
static scoped_ptr<AudioWriter> Create(const SessionConfig& config);
// Initializes the writer.
void Init(Session* session, const InitializedCallback& callback);
// Stops writing. Must be called on the network thread before this
// object is destroyed.
void Close();
// Returns true if the channel is connected.
bool is_connected();
virtual ~AudioWriter();
// AudioStub interface.
virtual void ProcessAudioPacket(scoped_ptr<AudioPacket> packet,
const base::Closure& done) OVERRIDE;
protected:
virtual void OnInitialized() OVERRIDE;
private:
AudioWriter();
void OnChannelReady(scoped_ptr<net::StreamSocket> socket);
Session* session_;
InitializedCallback initialized_callback_;
// TODO(sergeyu): Remove |channel_| and let |buffered_writer_| own it.
scoped_ptr<net::StreamSocket> channel_;
BufferedSocketWriter buffered_writer_;
DISALLOW_COPY_AND_ASSIGN(AudioWriter);
......
......@@ -6,6 +6,7 @@
#include "base/bind.h"
#include "net/socket/stream_socket.h"
#include "remoting/protocol/channel_factory.h"
#include "remoting/protocol/session.h"
namespace remoting {
......@@ -13,21 +14,21 @@ namespace protocol {
ChannelDispatcherBase::ChannelDispatcherBase(const char* channel_name)
: channel_name_(channel_name),
session_(NULL) {
channel_factory_(NULL) {
}
ChannelDispatcherBase::~ChannelDispatcherBase() {
if (session_)
session_->CancelChannelCreation(channel_name_);
if (channel_factory_)
channel_factory_->CancelChannelCreation(channel_name_);
}
void ChannelDispatcherBase::Init(Session* session,
const InitializedCallback& callback) {
DCHECK(session);
session_ = session;
channel_factory_ = session->GetTransportChannelFactory();
initialized_callback_ = callback;
session_->CreateStreamChannel(channel_name_, base::Bind(
channel_factory_->CreateStreamChannel(channel_name_, base::Bind(
&ChannelDispatcherBase::OnChannelReady, base::Unretained(this)));
}
......
......@@ -18,6 +18,7 @@ class StreamSocket;
namespace remoting {
namespace protocol {
class ChannelFactory;
class Session;
// Base class for channel message dispatchers. It's responsible for
......@@ -52,7 +53,7 @@ class ChannelDispatcherBase {
void OnChannelReady(scoped_ptr<net::StreamSocket> socket);
std::string channel_name_;
Session* session_;
ChannelFactory* channel_factory_;
InitializedCallback initialized_callback_;
scoped_ptr<net::StreamSocket> channel_;
......
......@@ -365,10 +365,6 @@ ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory,
base_channel_name_(base_channel_name),
next_channel_id_(0),
destroyed_flag_(NULL) {
factory->CreateStreamChannel(
base_channel_name,
base::Bind(&ChannelMultiplexer::OnBaseChannelReady,
base::Unretained(this)));
}
ChannelMultiplexer::~ChannelMultiplexer() {
......@@ -396,6 +392,14 @@ void ChannelMultiplexer::CreateStreamChannel(
} else {
// Still waiting for the |base_channel_|.
pending_channels_.push_back(PendingChannel(name, callback));
// If this is the first multiplexed channel then create the base channel.
if (pending_channels_.size() == 1U) {
base_channel_factory_->CreateStreamChannel(
base_channel_name_,
base::Bind(&ChannelMultiplexer::OnBaseChannelReady,
base::Unretained(this)));
}
}
}
......
......@@ -56,6 +56,11 @@ class ChannelMultiplexerTest : public testing::Test {
host_mux_.reset(new ChannelMultiplexer(&host_session_, kMuxChannelName));
client_mux_.reset(new ChannelMultiplexer(&client_session_,
kMuxChannelName));
}
// Connect sockets to each other. Must be called after we've created at least
// one channel with each multiplexer.
void ConnectSockets() {
FakeSocket* host_socket =
host_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName);
FakeSocket* client_socket =
......@@ -123,6 +128,8 @@ TEST_F(ChannelMultiplexerTest, OneChannel) {
scoped_ptr<net::StreamSocket> client_socket;
ASSERT_NO_FATAL_FAILURE(CreateChannel("test", &host_socket, &client_socket));
ConnectSockets();
StreamConnectionTester tester(host_socket.get(), client_socket.get(),
kMessageSize, kMessages);
tester.Start();
......@@ -141,6 +148,8 @@ TEST_F(ChannelMultiplexerTest, TwoChannels) {
ASSERT_NO_FATAL_FAILURE(
CreateChannel("ch2", &host_socket2_, &client_socket2_));
ConnectSockets();
StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(),
kMessageSize, kMessages);
StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(),
......@@ -176,6 +185,8 @@ TEST_F(ChannelMultiplexerTest, FourChannels) {
ASSERT_NO_FATAL_FAILURE(
CreateChannel("ch4", &host_socket4, &client_socket4));
ConnectSockets();
StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(),
kMessageSize, kMessages);
StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(),
......@@ -209,6 +220,8 @@ TEST_F(ChannelMultiplexerTest, SyncFail) {
ASSERT_NO_FATAL_FAILURE(
CreateChannel("ch2", &host_socket2_, &client_socket2_));
ConnectSockets();
host_session_.GetStreamChannel(kMuxChannelName)->
set_next_write_error(net::ERR_FAILED);
host_session_.GetStreamChannel(kMuxChannelName)->
......@@ -239,6 +252,8 @@ TEST_F(ChannelMultiplexerTest, AsyncFail) {
ASSERT_NO_FATAL_FAILURE(
CreateChannel("ch2", &host_socket2_, &client_socket2_));
ConnectSockets();
host_session_.GetStreamChannel(kMuxChannelName)->
set_next_write_error(net::ERR_FAILED);
host_session_.GetStreamChannel(kMuxChannelName)->
......@@ -267,6 +282,8 @@ TEST_F(ChannelMultiplexerTest, DeleteWhenFailed) {
ASSERT_NO_FATAL_FAILURE(
CreateChannel("ch2", &host_socket2_, &client_socket2_));
ConnectSockets();
host_session_.GetStreamChannel(kMuxChannelName)->
set_next_write_error(net::ERR_FAILED);
host_session_.GetStreamChannel(kMuxChannelName)->
......
......@@ -171,8 +171,9 @@ void ConnectionToHost::OnSessionStateChange(
audio_reader_ = AudioReader::Create(session_->config());
if (audio_reader_.get()) {
audio_reader_->Init(session_.get(), audio_stub_, base::Bind(
audio_reader_->Init(session_.get(), base::Bind(
&ConnectionToHost::OnChannelInitialized, base::Unretained(this)));
audio_reader_->set_audio_stub(audio_stub_);
}
control_dispatcher_.reset(new ClientControlDispatcher());
......
......@@ -308,23 +308,6 @@ ErrorCode FakeSession::error() {
return error_;
}
void FakeSession::CreateStreamChannel(
const std::string& name, const StreamChannelCallback& callback) {
scoped_ptr<FakeSocket> channel(new FakeSocket());
stream_channels_[name] = channel.get();
callback.Run(channel.PassAs<net::StreamSocket>());
}
void FakeSession::CreateDatagramChannel(
const std::string& name, const DatagramChannelCallback& callback) {
scoped_ptr<FakeUdpSocket> channel(new FakeUdpSocket());
datagram_channels_[name] = channel.get();
callback.Run(channel.PassAs<net::Socket>());
}
void FakeSession::CancelChannelCreation(const std::string& name) {
}
const std::string& FakeSession::jid() {
return jid_;
}
......@@ -341,9 +324,34 @@ void FakeSession::set_config(const SessionConfig& config) {
config_ = config;
}
ChannelFactory* FakeSession::GetTransportChannelFactory() {
return this;
}
ChannelFactory* FakeSession::GetMultiplexedChannelFactory() {
return this;
}
void FakeSession::Close() {
closed_ = true;
}
void FakeSession::CreateStreamChannel(
const std::string& name, const StreamChannelCallback& callback) {
scoped_ptr<FakeSocket> channel(new FakeSocket());
stream_channels_[name] = channel.get();
callback.Run(channel.PassAs<net::StreamSocket>());
}
void FakeSession::CreateDatagramChannel(
const std::string& name, const DatagramChannelCallback& callback) {
scoped_ptr<FakeUdpSocket> channel(new FakeUdpSocket());
datagram_channels_[name] = channel.get();
callback.Run(channel.PassAs<net::Socket>());
}
void FakeSession::CancelChannelCreation(const std::string& name) {
}
} // namespace protocol
} // namespace remoting
......@@ -14,6 +14,7 @@
#include "net/base/completion_callback.h"
#include "net/socket/socket.h"
#include "net/socket/stream_socket.h"
#include "remoting/protocol/channel_factory.h"
#include "remoting/protocol/session.h"
class MessageLoop;
......@@ -146,7 +147,8 @@ class FakeUdpSocket : public net::Socket {
// FakeSession is a dummy protocol::Session that uses FakeSocket for all
// channels.
class FakeSession : public Session {
class FakeSession : public Session,
public ChannelFactory {
public:
FakeSession();
virtual ~FakeSession();
......@@ -164,11 +166,18 @@ class FakeSession : public Session {
FakeSocket* GetStreamChannel(const std::string& name);
FakeUdpSocket* GetDatagramChannel(const std::string& name);
// Session implementation.
// Session interface.
virtual void SetEventHandler(EventHandler* event_handler) OVERRIDE;
virtual ErrorCode error() OVERRIDE;
virtual const std::string& jid() OVERRIDE;
virtual const CandidateSessionConfig* candidate_config() OVERRIDE;
virtual const SessionConfig& config() OVERRIDE;
virtual void set_config(const SessionConfig& config) OVERRIDE;
virtual ChannelFactory* GetTransportChannelFactory() OVERRIDE;
virtual ChannelFactory* GetMultiplexedChannelFactory() OVERRIDE;
virtual void Close() OVERRIDE;
// ChannelFactory interface.
virtual void CreateStreamChannel(
const std::string& name, const StreamChannelCallback& callback) OVERRIDE;
virtual void CreateDatagramChannel(
......@@ -176,14 +185,6 @@ class FakeSession : public Session {
const DatagramChannelCallback& callback) OVERRIDE;
virtual void CancelChannelCreation(const std::string& name) OVERRIDE;
virtual const std::string& jid() OVERRIDE;
virtual const CandidateSessionConfig* candidate_config() OVERRIDE;
virtual const SessionConfig& config() OVERRIDE;
virtual void set_config(const SessionConfig& config) OVERRIDE;
virtual void Close() OVERRIDE;
public:
EventHandler* event_handler_;
scoped_ptr<const CandidateSessionConfig> candidate_config_;
......
......@@ -13,6 +13,7 @@
#include "remoting/jingle_glue/iq_sender.h"
#include "remoting/protocol/authenticator.h"
#include "remoting/protocol/channel_authenticator.h"
#include "remoting/protocol/channel_multiplexer.h"
#include "remoting/protocol/content_description.h"
#include "remoting/protocol/jingle_messages.h"
#include "remoting/protocol/jingle_session_manager.h"
......@@ -38,6 +39,9 @@ const int kTransportInfoSendDelayMs = 2;
// |transport-info|.
const int kMessageResponseTimeoutSeconds = 10;
// Name of the multiplexed channel.
const char kMuxChannelName[] = "mux";
ErrorCode AuthRejectionReasonToErrorCode(
Authenticator::RejectionReason reason) {
switch (reason) {
......@@ -61,6 +65,7 @@ JingleSession::JingleSession(JingleSessionManager* session_manager)
}
JingleSession::~JingleSession() {
channel_multiplexer_.reset();
STLDeleteContainerPointers(pending_requests_.begin(),
pending_requests_.end());
STLDeleteContainerPairSecondPointers(channels_.begin(), channels_.end());
......@@ -170,6 +175,46 @@ void JingleSession::AcceptIncomingConnection(
return;
}
const std::string& JingleSession::jid() {
DCHECK(CalledOnValidThread());
return peer_jid_;
}
const CandidateSessionConfig* JingleSession::candidate_config() {
DCHECK(CalledOnValidThread());
return candidate_config_.get();
}
const SessionConfig& JingleSession::config() {
DCHECK(CalledOnValidThread());
return config_;
}
void JingleSession::set_config(const SessionConfig& config) {
DCHECK(CalledOnValidThread());
DCHECK(!config_is_set_);
config_ = config;
config_is_set_ = true;
}
ChannelFactory* JingleSession::GetTransportChannelFactory() {
DCHECK(CalledOnValidThread());
return this;
}
ChannelFactory* JingleSession::GetMultiplexedChannelFactory() {
DCHECK(CalledOnValidThread());
if (!channel_multiplexer_.get())
channel_multiplexer_.reset(new ChannelMultiplexer(this, kMuxChannelName));
return channel_multiplexer_.get();
}
void JingleSession::Close() {
DCHECK(CalledOnValidThread());
CloseInternal(OK);
}
void JingleSession::CreateStreamChannel(
const std::string& name,
const StreamChannelCallback& callback) {
......@@ -206,34 +251,6 @@ void JingleSession::CancelChannelCreation(const std::string& name) {
}
}
const std::string& JingleSession::jid() {
DCHECK(CalledOnValidThread());
return peer_jid_;
}
const CandidateSessionConfig* JingleSession::candidate_config() {
DCHECK(CalledOnValidThread());
return candidate_config_.get();
}
const SessionConfig& JingleSession::config() {
DCHECK(CalledOnValidThread());
return config_;
}
void JingleSession::set_config(const SessionConfig& config) {
DCHECK(CalledOnValidThread());
DCHECK(!config_is_set_);
config_ = config;
config_is_set_ = true;
}
void JingleSession::Close() {
DCHECK(CalledOnValidThread());
CloseInternal(OK);
}
void JingleSession::OnTransportCandidate(Transport* transport,
const cricket::Candidate& candidate) {
pending_candidates_.push_back(JingleMessage::NamedCandidate(
......
......@@ -15,6 +15,7 @@
#include "net/base/completion_callback.h"
#include "remoting/jingle_glue/iq_sender.h"
#include "remoting/protocol/authenticator.h"
#include "remoting/protocol/channel_factory.h"
#include "remoting/protocol/jingle_messages.h"
#include "remoting/protocol/session.h"
#include "remoting/protocol/session_config.h"
......@@ -28,12 +29,14 @@ class StreamSocket;
namespace remoting {
namespace protocol {
class ChannelMultiplexer;
class JingleSessionManager;
// JingleSessionManager and JingleSession implement the subset of the
// Jingle protocol used in Chromoting. Instances of this class are
// created by the JingleSessionManager.
class JingleSession : public Session,
public ChannelFactory,
public Transport::EventHandler {
public:
virtual ~JingleSession();
......@@ -41,6 +44,15 @@ class JingleSession : public Session,
// Session interface.
virtual void SetEventHandler(Session::EventHandler* event_handler) OVERRIDE;
virtual ErrorCode error() OVERRIDE;
virtual const std::string& jid() OVERRIDE;
virtual const CandidateSessionConfig* candidate_config() OVERRIDE;
virtual const SessionConfig& config() OVERRIDE;
virtual void set_config(const SessionConfig& config) OVERRIDE;
virtual ChannelFactory* GetTransportChannelFactory() OVERRIDE;
virtual ChannelFactory* GetMultiplexedChannelFactory() OVERRIDE;
virtual void Close() OVERRIDE;
// ChannelFactory interface.
virtual void CreateStreamChannel(
const std::string& name,
const StreamChannelCallback& callback) OVERRIDE;
......@@ -48,11 +60,6 @@ class JingleSession : public Session,
const std::string& name,
const DatagramChannelCallback& callback) OVERRIDE;
virtual void CancelChannelCreation(const std::string& name) OVERRIDE;
virtual const std::string& jid() OVERRIDE;
virtual const CandidateSessionConfig* candidate_config() OVERRIDE;
virtual const SessionConfig& config() OVERRIDE;
virtual void set_config(const SessionConfig& config) OVERRIDE;
virtual void Close() OVERRIDE;
// Transport::EventHandler interface.
virtual void OnTransportCandidate(
......@@ -142,6 +149,7 @@ class JingleSession : public Session,
std::list<IqRequest*> pending_requests_;
ChannelsMap channels_;
scoped_ptr<ChannelMultiplexer> channel_multiplexer_;
base::OneShotTimer<JingleSession> transport_infos_timer_;
std::list<JingleMessage::NamedCandidate> pending_candidates_;
......
......@@ -228,13 +228,15 @@ class JingleSessionTest : public testing::Test {
}
void CreateChannel() {
client_session_->CreateStreamChannel(kChannelName, base::Bind(
&JingleSessionTest::OnClientChannelCreated, base::Unretained(this)));
host_session_->CreateStreamChannel(kChannelName, base::Bind(
&JingleSessionTest::OnHostChannelCreated, base::Unretained(this)));
client_session_->GetTransportChannelFactory()->CreateStreamChannel(
kChannelName, base::Bind(&JingleSessionTest::OnClientChannelCreated,
base::Unretained(this)));
host_session_->GetTransportChannelFactory()->CreateStreamChannel(
kChannelName, base::Bind(&JingleSessionTest::OnHostChannelCreated,
base::Unretained(this)));
int counter = 2;
ExpectRouteChange();
ExpectRouteChange(kChannelName);
EXPECT_CALL(client_channel_callback_, OnDone(_))
.WillOnce(QuitThreadOnCounter(&counter));
EXPECT_CALL(host_channel_callback_, OnDone(_))
......@@ -245,12 +247,12 @@ class JingleSessionTest : public testing::Test {
EXPECT_TRUE(host_socket_.get());
}
void ExpectRouteChange() {
void ExpectRouteChange(const std::string&