Commit aa22c085 authored by sergeyu's avatar sergeyu Committed by Commit bot

Add P2PDatagramSocket and P2PStreamSocket interfaces.

Previously remoting code was using net::Socket and net::StreamSocket
for datagram and stream socket. Problem is that net::StreamSocket
interface contains a lot of methods that are not relevant for
peer-to-peer connections in remoting. Added P2PDatagramSocket and
P2PStreamSocket interfaces independent of net::Socket. This allowed to
remove a lot of the redundant code needed for net::StreamSocket
implementations. There are two new adapters required in
SslHmacChannelAuthenticator for the SSL layer, but these won't be
necessary after we migrate to QUIC.

Review URL: https://codereview.chromium.org/1197853003

Cr-Commit-Position: refs/heads/master@{#339489}
parent 2f8fcec9
......@@ -5,218 +5,130 @@
#include "remoting/base/buffered_socket_writer.h"
#include "base/bind.h"
#include "base/location.h"
#include "base/single_thread_task_runner.h"
#include "base/stl_util.h"
#include "base/thread_task_runner_handle.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/socket/socket.h"
namespace remoting {
struct BufferedSocketWriterBase::PendingPacket {
PendingPacket(scoped_refptr<net::IOBufferWithSize> data,
namespace {
int WriteNetSocket(net::Socket* socket,
const scoped_refptr<net::IOBuffer>& buf,
int buf_len,
const net::CompletionCallback& callback) {
return socket->Write(buf.get(), buf_len, callback);
}
} // namespace
struct BufferedSocketWriter::PendingPacket {
PendingPacket(scoped_refptr<net::DrainableIOBuffer> data,
const base::Closure& done_task)
: data(data),
done_task(done_task) {
}
scoped_refptr<net::IOBufferWithSize> data;
scoped_refptr<net::DrainableIOBuffer> data;
base::Closure done_task;
};
BufferedSocketWriterBase::BufferedSocketWriterBase()
: socket_(nullptr),
write_pending_(false),
closed_(false),
destroyed_flag_(nullptr) {
// static
scoped_ptr<BufferedSocketWriter> BufferedSocketWriter::CreateForSocket(
net::Socket* socket,
const WriteFailedCallback& write_failed_callback) {
scoped_ptr<BufferedSocketWriter> result(new BufferedSocketWriter());
result->Init(base::Bind(&WriteNetSocket, socket), write_failed_callback);
return result.Pass();
}
void BufferedSocketWriterBase::Init(net::Socket* socket,
const WriteFailedCallback& callback) {
DCHECK(CalledOnValidThread());
DCHECK(socket);
socket_ = socket;
write_failed_callback_ = callback;
BufferedSocketWriter::BufferedSocketWriter() : weak_factory_(this) {
}
BufferedSocketWriter::~BufferedSocketWriter() {
STLDeleteElements(&queue_);
}
bool BufferedSocketWriterBase::Write(
scoped_refptr<net::IOBufferWithSize> data, const base::Closure& done_task) {
DCHECK(CalledOnValidThread());
DCHECK(socket_);
void BufferedSocketWriter::Init(
const WriteCallback& write_callback,
const WriteFailedCallback& write_failed_callback) {
write_callback_ = write_callback;
write_failed_callback_ = write_failed_callback;
}
bool BufferedSocketWriter::Write(
const scoped_refptr<net::IOBufferWithSize>& data,
const base::Closure& done_task) {
DCHECK(thread_checker_.CalledOnValidThread());
DCHECK(data.get());
// Don't write after Close().
if (closed_)
// Don't write after error.
if (is_closed())
return false;
queue_.push_back(new PendingPacket(data, done_task));
queue_.push_back(new PendingPacket(
new net::DrainableIOBuffer(data.get(), data->size()), done_task));
DoWrite();
// DoWrite() may trigger OnWriteError() to be called.
return !closed_;
return !is_closed();
}
void BufferedSocketWriterBase::DoWrite() {
DCHECK(CalledOnValidThread());
DCHECK(socket_);
// Don't try to write if there is another write pending.
if (write_pending_)
return;
bool BufferedSocketWriter::is_closed() {
return write_callback_.is_null();
}
// Don't write after Close().
if (closed_)
return;
void BufferedSocketWriter::DoWrite() {
DCHECK(thread_checker_.CalledOnValidThread());
while (true) {
net::IOBuffer* current_packet;
int current_packet_size;
GetNextPacket(&current_packet, &current_packet_size);
// Return if the queue is empty.
if (!current_packet)
return;
int result = socket_->Write(
current_packet, current_packet_size,
base::Bind(&BufferedSocketWriterBase::OnWritten,
base::Unretained(this)));
bool write_again = false;
HandleWriteResult(result, &write_again);
if (!write_again)
return;
base::WeakPtr<BufferedSocketWriter> self = weak_factory_.GetWeakPtr();
while (self && !write_pending_ && !is_closed() && !queue_.empty()) {
int result = write_callback_.Run(
queue_.front()->data.get(), queue_.front()->data->BytesRemaining(),
base::Bind(&BufferedSocketWriter::OnWritten,
weak_factory_.GetWeakPtr()));
HandleWriteResult(result);
}
}
void BufferedSocketWriterBase::HandleWriteResult(int result,
bool* write_again) {
*write_again = false;
void BufferedSocketWriter::HandleWriteResult(int result) {
if (result < 0) {
if (result == net::ERR_IO_PENDING) {
write_pending_ = true;
} else {
HandleError(result);
if (!write_failed_callback_.is_null())
write_failed_callback_.Run(result);
write_callback_.Reset();
if (!write_failed_callback_.is_null()) {
WriteFailedCallback callback = write_failed_callback_;
callback.Run(result);
}
}
return;
}
base::Closure done_task = AdvanceBufferPosition(result);
if (!done_task.is_null()) {
bool destroyed = false;
destroyed_flag_ = &destroyed;
done_task.Run();
if (destroyed) {
// Stop doing anything if we've been destroyed by the callback.
return;
}
destroyed_flag_ = nullptr;
}
*write_again = true;
}
void BufferedSocketWriterBase::OnWritten(int result) {
DCHECK(CalledOnValidThread());
DCHECK(write_pending_);
write_pending_ = false;
bool write_again;
HandleWriteResult(result, &write_again);
if (write_again)
DoWrite();
}
void BufferedSocketWriterBase::HandleError(int result) {
DCHECK(CalledOnValidThread());
closed_ = true;
STLDeleteElements(&queue_);
// Notify subclass that an error is received.
OnError(result);
}
void BufferedSocketWriterBase::Close() {
DCHECK(CalledOnValidThread());
closed_ = true;
}
BufferedSocketWriterBase::~BufferedSocketWriterBase() {
if (destroyed_flag_)
*destroyed_flag_ = true;
STLDeleteElements(&queue_);
}
base::Closure BufferedSocketWriterBase::PopQueue() {
base::Closure result = queue_.front()->done_task;
delete queue_.front();
queue_.pop_front();
return result;
}
BufferedSocketWriter::BufferedSocketWriter() {
}
void BufferedSocketWriter::GetNextPacket(
net::IOBuffer** buffer, int* size) {
if (!current_buf_.get()) {
if (queue_.empty()) {
*buffer = nullptr;
return; // Nothing to write.
}
current_buf_ = new net::DrainableIOBuffer(queue_.front()->data.get(),
queue_.front()->data->size());
}
*buffer = current_buf_.get();
*size = current_buf_->BytesRemaining();
}
base::Closure BufferedSocketWriter::AdvanceBufferPosition(int written) {
current_buf_->DidConsume(written);
if (current_buf_->BytesRemaining() == 0) {
current_buf_ = nullptr;
return PopQueue();
}
return base::Closure();
}
void BufferedSocketWriter::OnError(int result) {
current_buf_ = nullptr;
}
DCHECK(!queue_.empty());
BufferedSocketWriter::~BufferedSocketWriter() {
}
queue_.front()->data->DidConsume(result);
BufferedDatagramWriter::BufferedDatagramWriter() {
}
if (queue_.front()->data->BytesRemaining() == 0) {
base::Closure done_task = queue_.front()->done_task;
delete queue_.front();
queue_.pop_front();
void BufferedDatagramWriter::GetNextPacket(
net::IOBuffer** buffer, int* size) {
if (queue_.empty()) {
*buffer = nullptr;
return; // Nothing to write.
if (!done_task.is_null())
done_task.Run();
}
*buffer = queue_.front()->data.get();
*size = queue_.front()->data->size();
}
base::Closure BufferedDatagramWriter::AdvanceBufferPosition(int written) {
DCHECK_EQ(written, queue_.front()->data->size());
return PopQueue();
}
void BufferedDatagramWriter::OnError(int result) {
// Nothing to do here.
}
void BufferedSocketWriter::OnWritten(int result) {
DCHECK(thread_checker_.CalledOnValidThread());
DCHECK(write_pending_);
write_pending_ = false;
BufferedDatagramWriter::~BufferedDatagramWriter() {
base::WeakPtr<BufferedSocketWriter> self = weak_factory_.GetWeakPtr();
HandleWriteResult(result);
if (self)
DoWrite();
}
} // namespace remoting
......@@ -8,10 +8,11 @@
#include <list>
#include "base/callback.h"
#include "base/memory/weak_ptr.h"
#include "base/synchronization/lock.h"
#include "base/threading/non_thread_safe.h"
#include "base/threading/thread_checker.h"
#include "net/base/completion_callback.h"
#include "net/base/io_buffer.h"
#include "net/socket/socket.h"
namespace net {
class Socket;
......@@ -19,105 +20,57 @@ class Socket;
namespace remoting {
// BufferedSocketWriter and BufferedDatagramWriter implement write data queue
// for stream and datagram sockets. BufferedSocketWriterBase is a base class
// that implements base functionality common for streams and datagrams.
// These classes are particularly useful when data comes from a thread
// that doesn't own the socket, as Write() can be called from any thread.
// Whenever new data is written it is just put in the queue, and then written
// on the thread that owns the socket. GetBufferChunks() and GetBufferSize()
// can be used to throttle writes.
class BufferedSocketWriterBase : public base::NonThreadSafe {
// BufferedSocketWriter implement write data queue for stream sockets.
class BufferedSocketWriter {
public:
typedef base::Callback<int(const scoped_refptr<net::IOBuffer>& buf,
int buf_len,
const net::CompletionCallback& callback)>
WriteCallback;
typedef base::Callback<void(int)> WriteFailedCallback;
BufferedSocketWriterBase();
virtual ~BufferedSocketWriterBase();
static scoped_ptr<BufferedSocketWriter> CreateForSocket(
net::Socket* socket,
const WriteFailedCallback& write_failed_callback);
BufferedSocketWriter();
virtual ~BufferedSocketWriter();
// Initializes the writer. Must be called on the thread that will be used
// to access the socket in the future. |callback| will be called after each
// failed write. Caller retains ownership of |socket|.
// TODO(sergeyu): Change it so that it take ownership of |socket|.
void Init(net::Socket* socket, const WriteFailedCallback& callback);
// Initializes the writer. |write_callback| is called to write data to the
// socket. |write_failed_callback| is called when write operation fails.
// Writing stops after the first failed write.
void Init(const WriteCallback& write_callback,
const WriteFailedCallback& write_failed_callback);
// Puts a new data chunk in the buffer. Returns false and doesn't enqueue
// the data if called before Init(). Can be called on any thread.
bool Write(scoped_refptr<net::IOBufferWithSize> buffer,
// Puts a new data chunk in the buffer. Returns false if writing has stopped
// because of an error.
bool Write(const scoped_refptr<net::IOBufferWithSize>& buffer,
const base::Closure& done_task);
// Returns true when there is data waiting to be written.
bool has_data_pending() { return !queue_.empty(); }
// Stops writing and drops current buffers. Must be called on the
// network thread.
void Close();
protected:
private:
struct PendingPacket;
typedef std::list<PendingPacket*> DataQueue;
DataQueue queue_;
// Removes element from the front of the queue and returns |done_task| for
// that element. Called from AdvanceBufferPosition() implementation, which
// then returns result of this function to its caller.
base::Closure PopQueue();
// Following three methods must be implemented in child classes.
// Returns next packet that needs to be written to the socket. Implementation
// must set |*buffer| to nullptr if there is nothing left in the queue.
virtual void GetNextPacket(net::IOBuffer** buffer, int* size) = 0;
// Returns true if the writer is closed due to an error.
bool is_closed();
// Returns closure that must be executed or null closure if the last write
// didn't complete any messages.
virtual base::Closure AdvanceBufferPosition(int written) = 0;
// This method is called whenever there is an error writing to the socket.
virtual void OnError(int result) = 0;
private:
void DoWrite();
void HandleWriteResult(int result, bool* write_again);
void HandleWriteResult(int result);
void OnWritten(int result);
// This method is called when an error is encountered.
void HandleError(int result);
base::ThreadChecker thread_checker_;
net::Socket* socket_;
WriteCallback write_callback_;
WriteFailedCallback write_failed_callback_;
bool write_pending_;
bool closed_;
bool* destroyed_flag_;
};
class BufferedSocketWriter : public BufferedSocketWriterBase {
public:
BufferedSocketWriter();
~BufferedSocketWriter() override;
protected:
void GetNextPacket(net::IOBuffer** buffer, int* size) override;
base::Closure AdvanceBufferPosition(int written) override;
void OnError(int result) override;
private:
scoped_refptr<net::DrainableIOBuffer> current_buf_;
};
DataQueue queue_;
class BufferedDatagramWriter : public BufferedSocketWriterBase {
public:
BufferedDatagramWriter();
~BufferedDatagramWriter() override;
bool write_pending_ = false;
protected:
void GetNextPacket(net::IOBuffer** buffer, int* size) override;
base::Closure AdvanceBufferPosition(int written) override;
void OnError(int result) override;
base::WeakPtrFactory<BufferedSocketWriter> weak_factory_;
};
} // namespace remoting
......
......@@ -18,7 +18,8 @@
namespace remoting {
namespace {
const int kTestBufferSize = 10 * 1024; // 10k;
const int kTestBufferSize = 10000;
const size_t kWriteChunkSize = 1024U;
class SocketDataProvider: public net::SocketDataProvider {
......@@ -93,9 +94,9 @@ class BufferedSocketWriterTest : public testing::Test {
net::MockConnect(net::SYNCHRONOUS, net::OK));
EXPECT_EQ(net::OK, socket_->Connect(net::CompletionCallback()));
writer_.reset(new BufferedSocketWriter());
writer_->Init(socket_.get(), base::Bind(
&BufferedSocketWriterTest::OnWriteFailed, base::Unretained(this)));
writer_ = BufferedSocketWriter::CreateForSocket(
socket_.get(), base::Bind(&BufferedSocketWriterTest::OnWriteFailed,
base::Unretained(this)));
test_buffer_ = new net::IOBufferWithSize(kTestBufferSize);
test_buffer_2_ = new net::IOBufferWithSize(kTestBufferSize);
for (int i = 0; i< kTestBufferSize; ++i) {
......@@ -126,7 +127,7 @@ class BufferedSocketWriterTest : public testing::Test {
void TestAppendInCallback() {
writer_->Write(test_buffer_, base::Bind(
base::IgnoreResult(&BufferedSocketWriterBase::Write),
base::IgnoreResult(&BufferedSocketWriter::Write),
base::Unretained(writer_.get()), test_buffer_2_,
base::Closure()));
base::RunLoop().RunUntilIdle();
......
......@@ -4,6 +4,7 @@
#include "remoting/client/key_event_mapper.h"
#include "base/bind.h"
#include "remoting/proto/event.pb.h"
#include "remoting/protocol/protocol_mock_objects.h"
#include "remoting/protocol/test_event_matchers.h"
......
......@@ -15,6 +15,7 @@
#include "remoting/protocol/authenticator.h"
#include "remoting/protocol/channel_authenticator.h"
#include "remoting/protocol/fake_stream_socket.h"
#include "remoting/protocol/p2p_stream_socket.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/webrtc/libjingle/xmllite/xmlelement.h"
......@@ -158,14 +159,14 @@ void AuthenticatorTestBase::RunChannelAuth(bool expected_fail) {
void AuthenticatorTestBase::OnHostConnected(
int error,
scoped_ptr<net::StreamSocket> socket) {
scoped_ptr<P2PStreamSocket> socket) {
host_callback_.OnDone(error);
host_socket_ = socket.Pass();
}
void AuthenticatorTestBase::OnClientConnected(
int error,
scoped_ptr<net::StreamSocket> socket) {
scoped_ptr<P2PStreamSocket> socket) {
client_callback_.OnDone(error);
client_socket_ = socket.Pass();
}
......
......@@ -12,10 +12,6 @@
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace net {
class StreamSocket;
} // namespace net
namespace remoting {
class RsaKeyPair;
......@@ -25,6 +21,7 @@ namespace protocol {
class Authenticator;
class ChannelAuthenticator;
class FakeStreamSocket;
class P2PStreamSocket;
class AuthenticatorTestBase : public testing::Test {
public:
......@@ -49,9 +46,9 @@ class AuthenticatorTestBase : public testing::Test {
void RunChannelAuth(bool expected_fail);
void OnHostConnected(int error,
scoped_ptr<net::StreamSocket> socket);
scoped_ptr<P2PStreamSocket> socket);
void OnClientConnected(int error,
scoped_ptr<net::StreamSocket> socket);
scoped_ptr<P2PStreamSocket> socket);
base::MessageLoop message_loop_;
......@@ -66,8 +63,8 @@ class AuthenticatorTestBase : public testing::Test {
scoped_ptr<ChannelAuthenticator> host_auth_;
MockChannelDoneCallback client_callback_;
MockChannelDoneCallback host_callback_;
scoped_ptr<net::StreamSocket> client_socket_;
scoped_ptr<net::StreamSocket> host_socket_;
scoped_ptr<P2PStreamSocket> client_socket_;
scoped_ptr<P2PStreamSocket> host_socket_;
DISALLOW_COPY_AND_ASSIGN(AuthenticatorTestBase);
};
......
......@@ -9,20 +9,18 @@
#include "base/callback_forward.h"
namespace net {
class StreamSocket;
} // namespace net
namespace remoting {
namespace protocol {
class P2PStreamSocket;
// Interface for channel authentications that perform channel-level
// authentication. Depending on implementation channel authenticators
// may also establish SSL connection. Each instance of this interface
// should be used only once for one channel.
class ChannelAuthenticator {
public:
typedef base::Callback<void(int error, scoped_ptr<net::StreamSocket>)>
typedef base::Callback<void(int error, scoped_ptr<P2PStreamSocket>)>
DoneCallback;
virtual ~ChannelAuthenticator() {}
......@@ -31,7 +29,7 @@ class ChannelAuthenticator {
// authentication is finished. Callback may be invoked before this method
// returns, and may delete the calling authenticator.
virtual void SecureAndAuthenticate(
scoped_ptr<net::StreamSocket> socket,
scoped_ptr<P2PStreamSocket> socket,
const DoneCallback& done_callback) = 0;
};
......
......@@ -5,7 +5,7 @@
#include "remoting/protocol/channel_dispatcher_base.h"
#include "base/bind.h"
#include "net/socket/stream_socket.h"
#include "remoting/protocol/p2p_stream_socket.h"
#include "remoting/protocol/session.h"
#include "remoting/protocol/session_config.h"
#include "remoting/protocol/stream_channel_factory.h"
......@@ -20,7 +20,6 @@ ChannelDispatcherBase::ChannelDispatcherBase(const char* channel_name)
}
ChannelDispatcherBase::~ChannelDispatcherBase() {
writer()->Close();
if (channel_factory_)
channel_factory_->CancelChannelCreation(channel_name_);
}
......@@ -49,7 +48,7 @@ void ChannelDispatcherBase::Init(Session* session,
}
void ChannelDispatcherBase::OnChannelReady(
scoped_ptr<net::StreamSocket> socket) {
scoped_ptr<P2PStreamSocket> socket) {
if (!socket.get()) {
event_handler_->OnChannelError(this, CHANNEL_CONNECTION_ERROR);
return;
......@@ -57,9 +56,10 @@ void ChannelDispatcherBase::OnChannelReady(
channel_factory_ = nullptr;
channel_ = socket.Pass();
writer_.Init(channel_.get(),
base::Bind(&ChannelDispatcherBase::OnReadWriteFailed,
base::Unretained(this)));
writer_.Init(
base::Bind(&P2PStreamSocket::Write, base::Unretained(channel_.get())),
base::Bind(&ChannelDispatcherBase::OnReadWriteFailed,
base::Unretained(this)));
reader_.StartReading(channel_.get(),
base::Bind(&ChannelDispatcherBase::OnReadWriteFailed,
base::Unretained(this)));
......
......@@ -14,10 +14,6 @@
#include "remoting/protocol/errors.h"
#include "remoting/protocol/message_reader.h"
namespace net {
class StreamSocket;
} // namespace net
namespace remoting {
namespace protocol {
......@@ -66,13 +62,13 @@ class ChannelDispatcherBase {
MessageReader* reader() { return &reader_; }
private:
void OnChannelReady(scoped_ptr<net::StreamSocket> socket);
void OnChannelReady(scoped_ptr<P2PStreamSocket> socket);
void OnReadWriteFailed(int error);
std::string channel_name_;
StreamChannelFactory* channel_factory_;
EventHandler* event_handler_;
scoped_ptr<net::StreamSocket> channel_;
scoped_ptr<P2PStreamSocket> channel_;
BufferedSocketWriter writer_;
MessageReader reader_;
......
......@@ -14,8 +14,8 @@
#include "base/stl_util.h"
#include "base/thread_task_runner_handle.h"
#include "net/base/net_errors.h"
#include "net/socket/stream_socket.h"
#include "remoting/protocol/message_serialization.h"
#include "remoting/protocol/p2p_stream_socket.h"
namespace remoting {