relay/src/state.rs

217 lines
6 KiB
Rust
Raw Normal View History

2020-03-15 02:05:40 +00:00
use activitystreams::primitives::XsdAnyUri;
use anyhow::Error;
use bb8_postgres::tokio_postgres::{row::Row, Client};
use futures::try_join;
2020-03-15 16:29:01 +00:00
use lru::LruCache;
2020-03-15 02:05:40 +00:00
use std::{collections::HashSet, sync::Arc};
2020-03-15 16:29:01 +00:00
use tokio::sync::RwLock;
use ttl_cache::TtlCache;
2020-03-15 02:05:40 +00:00
2020-03-15 16:29:01 +00:00
use crate::{apub::AcceptedActors, db_actor::Pool};
2020-03-15 02:05:40 +00:00
#[derive(Clone)]
pub struct State {
2020-03-15 17:49:27 +00:00
whitelist_enabled: bool,
2020-03-15 16:29:01 +00:00
actor_cache: Arc<RwLock<TtlCache<XsdAnyUri, AcceptedActors>>>,
actor_id_cache: Arc<RwLock<LruCache<XsdAnyUri, XsdAnyUri>>>,
2020-03-15 17:49:27 +00:00
blocks: Arc<RwLock<HashSet<String>>>,
whitelists: Arc<RwLock<HashSet<String>>>,
2020-03-15 02:05:40 +00:00
listeners: Arc<RwLock<HashSet<XsdAnyUri>>>,
}
2020-03-15 17:49:27 +00:00
#[derive(Clone, Debug, thiserror::Error)]
#[error("No host present in URI")]
pub struct HostError;
2020-03-15 02:05:40 +00:00
impl State {
2020-03-15 17:49:27 +00:00
pub async fn is_whitelisted(&self, actor_id: &XsdAnyUri) -> bool {
if !self.whitelist_enabled {
return true;
}
let hs = self.whitelists.clone();
if let Some(host) = actor_id.as_url().host() {
let read_guard = hs.read().await;
return read_guard.contains(&host.to_string());
}
false
}
pub async fn is_blocked(&self, actor_id: &XsdAnyUri) -> bool {
let hs = self.blocks.clone();
if let Some(host) = actor_id.as_url().host() {
let read_guard = hs.read().await;
return read_guard.contains(&host.to_string());
}
true
}
pub async fn is_listener(&self, actor_id: &XsdAnyUri) -> bool {
let hs = self.listeners.clone();
let read_guard = hs.read().await;
read_guard.contains(actor_id)
}
2020-03-15 16:29:01 +00:00
pub async fn get_actor(&self, actor_id: &XsdAnyUri) -> Option<AcceptedActors> {
let cache = self.actor_cache.clone();
2020-03-15 02:05:40 +00:00
2020-03-15 16:29:01 +00:00
let read_guard = cache.read().await;
read_guard.get(actor_id).cloned()
}
pub async fn cache_actor(&self, actor_id: XsdAnyUri, actor: AcceptedActors) {
let cache = self.actor_cache.clone();
let mut write_guard = cache.write().await;
write_guard.insert(actor_id, actor, std::time::Duration::from_secs(3600));
}
pub async fn is_cached(&self, object_id: &XsdAnyUri) -> bool {
let cache = self.actor_id_cache.clone();
let read_guard = cache.read().await;
read_guard.contains(object_id)
2020-03-15 02:05:40 +00:00
}
pub async fn cache(&self, object_id: XsdAnyUri, actor_id: XsdAnyUri) {
2020-03-15 16:29:01 +00:00
let cache = self.actor_id_cache.clone();
2020-03-15 02:05:40 +00:00
2020-03-15 16:29:01 +00:00
let mut write_guard = cache.write().await;
write_guard.put(object_id, actor_id);
2020-03-15 02:05:40 +00:00
}
pub async fn add_block(&self, client: &Client, block: XsdAnyUri) -> Result<(), Error> {
let blocks = self.blocks.clone();
2020-03-15 17:49:27 +00:00
let host = if let Some(host) = block.as_url().host() {
host
} else {
return Err(HostError.into());
};
2020-03-15 02:05:40 +00:00
client
.execute(
"INSERT INTO blocks (actor_id, created_at) VALUES ($1::TEXT, now);",
2020-03-15 17:49:27 +00:00
&[&host.to_string()],
2020-03-15 02:05:40 +00:00
)
.await?;
let mut write_guard = blocks.write().await;
2020-03-15 17:49:27 +00:00
write_guard.insert(host.to_string());
2020-03-15 02:05:40 +00:00
Ok(())
}
pub async fn add_whitelist(&self, client: &Client, whitelist: XsdAnyUri) -> Result<(), Error> {
let whitelists = self.whitelists.clone();
2020-03-15 17:49:27 +00:00
let host = if let Some(host) = whitelist.as_url().host() {
host
} else {
return Err(HostError.into());
};
2020-03-15 02:05:40 +00:00
client
.execute(
"INSERT INTO whitelists (actor_id, created_at) VALUES ($1::TEXT, now);",
2020-03-15 17:49:27 +00:00
&[&host.to_string()],
2020-03-15 02:05:40 +00:00
)
.await?;
let mut write_guard = whitelists.write().await;
2020-03-15 17:49:27 +00:00
write_guard.insert(host.to_string());
2020-03-15 02:05:40 +00:00
Ok(())
}
pub async fn add_listener(&self, client: &Client, listener: XsdAnyUri) -> Result<(), Error> {
let listeners = self.listeners.clone();
client
.execute(
"INSERT INTO listeners (actor_id, created_at) VALUES ($1::TEXT, now);",
2020-03-15 17:49:27 +00:00
&[&listener.as_str()],
2020-03-15 02:05:40 +00:00
)
.await?;
let mut write_guard = listeners.write().await;
write_guard.insert(listener);
Ok(())
}
2020-03-15 17:49:27 +00:00
pub async fn hydrate(whitelist_enabled: bool, pool: Pool) -> Result<Self, Error> {
2020-03-15 02:05:40 +00:00
let pool1 = pool.clone();
let pool2 = pool.clone();
let f1 = async move {
let conn = pool.get().await?;
hydrate_blocks(&conn).await
};
let f2 = async move {
let conn = pool1.get().await?;
hydrate_whitelists(&conn).await
};
let f3 = async move {
let conn = pool2.get().await?;
hydrate_listeners(&conn).await
};
let (blocks, whitelists, listeners) = try_join!(f1, f2, f3)?;
Ok(State {
2020-03-15 17:49:27 +00:00
whitelist_enabled,
2020-03-15 16:29:01 +00:00
actor_cache: Arc::new(RwLock::new(TtlCache::new(1024 * 8))),
actor_id_cache: Arc::new(RwLock::new(LruCache::new(1024 * 8))),
2020-03-15 02:05:40 +00:00
blocks: Arc::new(RwLock::new(blocks)),
whitelists: Arc::new(RwLock::new(whitelists)),
listeners: Arc::new(RwLock::new(listeners)),
})
}
}
2020-03-15 17:49:27 +00:00
pub async fn hydrate_blocks(client: &Client) -> Result<HashSet<String>, Error> {
let rows = client.query("SELECT domain_name FROM blocks", &[]).await?;
2020-03-15 02:05:40 +00:00
parse_rows(rows)
}
2020-03-15 17:49:27 +00:00
pub async fn hydrate_whitelists(client: &Client) -> Result<HashSet<String>, Error> {
let rows = client
.query("SELECT domain_name FROM whitelists", &[])
.await?;
2020-03-15 02:05:40 +00:00
parse_rows(rows)
}
/// Loads every listener actor id from the `listeners` table.
///
/// Rows that are not text or do not parse as an `XsdAnyUri` are skipped by
/// `parse_rows`.
pub async fn hydrate_listeners(client: &Client) -> Result<HashSet<XsdAnyUri>, Error> {
    let sql = "SELECT actor_id FROM listeners";
    let fetched = client.query(sql, &[]).await?;
    parse_rows(fetched)
}
2020-03-15 17:49:27 +00:00
/// Converts the first column of each row into a `T` and gathers the results.
///
/// This is best-effort: a row whose first column cannot be read as a `String`,
/// or whose text fails `T::from_str`, is silently skipped. The function
/// currently never returns `Err`.
pub fn parse_rows<T>(rows: Vec<Row>) -> Result<HashSet<T>, Error>
where
    T: std::str::FromStr + Eq + std::hash::Hash,
{
    let mut parsed = HashSet::new();
    for row in rows {
        let text: String = match row.try_get(0) {
            Ok(text) => text,
            Err(_) => continue,
        };
        if let Ok(value) = text.parse::<T>() {
            parsed.insert(value);
        }
    }
    Ok(parsed)
}