Commit a04c0468 authored by sergeyu's avatar sergeyu Committed by Commit bot

Fix MessageReader to pass errors to the channel

Previously MessageReader was stopping reading after the first error,
but wasn't notifying the client about the problem. This results in some
errors (e.g. from SSL layer) being ignores while they should terminate
connection.

BUG=487451

Review URL: https://codereview.chromium.org/1143443003

Cr-Commit-Position: refs/heads/master@{#329780}
parent 915be3ac
......@@ -57,14 +57,17 @@ void ChannelDispatcherBase::OnChannelReady(
channel_factory_ = nullptr;
channel_ = socket.Pass();
writer_.Init(channel_.get(), base::Bind(&ChannelDispatcherBase::OnWriteFailed,
base::Unretained(this)));
reader_.StartReading(channel_.get());
writer_.Init(channel_.get(),
base::Bind(&ChannelDispatcherBase::OnReadWriteFailed,
base::Unretained(this)));
reader_.StartReading(channel_.get(),
base::Bind(&ChannelDispatcherBase::OnReadWriteFailed,
base::Unretained(this)));
event_handler_->OnChannelInitialized(this);
}
void ChannelDispatcherBase::OnWriteFailed(int error) {
void ChannelDispatcherBase::OnReadWriteFailed(int error) {
event_handler_->OnChannelError(this, CHANNEL_CONNECTION_ERROR);
}
......
......@@ -67,7 +67,7 @@ class ChannelDispatcherBase {
private:
void OnChannelReady(scoped_ptr<net::StreamSocket> socket);
void OnWriteFailed(int error);
void OnReadWriteFailed(int error);
std::string channel_name_;
StreamChannelFactory* channel_factory_;
......
......@@ -8,6 +8,7 @@
#include "base/bind.h"
#include "base/callback.h"
#include "base/callback_helpers.h"
#include "base/location.h"
#include "base/single_thread_task_runner.h"
#include "base/stl_util.h"
......@@ -79,7 +80,7 @@ class ChannelMultiplexer::MuxChannel {
scoped_ptr<net::StreamSocket> CreateSocket();
void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
const base::Closure& done_task);
void OnWriteFailed();
void OnBaseChannelError(int error);
// Called by MuxSocket.
void OnSocketDestroyed();
......@@ -107,7 +108,7 @@ class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
~MuxSocket() override;
void OnWriteComplete();
void OnWriteFailed();
void OnBaseChannelError(int error);
void OnPacketReceived();
// net::StreamSocket interface.
......@@ -168,6 +169,8 @@ class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
private:
MuxChannel* channel_;
int base_channel_error_ = net::OK;
net::CompletionCallback read_callback_;
scoped_refptr<net::IOBuffer> read_buffer_;
int read_buffer_size_;
......@@ -220,9 +223,9 @@ void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
}
}
void ChannelMultiplexer::MuxChannel::OnWriteFailed() {
void ChannelMultiplexer::MuxChannel::OnBaseChannelError(int error) {
if (socket_)
socket_->OnWriteFailed();
socket_->OnBaseChannelError(error);
}
void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
......@@ -276,6 +279,9 @@ int ChannelMultiplexer::MuxSocket::Read(
DCHECK(CalledOnValidThread());
DCHECK(read_callback_.is_null());
if (base_channel_error_ != net::OK)
return base_channel_error_;
int result = channel_->DoRead(buffer, buffer_len);
if (result == 0) {
read_buffer_ = buffer;
......@@ -290,6 +296,10 @@ int ChannelMultiplexer::MuxSocket::Write(
net::IOBuffer* buffer, int buffer_len,
const net::CompletionCallback& callback) {
DCHECK(CalledOnValidThread());
DCHECK(write_callback_.is_null());
if (base_channel_error_ != net::OK)
return base_channel_error_;
scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
size_t size = std::min(kMaxPacketSize, buffer_len);
......@@ -317,19 +327,28 @@ int ChannelMultiplexer::MuxSocket::Write(
void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
write_pending_ = false;
if (!write_callback_.is_null()) {
net::CompletionCallback cb;
std::swap(cb, write_callback_);
cb.Run(write_result_);
}
if (!write_callback_.is_null())
base::ResetAndReturn(&write_callback_).Run(write_result_);
}
void ChannelMultiplexer::MuxSocket::OnWriteFailed() {
if (!write_callback_.is_null()) {
net::CompletionCallback cb;
std::swap(cb, write_callback_);
cb.Run(net::ERR_FAILED);
void ChannelMultiplexer::MuxSocket::OnBaseChannelError(int error) {
base_channel_error_ = error;
// Here only one of the read and write callbacks is called if both of them are
// pending. Ideally both of them should be called in that case, but that would
// require the second one to be called asynchronously which would complicate
// this code. Channels handle read and write errors the same way (see
// ChannelDispatcherBase::OnReadWriteFailed) so calling only one of the
// callbacks is enough.
if (!read_callback_.is_null()) {
base::ResetAndReturn(&read_callback_).Run(error);
return;
}
if (!write_callback_.is_null())
base::ResetAndReturn(&write_callback_).Run(error);
}
void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
......@@ -337,9 +356,7 @@ void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
read_buffer_ = nullptr;
DCHECK_GT(result, 0);
net::CompletionCallback cb;
std::swap(cb, read_callback_);
cb.Run(result);
base::ResetAndReturn(&read_callback_).Run(result);
}
}
......@@ -403,9 +420,11 @@ void ChannelMultiplexer::OnBaseChannelReady(
if (base_channel_.get()) {
// Initialize reader and writer.
reader_.StartReading(base_channel_.get());
reader_.StartReading(base_channel_.get(),
base::Bind(&ChannelMultiplexer::OnBaseChannelError,
base::Unretained(this)));
writer_.Init(base_channel_.get(),
base::Bind(&ChannelMultiplexer::OnWriteFailed,
base::Bind(&ChannelMultiplexer::OnBaseChannelError,
base::Unretained(this)));
}
......@@ -447,20 +466,21 @@ ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
}
void ChannelMultiplexer::OnWriteFailed(int error) {
void ChannelMultiplexer::OnBaseChannelError(int error) {
for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
it != channels_.end(); ++it) {
base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::Bind(&ChannelMultiplexer::NotifyWriteFailed,
weak_factory_.GetWeakPtr(), it->second->name()));
FROM_HERE,
base::Bind(&ChannelMultiplexer::NotifyBaseChannelError,
weak_factory_.GetWeakPtr(), it->second->name(), error));
}
}
void ChannelMultiplexer::NotifyWriteFailed(const std::string& name) {
void ChannelMultiplexer::NotifyBaseChannelError(const std::string& name,
int error) {
std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
if (it != channels_.end()) {
it->second->OnWriteFailed();
}
if (it != channels_.end())
it->second->OnBaseChannelError(error);
}
void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
......
......@@ -44,11 +44,12 @@ class ChannelMultiplexer : public StreamChannelFactory {
// Helper method used to create channels.
MuxChannel* GetOrCreateChannel(const std::string& name);
// Error handling callback for |writer_|.
void OnWriteFailed(int error);
// Error handling callback for |reader_| and |writer_|.
void OnBaseChannelError(int error);
// Failed write notifier, queued asynchronously by OnWriteFailed().
void NotifyWriteFailed(const std::string& name);
// Propagates base channel error to channel |name|, queued asynchronously by
// OnBaseChannelError().
void NotifyBaseChannelError(const std::string& name, int error);
// Callback for |reader_;
void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
......
......@@ -36,6 +36,7 @@ class ClientVideoDispatcherTest : public testing::Test,
protected:
void OnVideoAck(scoped_ptr<VideoAck> ack, const base::Closure& done);
void OnReadError(int error);
base::MessageLoop message_loop_;
......@@ -72,7 +73,9 @@ ClientVideoDispatcherTest::ClientVideoDispatcherTest()
DCHECK(initialized_);
host_socket_.PairWith(
session_.fake_channel_factory().GetFakeChannel(kVideoChannelName));
reader_.StartReading(&host_socket_);
reader_.StartReading(&host_socket_,
base::Bind(&ClientVideoDispatcherTest::OnReadError,
base::Unretained(this)));
writer_.Init(&host_socket_, BufferedSocketWriter::WriteFailedCallback());
}
......@@ -101,6 +104,10 @@ void ClientVideoDispatcherTest::OnVideoAck(scoped_ptr<VideoAck> ack,
done.Run();
}
void ClientVideoDispatcherTest::OnReadError(int error) {
LOG(FATAL) << "Unexpected read error: " << error;
}
// Verify that the client can receive video packets and acks are not sent for
// VideoPackets that don't have frame_id field set.
TEST_F(ClientVideoDispatcherTest, WithoutAcks) {
......
......@@ -38,10 +38,15 @@ void MessageReader::SetMessageReceivedCallback(
message_received_callback_ = callback;
}
void MessageReader::StartReading(net::Socket* socket) {
void MessageReader::StartReading(
net::Socket* socket,
const ReadFailedCallback& read_failed_callback) {
DCHECK(CalledOnValidThread());
DCHECK(socket);
DCHECK(!read_failed_callback.is_null());
socket_ = socket;
read_failed_callback_ = read_failed_callback;
DoRead();
}
......@@ -49,13 +54,16 @@ void MessageReader::DoRead() {
DCHECK(CalledOnValidThread());
// Don't try to read again if there is another read pending or we
// have messages that we haven't finished processing yet.
while (!closed_ && !read_pending_ && pending_messages_ == 0) {
bool read_succeeded = true;
while (read_succeeded && !closed_ && !read_pending_ &&
pending_messages_ == 0) {
read_buffer_ = new net::IOBuffer(kReadBufferSize);
int result = socket_->Read(
read_buffer_.get(),
kReadBufferSize,
base::Bind(&MessageReader::OnRead, weak_factory_.GetWeakPtr()));
HandleReadResult(result);
HandleReadResult(result, &read_succeeded);
}
}
......@@ -65,26 +73,34 @@ void MessageReader::OnRead(int result) {
read_pending_ = false;
if (!closed_) {
HandleReadResult(result);
DoRead();
bool read_succeeded;
HandleReadResult(result, &read_succeeded);
if (read_succeeded)
DoRead();
}
}
void MessageReader::HandleReadResult(int result) {
void MessageReader::HandleReadResult(int result, bool* read_succeeded) {
DCHECK(CalledOnValidThread());
if (closed_)
return;
*read_succeeded = true;
if (result > 0) {
OnDataReceived(read_buffer_.get(), result);
*read_succeeded = true;
} else if (result == net::ERR_IO_PENDING) {
read_pending_ = true;
} else {
if (result != net::ERR_CONNECTION_CLOSED) {
LOG(ERROR) << "Read() returned error " << result;
}
DCHECK_LT(result, 0);
// Stop reading after any error.
closed_ = true;
*read_succeeded = false;
LOG(ERROR) << "Read() returned error " << result;
read_failed_callback_.Run(result);
}
}
......
......@@ -35,6 +35,7 @@ class MessageReader : public base::NonThreadSafe {
public:
typedef base::Callback<void(scoped_ptr<CompoundBuffer>, const base::Closure&)>
MessageReceivedCallback;
typedef base::Callback<void(int)> ReadFailedCallback;
MessageReader();
virtual ~MessageReader();
......@@ -43,16 +44,19 @@ class MessageReader : public base::NonThreadSafe {
void SetMessageReceivedCallback(const MessageReceivedCallback& callback);
// Starts reading from |socket|.
void StartReading(net::Socket* socket);
void StartReading(net::Socket* socket,
const ReadFailedCallback& read_failed_callback);
private:
void DoRead();
void OnRead(int result);
void HandleReadResult(int result);
void HandleReadResult(int result, bool* read_succeeded);
void OnDataReceived(net::IOBuffer* data, int data_size);
void RunCallback(scoped_ptr<CompoundBuffer> message);
void OnMessageDone();
ReadFailedCallback read_failed_callback_;
net::Socket* socket_;
// Set to true, when we have a socket read pending, and expecting
......
......@@ -76,7 +76,8 @@ class MessageReaderTest : public testing::Test {
void InitReader() {
reader_->SetMessageReceivedCallback(
base::Bind(&MessageReaderTest::OnMessage, base::Unretained(this)));
reader_->StartReading(&socket_);
reader_->StartReading(&socket_, base::Bind(&MessageReaderTest::OnReadError,
base::Unretained(this)));
}
void AddMessage(const std::string& message) {
......@@ -92,6 +93,11 @@ class MessageReaderTest : public testing::Test {
return result == expected;
}
void OnReadError(int error) {
read_error_ = error;
reader_.reset();
}
void OnMessage(scoped_ptr<CompoundBuffer> buffer,
const base::Closure& done_callback) {
messages_.push_back(buffer.release());
......@@ -102,6 +108,7 @@ class MessageReaderTest : public testing::Test {
scoped_ptr<MessageReader> reader_;
FakeStreamSocket socket_;
MockMessageReceivedCallback callback_;
int read_error_ = 0;
std::vector<CompoundBuffer*> messages_;
bool in_callback_;
};
......@@ -281,13 +288,12 @@ TEST_F(MessageReaderTest, TwoMessages_Separately) {
TEST_F(MessageReaderTest, ReadError) {
socket_.AppendReadError(net::ERR_FAILED);
// Add a message. It should never be read after the error above.
AddMessage(kTestMessage1);
EXPECT_CALL(callback_, OnMessage(_))
.Times(0);
EXPECT_CALL(callback_, OnMessage(_)).Times(0);
InitReader();
EXPECT_EQ(net::ERR_FAILED, read_error_);
EXPECT_FALSE(reader_);
}
// Verify that we the OnMessage callback is not reentered.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment