Files
rustdesk-server/src/database.rs
2022-05-13 09:56:44 +08:00

232 lines
6.4 KiB
Rust

use async_trait::async_trait;
use hbb_common::{log, ResultType};
use serde_json::value::Value;
use sqlx::{
sqlite::SqliteConnectOptions, ConnectOptions, Connection, Error as SqlxError, SqliteConnection,
};
use std::{ops::DerefMut, str::FromStr};
//use sqlx::postgres::PgPoolOptions;
//use sqlx::mysql::MySqlPoolOptions;
/// The sqlx database driver type used by this crate (SQLite).
pub(crate) type DB = sqlx::Sqlite;
/// A JSON object: string keys mapped to arbitrary `serde_json` values.
pub(crate) type MapValue = serde_json::map::Map<String, Value>;
/// A plain string-to-string hash map.
pub(crate) type MapStr = std::collections::HashMap<String, String>;
/// The `deadpool` managed pool whose connections are created by [`DbPool`].
type Pool = deadpool::managed::Pool<DbPool>;
/// Connection manager given to `deadpool`; it opens SQLite connections on demand
/// through its `deadpool::managed::Manager` implementation.
pub struct DbPool {
/// Connection string handed to `SqliteConnectOptions::from_str` when a connection is created.
url: String,
}
// `deadpool` manager implementation: opens SQLite connections from `DbPool::url`
// and health-checks them before they are handed out again.
#[async_trait]
impl deadpool::managed::Manager for DbPool {
    type Type = SqliteConnection;
    type Error = SqlxError;

    /// Opens a new SQLite connection to `self.url` with statement logging at debug level.
    ///
    /// # Errors
    /// Returns the `sqlx` error when the URL cannot be parsed into connect options or
    /// when the connection itself cannot be established.
    async fn create(&self) -> Result<SqliteConnection, SqlxError> {
        // Propagate a malformed URL as an error instead of panicking inside the pool.
        let mut opt = SqliteConnectOptions::from_str(&self.url)?;
        opt.log_statements(log::LevelFilter::Debug);
        SqliteConnection::connect_with(&opt).await
    }

    /// Checks that a pooled connection is still usable by pinging it.
    ///
    /// # Errors
    /// Returns a recycle error wrapping the `sqlx` error when the ping fails.
    async fn recycle(
        &self,
        obj: &mut SqliteConnection,
    ) -> deadpool::managed::RecycleResult<SqlxError> {
        Ok(obj.ping().await?)
    }
}
/// Handle to the peer database. Cloning it clones the inner pool handle.
#[derive(Clone)]
pub struct Database {
/// Pool from which every query borrows its SQLite connection.
pool: Pool,
}
/// One row of the `peer` table, as selected by `Database::get_peer`.
#[derive(Default)]
pub struct Peer {
/// Value of the `guid` column (the table's primary key).
pub guid: Vec<u8>,
/// Value of the `id` column (covered by a unique index).
pub id: String,
/// Value of the `uuid` blob column.
pub uuid: Vec<u8>,
/// Value of the `pk` blob column — presumably the peer's public key; confirm with callers.
pub pk: Vec<u8>,
/// Value of the nullable `user` blob column.
pub user: Option<Vec<u8>>,
/// Value of the `info` text column.
pub info: String,
/// Value of the nullable `status` column.
pub status: Option<i64>,
}
impl Database {
/// Opens the SQLite database at `url`, creating the file and the schema when missing.
///
/// The pool size is read from the `MAX_DATABASE_CONNECTIONS` environment variable;
/// a missing, unparsable or zero value falls back to a single connection.
///
/// # Errors
/// Returns an error when a first connection cannot be obtained from the pool or
/// when creating the tables fails.
pub async fn new(url: &str) -> ResultType<Database> {
    if !std::path::Path::new(url).exists() {
        // Creating the file is best effort: on failure the connection attempt below
        // reports the real error, but leave a trace of why the file is missing.
        if let Err(err) = std::fs::File::create(url) {
            log::warn!("Failed to create database file {}: {}", url, err);
        }
    }
    // A pool of size zero would make every `pool.get()` wait forever, so zero is
    // treated like any other invalid value.
    let n: usize = std::env::var("MAX_DATABASE_CONNECTIONS")
        .ok()
        .and_then(|v| v.parse().ok())
        .filter(|&n| n > 0)
        .unwrap_or(1);
    log::debug!("MAX_DATABASE_CONNECTIONS={}", n);
    let pool = Pool::new(
        DbPool {
            url: url.to_owned(),
        },
        n,
    );
    // Fail early if the database cannot be opened at all.
    let _ = pool.get().await?;
    let db = Database { pool };
    db.create_tables().await?;
    Ok(db)
}
/// Creates the `peer` table and its indexes unless they already exist.
///
/// # Errors
/// Fails when no connection can be obtained from the pool or the statements fail.
async fn create_tables(&self) -> ResultType<()> {
    let mut conn = self.pool.get().await?;
    // The SQL text is kept verbatim; it is what the `query!` macro checks at compile time.
    sqlx::query!(
"
create table if not exists peer (
guid blob primary key not null,
id varchar(100) not null,
uuid blob not null,
pk blob not null,
created_at datetime not null default(current_timestamp),
user blob,
status tinyint,
note varchar(300),
info text not null
) without rowid;
create unique index if not exists index_peer_id on peer (id);
create index if not exists index_peer_user on peer (user);
create index if not exists index_peer_created_at on peer (created_at);
create index if not exists index_peer_status on peer (status);
"
    )
    .execute(conn.deref_mut())
    .await?;
    Ok(())
}
/// Fetches the peer whose `id` column equals `id`.
///
/// Returns `Ok(None)` when no such row exists.
///
/// # Errors
/// Fails when no connection can be obtained from the pool or the query fails.
pub async fn get_peer(&self, id: &str) -> ResultType<Option<Peer>> {
    let mut conn = self.pool.get().await?;
    let peer = sqlx::query_as!(
        Peer,
        "select guid, id, uuid, pk, user, status, info from peer where id = ?",
        id
    )
    .fetch_optional(conn.deref_mut())
    .await?;
    Ok(peer)
}
/// Returns the `id` of the peer stored under `guid`, or `Ok(None)` when there is no such row.
///
/// # Errors
/// Fails when no connection can be obtained from the pool or the query fails.
pub async fn get_peer_id(&self, guid: &[u8]) -> ResultType<Option<String>> {
    let mut conn = self.pool.get().await?;
    let row = sqlx::query!("select id from peer where guid = ?", guid)
        .fetch_optional(conn.deref_mut())
        .await?;
    Ok(row.map(|found| found.id))
}
/// Borrows a connection from the pool.
///
/// # Errors
/// Fails when the pool cannot provide a connection.
#[inline]
pub async fn get_conn(&self) -> ResultType<deadpool::managed::Object<DbPool>> {
    let conn = self.pool.get().await?;
    Ok(conn)
}
pub async fn update_peer(&self, payload: MapValue, guid: &[u8]) -> ResultType<()> {
let mut conn = self.get_conn().await?;
let mut tx = conn.begin().await?;
if let Some(v) = payload.get("note") {
let v = get_str(v);
sqlx::query!("update peer set note = ? where guid = ?", v, guid)
.execute(&mut tx)
.await?;
}
tx.commit().await?;
Ok(())
}
/// Inserts a new peer row and returns the freshly generated `guid` used as its primary key.
///
/// `uuid` and `pk` are taken as byte slices, so callers holding a `Vec<u8>` can keep
/// passing `&vec` unchanged.
///
/// # Errors
/// Fails when no connection can be obtained from the pool or the insert fails
/// (for example when `id` violates the unique index on that column).
pub async fn insert_peer(
    &self,
    id: &str,
    uuid: &[u8],
    pk: &[u8],
    info: &str,
) -> ResultType<Vec<u8>> {
    // The row key is a random v4 UUID stored as its 16 raw bytes.
    let guid = uuid::Uuid::new_v4().as_bytes().to_vec();
    sqlx::query!(
        "insert into peer(guid, id, uuid, pk, info) values(?, ?, ?, ?, ?)",
        guid,
        id,
        uuid,
        pk,
        info
    )
    .execute(self.pool.get().await?.deref_mut())
    .await?;
    Ok(guid)
}
/// Overwrites the `id`, `pk` and `info` columns of the peer stored under `guid`.
///
/// `guid` and `pk` are taken as byte slices, so callers holding a `Vec<u8>` can keep
/// passing `&vec` unchanged. The number of affected rows is not inspected.
///
/// # Errors
/// Fails when no connection can be obtained from the pool or the update fails.
pub async fn update_pk(
    &self,
    guid: &[u8],
    id: &str,
    pk: &[u8],
    info: &str,
) -> ResultType<()> {
    sqlx::query!(
        "update peer set id=?, pk=?, info=? where guid=?",
        id,
        pk,
        info,
        guid
    )
    .execute(self.pool.get().await?.deref_mut())
    .await?;
    Ok(())
}
}
#[cfg(test)]
mod tests {
    use hbb_common::tokio;

    /// Database file used by the stress test, relative to the working directory.
    const DB_PATH: &str = "test.sqlite3";
    /// Number of peers inserted and then looked up.
    const PEERS: i32 = 10000;

    /// Deletes the test database and the side files SQLite may leave next to it.
    /// Missing files are not an error.
    fn remove_db_files() {
        for suffix in ["", "-wal", "-shm", "-journal"].iter() {
            let _ = std::fs::remove_file(format!("{}{}", DB_PATH, suffix));
        }
    }

    #[test]
    fn test_insert() {
        insert();
    }

    /// Spawns concurrent inserts and lookups against one shared `Database`.
    #[tokio::main(flavor = "multi_thread")]
    async fn insert() {
        // Rows left by an earlier run would violate the unique index on `id`.
        remove_db_files();
        let db = super::Database::new(DB_PATH).await.unwrap();
        let mut jobs = vec![];
        for i in 0..PEERS {
            let cloned = db.clone();
            let id = i.to_string();
            jobs.push(tokio::spawn(async move {
                let empty_vec: Vec<u8> = Vec::new();
                cloned
                    .insert_peer(&id, &empty_vec, &empty_vec, "")
                    .await
                    .unwrap();
            }));
        }
        for i in 0..PEERS {
            let cloned = db.clone();
            let id = i.to_string();
            jobs.push(tokio::spawn(async move {
                cloned.get_peer(&id).await.unwrap();
            }));
        }
        // A panic inside a spawned task only surfaces through its join handle, so
        // every result has to be checked for the test to be able to fail.
        for outcome in hbb_common::futures::future::join_all(jobs).await {
            outcome.expect("database task failed");
        }
        // Release the connections before deleting the files.
        drop(db);
        remove_db_files();
    }
}
/// Formats a 16-byte `guid` as a UUID string.
///
/// # Panics
/// Panics when `guid` is not exactly 16 bytes long.
#[inline]
pub fn guid2str(guid: &Vec<u8>) -> String {
    let mut raw = [0u8; 16];
    // `copy_from_slice` requires both sides to have the same length.
    raw.copy_from_slice(guid.as_slice());
    let parsed = uuid::Uuid::from_bytes(raw);
    parsed.to_string()
}
pub(crate) fn get_str(v: &Value) -> Option<&str> {
match v {
Value::String(v) => {
let v = v.trim();
if v.is_empty() {
None
} else {
Some(v)
}
}
_ => None,
}
}