From c8b90319966a54ea5629344d328c54701f671a08 Mon Sep 17 00:00:00 2001 From: 21pages Date: Thu, 24 Oct 2024 23:14:43 +0800 Subject: [PATCH] refactor session insert, update if already exists (#9739) * All share the same last_receive_time * Not second port forward Signed-off-by: 21pages --- src/server/connection.rs | 126 ++++++++++++++++++--------------------- 1 file changed, 59 insertions(+), 67 deletions(-) diff --git a/src/server/connection.rs b/src/server/connection.rs index dc184ceac..12157ddbb 100644 --- a/src/server/connection.rs +++ b/src/server/connection.rs @@ -152,8 +152,6 @@ struct Session { last_recv_time: Arc>, random_password: String, tfa: bool, - conn_type: AuthConnType, - conn_id: i32, } #[cfg(not(any(target_os = "android", target_os = "ios")))] @@ -217,7 +215,7 @@ pub struct Connection { server_audit_conn: String, server_audit_file: String, lr: LoginRequest, - last_recv_time: Arc>, + session_last_recv_time: Option>>, chat_unanswered: bool, file_transferred: bool, #[cfg(windows)] @@ -364,7 +362,7 @@ impl Connection { server_audit_conn: "".to_owned(), server_audit_file: "".to_owned(), lr: Default::default(), - last_recv_time: Arc::new(Mutex::new(Instant::now())), + session_last_recv_time: None, chat_unanswered: false, file_transferred: false, #[cfg(windows)] @@ -595,7 +593,7 @@ impl Connection { }, Ok(bytes) => { last_recv_time = Instant::now(); - *conn.last_recv_time.lock().unwrap() = Instant::now(); + conn.session_last_recv_time.as_mut().map(|t| *t.lock().unwrap() = Instant::now()); if let Ok(msg_in) = Message::parse_from_bytes(&bytes) { if !conn.on_message(msg_in).await { break; @@ -762,6 +760,10 @@ impl Connection { } if let Err(err) = conn.try_port_forward_loop(&mut rx_from_cm).await { conn.on_close(&err.to_string(), false).await; + raii::AuthedConnID::remove_session_if_last_duplication( + conn.inner.id(), + conn.session_key(), + ); } conn.post_conn_audit(json!({ @@ -1140,6 +1142,11 @@ impl Connection { auth_conn_type, self.session_key(), )); + self.session_last_recv_time = SESSIONS + .lock() + .unwrap() + .get(&self.session_key()) + .map(|s| s.last_recv_time.clone()); self.post_conn_audit( json!({"peer": ((&self.lr.my_id, &self.lr.my_name)), "type": conn_type}), ); @@ -1549,15 +1556,10 @@ impl Connection { if password::temporary_enabled() { let password = password::temporary_password(); if self.validate_one_password(password.clone()) { - raii::AuthedConnID::insert_session( + raii::AuthedConnID::update_or_insert_session( self.session_key(), - Session { - last_recv_time: self.last_recv_time.clone(), - random_password: password, - tfa: false, - conn_type: self.conn_type(), - conn_id: self.inner.id(), - }, + Some(password), + Some(false), ); return true; } @@ -1581,15 +1583,11 @@ impl Connection { .get(&self.session_key()) .map(|s| s.to_owned()); // last_recv_time is a mutex variable shared with connection, can be updated lively. - if let Some(mut session) = session { + if let Some(session) = session { if !self.lr.password.is_empty() && (tfa && session.tfa || !tfa && self.validate_one_password(session.random_password.clone())) { - session.last_recv_time = self.last_recv_time.clone(); - session.conn_id = self.inner.id(); - session.conn_type = self.conn_type(); - raii::AuthedConnID::insert_session(self.session_key(), session); log::info!("is recent session"); return true; } @@ -1841,34 +1839,13 @@ impl Connection { if res { self.update_failure(failure, true, 1); self.require_2fa.take(); + raii::AuthedConnID::set_session_2fa(self.session_key()); self.send_logon_response().await; self.try_start_cm( self.lr.my_id.to_owned(), self.lr.my_name.to_owned(), self.authorized, ); - let session = SESSIONS - .lock() - .unwrap() - .get(&self.session_key()) - .map(|s| s.to_owned()); - if let Some(mut session) = session { - session.tfa = true; - session.conn_id = self.inner.id(); - session.conn_type = self.conn_type(); - raii::AuthedConnID::insert_session(self.session_key(), session); - } else { - raii::AuthedConnID::insert_session( - self.session_key(), - Session { - last_recv_time: self.last_recv_time.clone(), - random_password: "".to_owned(), - tfa: true, - conn_type: self.conn_type(), - conn_id: self.inner.id(), - }, - ); - } if !tfa.hwid.is_empty() && Self::enable_trusted_devices() { Config::add_trusted_device(TrustedDevice { hwid: tfa.hwid, @@ -3872,16 +3849,17 @@ mod raii { } pub fn remove_session_if_last_duplication(conn_id: i32, key: SessionKey) { - let contains = SESSIONS.lock().unwrap().contains_key(&key); + let mut lock = SESSIONS.lock().unwrap(); + let contains = lock.contains_key(&key); if contains { let another = AUTHED_CONNS .lock() .unwrap() .iter() - .any(|c| c.0 != conn_id && c.2 == key && c.1 != AuthConnType::PortForward); + .any(|c| c.0 != conn_id && c.2 == key); if !another { // Keep the session if there is another connection with same peer_id and session_id. - SESSIONS.lock().unwrap().remove(&key); + lock.remove(&key); log::info!("remove session"); } else { log::info!("skip remove session"); @@ -3889,32 +3867,46 @@ mod raii { } } - pub fn insert_session(key: SessionKey, session: Session) { - let mut insert = true; - if session.conn_type == AuthConnType::PortForward { - // port forward doesn't update last received time - let other_alive_conns = AUTHED_CONNS - .lock() - .unwrap() - .iter() - .filter(|c| { - c.2 == key && c.1 != AuthConnType::PortForward // port forward doesn't remove itself - }) - .map(|c| c.0) - .collect::>(); - let another = SESSIONS.lock().unwrap().get(&key).map(|s| { - other_alive_conns.contains(&s.conn_id) - && s.tfa == session.tfa - && s.conn_type != AuthConnType::PortForward - }) == Some(true); - if another { - insert = false; - log::info!("skip insert session for port forward"); + pub fn update_or_insert_session( + key: SessionKey, + password: Option, + tfa: Option, + ) { + let mut lock = SESSIONS.lock().unwrap(); + let session = lock.get_mut(&key); + if let Some(session) = session { + if let Some(password) = password { + session.random_password = password; } + if let Some(tfa) = tfa { + session.tfa = tfa; + } + } else { + lock.insert( + key, + Session { + random_password: password.unwrap_or_default(), + tfa: tfa.unwrap_or_default(), + last_recv_time: Arc::new(Mutex::new(Instant::now())), + }, + ); } - if insert { - log::info!("insert session for {:?}", session.conn_type); - SESSIONS.lock().unwrap().insert(key, session); + } + + pub fn set_session_2fa(key: SessionKey) { + let mut lock = SESSIONS.lock().unwrap(); + let session = lock.get_mut(&key); + if let Some(session) = session { + session.tfa = true; + } else { + lock.insert( + key, + Session { + last_recv_time: Arc::new(Mutex::new(Instant::now())), + random_password: "".to_owned(), + tfa: true, + }, + ); } } }