diff --git a/src/rpc_request.rs b/src/rpc_request.rs index 94987e7eae..b38177ab22 100644 --- a/src/rpc_request.rs +++ b/src/rpc_request.rs @@ -2,6 +2,7 @@ use reqwest; use reqwest::header::CONTENT_TYPE; use serde_json::{self, Value}; use std::net::SocketAddr; +use std::thread::sleep; use std::time::Duration; use std::{error, fmt}; @@ -64,23 +65,51 @@ impl RpcRequest { client: &RpcClient, id: u64, params: Option, + ) -> Result> { + self.retry_make_rpc_request(client, id, params, 0) + } + + pub fn retry_make_rpc_request( + &self, + client: &RpcClient, + id: u64, + params: Option, + mut retries: usize, ) -> Result> { let request = self.build_request_json(id, params); - let mut response = client - .client - .post(&client.addr) - .header(CONTENT_TYPE, "application/json") - .body(request.to_string()) - .send()?; - let json: Value = serde_json::from_str(&response.text()?)?; - if json["error"].is_object() { - Err(RpcError::RpcRequestError(format!( - "RPC Error response: {}", - serde_json::to_string(&json["error"]).unwrap() - )))? + loop { + match client + .client + .post(&client.addr) + .header(CONTENT_TYPE, "application/json") + .body(request.to_string()) + .send() + { + Ok(mut response) => { + let json: Value = serde_json::from_str(&response.text()?)?; + if json["error"].is_object() { + Err(RpcError::RpcRequestError(format!( + "RPC Error response: {}", + serde_json::to_string(&json["error"]).unwrap() + )))? + } + return Ok(json["result"].clone()); + } + Err(e) => { + info!( + "make_rpc_request() failed, {} retries left: {:?}", + retries, e + ); + if retries == 0 { + Err(e)?; + } + retries -= 1; + // TODO: Make the caller supply their desired retry frequency? + sleep(Duration::from_millis(500)); + } + } } - Ok(json["result"].clone()) } fn build_request_json(&self, id: u64, params: Option) -> Value { @@ -242,4 +271,44 @@ mod tests { RpcRequest::GetLastId.make_rpc_request(&rpc_client, 3, Some(json!("paramter"))); assert_eq!(last_id.is_err(), true); } + + #[test] + fn test_retry_make_rpc_request() { + solana_logger::setup(); + let (sender, receiver) = channel(); + thread::spawn(move || { + // 1. Pick a random port + // 2. Tell the client to start using it + // 3. Delay for 1.5 seconds before starting the server to ensure the client will fail + // and need to retry + let rpc_addr = socketaddr!(0, 4242); + sender.send(rpc_addr.clone()).unwrap(); + sleep(Duration::from_millis(1500)); + + let mut io = IoHandler::default(); + io.add_method("getBalance", move |_params: Params| { + Ok(Value::Number(Number::from(5))) + }); + let server = ServerBuilder::new(io) + .threads(1) + .cors(DomainsValidation::AllowOnly(vec![ + AccessControlAllowOrigin::Any, + ])) + .start_http(&rpc_addr) + .expect("Unable to start RPC server"); + server.wait(); + }); + + let rpc_addr = receiver.recv().unwrap(); + let rpc_client = RpcClient::new_from_socket(rpc_addr); + + let balance = RpcRequest::GetBalance.retry_make_rpc_request( + &rpc_client, + 1, + Some(json!(["deadbeefXjn8o3yroDHxUtKsZZgoy4GPkPPXfouKNHhw"])), + 10, + ); + assert!(balance.is_ok()); + assert_eq!(balance.unwrap().as_u64().unwrap(), 5); + } }