channel_multiplexer.cc 14.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10
// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "remoting/protocol/channel_multiplexer.h"

#include <string.h>

#include "base/bind.h"
#include "base/callback.h"
11
#include "base/callback_helpers.h"
12
#include "base/location.h"
13
#include "base/single_thread_task_runner.h"
14
#include "base/stl_util.h"
15
#include "base/thread_task_runner_handle.h"
16
#include "net/base/net_errors.h"
17
#include "remoting/protocol/message_serialization.h"
18
#include "remoting/protocol/p2p_stream_socket.h"
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61

namespace remoting {
namespace protocol {

namespace {
const int kChannelIdUnknown = -1;
const int kMaxPacketSize = 1024;

class PendingPacket {
 public:
  PendingPacket(scoped_ptr<MultiplexPacket> packet,
                const base::Closure& done_task)
      : packet(packet.Pass()),
        done_task(done_task),
        pos(0U) {
  }
  ~PendingPacket() {
    done_task.Run();
  }

  bool is_empty() { return pos >= packet->data().size(); }

  int Read(char* buffer, size_t size) {
    size = std::min(size, packet->data().size() - pos);
    memcpy(buffer, packet->data().data() + pos, size);
    pos += size;
    return size;
  }

 private:
  scoped_ptr<MultiplexPacket> packet;
  base::Closure done_task;
  size_t pos;

  DISALLOW_COPY_AND_ASSIGN(PendingPacket);
};

}  // namespace

// Channel name "mux". NOTE(review): presumably passed by callers as the
// |base_channel_name| of a ChannelMultiplexer; confirm against its users.
const char ChannelMultiplexer::kMuxChannelName[] = "mux";

// A CreateChannel() request queued while the base channel is still being
// created.
struct ChannelMultiplexer::PendingChannel {
  PendingChannel(const std::string& channel_name,
                 const ChannelCreatedCallback& created_callback)
      : name(channel_name), callback(created_callback) {}

  // Name of the requested multiplexed channel.
  std::string name;
  // Receives the new socket, or an empty one if the base channel failed.
  ChannelCreatedCallback callback;
};

// One logical channel carried over the shared base channel. Owns the received
// packets that have not been read yet and points (without owning) at the
// MuxSocket currently attached to it, if any.
class ChannelMultiplexer::MuxChannel {
 public:
  MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name,
             int send_id);
  ~MuxChannel();

  const std::string& name() const { return name_; }
  int receive_id() const { return receive_id_; }
  void set_receive_id(int id) { receive_id_ = id; }

  // Called by ChannelMultiplexer.
  scoped_ptr<P2PStreamSocket> CreateSocket();
  void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
                        const base::Closure& done_task);
  void OnBaseChannelError(int error);

  // Called by MuxSocket.
  void OnSocketDestroyed();
  void DoWrite(scoped_ptr<MultiplexPacket> packet,
               const base::Closure& done_task);
  int DoRead(const scoped_refptr<net::IOBuffer>& buffer, int buffer_len);

 private:
  ChannelMultiplexer* multiplexer_;
  std::string name_;
  int send_id_;        // Id stamped on every outgoing packet.
  bool id_sent_;       // Whether |name_| has been sent with a packet yet.
  int receive_id_;     // Peer-assigned id, or kChannelIdUnknown.
  MuxSocket* socket_;  // Not owned; null when no socket is attached.
  std::list<PendingPacket*> pending_packets_;  // Owned.

  DISALLOW_COPY_AND_ASSIGN(MuxChannel);
};

103
// Stream socket handed out for a MuxChannel. Reads drain the channel's queued
// packets; writes are split into packets of at most kMaxPacketSize bytes.
class ChannelMultiplexer::MuxSocket : public P2PStreamSocket,
                                      public base::NonThreadSafe,
                                      public base::SupportsWeakPtr<MuxSocket> {
 public:
  // |channel| must outlive this socket; the destructor calls back into it.
  explicit MuxSocket(MuxChannel* channel);
  ~MuxSocket() override;

  void OnWriteComplete();
  void OnBaseChannelError(int error);
  void OnPacketReceived();

  // P2PStreamSocket interface.
  int Read(const scoped_refptr<net::IOBuffer>& buffer, int buffer_len,
           const net::CompletionCallback& callback) override;
  int Write(const scoped_refptr<net::IOBuffer>& buffer, int buffer_len,
            const net::CompletionCallback& callback) override;

 private:
  MuxChannel* channel_;

  // Sticky error reported by the base channel; net::OK until one arrives.
  int base_channel_error_ = net::OK;

  // State of a read waiting for the next incoming packet.
  net::CompletionCallback read_callback_;
  scoped_refptr<net::IOBuffer> read_buffer_;
  int read_buffer_size_;

  // State of a write waiting for the base channel to accept the packet.
  bool write_pending_;
  int write_result_;
  net::CompletionCallback write_callback_;

  DISALLOW_COPY_AND_ASSIGN(MuxSocket);
};


// Creates a channel that stamps |id| on its outgoing packets. The peer's id
// for the channel is not known until its first packet arrives.
ChannelMultiplexer::MuxChannel::MuxChannel(ChannelMultiplexer* owner,
                                           const std::string& channel_name,
                                           int id)
    : multiplexer_(owner),
      name_(channel_name),
      send_id_(id),
      id_sent_(false),
      receive_id_(kChannelIdUnknown),
      socket_(nullptr) {}

ChannelMultiplexer::MuxChannel::~MuxChannel() {
  // An attached socket keeps a raw pointer to this channel, so it has to be
  // destroyed first.
  DCHECK(!socket_);

  // Free packets nobody read; each one runs its done task when deleted.
  for (PendingPacket* unread : pending_packets_)
    delete unread;
  pending_packets_.clear();
}

155
// Returns a new socket bound to this channel. The caller owns the socket; the
// channel only remembers it so it can forward incoming data and errors.
scoped_ptr<P2PStreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() {
  // At most one socket may be attached to a channel at any time.
  DCHECK(!socket_);
  MuxSocket* attached = new MuxSocket(this);
  socket_ = attached;
  return scoped_ptr<P2PStreamSocket>(attached);
}

// Queues the payload of |packet| for reading and wakes up a blocked reader.
// |done_task| runs once the payload has been consumed, or immediately when
// there is nothing to queue.
void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
    scoped_ptr<MultiplexPacket> packet,
    const base::Closure& done_task) {
  DCHECK_EQ(packet->channel_id(), receive_id_);

  if (packet->data().empty()) {
    // Nothing to deliver, but |done_task| must still run. The packet-dropping
    // paths in ChannelMultiplexer::OnIncomingPacket() do the same. Silently
    // dropping the closure would leave whoever waits on it blocked.
    done_task.Run();
    return;
  }

  // The PendingPacket takes over |done_task| and runs it on destruction.
  pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task));

  // Notify the socket that we have more data.
  if (socket_)
    socket_->OnPacketReceived();
}

175
void ChannelMultiplexer::MuxChannel::OnBaseChannelError(int error) {
176
  if (socket_)
177
    socket_->OnBaseChannelError(error);
178 179 180 181
}

// Invoked from ~MuxSocket(). Afterwards incoming packets are only queued; no
// socket is notified.
void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
  DCHECK(socket_);  // Only an attached socket can detach.
  socket_ = nullptr;
}

185
void ChannelMultiplexer::MuxChannel::DoWrite(
186 187 188 189 190 191 192
    scoped_ptr<MultiplexPacket> packet,
    const base::Closure& done_task) {
  packet->set_channel_id(send_id_);
  if (!id_sent_) {
    packet->set_channel_name(name_);
    id_sent_ = true;
  }
193
  multiplexer_->DoWrite(packet.Pass(), done_task);
194 195
}

196 197 198
int ChannelMultiplexer::MuxChannel::DoRead(
    const scoped_refptr<net::IOBuffer>& buffer,
    int buffer_len) {
199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226
  int pos = 0;
  while (buffer_len > 0 && !pending_packets_.empty()) {
    DCHECK(!pending_packets_.front()->is_empty());
    int result = pending_packets_.front()->Read(
        buffer->data() + pos, buffer_len);
    DCHECK_LE(result, buffer_len);
    pos += result;
    buffer_len -= pos;
    if (pending_packets_.front()->is_empty()) {
      delete pending_packets_.front();
      pending_packets_.erase(pending_packets_.begin());
    }
  }
  return pos;
}

// Starts with no read or write in flight.
ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* owner_channel)
    : channel_(owner_channel),
      read_buffer_size_(0),
      write_pending_(false),
      write_result_(0) {}

// Detaches from the channel so it stops notifying this socket about incoming
// packets and errors. The channel must still be alive here; ~MuxChannel()
// DCHECKs that its socket is already gone.
ChannelMultiplexer::MuxSocket::~MuxSocket() {
  channel_->OnSocketDestroyed();
}

int ChannelMultiplexer::MuxSocket::Read(
227
    const scoped_refptr<net::IOBuffer>& buffer, int buffer_len,
228 229 230 231
    const net::CompletionCallback& callback) {
  DCHECK(CalledOnValidThread());
  DCHECK(read_callback_.is_null());

232 233 234
  if (base_channel_error_ != net::OK)
    return base_channel_error_;

235 236 237 238 239 240 241 242 243 244 245
  int result = channel_->DoRead(buffer, buffer_len);
  if (result == 0) {
    read_buffer_ = buffer;
    read_buffer_size_ = buffer_len;
    read_callback_ = callback;
    return net::ERR_IO_PENDING;
  }
  return result;
}

int ChannelMultiplexer::MuxSocket::Write(
246
    const scoped_refptr<net::IOBuffer>& buffer, int buffer_len,
247 248
    const net::CompletionCallback& callback) {
  DCHECK(CalledOnValidThread());
249 250 251 252
  DCHECK(write_callback_.is_null());

  if (base_channel_error_ != net::OK)
    return base_channel_error_;
253 254 255 256 257 258

  scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
  size_t size = std::min(kMaxPacketSize, buffer_len);
  packet->mutable_data()->assign(buffer->data(), size);

  write_pending_ = true;
259
  channel_->DoWrite(packet.Pass(), base::Bind(
260 261 262 263 264 265 266 267 268 269 270 271 272 273 274
      &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr()));

  // OnWriteComplete() might be called above synchronously.
  if (write_pending_) {
    DCHECK(write_callback_.is_null());
    write_callback_ = callback;
    write_result_ = size;
    return net::ERR_IO_PENDING;
  }

  return size;
}

void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
  write_pending_ = false;
275 276 277
  if (!write_callback_.is_null())
    base::ResetAndReturn(&write_callback_).Run(write_result_);

278 279
}

280 281 282 283 284 285 286 287 288 289 290 291 292
void ChannelMultiplexer::MuxSocket::OnBaseChannelError(int error) {
  base_channel_error_ = error;

  // Here only one of the read and write callbacks is called if both of them are
  // pending. Ideally both of them should be called in that case, but that would
  // require the second one to be called asynchronously which would complicate
  // this code. Channels handle read and write errors the same way (see
  // ChannelDispatcherBase::OnReadWriteFailed) so calling only one of the
  // callbacks is enough.

  if (!read_callback_.is_null()) {
    base::ResetAndReturn(&read_callback_).Run(error);
    return;
293
  }
294 295 296

  if (!write_callback_.is_null())
    base::ResetAndReturn(&write_callback_).Run(error);
297 298 299 300
}

void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
  if (!read_callback_.is_null()) {
301
    int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
302
    read_buffer_ = nullptr;
303
    DCHECK_GT(result, 0);
304
    base::ResetAndReturn(&read_callback_).Run(result);
305 306 307
  }
}

308
// |channel_factory| is used to create the base channel named |channel_name|
// the first time a multiplexed channel is requested.
ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* channel_factory,
                                       const std::string& channel_name)
    : base_channel_factory_(channel_factory),
      base_channel_name_(channel_name),
      next_channel_id_(0),
      // Packets parsed from |reader_| are dispatched to OnIncomingPacket().
      parser_(base::Bind(&ChannelMultiplexer::OnIncomingPacket,
                         base::Unretained(this)),
              &reader_),
      weak_factory_(this) {}

ChannelMultiplexer::~ChannelMultiplexer() {
  // Every CreateChannel() request must have been answered or cancelled.
  DCHECK(pending_channels_.empty());
  STLDeleteValues(&channels_);

  // A non-null factory means OnBaseChannelReady() has not run yet. Cancel any
  // base channel creation that may still be outstanding.
  if (base_channel_factory_ != nullptr)
    base_channel_factory_->CancelChannelCreation(base_channel_name_);
}

sergeyu's avatar
sergeyu committed
328 329
void ChannelMultiplexer::CreateChannel(const std::string& name,
                                       const ChannelCreatedCallback& callback) {
330 331 332 333 334 335
  if (base_channel_.get()) {
    // Already have |base_channel_|. Create new multiplexed channel
    // synchronously.
    callback.Run(GetOrCreateChannel(name)->CreateSocket());
  } else if (!base_channel_.get() && !base_channel_factory_) {
    // Fail synchronously if we failed to create |base_channel_|.
336
    callback.Run(nullptr);
337 338 339
  } else {
    // Still waiting for the |base_channel_|.
    pending_channels_.push_back(PendingChannel(name, callback));
340 341 342

    // If this is the first multiplexed channel then create the base channel.
    if (pending_channels_.size() == 1U) {
sergeyu's avatar
sergeyu committed
343
      base_channel_factory_->CreateChannel(
344 345 346 347
          base_channel_name_,
          base::Bind(&ChannelMultiplexer::OnBaseChannelReady,
                     base::Unretained(this)));
    }
348 349 350 351 352 353 354 355 356 357 358 359 360 361
  }
}

void ChannelMultiplexer::CancelChannelCreation(const std::string& name) {
  for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
       it != pending_channels_.end(); ++it) {
    if (it->name == name) {
      pending_channels_.erase(it);
      return;
    }
  }
}

void ChannelMultiplexer::OnBaseChannelReady(
362
    scoped_ptr<P2PStreamSocket> socket) {
363
  base_channel_factory_ = nullptr;
364 365
  base_channel_ = socket.Pass();

366 367
  if (base_channel_.get()) {
    // Initialize reader and writer.
368 369 370
    reader_.StartReading(base_channel_.get(),
                         base::Bind(&ChannelMultiplexer::OnBaseChannelError,
                                    base::Unretained(this)));
371 372
    writer_.Init(base::Bind(&P2PStreamSocket::Write,
                            base::Unretained(base_channel_.get())),
373
                 base::Bind(&ChannelMultiplexer::OnBaseChannelError,
374
                            base::Unretained(this)));
375 376
  }

377 378
  DoCreatePendingChannels();
}
379

380 381 382 383 384 385 386 387 388 389 390 391 392 393
void ChannelMultiplexer::DoCreatePendingChannels() {
  if (pending_channels_.empty())
    return;

  // Every time this function is called it connects a single channel and posts a
  // separate task to connect other channels. This is necessary because the
  // callback may destroy the multiplexer or somehow else modify
  // |pending_channels_| list (e.g. call CancelChannelCreation()).
  base::ThreadTaskRunnerHandle::Get()->PostTask(
      FROM_HERE, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels,
                            weak_factory_.GetWeakPtr()));

  PendingChannel c = pending_channels_.front();
  pending_channels_.erase(pending_channels_.begin());
394
  scoped_ptr<P2PStreamSocket> socket;
395 396 397
  if (base_channel_.get())
    socket = GetOrCreateChannel(c.name)->CreateSocket();
  c.callback.Run(socket.Pass());
398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414
}

// Returns the channel called |name|, creating it with the next send id if it
// does not exist yet. The multiplexer owns the returned channel.
ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
    const std::string& name) {
  std::map<std::string, MuxChannel*>::iterator existing = channels_.find(name);
  if (existing != channels_.end())
    return existing->second;

  MuxChannel* created = new MuxChannel(this, name, next_channel_id_++);
  channels_[name] = created;
  return created;
}


415
void ChannelMultiplexer::OnBaseChannelError(int error) {
416 417
  for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
       it != channels_.end(); ++it) {
418
    base::ThreadTaskRunnerHandle::Get()->PostTask(
419 420 421
        FROM_HERE,
        base::Bind(&ChannelMultiplexer::NotifyBaseChannelError,
                   weak_factory_.GetWeakPtr(), it->second->name(), error));
422 423 424
  }
}

425 426
void ChannelMultiplexer::NotifyBaseChannelError(const std::string& name,
                                                int error) {
427
  std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
428 429
  if (it != channels_.end())
    it->second->OnBaseChannelError(error);
430 431 432 433
}

// Routes a packet parsed from the base channel to its MuxChannel. A channel id
// seen for the first time is bound to a channel through the packet's
// channel_name. |packet| comes from the remote peer, so malformed packets are
// logged and dropped (running |done_task|) rather than asserted on.
void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
                                          const base::Closure& done_task) {
  if (!packet->has_channel_id()) {
    LOG(ERROR) << "Received packet without channel_id.";
    done_task.Run();
    return;
  }

  int receive_id = packet->channel_id();
  MuxChannel* channel = nullptr;
  std::map<int, MuxChannel*>::iterator it =
      channels_by_receive_id_.find(receive_id);
  if (it != channels_by_receive_id_.end()) {
    channel = it->second;
  } else {
    // This is a new |channel_id| we haven't seen before. Look it up by name.
    if (!packet->has_channel_name()) {
      LOG(ERROR) << "Received packet with unknown channel_id and "
          "without channel_name.";
      done_task.Run();
      return;
    }
    channel = GetOrCreateChannel(packet->channel_name());
    if (channel->receive_id() != kChannelIdUnknown) {
      // The peer already bound this name to another id. Rebinding would leave
      // the old id mapped to the channel, and packets on it would then fail
      // the id check in MuxChannel::OnIncomingPacket().
      LOG(ERROR) << "Received packet with a new channel_id for a channel "
          "that already has one.";
      done_task.Run();
      return;
    }
    channel->set_receive_id(receive_id);
    channels_by_receive_id_[receive_id] = channel;
  }

  channel->OnIncomingPacket(packet.Pass(), done_task);
}

463
void ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet,
464
                                 const base::Closure& done_task) {
465
  writer_.Write(SerializeAndFrameMessage(*packet), done_task);
466 467 468 469
}

}  // namespace protocol
}  // namespace remoting