diff --git a/migrations/2020-03-16-160336_add_triggers/down.sql b/migrations/2020-03-16-160336_add_triggers/down.sql new file mode 100644 index 0000000..8b8f803 --- /dev/null +++ b/migrations/2020-03-16-160336_add_triggers/down.sql @@ -0,0 +1,8 @@ +-- This file should undo anything in `up.sql` +DROP TRIGGER IF EXISTS whitelists_notify ON whitelists; +DROP TRIGGER IF EXISTS blocks_notify ON blocks; +DROP TRIGGER IF EXISTS listeners_notify ON listeners; + +DROP FUNCTION IF EXISTS invoke_whitelists_trigger(); +DROP FUNCTION IF EXISTS invoke_blocks_trigger(); +DROP FUNCTION IF EXISTS invoke_listeners_trigger(); diff --git a/migrations/2020-03-16-160336_add_triggers/up.sql b/migrations/2020-03-16-160336_add_triggers/up.sql new file mode 100644 index 0000000..5520a34 --- /dev/null +++ b/migrations/2020-03-16-160336_add_triggers/up.sql @@ -0,0 +1,99 @@ +-- Your SQL goes here +CREATE OR REPLACE FUNCTION invoke_listeners_trigger () + RETURNS TRIGGER + LANGUAGE plpgsql +AS $$ +DECLARE + rec RECORD; + channel TEXT; + payload TEXT; +BEGIN + case TG_OP + WHEN 'INSERT' THEN + rec := NEW; + channel := 'new_listeners'; + payload := NEW.actor_id; + WHEN 'DELETE' THEN + rec := OLD; + channel := 'rm_listeners'; + payload := OLD.actor_id; + ELSE + RAISE EXCEPTION 'Unknown TG_OP: "%". Should not occur!', TG_OP; + END CASE; + + PERFORM pg_notify(channel, payload::TEXT); + RETURN rec; +END; +$$; + +CREATE OR REPLACE FUNCTION invoke_blocks_trigger () + RETURNS TRIGGER + LANGUAGE plpgsql +AS $$ +DECLARE + rec RECORD; + channel TEXT; + payload TEXT; +BEGIN + case TG_OP + WHEN 'INSERT' THEN + rec := NEW; + channel := 'new_blocks'; + payload := NEW.domain_name; + WHEN 'DELETE' THEN + rec := OLD; + channel := 'rm_blocks'; + payload := OLD.domain_name; + ELSE + RAISE EXCEPTION 'Unknown TG_OP: "%". Should not occur!', TG_OP; + END CASE; + + PERFORM pg_notify(channel, payload::TEXT); + RETURN NULL; +END; +$$; + +CREATE OR REPLACE FUNCTION invoke_whitelists_trigger () + RETURNS TRIGGER + LANGUAGE plpgsql +AS $$ +DECLARE + rec RECORD; + channel TEXT; + payload TEXT; +BEGIN + case TG_OP + WHEN 'INSERT' THEN + rec := NEW; + channel := 'new_whitelists'; + payload := NEW.domain_name; + WHEN 'DELETE' THEN + rec := OLD; + channel := 'rm_whitelists'; + payload := OLD.domain_name; + ELSE + RAISE EXCEPTION 'Unknown TG_OP: "%". Should not occur!', TG_OP; + END CASE; + + PERFORM pg_notify(channel, payload::TEXT); + RETURN rec; +END; +$$; + +CREATE TRIGGER listeners_notify + AFTER INSERT OR UPDATE OR DELETE + ON listeners +FOR EACH ROW + EXECUTE PROCEDURE invoke_listeners_trigger(); + +CREATE TRIGGER blocks_notify + AFTER INSERT OR UPDATE OR DELETE + ON blocks +FOR EACH ROW + EXECUTE PROCEDURE invoke_blocks_trigger(); + +CREATE TRIGGER whitelists_notify + AFTER INSERT OR UPDATE OR DELETE + ON whitelists +FOR EACH ROW + EXECUTE PROCEDURE invoke_whitelists_trigger(); diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..b37906d --- /dev/null +++ b/src/db.rs @@ -0,0 +1,158 @@ +use activitystreams::primitives::XsdAnyUri; +use anyhow::Error; +use bb8_postgres::tokio_postgres::{row::Row, Client}; +use log::info; +use rsa::RSAPrivateKey; +use rsa_pem::KeyExt; +use std::collections::HashSet; + +#[derive(Clone, Debug, thiserror::Error)] +#[error("No host present in URI")] +pub struct HostError; + +pub async fn listen(client: &Client) -> Result<(), Error> { + client + .batch_execute( + "LISTEN new_blocks; + LISTEN new_whitelists; + LISTEN new_listeners; + LISTEN rm_blocks; + LISTEN rm_whitelists; + LISTEN rm_listeners;", + ) + .await?; + + Ok(()) +} + +pub async fn hydrate_private_key(client: &Client) -> Result, Error> { + info!("SELECT value FROM settings WHERE key = 'private_key'"); + let rows = client + .query("SELECT value FROM settings WHERE key = 'private_key'", &[]) + .await?; + + if let Some(row) = rows.into_iter().next() { + let key_str: String = row.get(0); + return Ok(Some(KeyExt::from_pem_pkcs8(&key_str)?)); + } + + Ok(None) +} + +pub async fn update_private_key(client: &Client, key: &RSAPrivateKey) -> Result<(), Error> { + let pem_pkcs8 = key.to_pem_pkcs8()?; + + info!("INSERT INTO settings (key, value, created_at) VALUES ('private_key', $1::TEXT, 'now');"); + client.execute("INSERT INTO settings (key, value, created_at) VALUES ('private_key', $1::TEXT, 'now');", &[&pem_pkcs8]).await?; + Ok(()) +} + +pub async fn add_block(client: &Client, block: &XsdAnyUri) -> Result<(), Error> { + let host = if let Some(host) = block.as_url().host() { + host + } else { + return Err(HostError.into()); + }; + + info!( + "INSERT INTO blocks (domain_name, created_at) VALUES ($1::TEXT, 'now'); [{}]", + host.to_string() + ); + client + .execute( + "INSERT INTO blocks (domain_name, created_at) VALUES ($1::TEXT, 'now');", + &[&host.to_string()], + ) + .await?; + + Ok(()) +} + +pub async fn add_whitelist(client: &Client, whitelist: &XsdAnyUri) -> Result<(), Error> { + let host = if let Some(host) = whitelist.as_url().host() { + host + } else { + return Err(HostError.into()); + }; + + info!( + "INSERT INTO whitelists (domain_name, created_at) VALUES ($1::TEXT, 'now'); [{}]", + host.to_string() + ); + client + .execute( + "INSERT INTO whitelists (domain_name, created_at) VALUES ($1::TEXT, 'now');", + &[&host.to_string()], + ) + .await?; + + Ok(()) +} + +pub async fn remove_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), Error> { + info!( + "DELETE FROM listeners WHERE actor_id = {};", + listener.as_str() + ); + client + .execute( + "DELETE FROM listeners WHERE actor_id = $1::TEXT;", + &[&listener.as_str()], + ) + .await?; + + Ok(()) +} + +pub async fn add_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), Error> { + info!( + "INSERT INTO listeners (actor_id, created_at) VALUES ($1::TEXT, 'now'); [{}]", + listener.as_str(), + ); + client + .execute( + "INSERT INTO listeners (actor_id, created_at) VALUES ($1::TEXT, 'now');", + &[&listener.as_str()], + ) + .await?; + + Ok(()) +} + +pub async fn hydrate_blocks(client: &Client) -> Result, Error> { + info!("SELECT domain_name FROM blocks"); + let rows = client.query("SELECT domain_name FROM blocks", &[]).await?; + + parse_rows(rows) +} + +pub async fn hydrate_whitelists(client: &Client) -> Result, Error> { + info!("SELECT domain_name FROM whitelists"); + let rows = client + .query("SELECT domain_name FROM whitelists", &[]) + .await?; + + parse_rows(rows) +} + +pub async fn hydrate_listeners(client: &Client) -> Result, Error> { + info!("SELECT actor_id FROM listeners"); + let rows = client.query("SELECT actor_id FROM listeners", &[]).await?; + + parse_rows(rows) +} + +fn parse_rows(rows: Vec) -> Result, Error> +where + T: std::str::FromStr + Eq + std::hash::Hash, +{ + let hs = rows + .into_iter() + .filter_map(move |row| { + let s: String = row.try_get(0).ok()?; + s.parse().ok() + }) + .collect(); + + Ok(hs) +} diff --git a/src/inbox.rs b/src/inbox.rs index f796993..7f53989 100644 --- a/src/inbox.rs +++ b/src/inbox.rs @@ -63,17 +63,18 @@ async fn handle_undo( let inbox = actor.inbox().to_owned(); - let state2 = state.clone().into_inner(); db_actor.do_send(DbQuery(move |pool: Pool| { let inbox = inbox.clone(); async move { let conn = pool.get().await?; - state2.remove_listener(&conn, &inbox).await.map_err(|e| { - error!("Error removing listener, {}", e); - e - }) + crate::db::remove_listener(&conn, &inbox) + .await + .map_err(|e| { + error!("Error removing listener, {}", e); + e + }) } })); @@ -181,17 +182,14 @@ async fn handle_follow( } if !is_listener { - let state = state.clone().into_inner(); - let inbox = actor.inbox().to_owned(); db_actor.do_send(DbQuery(move |pool: Pool| { let inbox = inbox.clone(); - let state = state.clone(); async move { let conn = pool.get().await?; - state.add_listener(&conn, inbox).await.map_err(|e| { + crate::db::add_listener(&conn, &inbox).await.map_err(|e| { error!("Error adding listener, {}", e); e }) diff --git a/src/main.rs b/src/main.rs index 55e7132..b40f81c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,10 +6,12 @@ use rsa_pem::KeyExt; use sha2::{Digest, Sha256}; mod apub; +mod db; mod db_actor; mod error; mod inbox; mod label; +mod notify; mod state; mod verifier; mod webfinger; @@ -62,7 +64,7 @@ async fn actor_route(state: web::Data) -> Result #[actix_rt::main] async fn main() -> Result<(), anyhow::Error> { dotenv::dotenv().ok(); - std::env::set_var("RUST_LOG", "info"); + std::env::set_var("RUST_LOG", "debug"); pretty_env_logger::init(); let pg_config: tokio_postgres::Config = std::env::var("DATABASE_URL")?.parse()?; @@ -82,6 +84,8 @@ async fn main() -> Result<(), anyhow::Error> { .await? .await??; + let _ = notify::NotifyHandler::start_handler(state.clone(), pg_config.clone()); + HttpServer::new(move || { let actor = DbActor::new(pg_config.clone()); arbiter_labeler.clone().set_label(); diff --git a/src/notify.rs b/src/notify.rs new file mode 100644 index 0000000..c6d1b4b --- /dev/null +++ b/src/notify.rs @@ -0,0 +1,161 @@ +use crate::state::State; +use activitystreams::primitives::XsdAnyUri; +use actix::prelude::*; +use bb8_postgres::tokio_postgres::{tls::NoTls, AsyncMessage, Client, Config, Notification}; +use futures::{ + future::ready, + stream::{poll_fn, StreamExt}, +}; +use log::{debug, error, info}; +use tokio::sync::mpsc; + +#[derive(Message)] +#[rtype(result = "()")] +pub struct Notify(Notification); + +pub struct NotifyHandler { + client: Option, + state: State, + config: Config, +} + +impl NotifyHandler { + fn new(state: State, config: Config) -> Self { + NotifyHandler { + state, + config, + client: None, + } + } + + pub fn start_handler(state: State, config: Config) -> Addr { + Supervisor::start(|_| Self::new(state, config)) + } +} + +impl Actor for NotifyHandler { + type Context = Context; + + fn started(&mut self, ctx: &mut Self::Context) { + let config = self.config.clone(); + + let fut = async move { + let (client, mut conn) = match config.connect(NoTls).await { + Ok((client, conn)) => (client, conn), + Err(e) => { + error!("Error establishing DB Connection, {}", e); + return Err(()); + } + }; + + let mut stream = poll_fn(move |cx| conn.poll_message(cx)).filter_map(|m| match m { + Ok(AsyncMessage::Notification(n)) => { + debug!("Handling Notification, {:?}", n); + ready(Some(Notify(n))) + } + Ok(AsyncMessage::Notice(e)) => { + debug!("Handling Notice, {:?}", e); + ready(None) + } + Err(e) => { + debug!("Handling Error, {:?}", e); + ready(None) + } + _ => { + debug!("Handling rest"); + ready(None) + } + }); + + let (mut tx, rx) = mpsc::channel(256); + + Arbiter::spawn(async move { + debug!("Spawned stream handler"); + while let Some(n) = stream.next().await { + match tx.send(n).await { + Err(e) => error!("Error forwarding notification, {}", e), + _ => (), + }; + } + debug!("Stream handler ended"); + }); + + Ok((client, rx)) + }; + + let fut = fut.into_actor(self).map(|res, actor, ctx| match res { + Ok((client, stream)) => { + Self::add_stream(stream, ctx); + let f = async move { + match crate::db::listen(&client).await { + Err(e) => { + error!("Error listening, {}", e); + Err(()) + } + Ok(_) => Ok(client), + } + }; + + ctx.wait(f.into_actor(actor).map(|res, actor, ctx| match res { + Ok(client) => { + actor.client = Some(client); + } + Err(_) => { + ctx.stop(); + } + })); + } + Err(_) => { + ctx.stop(); + } + }); + + ctx.wait(fut); + info!("Listener starting"); + } +} + +impl StreamHandler for NotifyHandler { + fn handle(&mut self, Notify(notif): Notify, ctx: &mut Self::Context) { + let state = self.state.clone(); + + info!("Handling notification in {}", notif.channel()); + let fut = async move { + match notif.channel() { + "new_blocks" => { + debug!("Caching block of {}", notif.payload()); + state.cache_block(notif.payload().to_owned()).await; + } + "new_whitelists" => { + debug!("Caching whitelist of {}", notif.payload()); + state.cache_whitelist(notif.payload().to_owned()).await; + } + "new_listeners" => { + if let Ok(uri) = notif.payload().parse::() { + debug!("Caching listener {}", uri); + state.cache_listener(uri).await; + } + } + "rm_blocks" => { + debug!("Busting block cache for {}", notif.payload()); + state.bust_block(notif.payload()).await; + } + "rm_whitelists" => { + debug!("Busting whitelist cache for {}", notif.payload()); + state.bust_whitelist(notif.payload()).await; + } + "rm_listeners" => { + if let Ok(uri) = notif.payload().parse::() { + debug!("Busting listener cache for {}", uri); + state.bust_listener(&uri).await; + } + } + _ => (), + } + }; + + ctx.spawn(fut.into_actor(self)); + } +} + +impl Supervised for NotifyHandler {} diff --git a/src/state.rs b/src/state.rs index 1a0f146..cd97bec 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,12 +1,11 @@ use activitystreams::primitives::XsdAnyUri; use anyhow::Error; -use bb8_postgres::tokio_postgres::{row::Row, Client}; +use bb8_postgres::tokio_postgres::Client; use futures::try_join; use log::{error, info}; use lru::LruCache; use rand::thread_rng; use rsa::{RSAPrivateKey, RSAPublicKey}; -use rsa_pem::KeyExt; use std::{collections::HashSet, sync::Arc}; use tokio::sync::RwLock; use ttl_cache::TtlCache; @@ -42,10 +41,6 @@ pub enum UrlKind { MainKey, } -#[derive(Clone, Debug, thiserror::Error)] -#[error("No host present in URI")] -pub struct HostError; - #[derive(Clone, Debug, thiserror::Error)] #[error("Error generating RSA key")] pub struct RsaError; @@ -57,14 +52,8 @@ impl Settings { whitelist_enabled: bool, hostname: String, ) -> Result { - info!("SELECT value FROM settings WHERE key = 'private_key'"); - let rows = client - .query("SELECT value FROM settings WHERE key = 'private_key'", &[]) - .await?; - - let private_key = if let Some(row) = rows.into_iter().next() { - let key_str: String = row.get(0); - KeyExt::from_pem_pkcs8(&key_str)? + let private_key = if let Some(key) = crate::db::hydrate_private_key(client).await? { + key } else { info!("Generating new keys"); let mut rng = thread_rng(); @@ -72,10 +61,9 @@ impl Settings { error!("Error generating RSA key, {}", e); RsaError })?; - let pem_pkcs8 = key.to_pem_pkcs8()?; - info!("INSERT INTO settings (key, value, created_at) VALUES ('private_key', $1::TEXT, 'now');"); - client.execute("INSERT INTO settings (key, value, created_at) VALUES ('private_key', $1::TEXT, 'now');", &[&pem_pkcs8]).await?; + crate::db::update_private_key(client, &key).await?; + key }; @@ -131,21 +119,25 @@ impl State { self.settings.sign(bytes) } - pub async fn remove_listener(&self, client: &Client, inbox: &XsdAnyUri) -> Result<(), Error> { - let hs = self.listeners.clone(); + pub async fn bust_whitelist(&self, whitelist: &str) { + let hs = self.whitelists.clone(); - info!("DELETE FROM listeners WHERE actor_id = {};", inbox.as_str()); - client - .execute( - "DELETE FROM listeners WHERE actor_id = $1::TEXT;", - &[&inbox.as_str()], - ) - .await?; + let mut write_guard = hs.write().await; + write_guard.remove(whitelist); + } + + pub async fn bust_block(&self, block: &str) { + let hs = self.blocks.clone(); + + let mut write_guard = hs.write().await; + write_guard.remove(block); + } + + pub async fn bust_listener(&self, inbox: &XsdAnyUri) { + let hs = self.listeners.clone(); let mut write_guard = hs.write().await; write_guard.remove(inbox); - - Ok(()) } pub async fn listeners_without(&self, inbox: &XsdAnyUri, domain: &str) -> Vec { @@ -228,76 +220,25 @@ impl State { write_guard.put(object_id, actor_id); } - pub async fn add_block(&self, client: &Client, block: XsdAnyUri) -> Result<(), Error> { + pub async fn cache_block(&self, host: String) { let blocks = self.blocks.clone(); - let host = if let Some(host) = block.as_url().host() { - host - } else { - return Err(HostError.into()); - }; - - info!( - "INSERT INTO blocks (domain_name, created_at) VALUES ($1::TEXT, 'now'); [{}]", - host.to_string() - ); - client - .execute( - "INSERT INTO blocks (domain_name, created_at) VALUES ($1::TEXT, 'now');", - &[&host.to_string()], - ) - .await?; - let mut write_guard = blocks.write().await; - write_guard.insert(host.to_string()); - - Ok(()) + write_guard.insert(host); } - pub async fn add_whitelist(&self, client: &Client, whitelist: XsdAnyUri) -> Result<(), Error> { + pub async fn cache_whitelist(&self, host: String) { let whitelists = self.whitelists.clone(); - let host = if let Some(host) = whitelist.as_url().host() { - host - } else { - return Err(HostError.into()); - }; - - info!( - "INSERT INTO whitelists (domain_name, created_at) VALUES ($1::TEXT, 'now'); [{}]", - host.to_string() - ); - client - .execute( - "INSERT INTO whitelists (domain_name, created_at) VALUES ($1::TEXT, 'now');", - &[&host.to_string()], - ) - .await?; - let mut write_guard = whitelists.write().await; - write_guard.insert(host.to_string()); - - Ok(()) + write_guard.insert(host); } - pub async fn add_listener(&self, client: &Client, listener: XsdAnyUri) -> Result<(), Error> { + pub async fn cache_listener(&self, listener: XsdAnyUri) { let listeners = self.listeners.clone(); - info!( - "INSERT INTO listeners (actor_id, created_at) VALUES ($1::TEXT, 'now'); [{}]", - listener.as_str(), - ); - client - .execute( - "INSERT INTO listeners (actor_id, created_at) VALUES ($1::TEXT, 'now');", - &[&listener.as_str()], - ) - .await?; - let mut write_guard = listeners.write().await; write_guard.insert(listener); - - Ok(()) } pub async fn hydrate( @@ -313,19 +254,19 @@ impl State { let f1 = async move { let conn = pool.get().await?; - hydrate_blocks(&conn).await + crate::db::hydrate_blocks(&conn).await }; let f2 = async move { let conn = pool1.get().await?; - hydrate_whitelists(&conn).await + crate::db::hydrate_whitelists(&conn).await }; let f3 = async move { let conn = pool2.get().await?; - hydrate_listeners(&conn).await + crate::db::hydrate_listeners(&conn).await }; let f4 = async move { @@ -346,41 +287,3 @@ impl State { }) } } - -pub async fn hydrate_blocks(client: &Client) -> Result, Error> { - info!("SELECT domain_name FROM blocks"); - let rows = client.query("SELECT domain_name FROM blocks", &[]).await?; - - parse_rows(rows) -} - -pub async fn hydrate_whitelists(client: &Client) -> Result, Error> { - info!("SELECT domain_name FROM whitelists"); - let rows = client - .query("SELECT domain_name FROM whitelists", &[]) - .await?; - - parse_rows(rows) -} - -pub async fn hydrate_listeners(client: &Client) -> Result, Error> { - info!("SELECT actor_id FROM listeners"); - let rows = client.query("SELECT actor_id FROM listeners", &[]).await?; - - parse_rows(rows) -} - -pub fn parse_rows(rows: Vec) -> Result, Error> -where - T: std::str::FromStr + Eq + std::hash::Hash, -{ - let hs = rows - .into_iter() - .filter_map(move |row| { - let s: String = row.try_get(0).ok()?; - s.parse().ok() - }) - .collect(); - - Ok(hs) -}