Commit 571fe62b authored by Jun Choi's avatar Jun Choi Committed by Commit Bot

Reland "Update U2fPacket::GetSerializedData() to get rid of report_id."

Fixed bug introduced from r513104 by removing report ID from
serialization buffer both when reading from and writing to hid
devices. Also,removed use of net::IOBuffer since U2f Hid devices are
now servicified.

Bug: 782825
Change-Id: I084337af2e2355a74f2a888b5164ecb9c16cc9a0
Reviewed-on: https://chromium-review.googlesource.com/765058
Commit-Queue: Jun Choi <hongjunchoi@chromium.org>
Reviewed-by: 's avatarReilly Grant <reillyg@chromium.org>
Cr-Commit-Position: refs/heads/master@{#516117}
parent 26def82f
......@@ -13,7 +13,6 @@
#include "device/u2f/u2f_command_type.h"
#include "device/u2f/u2f_message.h"
#include "mojo/public/cpp/bindings/interface_request.h"
#include "net/base/io_buffer.h"
namespace device {
......@@ -21,6 +20,11 @@ namespace switches {
static constexpr char kEnableU2fHidTest[] = "enable-u2f-hid-tests";
} // namespace switches
namespace {
// U2F devices only provide a single report so specify a report ID of 0 here.
static constexpr uint8_t kReportId = 0x00;
} // namespace
U2fHidDevice::U2fHidDevice(device::mojom::HidDeviceInfoPtr device_info,
device::mojom::HidManager* hid_manager)
: U2fDevice(),
......@@ -172,12 +176,8 @@ void U2fHidDevice::WriteMessage(std::unique_ptr<U2fMessage> message,
return;
}
scoped_refptr<net::IOBufferWithSize> io_buffer = message->PopNextPacket();
std::vector<uint8_t> buffer(io_buffer->data() + 1,
io_buffer->data() + io_buffer->size());
connection_->Write(
0 /* report_id */, buffer,
kReportId, message->PopNextPacket(),
base::BindOnce(&U2fHidDevice::PacketWritten, weak_factory_.GetWeakPtr(),
std::move(message), true, std::move(callback)));
}
......@@ -215,11 +215,8 @@ void U2fHidDevice::OnRead(U2fHidMessageCallback callback,
}
DCHECK(buf);
std::vector<uint8_t> read_buffer;
read_buffer.push_back(report_id);
read_buffer.insert(read_buffer.end(), buf->begin(), buf->end());
std::unique_ptr<U2fMessage> read_message =
U2fMessage::CreateFromSerializedData(read_buffer);
U2fMessage::CreateFromSerializedData(*buf);
if (!read_message) {
std::move(callback).Run(false, nullptr);
......@@ -257,10 +254,7 @@ void U2fHidDevice::OnReadContinuation(
}
DCHECK(buf);
std::vector<uint8_t> read_buffer;
read_buffer.push_back(report_id);
read_buffer.insert(read_buffer.end(), buf->begin(), buf->end());
message->AddContinuationPacket(read_buffer);
message->AddContinuationPacket(*buf);
if (message->MessageComplete()) {
std::move(callback).Run(success, std::move(message));
return;
......
......@@ -4,7 +4,6 @@
#include "base/memory/ptr_util.h"
#include "device/u2f/u2f_packet.h"
#include "net/base/io_buffer.h"
#include "u2f_message.h"
......@@ -91,15 +90,13 @@ std::list<std::unique_ptr<U2fPacket>>::const_iterator U2fMessage::end() {
return packets_.cend();
}
scoped_refptr<net::IOBufferWithSize> U2fMessage::PopNextPacket() {
std::vector<uint8_t> U2fMessage::PopNextPacket() {
std::vector<uint8_t> data;
if (NumPackets() > 0) {
scoped_refptr<net::IOBufferWithSize> data =
packets_.front()->GetSerializedData();
data = packets_.front()->GetSerializedData();
packets_.pop_front();
return data;
}
return nullptr;
return data;
}
bool U2fMessage::AddContinuationPacket(const std::vector<uint8_t>& buf) {
......
......@@ -12,10 +12,6 @@
#include "device/u2f/u2f_command_type.h"
#include "device/u2f/u2f_packet.h"
namespace net {
class IOBufferWithSize;
} // namespace net
namespace device {
// U2fMessages are defined by the specification at
......@@ -37,7 +33,7 @@ class U2fMessage {
static std::unique_ptr<U2fMessage> CreateFromSerializedData(
const std::vector<uint8_t>& buf);
// Pop front of queue with next packet
scoped_refptr<net::IOBufferWithSize> PopNextPacket();
std::vector<uint8_t> PopNextPacket();
// Adds a continuation packet to the packet list, from the serialized
// response value
bool AddContinuationPacket(const std::vector<uint8_t>& packet_buf);
......
......@@ -4,7 +4,6 @@
#include "device/u2f/u2f_message.h"
#include "base/memory/ptr_util.h"
#include "net/base/io_buffer.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
......@@ -19,11 +18,11 @@ TEST_F(U2fMessageTest, TestPacketSize) {
auto init_packet =
std::make_unique<U2fInitPacket>(channel_id, 0, data, data.size());
EXPECT_EQ(65, init_packet->GetSerializedData()->size());
EXPECT_EQ(64u, init_packet->GetSerializedData().size());
auto continuation_packet =
std::make_unique<U2fContinuationPacket>(channel_id, 0, data);
EXPECT_EQ(65, continuation_packet->GetSerializedData()->size());
EXPECT_EQ(64u, continuation_packet->GetSerializedData().size());
}
/*
......@@ -42,29 +41,21 @@ TEST_F(U2fMessageTest, TestPacketData) {
uint8_t cmd = static_cast<uint8_t>(U2fCommandType::CMD_WINK);
auto init_packet =
std::make_unique<U2fInitPacket>(channel_id, cmd, data, data.size());
int index = 0;
scoped_refptr<net::IOBufferWithSize> serialized =
init_packet->GetSerializedData();
EXPECT_EQ(0, serialized->data()[index++]);
EXPECT_EQ((channel_id >> 24) & 0xff,
static_cast<uint8_t>(serialized->data()[index++]));
EXPECT_EQ((channel_id >> 16) & 0xff,
static_cast<uint8_t>(serialized->data()[index++]));
EXPECT_EQ((channel_id >> 8) & 0xff,
static_cast<uint8_t>(serialized->data()[index++]));
EXPECT_EQ(channel_id & 0xff,
static_cast<uint8_t>(serialized->data()[index++]));
EXPECT_EQ(cmd, static_cast<uint8_t>(serialized->data()[index++]));
EXPECT_EQ(data.size() >> 8,
static_cast<uint8_t>(serialized->data()[index++]));
EXPECT_EQ(data.size() & 0xff,
static_cast<uint8_t>(serialized->data()[index++]));
EXPECT_EQ(data[0], serialized->data()[index++]);
EXPECT_EQ(data[1], serialized->data()[index++]);
for (; index < serialized->size(); index++)
EXPECT_EQ(0, serialized->data()[index]) << "mismatch at index " << index;
size_t index = 0;
std::vector<uint8_t> serialized = init_packet->GetSerializedData();
EXPECT_EQ((channel_id >> 24) & 0xff, serialized[index++]);
EXPECT_EQ((channel_id >> 16) & 0xff, serialized[index++]);
EXPECT_EQ((channel_id >> 8) & 0xff, serialized[index++]);
EXPECT_EQ(channel_id & 0xff, serialized[index++]);
EXPECT_EQ(cmd, serialized[index++]);
EXPECT_EQ(data.size() >> 8, serialized[index++]);
EXPECT_EQ(data.size() & 0xff, serialized[index++]);
EXPECT_EQ(data[0], serialized[index++]);
EXPECT_EQ(data[1], serialized[index++]);
for (; index < serialized.size(); index++)
EXPECT_EQ(0, serialized[index]) << "mismatch at index " << index;
}
TEST_F(U2fMessageTest, TestPacketConstructors) {
......@@ -75,10 +66,8 @@ TEST_F(U2fMessageTest, TestPacketConstructors) {
std::make_unique<U2fInitPacket>(channel_id, cmd, data, data.size());
size_t payload_length = static_cast<size_t>(orig_packet->payload_length());
scoped_refptr<net::IOBufferWithSize> buffer =
orig_packet->GetSerializedData();
std::vector<uint8_t> orig_data(buffer->data(),
buffer->data() + buffer->size());
std::vector<uint8_t> orig_data = orig_packet->GetSerializedData();
std::unique_ptr<U2fInitPacket> reconstructed_packet =
U2fInitPacket::CreateFromSerializedData(orig_data, &payload_length);
EXPECT_EQ(orig_packet->command(), reconstructed_packet->command());
......@@ -89,15 +78,15 @@ TEST_F(U2fMessageTest, TestPacketConstructors) {
EXPECT_EQ(channel_id, reconstructed_packet->channel_id());
ASSERT_EQ(orig_packet->GetSerializedData()->size(),
reconstructed_packet->GetSerializedData()->size());
for (int index = 0; index < orig_packet->GetSerializedData()->size();
ASSERT_EQ(orig_packet->GetSerializedData().size(),
reconstructed_packet->GetSerializedData().size());
for (size_t index = 0; index < orig_packet->GetSerializedData().size();
++index) {
EXPECT_EQ(orig_packet->GetSerializedData()->data()[index],
reconstructed_packet->GetSerializedData()->data()[index])
EXPECT_EQ(orig_packet->GetSerializedData()[index],
reconstructed_packet->GetSerializedData()[index])
<< "mismatch at index " << index;
}
}
}
TEST_F(U2fMessageTest, TestMaxLengthPacketConstructors) {
uint32_t channel_id = 0xAAABACAD;
......@@ -110,15 +99,12 @@ TEST_F(U2fMessageTest, TestMaxLengthPacketConstructors) {
U2fMessage::Create(channel_id, cmd, data);
auto it = orig_msg->begin();
scoped_refptr<net::IOBufferWithSize> buffer = (*it)->GetSerializedData();
std::vector<uint8_t> msg_data(buffer->data(),
buffer->data() + buffer->size());
std::vector<uint8_t> msg_data = (*it)->GetSerializedData();
std::unique_ptr<U2fMessage> new_msg =
U2fMessage::CreateFromSerializedData(msg_data);
it++;
for (; it != orig_msg->end(); ++it) {
buffer = (*it)->GetSerializedData();
msg_data.assign(buffer->data(), buffer->data() + buffer->size());
msg_data = (*it)->GetSerializedData();
new_msg->AddContinuationPacket(msg_data);
}
......@@ -132,12 +118,12 @@ TEST_F(U2fMessageTest, TestMaxLengthPacketConstructors) {
EXPECT_EQ((*orig_it)->channel_id(), (*new_it)->channel_id());
ASSERT_EQ((*orig_it)->GetSerializedData()->size(),
(*new_it)->GetSerializedData()->size());
for (int index = 0; index < (*orig_it)->GetSerializedData()->size();
ASSERT_EQ((*orig_it)->GetSerializedData().size(),
(*new_it)->GetSerializedData().size());
for (size_t index = 0; index < (*orig_it)->GetSerializedData().size();
++index) {
EXPECT_EQ((*orig_it)->GetSerializedData()->data()[index],
(*new_it)->GetSerializedData()->data()[index])
EXPECT_EQ((*orig_it)->GetSerializedData()[index],
(*new_it)->GetSerializedData()[index])
<< "mismatch at index " << index;
}
}
......@@ -185,23 +171,21 @@ TEST_F(U2fMessageTest, TestDeserialize) {
std::unique_ptr<U2fMessage> orig_message =
U2fMessage::Create(channel_id, U2fCommandType::CMD_PING, data);
std::list<scoped_refptr<net::IOBufferWithSize>> orig_list;
scoped_refptr<net::IOBufferWithSize> buf = orig_message->PopNextPacket();
std::list<std::vector<uint8_t>> orig_list;
std::vector<uint8_t> buf = orig_message->PopNextPacket();
orig_list.push_back(buf);
std::vector<uint8_t> message_data(buf->data(), buf->data() + buf->size());
std::unique_ptr<U2fMessage> new_message =
U2fMessage::CreateFromSerializedData(message_data);
U2fMessage::CreateFromSerializedData(buf);
while (!new_message->MessageComplete()) {
buf = orig_message->PopNextPacket();
orig_list.push_back(buf);
message_data.assign(buf->data(), buf->data() + buf->size());
new_message->AddContinuationPacket(message_data);
new_message->AddContinuationPacket(buf);
}
while ((buf = new_message->PopNextPacket())) {
ASSERT_EQ(buf->size(), orig_list.front()->size());
EXPECT_EQ(0, memcmp(buf->data(), orig_list.front()->data(), buf->size()));
while (!(buf = new_message->PopNextPacket()).empty()) {
ASSERT_EQ(buf.size(), orig_list.front().size());
EXPECT_EQ(0, memcmp(buf.data(), orig_list.front().data(), buf.size()));
orig_list.pop_front();
}
}
......
......@@ -5,7 +5,6 @@
#include <cstring>
#include "base/memory/ptr_util.h"
#include "net/base/io_buffer.h"
#include "u2f_packet.h"
......@@ -37,24 +36,19 @@ U2fInitPacket::U2fInitPacket(uint32_t channel_id,
command_(cmd),
payload_length_(payload_length) {}
scoped_refptr<net::IOBufferWithSize> U2fInitPacket::GetSerializedData() {
auto serialized =
base::WrapRefCounted(new net::IOBufferWithSize(kPacketSize));
size_t index = 0;
// Byte at offset 0 is the report ID, which is always 0
serialized->data()[index++] = 0;
serialized->data()[index++] = (channel_id_ >> 24) & 0xff;
serialized->data()[index++] = (channel_id_ >> 16) & 0xff;
serialized->data()[index++] = (channel_id_ >> 8) & 0xff;
serialized->data()[index++] = channel_id_ & 0xff;
serialized->data()[index++] = command_;
serialized->data()[index++] = (payload_length_ >> 8) & 0xff;
serialized->data()[index++] = payload_length_ & 0xff;
std::memcpy(&serialized->data()[index], data_.data(), data_.size());
index += data_.size();
std::memset(&serialized->data()[index], 0, serialized->size() - index);
std::vector<uint8_t> U2fInitPacket::GetSerializedData() {
std::vector<uint8_t> serialized;
serialized.reserve(kPacketSize);
serialized.push_back((channel_id_ >> 24) & 0xff);
serialized.push_back((channel_id_ >> 16) & 0xff);
serialized.push_back((channel_id_ >> 8) & 0xff);
serialized.push_back(channel_id_ & 0xff);
serialized.push_back(command_);
serialized.push_back((payload_length_ >> 8) & 0xff);
serialized.push_back(payload_length_ & 0xff);
serialized.insert(serialized.end(), data_.begin(), data_.end());
serialized.resize(kPacketSize, 0);
return serialized;
}
......@@ -70,16 +64,14 @@ std::unique_ptr<U2fInitPacket> U2fInitPacket::CreateFromSerializedData(
U2fInitPacket::U2fInitPacket(const std::vector<uint8_t>& serialized,
size_t* remaining_size) {
// Report ID is at index 0, so start at index 1 for channel ID
size_t index = 1;
uint16_t payload_size = 0;
size_t index = 0;
channel_id_ = (serialized[index++] & 0xff) << 24;
channel_id_ |= (serialized[index++] & 0xff) << 16;
channel_id_ |= (serialized[index++] & 0xff) << 8;
channel_id_ |= serialized[index++] & 0xff;
command_ = serialized[index++];
payload_size = serialized[index++] << 8;
uint16_t payload_size = serialized[index++] << 8;
payload_size |= serialized[index++];
payload_length_ = payload_size;
......@@ -105,23 +97,17 @@ U2fContinuationPacket::U2fContinuationPacket(const uint32_t channel_id,
const std::vector<uint8_t>& data)
: U2fPacket(data, channel_id), sequence_(sequence) {}
scoped_refptr<net::IOBufferWithSize>
U2fContinuationPacket::GetSerializedData() {
auto serialized =
base::WrapRefCounted(new net::IOBufferWithSize(kPacketSize));
size_t index = 0;
// Byte at offset 0 is the report ID, which is always 0
serialized->data()[index++] = 0;
serialized->data()[index++] = (channel_id_ >> 24) & 0xff;
serialized->data()[index++] = (channel_id_ >> 16) & 0xff;
serialized->data()[index++] = (channel_id_ >> 8) & 0xff;
serialized->data()[index++] = channel_id_ & 0xff;
serialized->data()[index++] = sequence_;
std::memcpy(&serialized->data()[index], data_.data(), data_.size());
index += data_.size();
std::memset(&serialized->data()[index], 0, serialized->size() - index);
std::vector<uint8_t> U2fContinuationPacket::GetSerializedData() {
std::vector<uint8_t> serialized;
serialized.reserve(kPacketSize);
serialized.push_back((channel_id_ >> 24) & 0xff);
serialized.push_back((channel_id_ >> 16) & 0xff);
serialized.push_back((channel_id_ >> 8) & 0xff);
serialized.push_back(channel_id_ & 0xff);
serialized.push_back(sequence_);
serialized.insert(serialized.end(), data_.begin(), data_.end());
serialized.resize(kPacketSize, 0);
return serialized;
}
......@@ -139,10 +125,7 @@ U2fContinuationPacket::CreateFromSerializedData(
U2fContinuationPacket::U2fContinuationPacket(
const std::vector<uint8_t>& serialized,
size_t* remaining_size) {
// Report ID is at index 0, so start at index 1 for channel ID
size_t index = 1;
size_t data_size;
size_t index = 0;
channel_id_ = (serialized[index++] & 0xff) << 24;
channel_id_ |= (serialized[index++] & 0xff) << 16;
channel_id_ |= (serialized[index++] & 0xff) << 8;
......@@ -150,7 +133,7 @@ U2fContinuationPacket::U2fContinuationPacket(
sequence_ = serialized[index++];
// Check to see if packet payload is less than maximum size and padded with 0s
data_size = std::min(*remaining_size, kPacketSize - index);
size_t data_size = std::min(*remaining_size, kPacketSize - index);
*remaining_size -= data_size;
data_.insert(data_.end(), serialized.begin() + index,
serialized.begin() + index + data_size);
......
......@@ -10,10 +10,6 @@
#include "base/memory/ref_counted.h"
namespace net {
class IOBufferWithSize;
} // namespace net
namespace device {
// U2fPackets are defined by the specification at
......@@ -28,15 +24,14 @@ class U2fPacket {
U2fPacket(const std::vector<uint8_t>& data, uint32_t channel_id);
virtual ~U2fPacket();
virtual scoped_refptr<net::IOBufferWithSize> GetSerializedData() = 0;
virtual std::vector<uint8_t> GetSerializedData() = 0;
std::vector<uint8_t> GetPacketPayload() const;
uint32_t channel_id() { return channel_id_; }
protected:
U2fPacket();
// Packet size of 64 bytes + 1 byte report ID
static constexpr size_t kPacketSize = 65;
static constexpr size_t kPacketSize = 64;
std::vector<uint8_t> data_;
uint32_t channel_id_;
......@@ -66,7 +61,7 @@ class U2fInitPacket : public U2fPacket {
static std::unique_ptr<U2fInitPacket> CreateFromSerializedData(
const std::vector<uint8_t>& serialized,
size_t* remaining_size);
scoped_refptr<net::IOBufferWithSize> GetSerializedData() final;
std::vector<uint8_t> GetSerializedData() final;
uint8_t command() { return command_; }
uint16_t payload_length() { return payload_length_; }
......@@ -96,7 +91,7 @@ class U2fContinuationPacket : public U2fPacket {
static std::unique_ptr<U2fContinuationPacket> CreateFromSerializedData(
const std::vector<uint8_t>& serialized,
size_t* remaining_size);
scoped_refptr<net::IOBufferWithSize> GetSerializedData() final;
std::vector<uint8_t> GetSerializedData() final;
uint8_t sequence() { return sequence_; }
private:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment