diff --git a/Cargo.lock b/Cargo.lock index 2087d1e..2adb230 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -363,9 +363,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68803225a7b13e47191bab76f2687382b60d259e8cf37f6e1893658b84bb9479" +checksum = "ee67c11feeac938fae061b232e38e0b6d94f97a9df10e6271319325ac4c56a86" [[package]] name = "arrayvec" @@ -382,6 +382,16 @@ dependencies = [ "event-listener", ] +[[package]] +name = "async-rwlock" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "261803dcc39ba9e72760ba6e16d0199b1eef9fc44e81bffabbebb9f5aea3906c" +dependencies = [ + "async-mutex", + "event-listener", +] + [[package]] name = "async-trait" version = "0.1.42" @@ -1247,9 +1257,9 @@ dependencies = [ [[package]] name = "itoa" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc6f3ad7b9d11a0c00842ff8de1b60ee58661048eb8049ed33c73594f359d7e6" +checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736" [[package]] name = "js-sys" @@ -1989,10 +1999,10 @@ dependencies = [ "ammonia", "anyhow", "async-mutex", + "async-rwlock", "async-trait", "background-jobs", "base64 0.13.0", - "bytes", "chrono", "config", "deadpool", @@ -2016,7 +2026,6 @@ dependencies = [ "sha2", "structopt", "thiserror", - "tokio", "tokio-postgres", "ttl_cache", "uuid", @@ -2241,9 +2250,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.60" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1500e84d27fe482ed1dc791a56eddc2f230046a040fa908c08bda1d9fb615779" +checksum = "4fceb2595057b6891a4ee808f70054bd2d12f0e97f1cbb78689b59f676df325a" dependencies = [ "itoa", "ryu", @@ -2495,9 +2504,9 @@ checksum = "1e81da0851ada1f3e9d4312c704aa4f8806f0f9d69faaf8df2f3464b4a9437c2" [[package]] name = "syn" -version = "1.0.55" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a571a711dddd09019ccc628e1b17fe87c59b09d513c06c026877aa708334f37a" +checksum = "a9802ddde94170d186eeee5005b798d9c159fa970403f1be19976d0cfb939b72" dependencies = [ "proc-macro2", "quote", @@ -2547,18 +2556,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.22" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e9ae34b84616eedaaf1e9dd6026dbe00dcafa92aa0c8077cb69df1fcfe5e53e" +checksum = "76cc616c6abf8c8928e2fdcc0dbfab37175edd8fb49a4641066ad1364fdab146" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.22" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ba20f23e85b10754cd195504aebf6a27e2e6cbe28c17778a0c930724628dd56" +checksum = "9be73a2caec27583d0046ef3796c3794f868a5bc813db689eed00c7631275cd1" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index a7063f9..203c4ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,9 +21,9 @@ activitystreams = "0.7.0-alpha.4" activitystreams-ext = "0.1.0-alpha.2" ammonia = "3.1.0" async-mutex = "1.0.1" +async-rwlock = "1.3.0" async-trait = "0.1.24" background-jobs = "0.8.0" -bytes = "0.5.4" base64 = "0.13" chrono = "0.4.19" config = "0.10.1" @@ -47,7 +47,6 @@ serde_json = "1.0" sha2 = "0.9" structopt = "0.3.12" thiserror = "1.0" -tokio = { version = "0.2.13", features = ["sync"] } tokio-postgres = { version = "0.5.1", features = ["with-serde_json-1", "with-uuid-0_8", "with-chrono-0_4"] } ttl_cache = "0.5.1" uuid = { version = "0.8", features = ["v4", "serde"] } diff --git a/src/data/actor.rs b/src/data/actor.rs index 48cb67a..b9cdc64 100644 --- a/src/data/actor.rs +++ b/src/data/actor.rs @@ -1,8 +1,8 @@ use crate::{apub::AcceptedActors, db::Db, error::MyError, requests::Requests}; use activitystreams::{prelude::*, uri, url::Url}; +use async_rwlock::RwLock; use log::error; use std::{collections::HashSet, sync::Arc, time::Duration}; -use tokio::sync::RwLock; use ttl_cache::TtlCache; use uuid::Uuid; diff --git a/src/data/media.rs b/src/data/media.rs index 1fe838b..a8fcfae 100644 --- a/src/data/media.rs +++ b/src/data/media.rs @@ -1,11 +1,11 @@ use crate::{db::Db, error::MyError}; use activitystreams::url::Url; +use actix_web::web::Bytes; use async_mutex::Mutex; -use bytes::Bytes; +use async_rwlock::RwLock; use futures::join; use lru::LruCache; use std::{collections::HashMap, sync::Arc, time::Duration}; -use tokio::sync::RwLock; use ttl_cache::TtlCache; use uuid::Uuid; diff --git a/src/data/node.rs b/src/data/node.rs index 9742d58..eb449c7 100644 --- a/src/data/node.rs +++ b/src/data/node.rs @@ -1,12 +1,12 @@ use crate::{db::Db, error::MyError}; use activitystreams::{uri, url::Url}; +use async_rwlock::RwLock; use log::{debug, error}; use std::{ collections::{HashMap, HashSet}, sync::Arc, time::{Duration, SystemTime}, }; -use tokio::sync::RwLock; use tokio_postgres::types::Json; use uuid::Uuid; diff --git a/src/data/state.rs b/src/data/state.rs index 9192500..bc09eac 100644 --- a/src/data/state.rs +++ b/src/data/state.rs @@ -11,13 +11,13 @@ use actix_rt::{ time::{interval_at, Instant}, }; use actix_web::web; +use async_rwlock::RwLock; use futures::{join, try_join}; use log::{error, info}; use lru::LruCache; use rand::thread_rng; use rsa::{RSAPrivateKey, RSAPublicKey}; use std::{collections::HashSet, sync::Arc, time::Duration}; -use tokio::sync::RwLock; #[derive(Clone)] pub struct State { diff --git a/src/middleware/payload.rs b/src/middleware/payload.rs index 002fa69..fdff1ea 100644 --- a/src/middleware/payload.rs +++ b/src/middleware/payload.rs @@ -1,16 +1,17 @@ use actix_web::{ dev::{Payload, Service, ServiceRequest, Transform}, http::StatusCode, + web::BytesMut, HttpMessage, HttpResponse, ResponseError, }; -use bytes::BytesMut; use futures::{ + channel::mpsc::channel, future::{ok, try_join, LocalBoxFuture, Ready}, + sink::SinkExt, stream::StreamExt, }; use log::{error, info}; use std::task::{Context, Poll}; -use tokio::sync::mpsc::channel; #[derive(Clone, Debug)] pub struct DebugPayload(pub bool); @@ -68,7 +69,7 @@ where fn call(&mut self, mut req: S::Request) -> Self::Future { if self.0 { - let (mut tx, rx) = channel(1); + let (mut tx, rx) = channel(0); let mut pl = req.take_payload(); req.set_payload(Payload::Stream(Box::pin(rx))); diff --git a/src/requests.rs b/src/requests.rs index 4c9f9a5..ef09275 100644 --- a/src/requests.rs +++ b/src/requests.rs @@ -1,7 +1,8 @@ use crate::error::MyError; use activitystreams::url::Url; -use actix_web::{client::Client, http::header::Date}; -use bytes::Bytes; +use actix_web::{client::Client, http::header::Date, web::Bytes}; +use async_mutex::Mutex; +use async_rwlock::RwLock; use chrono::{DateTime, Utc}; use http_signature_normalization_actix::prelude::*; use log::{debug, info, warn}; @@ -13,14 +14,14 @@ use std::{ rc::Rc, sync::{ atomic::{AtomicUsize, Ordering}, - Arc, Mutex, + Arc, }, time::SystemTime, }; #[derive(Clone)] pub struct Breakers { - inner: Arc>>, + inner: Arc>>>>, } impl Breakers { @@ -28,32 +29,47 @@ impl Breakers { Self::default() } - fn should_try(&self, url: &Url) -> bool { + async fn should_try(&self, url: &Url) -> bool { if let Some(domain) = url.domain() { - self.inner - .lock() - .expect("Breakers poisoned") - .get(domain) - .map(|breaker| breaker.should_try()) - .unwrap_or(true) + if let Some(breaker) = self.inner.read().await.get(domain) { + breaker.lock().await.should_try() + } else { + true + } } else { false } } - fn fail(&self, url: &Url) { + async fn fail(&self, url: &Url) { if let Some(domain) = url.domain() { - let mut hm = self.inner.lock().expect("Breakers poisoned"); - let entry = hm.entry(domain.to_owned()).or_insert(Breaker::default()); - entry.fail(); + if let Some(breaker) = self.inner.read().await.get(domain) { + let owned_breaker = Arc::clone(&breaker); + drop(breaker); + owned_breaker.lock().await.fail(); + } else { + let mut hm = self.inner.write().await; + let breaker = hm + .entry(domain.to_owned()) + .or_insert(Arc::new(Mutex::new(Breaker::default()))); + breaker.lock().await.fail(); + } } } - fn succeed(&self, url: &Url) { + async fn succeed(&self, url: &Url) { if let Some(domain) = url.domain() { - let mut hm = self.inner.lock().expect("Breakers poisoned"); - let entry = hm.entry(domain.to_owned()).or_insert(Breaker::default()); - entry.succeed(); + if let Some(breaker) = self.inner.read().await.get(domain) { + let owned_breaker = Arc::clone(&breaker); + drop(breaker); + owned_breaker.lock().await.succeed(); + } else { + let mut hm = self.inner.write().await; + let breaker = hm + .entry(domain.to_owned()) + .or_insert(Arc::new(Mutex::new(Breaker::default()))); + breaker.lock().await.succeed(); + } } } } @@ -61,7 +77,7 @@ impl Breakers { impl Default for Breakers { fn default() -> Self { Breakers { - inner: Arc::new(Mutex::new(HashMap::new())), + inner: Arc::new(RwLock::new(HashMap::new())), } } } @@ -180,7 +196,7 @@ impl Requests { { let parsed_url = url.parse::()?; - if !self.breakers.should_try(&parsed_url) { + if !self.breakers.should_try(&parsed_url).await { return Err(MyError::Breaker); } @@ -202,7 +218,7 @@ impl Requests { if res.is_err() { self.count_err(); - self.breakers.fail(&parsed_url); + self.breakers.fail(&parsed_url).await; } let mut res = res.map_err(|e| MyError::SendRequest(url.to_string(), e.to_string()))?; @@ -218,12 +234,12 @@ impl Requests { } } - self.breakers.fail(&parsed_url); + self.breakers.fail(&parsed_url).await; return Err(MyError::Status(url.to_string(), res.status())); } - self.breakers.succeed(&parsed_url); + self.breakers.succeed(&parsed_url).await; let body = res .body() @@ -236,7 +252,7 @@ impl Requests { pub async fn fetch_bytes(&self, url: &str) -> Result<(String, Bytes), MyError> { let parsed_url = url.parse::()?; - if !self.breakers.should_try(&parsed_url) { + if !self.breakers.should_try(&parsed_url).await { return Err(MyError::Breaker); } @@ -258,7 +274,7 @@ impl Requests { .await; if res.is_err() { - self.breakers.fail(&parsed_url); + self.breakers.fail(&parsed_url).await; self.count_err(); } @@ -285,12 +301,12 @@ impl Requests { } } - self.breakers.fail(&parsed_url); + self.breakers.fail(&parsed_url).await; return Err(MyError::Status(url.to_string(), res.status())); } - self.breakers.succeed(&parsed_url); + self.breakers.succeed(&parsed_url).await; let bytes = match res.body().limit(1024 * 1024 * 4).await { Err(e) => { @@ -306,7 +322,7 @@ impl Requests { where T: serde::ser::Serialize, { - if !self.breakers.should_try(&inbox) { + if !self.breakers.should_try(&inbox).await { return Err(MyError::Breaker); } @@ -332,7 +348,7 @@ impl Requests { if res.is_err() { self.count_err(); - self.breakers.fail(&inbox); + self.breakers.fail(&inbox).await; } let mut res = res.map_err(|e| MyError::SendRequest(inbox.to_string(), e.to_string()))?; @@ -348,11 +364,11 @@ impl Requests { } } - self.breakers.fail(&inbox); + self.breakers.fail(&inbox).await; return Err(MyError::Status(inbox.to_string(), res.status())); } - self.breakers.succeed(&inbox); + self.breakers.succeed(&inbox).await; Ok(()) } diff --git a/src/routes/media.rs b/src/routes/media.rs index e40a823..4cdd625 100644 --- a/src/routes/media.rs +++ b/src/routes/media.rs @@ -3,7 +3,6 @@ use actix_web::{ http::header::{CacheControl, CacheDirective}, web, HttpResponse, }; -use bytes::Bytes; use uuid::Uuid; pub async fn route( @@ -30,7 +29,7 @@ pub async fn route( Ok(HttpResponse::NotFound().finish()) } -fn cached(content_type: String, bytes: Bytes) -> HttpResponse { +fn cached(content_type: String, bytes: web::Bytes) -> HttpResponse { HttpResponse::Ok() .set(CacheControl(vec![ CacheDirective::Public,