Azure / iot-identity-service

Source of the Azure IoT Identity Service and related services.
MIT License
37 stars 46 forks source link

Add exponential backoff for IotHub throttle #502

Closed lfitchett closed 1 year ago

lfitchett commented 1 year ago

Adds exponential backoff to identityd. The lack of backoff has caused several ICMs when many devices are updated at once. Previously, the devices would attempt to connect to the hub every 10-20 seconds, causing the throttle to persist forever.

Now devices will go into an exponential backoff when throttled. This only affects 429 (Too Many Requests) errors; all other errors use the old logic.

As a result of the backoff change, the edged timeout has been increased to 10 minutes.

This also moves the download of the response payload inside the timeout. This has never been an issue before, but it is more correct.

onalante-msft commented 1 year ago

Alternative implementation:

PATCH ```diff diff --git a/Cargo.lock b/Cargo.lock index 650835c..cce9c70 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2160,12 +2160,6 @@ version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160" -[[package]] -name = "ppv-lite86" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" - [[package]] name = "proc-macro-error" version = "1.0.4" @@ -2214,18 +2208,6 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ - "libc", - "rand_chacha", - "rand_core", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", "rand_core", ] diff --git a/http-common/Cargo.toml b/http-common/Cargo.toml index 21b3fd6..5adf28f 100644 --- a/http-common/Cargo.toml +++ b/http-common/Cargo.toml @@ -19,7 +19,7 @@ nix = "0.24" openssl = { version = "0.10" } openssl-sys = { version = "0.9" } percent-encoding = "2" -rand = "0.8.5" +rand = { version = "0.8", features = ["getrandom"], default-features = false } serde = { version = "1", features = ["derive"] } serde_json = "1" tokio = { version = "1", features = ["net", "rt-multi-thread", "sync", "time"] } diff --git a/http-common/src/backoff.rs b/http-common/src/backoff.rs deleted file mode 100644 index 86b8ca0..0000000 --- a/http-common/src/backoff.rs +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. 
- -use std::time::Duration; - -use rand::Rng; - -pub const DEFAULT_BACKOFF: Backoff<4> = Backoff { - pattern: [ - BackoffInstance::from_secs(60, 10), - BackoffInstance::from_secs(120, 20), - BackoffInstance::from_secs(180, 30), - BackoffInstance::from_secs(300, 30), - ], -}; - -pub struct Backoff { - pattern: [BackoffInstance; N], -} - -impl Backoff { - #[allow(clippy::unused_self, clippy::cast_possible_truncation)] - pub fn max_retries(&self) -> u32 { - N as u32 - } - - /// Computes backoff for current try. Returns None if no retry attempts left - pub fn get_backoff_duration(&self, current_attempt: u32) -> Option { - self.pattern - .get(current_attempt as usize - 1) - .map(BackoffInstance::backoff_duration) - } -} - -pub struct BackoffInstance { - duration: Duration, - max_jitter: Duration, -} - -impl BackoffInstance { - const fn from_secs(duration: u64, max_jitter: u64) -> Self { - Self { - duration: Duration::from_secs(duration), - max_jitter: Duration::from_secs(max_jitter), - } - } - - fn backoff_duration(&self) -> Duration { - let mut rng = rand::thread_rng(); - let jitter_multiple = rng.gen_range(0.0..1.0); - - self.duration + self.max_jitter.mul_f32(jitter_multiple) - } -} diff --git a/http-common/src/lib.rs b/http-common/src/lib.rs index cf683f6..1e35f1e 100644 --- a/http-common/src/lib.rs +++ b/http-common/src/lib.rs @@ -30,8 +30,6 @@ pub use request::HttpRequest; pub mod server; -mod backoff; - mod uid; /// Ref diff --git a/http-common/src/request.rs b/http-common/src/request.rs index 7b57ae3..c8b7ae1 100644 --- a/http-common/src/request.rs +++ b/http-common/src/request.rs @@ -1,10 +1,14 @@ // Copyright (c) Microsoft. All rights reserved. 
use std::io::{Error, ErrorKind}; +use std::time::Duration; -use crate::backoff::DEFAULT_BACKOFF; +use rand::Rng; const CONTENT_TYPE_JSON: &str = "application/json"; +const INITIAL_BACKOFF: Duration = Duration::from_secs(60); +const BACKOFF_FACTOR: f32 = 1.5; +const JITTER_RATIO: f32 = 0.25; pub struct HttpRequest { connector: TConnector, @@ -87,11 +91,7 @@ where self } - pub fn add_header( - &mut self, - name: hyper::header::HeaderName, - value: &str, - ) -> Result<(), Error> { + pub fn add_header(&mut self, name: http::header::HeaderName, value: &str) -> Result<(), Error> { let value = http::HeaderValue::from_str(value) .map_err(|err| Error::new(ErrorKind::InvalidInput, err))?; @@ -131,6 +131,8 @@ where .to_str() .map_err(|err| Error::new(ErrorKind::InvalidData, err))?; + // NOTE: `str::contains` since the content type can be + // followed by additional data like `charset`. content_type.contains(CONTENT_TYPE_JSON) } else { false @@ -162,15 +164,23 @@ where > { let client: hyper::Client<_, hyper::Body> = hyper::Client::builder().build(self.connector); - let mut current_attempt = 1; - - loop { + let attempt_limit = self.retries + 1; + let mut attempt = 0; + let mut backoff = 0; + let result = loop { + attempt += 1; let mut request = hyper::Request::builder() .method(&self.method) .uri(&self.uri); + let headers = request + .headers_mut() + .expect("cannot fail to get request headers"); let request_body = if let Some(body) = &self.body { - request = request.header(hyper::header::CONTENT_TYPE, CONTENT_TYPE_JSON); + headers.insert( + http::header::CONTENT_TYPE, + http::header::HeaderValue::from_static(CONTENT_TYPE_JSON), + ); serde_json::to_vec(body) .expect("cannot fail to serialize request") @@ -179,107 +189,68 @@ where hyper::Body::default() }; - for (header_name, header_value) in &self.headers { - request = request.header(header_name, header_value); - } + headers.extend(self.headers.clone()); let request = request .body(request_body) .expect("cannot fail to create 
request"); - let (err, backoff_exponential) = - match tokio::time::timeout(self.timeout, client.request(request)).await { - Ok(response) => { - match response { - Ok(response) => { - let ( - http::response::Parts { - status: response_status, - headers: response_headers, - .. - }, - response_body, - ) = response.into_parts(); - - // Make sure to download body inside the timeout - let response_body = if has_response_body { - let response_body = hyper::body::to_bytes(response_body) - .await - .map_err(|err| Error::new(ErrorKind::Other, err))?; - - Some(response_body) - } else { - None - }; - - // if response throttled, go into exponential backoff - if response_status == http::StatusCode::TOO_MANY_REQUESTS { - ( - std::io::Error::new( - std::io::ErrorKind::Other, - "429: Too many requests", - ), - true, - ) - } else { - // Return results - return Ok((response_status, response_headers, response_body)); - } - } - Err(err) => { - if err.is_connect() { - // Network error. - ( - std::io::Error::new(std::io::ErrorKind::NotConnected, err), - false, - ) - } else { - (std::io::Error::new(std::io::ErrorKind::Other, err), false) - } - } - } + let mut wait = Duration::from_secs(3); + let err = match tokio::time::timeout(self.timeout, client.request(request)).await { + Ok(Ok(response)) => { + if response.status() == http::StatusCode::TOO_MANY_REQUESTS { + backoff += 1; + let base_backoff = INITIAL_BACKOFF.mul_f32(BACKOFF_FACTOR.powi(backoff)); + let wait_ratio = + 1.0 + rand::rngs::OsRng.gen_range(-JITTER_RATIO..JITTER_RATIO); + wait = base_backoff.mul_f32(wait_ratio); + Error::new(ErrorKind::Other, "too many requests") + } else { + break Ok(response); } - - Err(timeout) => (timeout.into(), false), - }; - - if backoff_exponential { - if let Some(backoff_duration) = - DEFAULT_BACKOFF.get_backoff_duration(current_attempt) - { - log::warn!( - "HTTP request throttled (attempt {} of {}). 
Sleeping for {} seconds.", - current_attempt, - DEFAULT_BACKOFF.max_retries() + 1, - backoff_duration.as_secs() - ); - tokio::time::sleep(backoff_duration).await; - } else { - log::warn!( - "Final HTTP request throttled (attempt {} of {}).", - current_attempt, - DEFAULT_BACKOFF.max_retries() + 1, - ); - return Err(err); } - } else { + Ok(Err(err)) => Error::new( + if err.is_connect() { + ErrorKind::NotConnected + } else { + ErrorKind::Other + }, + err, + ), + Err(err) => Error::from(err), + }; + + if attempt < attempt_limit { log::warn!( - "Failed to send HTTP request (attempt {} of {}): {}", - current_attempt, - self.retries + 1, - err + "Failed to send request (attempt {} of {}): {}. Retrying in {} seconds.", + attempt, + attempt_limit, + err, + wait.as_secs(), ); - - if current_attempt > self.retries { - return Err(err); - } - // Wait a short time between failed requests. - tokio::time::sleep(tokio::time::Duration::from_secs(3)).await; + tokio::time::sleep(wait).await + } else { + break Err(err); } + }; - current_attempt += 1; - } + let ( + http::response::Parts { + status, headers, .. + }, + response_body, + ) = result?.into_parts(); + let body = if has_response_body { + Some( + hyper::body::to_bytes(response_body) + .await + .map_err(|err| Error::new(ErrorKind::Other, err))?, + ) + } else { + None + }; + Ok((status, headers, body)) } } ```