channel_multiplexer.cc 14.1 KB
Newer Older
1 2 3 4 5 6
// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "remoting/protocol/channel_multiplexer.h"

#include <stddef.h>
#include <string.h>

#include <algorithm>
#include <utility>

#include "base/bind.h"
#include "base/callback.h"
#include "base/callback_helpers.h"
#include "base/location.h"
#include "base/macros.h"
#include "base/memory/ptr_util.h"
#include "base/sequence_checker.h"
#include "base/single_thread_task_runner.h"
#include "base/threading/thread_task_runner_handle.h"
#include "net/base/net_errors.h"
#include "remoting/protocol/message_serialization.h"
#include "remoting/protocol/p2p_stream_socket.h"

namespace remoting {
namespace protocol {

namespace {
const int kChannelIdUnknown = -1;
const int kMaxPacketSize = 1024;

class PendingPacket {
 public:
34
  PendingPacket(std::unique_ptr<MultiplexPacket> packet)
35 36
      : packet(std::move(packet)) {}
  ~PendingPacket() {}
37 38 39 40 41 42 43 44 45 46 47

  bool is_empty() { return pos >= packet->data().size(); }

  int Read(char* buffer, size_t size) {
    size = std::min(size, packet->data().size() - pos);
    memcpy(buffer, packet->data().data() + pos, size);
    pos += size;
    return size;
  }

 private:
48
  std::unique_ptr<MultiplexPacket> packet;
49
  size_t pos = 0U;
50 51 52 53 54 55 56 57 58 59

  DISALLOW_COPY_AND_ASSIGN(PendingPacket);
};

}  // namespace

// Channel name "mux". NOTE(review): presumably the name callers pass as
// |base_channel_name|; confirm against users of kMuxChannelName.
const char ChannelMultiplexer::kMuxChannelName[]{"mux"};

// A CreateChannel() request queued while the base channel is still being
// created; DoCreatePendingChannels() completes it.
struct ChannelMultiplexer::PendingChannel {
  PendingChannel(const std::string& name,
                 const ChannelCreatedCallback& callback)
      : name(name),
        callback(callback) {}

  // Name of the requested multiplexed channel.
  std::string name;

  // Receives the created socket, or nullptr if the channel can't be created.
  ChannelCreatedCallback callback;
};

// One logical channel carried over the shared base channel. It buffers the
// packets received for it until they are read and keeps a non-owning pointer
// to the socket handed out by CreateSocket().
class ChannelMultiplexer::MuxChannel {
 public:
  MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name,
             int send_id);
  ~MuxChannel();

  const std::string& name() const { return name_; }
  int receive_id() const { return receive_id_; }
  void set_receive_id(int id) { receive_id_ = id; }

  // Called by ChannelMultiplexer.
  std::unique_ptr<P2PStreamSocket> CreateSocket();
  void OnIncomingPacket(std::unique_ptr<MultiplexPacket> packet);
  void OnBaseChannelError(int error);

  // Called by MuxSocket.
  void OnSocketDestroyed();
  void DoWrite(std::unique_ptr<MultiplexPacket> packet,
               const base::Closure& done_task);
  int DoRead(const scoped_refptr<net::IOBuffer>& buffer, int buffer_len);

 private:
  ChannelMultiplexer* multiplexer_;
  std::string name_;
  // Id stamped on every outgoing packet of this channel.
  int send_id_;
  // Whether |name_| has already been sent along with |send_id_|.
  bool id_sent_;
  // Id the peer uses for this channel; kChannelIdUnknown until learned.
  int receive_id_;
  // Socket returned by CreateSocket(). Not owned; reset by
  // OnSocketDestroyed().
  MuxSocket* socket_;
  // Received packets whose payload has not been fully read yet.
  std::list<std::unique_ptr<PendingPacket>> pending_packets_;

  DISALLOW_COPY_AND_ASSIGN(MuxChannel);
};

100
// P2PStreamSocket handed to users of a multiplexed channel. Reads and writes
// are forwarded to the MuxChannel it was created for.
class ChannelMultiplexer::MuxSocket : public P2PStreamSocket {
 public:
  explicit MuxSocket(MuxChannel* channel);
  ~MuxSocket() override;

  void OnWriteComplete();
  void OnBaseChannelError(int error);
  void OnPacketReceived();

  // P2PStreamSocket interface.
  int Read(const scoped_refptr<net::IOBuffer>& buffer, int buffer_len,
           const net::CompletionCallback& callback) override;
  int Write(const scoped_refptr<net::IOBuffer>& buffer, int buffer_len,
            const net::CompletionCallback& callback) override;

 private:
  // Not owned. ~MuxSocket() calls OnSocketDestroyed() on it.
  MuxChannel* channel_;

  // Error reported by the base channel; net::OK while none has occurred.
  int base_channel_error_ = net::OK;

  // State of a Read() that returned net::ERR_IO_PENDING.
  net::CompletionCallback read_callback_;
  scoped_refptr<net::IOBuffer> read_buffer_;
  int read_buffer_size_;

  // State of a Write() waiting for OnWriteComplete().
  bool write_pending_;
  int write_result_;
  net::CompletionCallback write_callback_;

  SEQUENCE_CHECKER(sequence_checker_);

  base::WeakPtrFactory<MuxSocket> weak_factory_;

  DISALLOW_COPY_AND_ASSIGN(MuxSocket);
};

135 136 137
ChannelMultiplexer::MuxChannel::MuxChannel(ChannelMultiplexer* multiplexer,
                                           const std::string& name,
                                           int send_id)
138 139 140 141 142
    : multiplexer_(multiplexer),
      name_(name),
      send_id_(send_id),
      id_sent_(false),
      receive_id_(kChannelIdUnknown),
143
      socket_(nullptr) {}
144 145 146 147 148 149

ChannelMultiplexer::MuxChannel::~MuxChannel() {
  // The socket holds a raw pointer back to this channel, so it has to be
  // destroyed first.
  DCHECK(socket_ == nullptr);
}

150 151
std::unique_ptr<P2PStreamSocket>
ChannelMultiplexer::MuxChannel::CreateSocket() {
152
  DCHECK(!socket_);  // Can't create more than one socket per channel.
153
  std::unique_ptr<MuxSocket> result(new MuxSocket(this));
154
  socket_ = result.get();
155
  return std::move(result);
156 157 158
}

void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
159
    std::unique_ptr<MultiplexPacket> packet) {
160 161
  DCHECK_EQ(packet->channel_id(), receive_id_);
  if (packet->data().size() > 0) {
162 163
    pending_packets_.push_back(
        base::MakeUnique<PendingPacket>(std::move(packet)));
164 165 166 167 168 169 170
    if (socket_) {
      // Notify the socket that we have more data.
      socket_->OnPacketReceived();
    }
  }
}

171
void ChannelMultiplexer::MuxChannel::OnBaseChannelError(int error) {
172
  if (socket_)
173
    socket_->OnBaseChannelError(error);
174 175 176 177
}

void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
  // Called from ~MuxSocket(): detach so nothing is forwarded to the socket
  // anymore.
  DCHECK(socket_ != nullptr);
  socket_ = nullptr;
}

181
void ChannelMultiplexer::MuxChannel::DoWrite(
182
    std::unique_ptr<MultiplexPacket> packet,
183 184 185 186 187 188
    const base::Closure& done_task) {
  packet->set_channel_id(send_id_);
  if (!id_sent_) {
    packet->set_channel_name(name_);
    id_sent_ = true;
  }
189
  multiplexer_->DoWrite(std::move(packet), done_task);
190 191
}

192 193 194
// Copies buffered payload into |buffer| until either |buffer_len| bytes have
// been copied or no pending packets remain, spanning several packets if
// needed. Packets whose payload is fully consumed are dropped. Returns the
// number of bytes copied (0 when nothing is buffered).
int ChannelMultiplexer::MuxChannel::DoRead(
    const scoped_refptr<net::IOBuffer>& buffer,
    int buffer_len) {
  int pos = 0;
  while (buffer_len > 0 && !pending_packets_.empty()) {
    DCHECK(!pending_packets_.front()->is_empty());
    int result = pending_packets_.front()->Read(
        buffer->data() + pos, buffer_len);
    DCHECK_LE(result, buffer_len);
    pos += result;
    // Only this iteration's bytes shrink the remaining space. Subtracting the
    // running total |pos| would under-count it and end the read early.
    buffer_len -= result;

    if (pending_packets_.front()->is_empty())
      pending_packets_.pop_front();
  }
  return pos;
}

ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel)
    : channel_(channel),
      read_buffer_size_(0),
      write_pending_(false),
213 214
      write_result_(0),
      weak_factory_(this) {}
215 216 217 218 219 220

ChannelMultiplexer::MuxSocket::~MuxSocket() {
  // Detach from the channel so it stops forwarding packets and errors here.
  channel_->OnSocketDestroyed();
}

int ChannelMultiplexer::MuxSocket::Read(
221
    const scoped_refptr<net::IOBuffer>& buffer, int buffer_len,
222
    const net::CompletionCallback& callback) {
223
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
224 225
  DCHECK(read_callback_.is_null());

226 227 228
  if (base_channel_error_ != net::OK)
    return base_channel_error_;

229 230 231 232 233 234 235 236 237 238 239
  int result = channel_->DoRead(buffer, buffer_len);
  if (result == 0) {
    read_buffer_ = buffer;
    read_buffer_size_ = buffer_len;
    read_callback_ = callback;
    return net::ERR_IO_PENDING;
  }
  return result;
}

int ChannelMultiplexer::MuxSocket::Write(
240
    const scoped_refptr<net::IOBuffer>& buffer, int buffer_len,
241
    const net::CompletionCallback& callback) {
242
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
243 244 245 246
  DCHECK(write_callback_.is_null());

  if (base_channel_error_ != net::OK)
    return base_channel_error_;
247

248
  std::unique_ptr<MultiplexPacket> packet(new MultiplexPacket());
249 250 251 252
  size_t size = std::min(kMaxPacketSize, buffer_len);
  packet->mutable_data()->assign(buffer->data(), size);

  write_pending_ = true;
253 254 255
  channel_->DoWrite(std::move(packet),
                    base::Bind(&ChannelMultiplexer::MuxSocket::OnWriteComplete,
                               weak_factory_.GetWeakPtr()));
256 257 258 259 260 261 262 263 264 265 266 267 268 269

  // OnWriteComplete() might be called above synchronously.
  if (write_pending_) {
    DCHECK(write_callback_.is_null());
    write_callback_ = callback;
    write_result_ = size;
    return net::ERR_IO_PENDING;
  }

  return size;
}

void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
  write_pending_ = false;
270 271 272
  if (!write_callback_.is_null())
    base::ResetAndReturn(&write_callback_).Run(write_result_);

273 274
}

275 276 277 278 279 280 281 282 283 284 285 286 287
void ChannelMultiplexer::MuxSocket::OnBaseChannelError(int error) {
  base_channel_error_ = error;

  // Here only one of the read and write callbacks is called if both of them are
  // pending. Ideally both of them should be called in that case, but that would
  // require the second one to be called asynchronously which would complicate
  // this code. Channels handle read and write errors the same way (see
  // ChannelDispatcherBase::OnReadWriteFailed) so calling only one of the
  // callbacks is enough.

  if (!read_callback_.is_null()) {
    base::ResetAndReturn(&read_callback_).Run(error);
    return;
288
  }
289 290 291

  if (!write_callback_.is_null())
    base::ResetAndReturn(&write_callback_).Run(error);
292 293 294 295
}

void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
  if (!read_callback_.is_null()) {
296
    int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
297
    read_buffer_ = nullptr;
298
    DCHECK_GT(result, 0);
299
    base::ResetAndReturn(&read_callback_).Run(result);
300 301 302
  }
}

303
ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory,
304 305 306 307
                                       const std::string& base_channel_name)
    : base_channel_factory_(factory),
      base_channel_name_(base_channel_name),
      next_channel_id_(0),
sergeyu's avatar
sergeyu committed
308
      weak_factory_(this) {}
309 310

ChannelMultiplexer::~ChannelMultiplexer() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(pending_channels_.empty());

  // A non-null factory means OnBaseChannelReady() has not run yet, so cancel
  // any base channel creation that may still be in flight.
  if (base_channel_factory_ != nullptr)
    base_channel_factory_->CancelChannelCreation(base_channel_name_);
}

sergeyu's avatar
sergeyu committed
319 320
void ChannelMultiplexer::CreateChannel(const std::string& name,
                                       const ChannelCreatedCallback& callback) {
321 322 323 324 325 326
  if (base_channel_.get()) {
    // Already have |base_channel_|. Create new multiplexed channel
    // synchronously.
    callback.Run(GetOrCreateChannel(name)->CreateSocket());
  } else if (!base_channel_.get() && !base_channel_factory_) {
    // Fail synchronously if we failed to create |base_channel_|.
327
    callback.Run(nullptr);
328 329 330
  } else {
    // Still waiting for the |base_channel_|.
    pending_channels_.push_back(PendingChannel(name, callback));
331 332 333

    // If this is the first multiplexed channel then create the base channel.
    if (pending_channels_.size() == 1U) {
sergeyu's avatar
sergeyu committed
334
      base_channel_factory_->CreateChannel(
335 336 337 338
          base_channel_name_,
          base::Bind(&ChannelMultiplexer::OnBaseChannelReady,
                     base::Unretained(this)));
    }
339 340 341 342 343 344 345 346 347 348 349 350 351 352
  }
}

void ChannelMultiplexer::CancelChannelCreation(const std::string& name) {
  // Drops the first queued request with a matching name, if there is one.
  auto match = pending_channels_.begin();
  while (match != pending_channels_.end() && match->name != name)
    ++match;

  if (match != pending_channels_.end())
    pending_channels_.erase(match);
}

void ChannelMultiplexer::OnBaseChannelReady(
353
    std::unique_ptr<P2PStreamSocket> socket) {
354
  base_channel_factory_ = nullptr;
355
  base_channel_ = std::move(socket);
356

357 358
  if (base_channel_.get()) {
    // Initialize reader and writer.
359
    reader_.StartReading(base_channel_.get(),
sergeyu's avatar
sergeyu committed
360 361
                         base::Bind(&ChannelMultiplexer::OnIncomingPacket,
                                    base::Unretained(this)),
362 363
                         base::Bind(&ChannelMultiplexer::OnBaseChannelError,
                                    base::Unretained(this)));
364 365 366 367
    writer_.Start(base::Bind(&P2PStreamSocket::Write,
                             base::Unretained(base_channel_.get())),
                  base::Bind(&ChannelMultiplexer::OnBaseChannelError,
                             base::Unretained(this)));
368 369
  }

370 371
  DoCreatePendingChannels();
}
372

373 374 375 376 377 378 379 380 381 382 383 384 385 386
void ChannelMultiplexer::DoCreatePendingChannels() {
  if (pending_channels_.empty())
    return;

  // Every time this function is called it connects a single channel and posts a
  // separate task to connect other channels. This is necessary because the
  // callback may destroy the multiplexer or somehow else modify
  // |pending_channels_| list (e.g. call CancelChannelCreation()).
  base::ThreadTaskRunnerHandle::Get()->PostTask(
      FROM_HERE, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels,
                            weak_factory_.GetWeakPtr()));

  PendingChannel c = pending_channels_.front();
  pending_channels_.erase(pending_channels_.begin());
387
  std::unique_ptr<P2PStreamSocket> socket;
388 389
  if (base_channel_.get())
    socket = GetOrCreateChannel(c.name)->CreateSocket();
390
  c.callback.Run(std::move(socket));
391 392 393 394
}

ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
    const std::string& name) {
  // operator[] inserts an empty slot the first time |name| is seen.
  std::unique_ptr<MuxChannel>& slot = channels_[name];
  if (!slot) {
    // Send ids are handed out consecutively in channel creation order.
    slot = base::MakeUnique<MuxChannel>(this, name, next_channel_id_++);
  }
  return slot.get();
}


406
void ChannelMultiplexer::OnBaseChannelError(int error) {
407
  for (auto it = channels_.begin(); it != channels_.end(); ++it) {
408
    base::ThreadTaskRunnerHandle::Get()->PostTask(
409 410 411
        FROM_HERE,
        base::Bind(&ChannelMultiplexer::NotifyBaseChannelError,
                   weak_factory_.GetWeakPtr(), it->second->name(), error));
412 413 414
  }
}

415 416
void ChannelMultiplexer::NotifyBaseChannelError(const std::string& name,
                                                int error) {
417
  auto it = channels_.find(name);
418 419
  if (it != channels_.end())
    it->second->OnBaseChannelError(error);
420 421
}

422 423 424
void ChannelMultiplexer::OnIncomingPacket(
    std::unique_ptr<CompoundBuffer> buffer) {
  std::unique_ptr<MultiplexPacket> packet =
sergeyu's avatar
sergeyu committed
425 426 427 428
      ParseMessage<MultiplexPacket>(buffer.get());
  if (!packet)
    return;

429
  DCHECK(packet->has_channel_id());
430 431 432 433 434 435
  if (!packet->has_channel_id()) {
    LOG(ERROR) << "Received packet without channel_id.";
    return;
  }

  int receive_id = packet->channel_id();
436
  MuxChannel* channel = nullptr;
437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452
  std::map<int, MuxChannel*>::iterator it =
      channels_by_receive_id_.find(receive_id);
  if (it != channels_by_receive_id_.end()) {
    channel = it->second;
  } else {
    // This is a new |channel_id| we haven't seen before. Look it up by name.
    if (!packet->has_channel_name()) {
      LOG(ERROR) << "Received packet with unknown channel_id and "
          "without channel_name.";
      return;
    }
    channel = GetOrCreateChannel(packet->channel_name());
    channel->set_receive_id(receive_id);
    channels_by_receive_id_[receive_id] = channel;
  }

453
  channel->OnIncomingPacket(std::move(packet));
454 455
}

456
void ChannelMultiplexer::DoWrite(std::unique_ptr<MultiplexPacket> packet,
457
                                 const base::Closure& done_task) {
458
  writer_.Write(SerializeAndFrameMessage(*packet), done_task);
459 460 461 462
}

}  // namespace protocol
}  // namespace remoting