2021-11-07 18:53:39 +01:00
use std ::{ sync ::Arc , time ::Duration } ;
2020-08-18 17:15:44 +02:00
use diesel ::r2d2 ::{ ConnectionManager , Pool , PooledConnection } ;
2020-07-14 18:00:09 +02:00
use rocket ::{
http ::Status ,
2021-11-07 18:53:39 +01:00
outcome ::IntoOutcome ,
2020-07-19 21:01:31 +02:00
request ::{ FromRequest , Outcome } ,
2021-11-07 18:53:39 +01:00
Request ,
} ;
use tokio ::{
sync ::{ Mutex , OwnedSemaphorePermit , Semaphore } ,
time ::timeout ,
2020-07-14 18:00:09 +02:00
} ;
2019-05-03 15:46:29 +02:00
2020-08-18 17:15:44 +02:00
use crate ::{
error ::{ Error , MapResult } ,
CONFIG ,
} ;
2018-02-10 01:00:55 +01:00
2020-08-18 17:15:44 +02:00
// Each supported backend has its own generated Diesel schema; only the one(s)
// whose cfg flag (sqlite/mysql/postgresql) is enabled get compiled in. The `__`
// prefix marks these as internal — they are imported indirectly via the
// `db_run!` / `db_object!` macros below.
#[cfg(sqlite)]
#[path = "schemas/sqlite/schema.rs"]
pub mod __sqlite_schema;

#[cfg(mysql)]
#[path = "schemas/mysql/schema.rs"]
pub mod __mysql_schema;

#[cfg(postgresql)]
#[path = "schemas/postgresql/schema.rs"]
pub mod __postgresql_schema;
2021-11-16 17:07:55 +01:00
// These changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools
2021-11-07 18:53:39 +01:00
// A wrapper around spawn_blocking that propagates panics to the calling code.
pub async fn run_blocking < F , R > ( job : F ) -> R
where
F : FnOnce ( ) -> R + Send + 'static ,
R : Send + 'static ,
{
match tokio ::task ::spawn_blocking ( job ) . await {
Ok ( ret ) = > ret ,
Err ( e ) = > match e . try_into_panic ( ) {
Ok ( panic ) = > std ::panic ::resume_unwind ( panic ) ,
Err ( _ ) = > unreachable! ( " spawn_blocking tasks are never cancelled " ) ,
} ,
}
}
2020-08-18 17:15:44 +02:00
// This is used to generate the main DbConn and DbPool enums, which contain one variant for each database supported
macro_rules! generate_connections {
    ( $( $name:ident : $ty:ty ),+ ) => {
        // One variant per supported backend; used to tag which backend a URL maps to.
        #[allow(non_camel_case_types, dead_code)]
        #[derive(Eq, PartialEq)]
        pub enum DbConnType { $( $name, )+ }

        // A single checked-out connection. The inner Option lets Drop move the
        // connection out; the permit ties the connection to the pool's semaphore.
        pub struct DbConn {
            conn: Arc<Mutex<Option<DbConnInner>>>,
            permit: Option<OwnedSemaphorePermit>,
        }

        #[allow(non_camel_case_types)]
        pub enum DbConnInner { $( #[cfg($name)] $name(PooledConnection<ConnectionManager<$ty>>), )+ }

        #[derive(Clone)]
        pub struct DbPool {
            // This is an 'Option' so that we can drop the pool in a 'spawn_blocking'.
            pool: Option<DbPoolInner>,
            // Caps concurrent checkouts; see DbPool::get below.
            semaphore: Arc<Semaphore>
        }

        #[allow(non_camel_case_types)]
        #[derive(Clone)]
        pub enum DbPoolInner { $( #[cfg($name)] $name(Pool<ConnectionManager<$ty>>), )+ }

        impl Drop for DbConn {
            fn drop(&mut self) {
                let conn = self.conn.clone();
                let permit = self.permit.take();

                // Since connection can't be on the stack in an async fn during an
                // await, we have to spawn a new blocking-safe thread...
                tokio::task::spawn_blocking(move || {
                    // And then re-enter the runtime to wait on the async mutex, but in a blocking fashion.
                    // NOTE(review): Handle::current() assumes this Drop runs inside a Tokio
                    // runtime context — confirm that holds for all drop sites.
                    let mut conn = tokio::runtime::Handle::current().block_on(conn.lock_owned());
                    if let Some(conn) = conn.take() {
                        drop(conn);
                    }

                    // Drop permit after the connection is dropped, so the semaphore
                    // slot is only released once the r2d2 connection is back in the pool.
                    drop(permit);
                });
            }
        }

        impl Drop for DbPool {
            fn drop(&mut self) {
                // Dropping an r2d2 pool can block, so move it off the async threads.
                let pool = self.pool.take();
                tokio::task::spawn_blocking(move || drop(pool));
            }
        }

        impl DbPool {
            // For the given database URL, guess it's type, run migrations create pool and return it
            pub fn from_config() -> Result<Self, Error> {
                let url = CONFIG.database_url();
                let conn_type = DbConnType::from_url(&url)?;

                match conn_type { $(
                    DbConnType::$name => {
                        #[cfg($name)]
                        {
                            // e.g. sqlite_migrations::run_migrations()
                            paste::paste! { [< $name _migrations >]::run_migrations()?; }

                            let manager = ConnectionManager::new(&url);
                            let pool = Pool::builder()
                                .max_size(CONFIG.database_max_conns())
                                .connection_timeout(Duration::from_secs(CONFIG.database_timeout()))
                                .build(manager)
                                .map_res("Failed to create pool")?;

                            // Semaphore permits mirror the pool's max connection count.
                            return Ok(DbPool {
                                pool: Some(DbPoolInner::$name(pool)),
                                semaphore: Arc::new(Semaphore::new(CONFIG.database_max_conns() as usize)),
                            });
                        }
                        // Selected backend was not compiled in; from_url should already
                        // have rejected this, hence unreachable.
                        #[cfg(not($name))]
                        #[allow(unreachable_code)]
                        return unreachable!("Trying to use a DB backend when it's feature is disabled");
                    },
                )+ }
            }

            // Get a connection from the pool
            pub async fn get(&self) -> Result<DbConn, Error> {
                let duration = Duration::from_secs(CONFIG.database_timeout());

                // First acquire a semaphore permit (bounded wait), so we never block
                // a worker thread below when the pool is exhausted.
                let permit = match timeout(duration, self.semaphore.clone().acquire_owned()).await {
                    Ok(p) => p.expect("Semaphore should be open"),
                    Err(_) => {
                        err!("Timeout waiting for database connection");
                    }
                };

                match self.pool.as_ref().expect("DbPool.pool should always be Some()") { $(
                    #[cfg($name)]
                    DbPoolInner::$name(p) => {
                        let pool = p.clone();
                        // r2d2's get_timeout is blocking, so run it on the blocking pool.
                        let c = run_blocking(move || pool.get_timeout(duration)).await.map_res("Error retrieving connection from pool")?;

                        return Ok(DbConn {
                            conn: Arc::new(Mutex::new(Some(DbConnInner::$name(c)))),
                            permit: Some(permit)
                        });
                    },
                )+ }
            }
        }
    };
}
2019-05-26 23:02:41 +02:00
2020-08-18 17:15:44 +02:00
// Instantiate DbConn/DbPool/DbConnType with one variant per supported backend.
// Each variant only compiles when its matching cfg flag is enabled.
generate_connections! {
    sqlite: diesel::sqlite::SqliteConnection,
    mysql: diesel::mysql::MysqlConnection,
    postgresql: diesel::pg::PgConnection
}
impl DbConnType {
    /// Guesses the database backend from the `DATABASE_URL` scheme.
    ///
    /// `mysql:` → MySQL, `postgresql:`/`postgres:` → PostgreSQL, anything else
    /// is assumed to be a SQLite file path. Errors if the matching backend
    /// feature was not compiled in.
    pub fn from_url(url: &str) -> Result<DbConnType, Error> {
        // Mysql
        if url.starts_with("mysql:") {
            #[cfg(mysql)]
            return Ok(DbConnType::mysql);

            #[cfg(not(mysql))]
            err!("`DATABASE_URL` is a MySQL URL, but the 'mysql' feature is not enabled")

        // Postgres
        } else if url.starts_with("postgresql:") || url.starts_with("postgres:") {
            #[cfg(postgresql)]
            return Ok(DbConnType::postgresql);

            #[cfg(not(postgresql))]
            err!("`DATABASE_URL` is a PostgreSQL URL, but the 'postgresql' feature is not enabled")

        //Sqlite
        } else {
            #[cfg(sqlite)]
            return Ok(DbConnType::sqlite);

            #[cfg(not(sqlite))]
            err!("`DATABASE_URL` looks like a SQLite URL, but 'sqlite' feature is not enabled")
        }
    }
}
2020-08-18 17:15:44 +02:00
// Runs `$body` against the connection held by `$conn`, dispatching on the
// active backend. The plain form auto-imports the backend's schema and the
// `__<db>_model` wrappers; the `@raw` form only imports the schema (for
// queries that don't touch the generated model structs).
#[macro_export]
macro_rules! db_run {
    // Same for all dbs
    ( $conn:ident: $body:block ) => {
        db_run! { $conn: sqlite, mysql, postgresql $body }
    };

    ( @raw $conn:ident: $body:block ) => {
        db_run! { @raw $conn: sqlite, mysql, postgresql $body }
    };

    // Different code for each db
    ( $conn:ident: $( $( $db:ident ),+ $body:block )+ ) => {{
        #[allow(unused)] use diesel::prelude::*;
        #[allow(unused)] use crate::db::FromDb;

        // Lock the async mutex guarding the connection for the whole body.
        let conn = $conn.conn.clone();
        let mut conn = conn.lock_owned().await;
        match conn.as_mut().expect("internal invariant broken: self.connection is Some") {
            $($(
                #[cfg($db)]
                crate::db::DbConnInner::$db($conn) => {
                    paste::paste! {
                        #[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *};
                        #[allow(unused)] use [<__ $db _model>]::*;
                    }

                    tokio::task::block_in_place(move || { $body }) // Run blocking can't be used due to the 'static limitation, use block_in_place instead
                },
            )+)+
        }
    }};

    ( @raw $conn:ident: $( $( $db:ident ),+ $body:block )+ ) => {{
        #[allow(unused)] use diesel::prelude::*;
        #[allow(unused)] use crate::db::FromDb;

        let conn = $conn.conn.clone();
        let mut conn = conn.lock_owned().await;
        match conn.as_mut().expect("internal invariant broken: self.connection is Some") {
            $($(
                #[cfg($db)]
                crate::db::DbConnInner::$db($conn) => {
                    paste::paste! {
                        #[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *};
                        // @RAW: #[allow(unused)] use [<__ $db _model>]::*;
                    }

                    tokio::task::block_in_place(move || { $body }) // Run blocking can't be used due to the 'static limitation, use block_in_place instead
                },
            )+)+
        }
    }};
}
2020-08-18 17:15:44 +02:00
/// Conversion from a backend-specific `<Name>Db` wrapper (generated by
/// `db_object!`) back into the plain model struct used by the rest of the app.
pub trait FromDb {
    // The plain (non-Db) type this converts into.
    type Output;
    #[allow(clippy::wrong_self_convention)]
    fn from_db(self) -> Self::Output;
}
2021-05-02 17:49:25 +02:00
impl < T : FromDb > FromDb for Vec < T > {
type Output = Vec < T ::Output > ;
#[ allow(clippy::wrong_self_convention) ]
#[ inline(always) ]
fn from_db ( self ) -> Self ::Output {
self . into_iter ( ) . map ( crate ::db ::FromDb ::from_db ) . collect ( )
}
}
impl < T : FromDb > FromDb for Option < T > {
type Output = Option < T ::Output > ;
#[ allow(clippy::wrong_self_convention) ]
#[ inline(always) ]
fn from_db ( self ) -> Self ::Output {
self . map ( crate ::db ::FromDb ::from_db )
}
}
2021-03-28 00:10:01 +01:00
// For each struct eg. Cipher, we create a CipherDb inside a module named __$db_model (where $db is sqlite, mysql or postgresql),
// to implement the Diesel traits. We also provide methods to convert between them and the basic structs. Later, that module will be auto imported when using db_run!
#[macro_export]
macro_rules! db_object {
    ( $(
        $( #[$attr:meta] )*
        pub struct $name:ident {
            $( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty ),+
            $(,)?
        }
    )+ ) => {
        // Create the normal struct, without attributes
        $( pub struct $name { $( /*$( #[$field_attr] )*/ $vis $field : $typ, )+ } )+

        // Backend-specific wrapper modules; each re-expands via the `@db` arm below.
        #[cfg(sqlite)]
        pub mod __sqlite_model { $( db_object! { @db sqlite | $( #[$attr] )* | $name | $( $( #[$field_attr] )* $field : $typ ),+ } )+ }
        #[cfg(mysql)]
        pub mod __mysql_model { $( db_object! { @db mysql | $( #[$attr] )* | $name | $( $( #[$field_attr] )* $field : $typ ),+ } )+ }
        #[cfg(postgresql)]
        pub mod __postgresql_model { $( db_object! { @db postgresql | $( #[$attr] )* | $name | $( $( #[$field_attr] )* $field : $typ ),+ } )+ }
    };

    // Internal arm: emits the `<Name>Db` struct carrying the Diesel attributes,
    // plus to_db/from_db conversions to and from the plain struct.
    ( @db $db:ident | $( #[$attr:meta] )* | $name:ident | $( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty ),+ ) => {
        paste::paste! {
            #[allow(unused)] use super::*;
            #[allow(unused)] use diesel::prelude::*;
            #[allow(unused)] use crate::db::[<__ $db _schema>]::*;

            $( #[$attr] )*
            pub struct [<$name Db>] { $(
                $( #[$field_attr] )* $vis $field : $typ,
            )+ }

            impl [<$name Db>] {
                // Clones every field from the plain struct into the Db wrapper.
                #[allow(clippy::wrong_self_convention)]
                #[inline(always)] pub fn to_db(x: &super::$name) -> Self { Self { $( $field: x.$field.clone(), )+ } }
            }

            impl crate::db::FromDb for [<$name Db>] {
                type Output = super::$name;
                // Moves every field back out into the plain struct (no clones).
                #[allow(clippy::wrong_self_convention)]
                #[inline(always)] fn from_db(self) -> Self::Output { super::$name { $( $field: self.$field, )+ } }
            }
        }
    };
}
// Reexport the models, needs to be after the macros are defined so it can access them
pub mod models ;
2021-04-05 15:09:16 +02:00
/// Creates a back-up of the sqlite database
/// MySQL/MariaDB and PostgreSQL are not supported.
pub async fn backup_database(conn: &DbConn) -> Result<(), Error> {
    db_run! {@raw conn:
        postgresql, mysql {
            // `conn` is unused in these arms; bind it away to avoid a warning.
            let _ = conn;
            err!("PostgreSQL and MySQL/MariaDB do not support this backup feature");
        }
        sqlite {
            use std::path::Path;
            let db_url = CONFIG.database_url();
            // The backup file is written into the same directory as the database file.
            let db_path = Path::new(&db_url).parent().unwrap().to_string_lossy();
            // Timestamp for the backup file name, e.g. db_20211107_185339.sqlite3
            let file_date = chrono::Utc::now().format("%Y%m%d_%H%M%S").to_string();
            // VACUUM INTO writes a compacted copy of the database to the given path.
            diesel::sql_query(format!("VACUUM INTO '{}/db_{}.sqlite3'", db_path, file_date)).execute(conn)?;
            Ok(())
        }
    }
}
2021-03-28 00:10:01 +01:00
/// Get the SQL Server version
///
/// Returns the backend's reported version string (`version()` on
/// MySQL/PostgreSQL, `sqlite_version()` on SQLite), or "Unknown" if the
/// query fails.
pub async fn get_sql_server_version(conn: &DbConn) -> String {
    db_run! {@raw conn:
        postgresql, mysql {
            no_arg_sql_function!(version, diesel::sql_types::Text);
            diesel::select(version).get_result::<String>(conn).unwrap_or_else(|_| "Unknown".to_string())
        }
        sqlite {
            no_arg_sql_function!(sqlite_version, diesel::sql_types::Text);
            diesel::select(sqlite_version).get_result::<String>(conn).unwrap_or_else(|_| "Unknown".to_string())
        }
    }
}
2018-02-10 01:00:55 +01:00
/// Attempts to retrieve a single connection from the managed database pool. If
/// no pool is currently managed, fails with an `InternalServerError` status. If
/// no connections are available, fails with a `ServiceUnavailable` status.
#[rocket::async_trait]
impl<'r> FromRequest<'r> for DbConn {
    type Error = ();

    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        match request.rocket().state::<DbPool>() {
            // Pool is managed: try to check a connection out of it.
            Some(p) => p.get().await.map_err(|_| ()).into_outcome(Status::ServiceUnavailable),
            // No DbPool was ever attached to the Rocket instance.
            None => Outcome::Failure((Status::InternalServerError, ())),
        }
    }
}
2020-08-18 17:15:44 +02:00
// Embed the migrations from the migrations folder into the application
// This way, the program automatically migrates the database to the latest version
// https://docs.rs/diesel_migrations/*/diesel_migrations/macro.embed_migrations.html
#[cfg(sqlite)]
mod sqlite_migrations {
    embed_migrations!("migrations/sqlite");

    /// Ensures the database directory exists, then runs the embedded SQLite
    /// migrations (creating the database file if needed). Exits the process
    /// if the directory cannot be created.
    pub fn run_migrations() -> Result<(), super::Error> {
        use diesel::{Connection, RunQueryDsl};

        // Make sure the directory exists
        let url = crate::CONFIG.database_url();
        let path = std::path::Path::new(&url);

        if let Some(parent) = path.parent() {
            if std::fs::create_dir_all(parent).is_err() {
                error!("Error creating database directory");
                std::process::exit(1);
            }
        }

        // Make sure the database is up to date (create if it doesn't exist, or run the migrations)
        // Reuse the already-fetched URL instead of reading the config a second time.
        let connection = diesel::sqlite::SqliteConnection::establish(&url)?;

        // Disable Foreign Key Checks during migration
        // Scoped to a connection.
        diesel::sql_query("PRAGMA foreign_keys = OFF")
            .execute(&connection)
            .expect("Failed to disable Foreign Key Checks during migrations");

        // Turn on WAL in SQLite
        if crate::CONFIG.enable_db_wal() {
            diesel::sql_query("PRAGMA journal_mode=wal").execute(&connection).expect("Failed to turn on WAL");
        }

        embedded_migrations::run_with_output(&connection, &mut std::io::stdout())?;
        Ok(())
    }
}
#[cfg(mysql)]
mod mysql_migrations {
    embed_migrations!("migrations/mysql");

    /// Runs the embedded MySQL/MariaDB migrations against the configured database.
    pub fn run_migrations() -> Result<(), super::Error> {
        use diesel::{Connection, RunQueryDsl};
        // Make sure the database is up to date (create if it doesn't exist, or run the migrations)
        let connection = diesel::mysql::MysqlConnection::establish(&crate::CONFIG.database_url())?;
        // Disable Foreign Key Checks during migration
        // Scoped to a connection/session.
        diesel::sql_query("SET FOREIGN_KEY_CHECKS = 0")
            .execute(&connection)
            .expect("Failed to disable Foreign Key Checks during migrations");

        embedded_migrations::run_with_output(&connection, &mut std::io::stdout())?;
        Ok(())
    }
}
#[cfg(postgresql)]
mod postgresql_migrations {
    embed_migrations!("migrations/postgresql");

    /// Runs the embedded PostgreSQL migrations against the configured database.
    pub fn run_migrations() -> Result<(), super::Error> {
        use diesel::{Connection, RunQueryDsl};
        // Make sure the database is up to date (create if it doesn't exist, or run the migrations)
        let connection = diesel::pg::PgConnection::establish(&crate::CONFIG.database_url())?;
        // Disable Foreign Key Checks during migration

        // FIXME: Per https://www.postgresql.org/docs/12/sql-set-constraints.html,
        // "SET CONSTRAINTS sets the behavior of constraint checking within the
        // current transaction", so this setting probably won't take effect for
        // any of the migrations since it's being run outside of a transaction.
        // Migrations that need to disable foreign key checks should run this
        // from within the migration script itself.
        diesel::sql_query("SET CONSTRAINTS ALL DEFERRED")
            .execute(&connection)
            .expect("Failed to disable Foreign Key Checks during migrations");

        embedded_migrations::run_with_output(&connection, &mut std::io::stdout())?;
        Ok(())
    }
}