Commit a15e4f0a authored by sergeyu's avatar sergeyu Committed by Commit bot

Make TokenValidatorFactory ref-counted.

This makes TokeValidatorFactory consistent with PairingRegistry and now
it can be shared between multiple authenticator objects.

Review URL: https://codereview.chromium.org/1788443005

Cr-Commit-Position: refs/heads/master@{#380871}
parent f18571e9
......@@ -806,13 +806,12 @@ void HostProcess::CreateAuthenticatorFactory() {
DCHECK(third_party_auth_config_.token_url.is_valid());
DCHECK(third_party_auth_config_.token_validation_url.is_valid());
scoped_ptr<protocol::TokenValidatorFactory> token_validator_factory(
new TokenValidatorFactoryImpl(
third_party_auth_config_,
key_pair_, context_->url_request_context_getter()));
scoped_refptr<protocol::TokenValidatorFactory> token_validator_factory =
new TokenValidatorFactoryImpl(third_party_auth_config_, key_pair_,
context_->url_request_context_getter());
factory = protocol::Me2MeHostAuthenticatorFactory::CreateWithThirdPartyAuth(
use_service_account_, host_owner_, local_certificate, key_pair_,
client_domain_, std::move(token_validator_factory));
client_domain_, token_validator_factory);
}
#if defined(OS_POSIX)
......
......@@ -28,14 +28,14 @@ class TokenValidatorFactoryImpl : public protocol::TokenValidatorFactory {
scoped_refptr<RsaKeyPair> key_pair,
scoped_refptr<net::URLRequestContextGetter> request_context_getter);
~TokenValidatorFactoryImpl() override;
// TokenValidatorFactory interface.
scoped_ptr<protocol::TokenValidator> CreateTokenValidator(
const std::string& local_jid,
const std::string& remote_jid) override;
private:
~TokenValidatorFactoryImpl() override;
ThirdPartyAuthConfig third_party_auth_config_;
scoped_refptr<RsaKeyPair> key_pair_;
scoped_refptr<net::URLRequestContextGetter> request_context_getter_;
......
......@@ -98,8 +98,8 @@ class TokenValidatorFactoryImplTest : public testing::Test {
config.token_url = GURL(kTokenUrl);
config.token_validation_url = GURL(kTokenValidationUrl);
config.token_validation_cert_issuer = kTokenValidationCertIssuer;
token_validator_factory_.reset(new TokenValidatorFactoryImpl(
config, key_pair_, request_context_getter_));
token_validator_factory_ = new TokenValidatorFactoryImpl(
config, key_pair_, request_context_getter_);
}
static std::string CreateResponse(const std::string& scope) {
......@@ -131,7 +131,7 @@ class TokenValidatorFactoryImplTest : public testing::Test {
base::MessageLoop message_loop_;
scoped_refptr<RsaKeyPair> key_pair_;
scoped_refptr<net::URLRequestContextGetter> request_context_getter_;
scoped_ptr<TokenValidatorFactoryImpl> token_validator_factory_;
scoped_refptr<TokenValidatorFactoryImpl> token_validator_factory_;
scoped_ptr<protocol::TokenValidator> token_validator_;
};
......
......@@ -48,9 +48,8 @@ Me2MeHostAuthenticatorFactory::CreateWithThirdPartyAuth(
const std::string& host_owner,
const std::string& local_cert,
scoped_refptr<RsaKeyPair> key_pair,
const std::string& required_client_domain,
scoped_ptr<TokenValidatorFactory>
token_validator_factory) {
const std::string& required_client_domain,
scoped_refptr<TokenValidatorFactory> token_validator_factory) {
scoped_ptr<Me2MeHostAuthenticatorFactory> result(
new Me2MeHostAuthenticatorFactory());
result->use_service_account_ = use_service_account;
......@@ -58,7 +57,7 @@ Me2MeHostAuthenticatorFactory::CreateWithThirdPartyAuth(
result->local_cert_ = local_cert;
result->key_pair_ = key_pair;
result->required_client_domain_ = required_client_domain;
result->token_validator_factory_ = std::move(token_validator_factory);
result->token_validator_factory_ = token_validator_factory;
return std::move(result);
}
......@@ -120,8 +119,7 @@ scoped_ptr<Authenticator> Me2MeHostAuthenticatorFactory::CreateAuthenticator(
if (token_validator_factory_) {
return NegotiatingHostAuthenticator::CreateWithThirdPartyAuth(
local_jid, remote_jid, local_cert_, key_pair_,
token_validator_factory_->CreateTokenValidator(local_jid,
remote_jid));
token_validator_factory_);
}
return NegotiatingHostAuthenticator::CreateWithPin(
......
......@@ -42,7 +42,7 @@ class Me2MeHostAuthenticatorFactory : public AuthenticatorFactory {
const std::string& local_cert,
scoped_refptr<RsaKeyPair> key_pair,
const std::string& required_client_domain,
scoped_ptr<TokenValidatorFactory> token_validator_factory);
scoped_refptr<TokenValidatorFactory> token_validator_factory);
Me2MeHostAuthenticatorFactory();
~Me2MeHostAuthenticatorFactory() override;
......@@ -64,7 +64,7 @@ class Me2MeHostAuthenticatorFactory : public AuthenticatorFactory {
std::string pin_hash_;
// Used only for third party host authenticators.
scoped_ptr<TokenValidatorFactory> token_validator_factory_;
scoped_refptr<TokenValidatorFactory> token_validator_factory_;
// Used only for pairing host authenticators.
scoped_refptr<PairingRegistry> pairing_registry_;
......
......@@ -78,11 +78,11 @@ NegotiatingHostAuthenticator::CreateWithThirdPartyAuth(
const std::string& remote_id,
const std::string& local_cert,
scoped_refptr<RsaKeyPair> key_pair,
scoped_ptr<TokenValidator> token_validator) {
scoped_refptr<TokenValidatorFactory> token_validator_factory) {
scoped_ptr<NegotiatingHostAuthenticator> result(
new NegotiatingHostAuthenticator(local_id, remote_id, local_cert,
key_pair));
result->token_validator_ = std::move(token_validator);
result->token_validator_factory_ = token_validator_factory;
result->AddMethod(Method::THIRD_PARTY_SPAKE2_CURVE25519);
result->AddMethod(Method::THIRD_PARTY_SPAKE2_P224);
return std::move(result);
......@@ -183,23 +183,15 @@ void NegotiatingHostAuthenticator::CreateAuthenticator(
DCHECK(current_method_ != Method::INVALID);
if (current_method_ == Method::THIRD_PARTY_SPAKE2_P224) {
// |ThirdPartyHostAuthenticator| takes ownership of |token_validator_|.
// The authentication method negotiation logic should guarantee that only
// one |ThirdPartyHostAuthenticator| will need to be created per session.
DCHECK(token_validator_);
current_authenticator_.reset(new ThirdPartyHostAuthenticator(
base::Bind(&V2Authenticator::CreateForHost, local_cert_,
local_key_pair_),
std::move(token_validator_)));
token_validator_factory_->CreateTokenValidator(local_id_, remote_id_)));
} else if (current_method_ == Method::THIRD_PARTY_SPAKE2_CURVE25519) {
// |ThirdPartyHostAuthenticator| takes ownership of |token_validator_|.
// The authentication method negotiation logic should guarantee that only
// one |ThirdPartyHostAuthenticator| will need to be created per session.
DCHECK(token_validator_);
current_authenticator_.reset(new ThirdPartyHostAuthenticator(
base::Bind(&Spake2Authenticator::CreateForHost, local_id_, remote_id_,
local_cert_, local_key_pair_),
std::move(token_validator_)));
token_validator_factory_->CreateTokenValidator(local_id_, remote_id_)));
} else if (current_method_ == Method::SHARED_SECRET_SPAKE2_CURVE25519) {
current_authenticator_ = Spake2Authenticator::CreateForHost(
local_id_, remote_id_, local_cert_, local_key_pair_,
......
......@@ -22,6 +22,8 @@ class RsaKeyPair;
namespace protocol {
class TokenValidatorFactory;
// Host-side implementation of NegotiatingAuthenticatorBase.
// See comments in negotiating_authenticator_base.h for a general explanation.
class NegotiatingHostAuthenticator : public NegotiatingAuthenticatorBase {
......@@ -53,7 +55,7 @@ class NegotiatingHostAuthenticator : public NegotiatingAuthenticatorBase {
const std::string& remote_id,
const std::string& local_cert,
scoped_refptr<RsaKeyPair> key_pair,
scoped_ptr<TokenValidator> token_validator);
scoped_refptr<TokenValidatorFactory> token_validator_factory);
// Overriden from Authenticator.
void ProcessMessage(const buzz::XmlElement* message,
......@@ -83,7 +85,7 @@ class NegotiatingHostAuthenticator : public NegotiatingAuthenticatorBase {
std::string shared_secret_hash_;
// Used only for third party host authenticators.
scoped_ptr<TokenValidator> token_validator_;
scoped_refptr<TokenValidatorFactory> token_validator_factory_;
// Used only for pairing authenticators.
scoped_refptr<PairingRegistry> pairing_registry_;
......
......@@ -8,6 +8,7 @@
#include <string>
#include "base/callback.h"
#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
#include "url/gurl.h"
......@@ -48,15 +49,19 @@ class TokenValidator {
};
// Factory for |TokenValidator|.
class TokenValidatorFactory {
class TokenValidatorFactory
: public base::RefCountedThreadSafe<TokenValidatorFactory> {
public:
virtual ~TokenValidatorFactory() {}
// Creates a TokenValidator. |local_jid| and |remote_jid| are used to create
// a token scope that is restricted to the current connection's JIDs.
virtual scoped_ptr<TokenValidator> CreateTokenValidator(
const std::string& local_jid,
const std::string& remote_jid) = 0;
protected:
friend class base::RefCountedThreadSafe<TokenValidatorFactory>;
virtual ~TokenValidatorFactory() {}
};
} // namespace protocol
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment