| 1 | # -*- coding: utf-8 -*-
|
|---|
| 2 | #
|
|---|
| 3 | # Copyright (C) 2005-2023 Edgewall Software
|
|---|
| 4 | # Copyright (C) 2005 Christopher Lenz <cmlenz@gmx.de>
|
|---|
| 5 | # All rights reserved.
|
|---|
| 6 | #
|
|---|
| 7 | # This software is licensed as described in the file COPYING, which
|
|---|
| 8 | # you should have received as part of this distribution. The terms
|
|---|
| 9 | # are also available at https://trac.edgewall.org/wiki/TracLicense.
|
|---|
| 10 | #
|
|---|
| 11 | # This software consists of voluntary contributions made by many
|
|---|
| 12 | # individuals. For the exact contribution history, see the revision
|
|---|
| 13 | # history and logs, available at https://trac.edgewall.org/log/.
|
|---|
| 14 | #
|
|---|
| 15 | # Author: Christopher Lenz <cmlenz@gmx.de>
|
|---|
| 16 |
|
|---|
| 17 | import importlib
|
|---|
| 18 | import os
|
|---|
| 19 | import time
|
|---|
| 20 | import urllib.parse
|
|---|
| 21 | from abc import ABCMeta, abstractmethod
|
|---|
| 22 |
|
|---|
| 23 | from trac import db_default
|
|---|
| 24 | from trac.api import IEnvironmentSetupParticipant, ISystemInfoProvider
|
|---|
| 25 | from trac.config import BoolOption, ConfigurationError, IntOption, Option
|
|---|
| 26 | from trac.core import *
|
|---|
| 27 | from trac.db.pool import ConnectionPool
|
|---|
| 28 | from trac.db.schema import Table
|
|---|
| 29 | from trac.db.util import ConnectionWrapper
|
|---|
| 30 | from trac.util.concurrency import ThreadLocal
|
|---|
| 31 | from trac.util.html import tag
|
|---|
| 32 | from trac.util.text import unicode_passwd
|
|---|
| 33 | from trac.util.translation import _, tag_
|
|---|
| 34 |
|
|---|
| 35 |
|
|---|
class DbContextManager(object):
    """Common base for the database context managers.

    Subclasses provide ``__enter__``/``__exit__``; only the outermost
    `DbContextManager` closes the connection.
    """

    #: Connection owned by this manager; stays `None` unless this is the
    #: outermost manager for the current thread.
    db = None

    def __init__(self, env):
        self.dbmgr = DatabaseManager(env)

    def execute(self, query, params=None):
        """Enter the context and run `query` with `params` on the
        connection, returning whatever the connection's `execute` returns.
        """
        with self as connection:
            result = connection.execute(query, params)
        return result

    # Calling the manager directly is an alias for `execute`.
    __call__ = execute

    def executemany(self, query, params=None):
        """Enter the context and run `executemany(query, params)` on the
        connection, returning its result.
        """
        with self as connection:
            result = connection.executemany(query, params)
        return result
| 58 |
|
|---|
| 59 |
|
|---|
class TransactionContextManager(DbContextManager):
    """Transactioned Database Context Manager for retrieving a
    `~trac.db.util.ConnectionWrapper`.

    The outermost such context manager will perform a commit upon
    normal exit or a rollback after an exception.
    """

    def __enter__(self):
        """Return the thread's writable connection, creating it if needed.

        Nested transaction managers reuse the outermost writable
        connection. When only a read-only connection is active, its
        underlying connection is re-wrapped as writable. Otherwise a new
        connection is taken from the pool. Only the manager that creates
        the writable connection records it in `self.db`.
        """
        local = self.dbmgr._transaction_local
        db = local.wdb  # outermost writable db
        if not db:
            db = local.rdb  # reuse wrapped connection
            if db:
                db = ConnectionWrapper(db.cnx, db.log)
            else:
                db = self.dbmgr.get_connection()
            local.wdb = self.db = db
        return db

    def __exit__(self, et, ev, tb):
        """Commit (no exception) or roll back (exception), then release.

        Only the outermost manager (the one with `self.db` set) acts.
        The connection is closed unless a read-only context still uses it.
        Closing is done in a `finally` so that a failing commit/rollback
        does not leave the connection checked out of the pool. Exceptions
        are never suppressed.
        """
        if self.db:
            local = self.dbmgr._transaction_local
            local.wdb = None
            try:
                if et is None:
                    self.db.commit()
                else:
                    self.db.rollback()
            finally:
                # An enclosing read-only context shares the underlying
                # connection and is responsible for closing it.
                if not local.rdb:
                    self.db.close()
| 88 |
|
|---|
| 89 |
|
|---|
class QueryContextManager(DbContextManager):
    """Database Context Manager for retrieving a read-only
    `~trac.db.util.ConnectionWrapper`.
    """

    def __enter__(self):
        """Return the thread's read-only connection, creating it if needed.

        An active read-only connection is simply reused. If only a
        writable connection is active, its underlying connection is
        wrapped read-only. Otherwise a read-only connection is taken from
        the pool. Only the creating manager records it in `self.db`.
        """
        local = self.dbmgr._transaction_local
        outer = local.rdb  # outermost readonly db
        if outer:
            return outer
        writer = local.wdb  # reuse wrapped connection
        if writer:
            created = ConnectionWrapper(writer.cnx, writer.log,
                                        readonly=True)
        else:
            created = self.dbmgr.get_connection(readonly=True)
        local.rdb = self.db = created
        return created

    def __exit__(self, et, ev, tb):
        """Forget the read-only connection and close it, unless a
        transaction still uses the underlying connection.
        """
        if not self.db:
            return
        local = self.dbmgr._transaction_local
        local.rdb = None
        if not local.wdb:
            self.db.close()
| 111 |
|
|---|
| 112 |
|
|---|
class ConnectionBase(metaclass=ABCMeta):
    """Abstract interface that every backend connection class implements.

    Each method below is abstract; the bodies only document the contract.
    """

    @abstractmethod
    def cast(self, column, type):
        """Returns a clause casting `column` as `type`."""

    @abstractmethod
    def concat(self, *args):
        """Returns a clause concatenating the sequence `args`."""

    @abstractmethod
    def drop_column(self, table, column):
        """Drops the `column` from `table`."""

    @abstractmethod
    def drop_table(self, table):
        """Drops the `table`."""

    @abstractmethod
    def get_column_names(self, table):
        """Returns the list of the column names in `table`."""

    @abstractmethod
    def get_last_id(self, cursor, table, column='id'):
        """Returns the current value of the primary key sequence for `table`.

        :param column: the primary key column; defaults to `id`.
        """

    @abstractmethod
    def get_sequence_names(self):
        """Returns a list of the sequence names."""

    @abstractmethod
    def get_table_names(self):
        """Returns a list of the table names."""

    @abstractmethod
    def has_table(self, table):
        """Returns whether the table exists."""

    @abstractmethod
    def like(self):
        """Returns a case-insensitive `LIKE` clause."""

    @abstractmethod
    def like_escape(self, text):
        """Returns `text` escaped for use in a `LIKE` clause."""

    @abstractmethod
    def prefix_match(self):
        """Returns a case-sensitive prefix-matching operator."""

    @abstractmethod
    def prefix_match_value(self, prefix):
        """Returns a value for the case-sensitive prefix-matching operator."""

    @abstractmethod
    def quote(self, identifier):
        """Returns the quoted `identifier`."""

    @abstractmethod
    def reset_tables(self):
        """Deletes all data from the tables and resets autoincrement indexes.

        :return: list of names of the tables that were reset.
        """

    @abstractmethod
    def update_sequence(self, cursor, table, column='id'):
        """Updates the current value of the primary key sequence for `table`.

        :param column: the primary key column; defaults to `id`.
        """
| 202 |
|
|---|
| 203 |
|
|---|
class IDatabaseConnector(Interface):
    """Extension point interface for components that support the
    connection to relational databases.

    Like all extension point interfaces, the methods are declared without
    `self`.
    """

    def get_supported_schemes():
        """Return the connection URL schemes supported by the
        connector, and their relative priorities as an iterable of
        `(scheme, priority)` tuples.

        If `priority` is a negative number, this is indicative of an
        error condition with the connector. An error message should be
        attached to the `error` attribute of the connector.
        """

    def get_connection(path, log=None, **kwargs):
        """Create a new connection to the database."""

    def get_exceptions():
        """Return an object (typically a module) containing all the
        backend-specific exception types as attributes, named
        according to the Python Database API
        (https://peps.python.org/pep-0249/).
        """

    def init_db(path, schema=None, log=None, **kwargs):
        """Initialize the database."""

    def destroy_db(path, log=None, **kwargs):
        """Destroy the database."""

    def db_exists(path, log=None, **kwargs):
        """Return `True` if the database exists."""

    def to_sql(table):
        """Return the DDL statements necessary to create the specified
        table, including indices."""

    def backup(dest):
        """Backup the database to `dest`, whose default is derived from
        the `[trac] backup_dir` option, and return the name of the file
        actually written."""

    def get_system_info():
        """Yield a sequence of `(name, version)` tuples describing the
        name and version information of external packages used by the
        connector.
        """
| 251 |
|
|---|
| 252 |
|
|---|
class DatabaseManager(Component):
    """Component used to manage the `IDatabaseConnector` implementations.

    It selects the connector for the configured connection string, owns
    the connection pool and the per-thread connection state used by the
    transaction/query context managers, and provides schema helpers.
    """

    implements(IEnvironmentSetupParticipant, ISystemInfoProvider)

    connectors = ExtensionPoint(IDatabaseConnector)

    connection_uri = Option('trac', 'database', 'sqlite:db/trac.db',
        """Database connection
        [wiki:TracEnvironment#DatabaseConnectionStrings string] for this
        project""")

    backup_dir = Option('trac', 'backup_dir', 'db',
        """Database backup location""")

    timeout = IntOption('trac', 'timeout', '20',
        """Timeout value for database connection, in seconds.
        Use '0' to specify ''no timeout''.""")

    debug_sql = BoolOption('trac', 'debug_sql', False,
        """Show the SQL queries in the Trac log, at DEBUG level.
        """)

    def __init__(self):
        # Created lazily by `get_connection`.
        self._cnx_pool = None
        # Per-thread outermost writable (`wdb`) and read-only (`rdb`)
        # connections shared by the transaction/query context managers.
        self._transaction_local = ThreadLocal(wdb=None, rdb=None)

    def init_db(self):
        """Create the database using the default Trac schema and record
        the default version as both `initial_database_version` and
        `database_version`.
        """
        connector, args = self.get_connector()
        args['schema'] = db_default.schema
        connector.init_db(**args)
        version = db_default.db_version
        self.set_database_version(version, 'initial_database_version')
        self.set_database_version(version)

    def insert_default_data(self):
        """Insert the default Trac data into the existing tables."""
        self.insert_into_tables(db_default.get_data)

    def destroy_db(self):
        """Destroy the database after shutting down the connection pool."""
        connector, args = self.get_connector()
        # Connections to on-disk db must be closed before deleting it.
        self.shutdown()
        connector.destroy_db(**args)

    def db_exists(self):
        """Return whether the configured database exists."""
        connector, args = self.get_connector()
        return connector.db_exists(**args)

    def create_tables(self, schema):
        """Create the specified tables.

        :param schema: an iterable of table objects.

        :since: version 1.0.2
        """
        connector = self.get_connector()[0]
        with self.env.db_transaction as db:
            for table in schema:
                for sql in connector.to_sql(table):
                    db(sql)

    def drop_columns(self, table, columns):
        """Drops the specified columns from table.

        :param table: a `Table` object or table name.
        :param columns: an iterable of column names.
        :raises OperationalError: (backend-specific) if the table does
                                  not exist.

        :since: version 1.2
        """
        table_name = table.name if isinstance(table, Table) else table
        with self.env.db_transaction as db:
            if not db.has_table(table_name):
                raise self.env.db_exc.OperationalError('Table %s not found' %
                                                       db.quote(table_name))
            for col in columns:
                db.drop_column(table_name, col)

    def drop_tables(self, schema):
        """Drop the specified tables.

        :param schema: an iterable of `Table` objects or table names.

        :since: version 1.0.2
        """
        with self.env.db_transaction as db:
            for table in schema:
                table_name = table.name if isinstance(table, Table) else table
                db.drop_table(table_name)

    def insert_into_tables(self, data_or_callable):
        """Insert data into existing tables.

        :param data_or_callable: Nested tuples of table names, column names
                                 and row data::

                                   (table1,
                                    (column1, column2),
                                    ((row1col1, row1col2),
                                     (row2col1, row2col2)),
                                    table2, ...)

                                 or a callable that takes a single parameter
                                 `db` and returns the aforementioned nested
                                 tuple.

        :since: version 1.1.3
        """
        with self.env.db_transaction as db:
            data = data_or_callable(db) if callable(data_or_callable) \
                   else data_or_callable
            for table, cols, vals in data:
                db.executemany("INSERT INTO %s (%s) VALUES (%s)"
                               % (db.quote(table), ','.join(cols),
                                  ','.join(['%s'] * len(cols))), vals)

    def reset_tables(self):
        """Deletes all data from the tables and resets autoincrement indexes.

        :return: list of names of the tables that were reset.

        :since: version 1.1.3
        """
        with self.env.db_transaction as db:
            return db.reset_tables()

    def upgrade_tables(self, new_schema):
        """Upgrade table schema to `new_schema`, preserving data in
        columns that exist in the current schema and `new_schema`.

        Each table is recreated. The data of the shared columns is saved
        in a temporary `<name>_old` table, copied back, and the
        auto-increment sequences are updated.

        :param new_schema: tuple or list of `Table` objects

        :since: version 1.2
        """
        with self.env.db_transaction as db:
            cursor = db.cursor()
            for new_table in new_schema:
                temp_table_name = new_table.name + '_old'
                has_table = self.has_table(new_table)
                column_names = set()
                if has_table:
                    old_column_names = set(self.get_column_names(new_table))
                    new_column_names = {col.name for col in new_table.columns}
                    column_names = old_column_names & new_column_names
                    if column_names:
                        # The same string is used for both the INSERT
                        # column list and the SELECT list, so the set's
                        # iteration order does not matter.
                        cols_to_copy = ','.join(db.quote(name)
                                                for name in column_names)
                        cursor.execute("""
                            CREATE TEMPORARY TABLE %s AS SELECT * FROM %s
                            """ % (db.quote(temp_table_name),
                                   db.quote(new_table.name)))
                    self.drop_tables((new_table,))
                self.create_tables((new_table,))
                if column_names:
                    cursor.execute("""
                        INSERT INTO %s (%s) SELECT %s FROM %s
                        """ % (db.quote(new_table.name), cols_to_copy,
                               cols_to_copy, db.quote(temp_table_name)))
                    for col in new_table.columns:
                        if col.auto_increment:
                            db.update_sequence(cursor, new_table.name,
                                               col.name)
                    self.drop_tables((temp_table_name,))

    def get_connection(self, readonly=False):
        """Get a database connection from the pool.

        The pool (of size 5) is created on first use. A `timeout` of 0
        is passed to the pool as `None`.

        If `readonly` is `True`, the returned connection will purposely
        lack the `rollback` and `commit` methods.
        """
        if not self._cnx_pool:
            connector, args = self.get_connector()
            self._cnx_pool = ConnectionPool(5, connector, **args)
        db = self._cnx_pool.get_cnx(self.timeout or None)
        if readonly:
            db = ConnectionWrapper(db, readonly=True)
        return db

    def get_database_version(self, name='database_version'):
        """Returns the database version from the SYSTEM table as an int,
        or `False` if the entry is not found.

        :param name: The name of the entry that contains the database version
                     in the SYSTEM table. Defaults to `database_version`,
                     which contains the database version for Trac.
        """
        with self.env.db_query as db:
            for value, in db("""
                    SELECT value FROM {0} WHERE name=%s
                    """.format(db.quote('system')), (name,)):
                return int(value)
            else:
                return False

    def get_exceptions(self):
        """Return the exception namespace of the selected connector."""
        return self.get_connector()[0].get_exceptions()

    def get_sequence_names(self):
        """Returns a list of the sequence names.

        :since: 1.3.2
        """
        with self.env.db_query as db:
            return db.get_sequence_names()

    def get_table_names(self):
        """Returns a list of the table names.

        :since: 1.1.6
        """
        with self.env.db_query as db:
            return db.get_table_names()

    def get_column_names(self, table):
        """Returns a list of the column names for `table`.

        :param table: a `Table` object or table name.
        :raises OperationalError: (backend-specific) if the table does
                                  not exist.

        :since: 1.2
        """
        table_name = table.name if isinstance(table, Table) else table
        with self.env.db_query as db:
            if not db.has_table(table_name):
                raise self.env.db_exc.OperationalError('Table %s not found' %
                                                       db.quote(table_name))
            return db.get_column_names(table_name)

    def has_table(self, table):
        """Returns whether the table exists.

        :param table: a `Table` object or table name.
        """
        table_name = table.name if isinstance(table, Table) else table
        with self.env.db_query as db:
            return db.has_table(table_name)

    def set_database_version(self, version, name='database_version'):
        """Sets the database version in the SYSTEM table.

        The row is inserted when missing, or updated (and the change
        logged) when it holds a different version.

        :param version: an integer database version.
        :param name: The name of the entry that contains the database version
                     in the SYSTEM table. Defaults to `database_version`,
                     which contains the database version for Trac.
        """
        current_database_version = self.get_database_version(name)
        if current_database_version is False:
            with self.env.db_transaction as db:
                db("""
                    INSERT INTO {0} (name, value) VALUES (%s, %s)
                    """.format(db.quote('system')), (name, version))
        elif version != current_database_version:
            with self.env.db_transaction as db:
                db("""
                    UPDATE {0} SET value=%s WHERE name=%s
                    """.format(db.quote('system')), (version, name))
            self.log.info("Upgraded %s from %d to %d",
                          name, current_database_version, version)

    def needs_upgrade(self, version, name='database_version'):
        """Checks the database version to determine if an upgrade is needed.

        :param version: the expected integer database version.
        :param name: the name of the entry in the SYSTEM table that contains
                     the database version. Defaults to `database_version`,
                     which contains the database version for Trac.

        :return: `True` if the stored version is less than the expected
                  version, `False` if it is equal to the expected version.
        :raises TracError: if the stored version is greater than the expected
                           version.
        """
        dbver = self.get_database_version(name)
        if dbver == version:
            return False
        elif dbver > version:
            raise TracError(_("Need to downgrade %(name)s.", name=name))
        self.log.info("Need to upgrade %s from %d to %d",
                      name, dbver, version)
        return True

    def upgrade(self, version, name='database_version', pkg='trac.upgrades'):
        """Invokes `do_upgrade(env, version, cursor)` in module
        `"%s/db%i.py" % (pkg, version)`, for each required version upgrade.

        Each step runs in its own transaction and stores the new version
        once it succeeds.

        :param version: the expected integer database version.
        :param name: the name of the entry in the SYSTEM table that contains
                     the database version. Defaults to `database_version`,
                     which contains the database version for Trac.
        :param pkg: the package containing the upgrade modules.

        :raises TracError: if the package or module doesn't exist.
        """
        dbver = self.get_database_version(name)
        for i in range(dbver + 1, version + 1):
            module = '%s.db%i' % (pkg, i)
            try:
                upgrader = importlib.import_module(module)
            except ImportError:
                raise TracError(_("No upgrade module %(module)s.py",
                                  module=module))
            with self.env.db_transaction as db:
                cursor = db.cursor()
                upgrader.do_upgrade(self.env, i, cursor)
                self.set_database_version(i, name)

    def shutdown(self, tid=None):
        """Shut down the connection pool, passing `tid` through to it.

        The pool itself is discarded only when `tid` is not given.
        """
        if self._cnx_pool:
            self._cnx_pool.shutdown(tid)
            if not tid:
                self._cnx_pool = None

    def backup(self, dest=None):
        """Save a backup of the database.

        :param dest: base filename to write to. When omitted, a name of
                     the form `<scheme>.<db version>.<timestamp>.bak` in
                     `[trac] backup_dir` is used (relative directories are
                     resolved against the environment directory).

        Missing parent directories are created.

        Returns the file actually written.
        """
        connector, args = self.get_connector()
        if not dest:
            backup_dir = self.backup_dir
            if not os.path.isabs(backup_dir):
                backup_dir = os.path.join(self.env.path, backup_dir)
            db_str = self.config.get('trac', 'database')
            # get_connector() above already validated the format.
            db_name = db_str.split(':', 1)[0]
            dest_name = '%s.%i.%d.bak' % (db_name, self.env.database_version,
                                          int(time.time()))
            dest = os.path.join(backup_dir, dest_name)
        else:
            backup_dir = os.path.dirname(dest)
        # A bare filename has no directory part: os.makedirs('') would
        # raise. exist_ok guards against a concurrent creation.
        if backup_dir and not os.path.exists(backup_dir):
            os.makedirs(backup_dir, exist_ok=True)
        return connector.backup(dest)

    def get_connector(self):
        """Return the `(connector, args)` pair for the connection string.

        The connector with the highest priority for the scheme is chosen.
        Relative SQLite paths are resolved against the environment
        directory, and the log is passed along when `debug_sql` is set.

        :raises TracError: if no connector supports the scheme or the
                           selected connector reports an error.
        """
        scheme, args = parse_connection_uri(self.connection_uri)
        candidates = [
            (priority, connector)
            for connector in self.connectors
            for scheme_, priority in connector.get_supported_schemes()
            if scheme_ == scheme
        ]
        if not candidates:
            raise TracError(_('Unsupported database type "%(scheme)s"',
                              scheme=scheme))
        # Compare on priority only: two distinct connectors with the same
        # priority are not orderable and would raise TypeError.
        priority, connector = max(candidates, key=lambda item: item[0])
        if priority < 0:
            raise TracError(connector.error)

        if scheme == 'sqlite':
            if args['path'] == ':memory:':
                # Special case for SQLite in-memory database, always get
                # the /same/ connection over
                pass
            elif not os.path.isabs(args['path']):
                # Special case for SQLite to support a path relative to the
                # environment directory
                args['path'] = os.path.join(self.env.path,
                                            args['path'].lstrip('/'))

        if self.debug_sql:
            args['log'] = self.log
        return connector, args

    # IEnvironmentSetupParticipant methods

    def environment_created(self):
        pass

    def environment_needs_upgrade(self):
        return self.needs_upgrade(db_default.db_version)

    def upgrade_environment(self):
        self.upgrade(db_default.db_version)

    # ISystemInfoProvider methods

    def get_system_info(self):
        connector = self.get_connector()[0]
        yield from connector.get_system_info()
| 625 |
|
|---|
| 626 |
|
|---|
def get_column_names(cursor):
    """Return the column names described by `cursor`.

    Names the driver reports as `bytes` are decoded as UTF-8; an empty
    list is returned when the cursor has no (or an empty) description.
    """
    description = cursor.description
    if not description:
        return []
    names = []
    for column in description:
        name = column[0]
        if isinstance(name, bytes):
            name = name.decode('utf-8')
        names.append(name)
    return names
| 631 |
|
|---|
| 632 |
|
|---|
def parse_connection_uri(db_str):
    """Parse the database connection string.

    The database connection string for an environment is specified through
    the `database` option in the `[trac]` section of trac.ini.

    Examples: ``sqlite:db/trac.db`` (relative SQLite path),
    ``sqlite::memory:``, ``sqlite:///abs/path.db``,
    ``postgres://user:pass@host:5432/dbname?schema=trac``.

    :param db_str: the connection string.
    :return: a tuple containing the scheme and a dictionary of attributes:
             `user`, `password`, `host`, `port`, `path`, `params`. Only
             attributes with a truthy value are included.
    :raises ConfigurationError: if `db_str` is empty or malformed.
    :since: 1.1.3
    """
    if not db_str:
        section = tag.a("[trac]",
                        title=_("TracIni documentation"),
                        class_='trac-target-new',
                        href='https://trac.edgewall.org/wiki/TracIni'
                             '#trac-section')
        raise ConfigurationError(
            tag_("Database connection string is empty. Set the %(option)s "
                 "configuration option in the %(section)s section of "
                 "trac.ini. Please refer to the %(doc)s for help.",
                 option=tag.code("database"), section=section,
                 doc=_doc_db_str()))

    try:
        scheme, rest = db_str.split(':', 1)
    except ValueError:
        raise _invalid_db_str(db_str)

    if not rest.startswith('/'):
        if scheme == 'sqlite' and rest:
            # Support for relative and in-memory SQLite connection strings
            host = None
            path = rest
        else:
            raise _invalid_db_str(db_str)
    else:
        if not rest.startswith('//'):
            host = None
            rest = rest[1:]
        elif rest.startswith('///'):
            host = None
            rest = rest[3:]
        else:
            rest = rest[2:]
            if '/' in rest:
                host, rest = rest.split('/', 1)
            else:
                host = rest
                rest = ''
        path = None

    if host and '@' in host:
        # Split on the *last* '@', like urllib.parse does for a netloc, so
        # an unescaped '@' in the credentials does not leak into the host.
        user, host = host.rsplit('@', 1)
        if ':' in user:
            user, password = user.split(':', 1)
        else:
            password = None
        if user:
            user = urllib.parse.unquote(user)
        if password:
            password = unicode_passwd(urllib.parse.unquote(password))
    else:
        user = password = None

    if host and ':' in host:
        host, port = host.split(':', 1)
        try:
            port = int(port)
        except ValueError:
            raise _invalid_db_str(db_str)
    else:
        port = None

    if not path:
        path = '/' + rest
    if os.name == 'nt':
        # Support local paths containing drive letters on Win32
        if len(rest) > 1 and rest[1] == '|':
            path = "%s:%s" % (rest[0], rest[2:])

    params = {}
    if '?' in path:
        path, qs = path.split('?', 1)
        qs = qs.split('&')
        for param in qs:
            try:
                name, value = param.split('=', 1)
            except ValueError:
                raise _invalid_db_str(db_str)
            value = urllib.parse.unquote(value)
            params[name] = value

    args = zip(('user', 'password', 'host', 'port', 'path', 'params'),
               (user, password, host, port, path, params))
    return scheme, {key: value for key, value in args if value}
| 728 |
|
|---|
| 729 |
|
|---|
def _invalid_db_str(db_str):
    """Build (not raise) the `ConfigurationError` reported for a
    malformed database connection string `db_str`.
    """
    message = tag_("Invalid format %(db_str)s for the database connection string. "
                   "Please refer to the %(doc)s for help.",
                   db_str=tag.code(db_str), doc=_doc_db_str())
    return ConfigurationError(message)
| 735 |
|
|---|
| 736 |
|
|---|
def _doc_db_str():
    """Return a link to the TracIni Database Connection Strings docs."""
    label = _("documentation")
    tooltip = _("Database Connection Strings documentation")
    return tag.a(label, title=tooltip, class_='trac-target-new',
                 href='https://trac.edgewall.org/wiki/'
                      'TracIni#DatabaseConnectionStrings')