diff --git a/mRemoteNG/Config/Serializers/ConnectionSerializers/Sql/SqlDatabaseMetaDataRetriever.cs b/mRemoteNG/Config/Serializers/ConnectionSerializers/Sql/SqlDatabaseMetaDataRetriever.cs index d900a959..27fe95bc 100644 --- a/mRemoteNG/Config/Serializers/ConnectionSerializers/Sql/SqlDatabaseMetaDataRetriever.cs +++ b/mRemoteNG/Config/Serializers/ConnectionSerializers/Sql/SqlDatabaseMetaDataRetriever.cs @@ -99,9 +99,22 @@ namespace mRemoteNG.Config.Serializers.ConnectionSerializers.Sql if (rootTreeNode != null) { cmd = databaseConnector.DbCommand( - "INSERT INTO tblRoot (Name, Export, Protected, ConfVersion) VALUES('" + - MiscTools.PrepareValueForDB(rootTreeNode.Name) + "', 0, '" + strProtected + "','" + - ConnectionsFileInfo.ConnectionFileVersion + "')"); + "INSERT INTO tblRoot (Name, Export, Protected, ConfVersion) VALUES(@Name, 0, @Protected, @ConfVersion)"); + + DbParameter nameParam = cmd.CreateParameter(); + nameParam.ParameterName = "@Name"; + nameParam.Value = rootTreeNode.Name; + cmd.Parameters.Add(nameParam); + + DbParameter protectedParam = cmd.CreateParameter(); + protectedParam.ParameterName = "@Protected"; + protectedParam.Value = strProtected; + cmd.Parameters.Add(protectedParam); + + DbParameter confVersionParam = cmd.CreateParameter(); + confVersionParam.ParameterName = "@ConfVersion"; + confVersionParam.Value = ConnectionsFileInfo.ConnectionFileVersion.ToString(); + cmd.Parameters.Add(confVersionParam); cmd.ExecuteNonQuery(); } @@ -111,6 +124,22 @@ namespace mRemoteNG.Config.Serializers.ConnectionSerializers.Sql } } + private bool IsValidTableName(string tableName) + { + // Table names should only contain alphanumeric characters and underscores + // This prevents SQL injection when table names must be used directly in queries + if (string.IsNullOrWhiteSpace(tableName)) + return false; + + foreach (char c in tableName) + { + if (!char.IsLetterOrDigit(c) && c != '_') + return false; + } + + return true; + } + private bool DoesDbTableExist(IDatabaseConnector databaseConnector, string tableName) { bool exists; @@ -119,7 +148,18 @@ namespace mRemoteNG.Config.Serializers.ConnectionSerializers.Sql { // ANSI SQL way. Works in PostgreSQL, MSSQL, MySQL. string database_name = Properties.OptionsDBsPage.Default.SQLDatabaseName; - DbCommand cmd = databaseConnector.DbCommand("select case when exists((select * from information_schema.tables where table_name = '" + tableName + "' and table_schema='"+ database_name + "')) then 1 else 0 end"); + DbCommand cmd = databaseConnector.DbCommand("select case when exists((select * from information_schema.tables where table_name = @TableName and table_schema = @DatabaseName)) then 1 else 0 end"); + + DbParameter tableNameParam = cmd.CreateParameter(); + tableNameParam.ParameterName = "@TableName"; + tableNameParam.Value = tableName; + cmd.Parameters.Add(tableNameParam); + + DbParameter databaseNameParam = cmd.CreateParameter(); + databaseNameParam.ParameterName = "@DatabaseName"; + databaseNameParam.Value = database_name; + cmd.Parameters.Add(databaseNameParam); + short cmdResult = Convert.ToInt16(cmd.ExecuteScalar()); exists = (cmdResult == 1); } @@ -128,9 +168,18 @@ namespace mRemoteNG.Config.Serializers.ConnectionSerializers.Sql try { // Other RDBMS. Graceful degradation - exists = true; - DbCommand cmdOthers = databaseConnector.DbCommand("select 1 from " + tableName + " where 1 = 0"); - cmdOthers.ExecuteNonQuery(); + // Note: Table names cannot be parameterized in standard SQL. + // Validate tableName to prevent SQL injection + if (!IsValidTableName(tableName)) + { + exists = false; + } + else + { + exists = true; + DbCommand cmdOthers = databaseConnector.DbCommand($"select 1 from {tableName} where 1 = 0"); + cmdOthers.ExecuteNonQuery(); + } } catch {