diff --git a/storage-bigtable/src/access_token.rs b/storage-bigtable/src/access_token.rs index 30eeb7932d..e74dea211c 100644 --- a/storage-bigtable/src/access_token.rs +++ b/storage-bigtable/src/access_token.rs @@ -2,11 +2,16 @@ use goauth::{ auth::{JwtClaims, Token}, credentials::Credentials, - get_token, }; use log::*; use smpl_jwt::Jwt; -use std::time::Instant; +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + {Arc, RwLock}, + }, + time::Instant, +}; pub use goauth::scopes::Scope; @@ -23,36 +28,29 @@ fn load_credentials() -> Result { }) } +#[derive(Clone)] pub struct AccessToken { credentials: Credentials, - jwt: Jwt, - token: Option<(Token, Instant)>, + scope: Scope, + refresh_active: Arc, + token: Arc>, } impl AccessToken { - pub fn new(scope: &Scope) -> Result { + pub async fn new(scope: Scope) -> Result { let credentials = load_credentials()?; - - let claims = JwtClaims::new( - credentials.iss(), - &scope, - credentials.token_uri(), - None, - None, - ); - let jwt = Jwt::new( - claims, - credentials - .rsa_key() - .map_err(|err| format!("Invalid rsa key: {}", err))?, - None, - ); - - Ok(Self { - credentials, - jwt, - token: None, - }) + if let Err(err) = credentials.rsa_key() { + Err(format!("Invalid rsa key: {}", err)) + } else { + let token = Arc::new(RwLock::new(Self::get_token(&credentials, &scope).await?)); + let access_token = Self { + credentials, + scope, + token, + refresh_active: Arc::new(AtomicBool::new(false)), + }; + Ok(access_token) + } } /// The project that this token grants access to @@ -60,32 +58,61 @@ impl AccessToken { self.credentials.project() } - /// Call this function regularly, and before calling `access_token()` - pub async fn refresh(&mut self) { - if let Some((token, last_refresh)) = self.token.as_ref() { - if last_refresh.elapsed().as_secs() < token.expires_in() as u64 / 2 { + async fn get_token( + credentials: &Credentials, + scope: &Scope, + ) -> Result<(Token, Instant), String> { + info!("Requesting token for {:?} scope", scope); + let claims = JwtClaims::new( + credentials.iss(), + scope, + credentials.token_uri(), + None, + None, + ); + let jwt = Jwt::new(claims, credentials.rsa_key().unwrap(), None); + + let token = goauth::get_token(&jwt, credentials) + .await + .map_err(|err| format!("Failed to refresh access token: {}", err))?; + + info!("Token expires in {} seconds", token.expires_in()); + Ok((token, Instant::now())) + } + + /// Call this function regularly to ensure the access token does not expire + pub async fn refresh(&self) { + // Check if it's time to try a token refresh + { + let token_r = self.token.read().unwrap(); + if token_r.1.elapsed().as_secs() < token_r.0.expires_in() as u64 / 2 { + return; + } + + if self + .refresh_active + .compare_and_swap(false, true, Ordering::Relaxed) + { + // Refresh already pending return; } } info!("Refreshing token"); - match get_token(&self.jwt, &self.credentials).await { - Ok(new_token) => { - info!("Token expires in {} seconds", new_token.expires_in()); - self.token = Some((new_token, Instant::now())); - } - Err(err) => { - warn!("Failed to get new token: {}", err); + let new_token = Self::get_token(&self.credentials, &self.scope).await; + { + let mut token_w = self.token.write().unwrap(); + match new_token { + Ok(new_token) => *token_w = new_token, + Err(err) => warn!("{}", err), } + self.refresh_active.store(false, Ordering::Relaxed); } } /// Return an access token suitable for use in an HTTP authorization header - pub fn get(&self) -> Result { - if let Some((token, _)) = self.token.as_ref() { - Ok(format!("{} {}", token.token_type(), token.access_token())) - } else { - Err("Access token not available".into()) - } + pub fn get(&self) -> String { + let token_r = self.token.read().unwrap(); + format!("{} {}", token_r.0.token_type(), token_r.0.access_token()) } } diff --git a/storage-bigtable/src/bigtable.rs b/storage-bigtable/src/bigtable.rs index 675fca3003..99b25a45b9 100644 --- a/storage-bigtable/src/bigtable.rs +++ b/storage-bigtable/src/bigtable.rs @@ -4,7 +4,6 @@ use crate::access_token::{AccessToken, Scope}; use crate::compression::{compress_best, decompress}; use crate::root_ca_certificate; use log::*; -use std::sync::{Arc, RwLock}; use thiserror::Error; use tonic::{metadata::MetadataValue, transport::ClientTlsConfig, Request}; @@ -86,7 +85,7 @@ pub type Result = std::result::Result; #[derive(Clone)] pub struct BigTableConnection { - access_token: Option>>, + access_token: Option, channel: tonic::transport::Channel, table_prefix: String, } @@ -115,14 +114,14 @@ impl BigTableConnection { } Err(_) => { - let mut access_token = AccessToken::new(if read_only { - &Scope::BigTableDataReadOnly + let access_token = AccessToken::new(if read_only { + Scope::BigTableDataReadOnly } else { - &Scope::BigTableData + Scope::BigTableData }) + .await .map_err(Error::AccessTokenError)?; - access_token.refresh().await; let table_prefix = format!( "projects/{}/instances/{}/tables/", access_token.project(), @@ -130,7 +129,7 @@ impl BigTableConnection { ); Ok(Self { - access_token: Some(Arc::new(RwLock::new(access_token))), + access_token: Some(access_token), channel: tonic::transport::Channel::from_static( "https://bigtable.googleapis.com", ) @@ -153,28 +152,25 @@ impl BigTableConnection { /// Clients require `&mut self`, due to `Tonic::transport::Channel` limitations, however /// creating new clients is cheap and thus can be used as a work around for ease of use. pub fn client(&self) -> BigTable { - let client = { - if let Some(ref access_token) = self.access_token { - let access_token = access_token.clone(); - bigtable_client::BigtableClient::with_interceptor( - self.channel.clone(), - move |mut req: Request<()>| { - match access_token.read().unwrap().get() { - Ok(access_token) => match MetadataValue::from_str(&access_token) { - Ok(authorization_header) => { - req.metadata_mut() - .insert("authorization", authorization_header); - } - Err(err) => warn!("Failed to set authorization header: {}", err), - }, - Err(err) => warn!("{}", err), + let client = if let Some(access_token) = &self.access_token { + let access_token = access_token.clone(); + bigtable_client::BigtableClient::with_interceptor( + self.channel.clone(), + move |mut req: Request<()>| { + match MetadataValue::from_str(&access_token.get()) { + Ok(authorization_header) => { + req.metadata_mut() + .insert("authorization", authorization_header); } - Ok(req) - }, - ) - } else { - bigtable_client::BigtableClient::new(self.channel.clone()) - } + Err(err) => { + warn!("Failed to set authorization header: {}", err); + } + } + Ok(req) + }, + ) + } else { + bigtable_client::BigtableClient::new(self.channel.clone()) }; BigTable { access_token: self.access_token.clone(), @@ -202,7 +198,7 @@ impl BigTableConnection { } pub struct BigTable { - access_token: Option>>, + access_token: Option, client: bigtable_client::BigtableClient, table_prefix: String, } @@ -283,7 +279,7 @@ impl BigTable { async fn refresh_access_token(&self) { if let Some(ref access_token) = self.access_token { - access_token.write().unwrap().refresh().await; + access_token.refresh().await; } } @@ -298,7 +294,6 @@ impl BigTable { rows_limit: i64, ) -> Result> { self.refresh_access_token().await; - let response = self .client .read_rows(ReadRowsRequest {