Edgewall Software

source: trunk/trac/db/api.py

Last change on this file was 17657, checked in by Jun Omae, 8 months ago

1.5.4dev: update copyright year to 2023 (refs #13402)

[skip ci]

  • Property svn:eol-style set to native
File size: 25.7 KB
Line 
1# -*- coding: utf-8 -*-
2#
3# Copyright (C) 2005-2023 Edgewall Software
4# Copyright (C) 2005 Christopher Lenz <cmlenz@gmx.de>
5# All rights reserved.
6#
7# This software is licensed as described in the file COPYING, which
8# you should have received as part of this distribution. The terms
9# are also available at https://trac.edgewall.org/wiki/TracLicense.
10#
11# This software consists of voluntary contributions made by many
12# individuals. For the exact contribution history, see the revision
13# history and logs, available at https://trac.edgewall.org/log/.
14#
15# Author: Christopher Lenz <cmlenz@gmx.de>
16
17import importlib
18import os
19import time
20import urllib.parse
21from abc import ABCMeta, abstractmethod
22
23from trac import db_default
24from trac.api import IEnvironmentSetupParticipant, ISystemInfoProvider
25from trac.config import BoolOption, ConfigurationError, IntOption, Option
26from trac.core import *
27from trac.db.pool import ConnectionPool
28from trac.db.schema import Table
29from trac.db.util import ConnectionWrapper
30from trac.util.concurrency import ThreadLocal
31from trac.util.html import tag
32from trac.util.text import unicode_passwd
33from trac.util.translation import _, tag_
34
35
class DbContextManager(object):
    """Common base of the database context managers.

    Subclasses implement `__enter__` and `__exit__`. When managers are
    nested, the outermost `DbContextManager` will close the connection.
    """

    # Connection owned by this manager; subclasses only set it on the
    # outermost manager.
    db = None

    def __init__(self, env):
        self.dbmgr = DatabaseManager(env)

    def execute(self, query, params=None):
        """Shortcut for directly executing a query.

        Enters this manager, runs `execute(query, params)` on the
        connection it yields and returns that call's result.
        """
        with self as cnx:
            outcome = cnx.execute(query, params)
        return outcome

    # Calling the manager itself is the same as calling `execute`.
    __call__ = execute

    def executemany(self, query, params=None):
        """Shortcut for directly calling "executemany" on a query.

        Enters this manager, runs `executemany(query, params)` on the
        connection it yields and returns that call's result.
        """
        with self as cnx:
            outcome = cnx.executemany(query, params)
        return outcome
59
class TransactionContextManager(DbContextManager):
    """Transactioned Database Context Manager for retrieving a
    `~trac.db.util.ConnectionWrapper`.

    The outermost such context manager will perform a commit upon
    normal exit or a rollback after an exception.
    """

    def __enter__(self):
        """Return the writable connection, acquiring one if needed.

        If a writable connection is already registered, it is returned
        and this manager takes no ownership (`self.db` stays `None`).
        Otherwise an active read-only connection is reused through a new
        writable wrapper, or a new connection is obtained from
        `DatabaseManager.get_connection`. The connection is then
        registered as the outermost writable one and owned by this
        manager.
        """
        db = self.dbmgr._transaction_local.wdb  # outermost writable db
        if not db:
            db = self.dbmgr._transaction_local.rdb  # reuse wrapped connection
            if db:
                db = ConnectionWrapper(db.cnx, db.log)
            else:
                db = self.dbmgr.get_connection()
            self.dbmgr._transaction_local.wdb = self.db = db
        return db

    def __exit__(self, et, ev, tb):
        """Commit (normal exit) or roll back (exception) the transaction,
        if this manager owns it.

        The connection is closed afterwards unless a read-only context is
        still using it. The close is attempted even when the commit or
        rollback itself raises, so a failing commit does not leave the
        connection unreleased.
        """
        if self.db:
            self.dbmgr._transaction_local.wdb = None
            try:
                if et is None:
                    self.db.commit()
                else:
                    self.db.rollback()
            finally:
                if not self.dbmgr._transaction_local.rdb:
                    self.db.close()
88
89
class QueryContextManager(DbContextManager):
    """Database Context Manager for retrieving a read-only
    `~trac.db.util.ConnectionWrapper`.

    Only the outermost read-only manager takes ownership of the
    connection; on exit it closes it unless a transaction still uses it.
    """

    def __enter__(self):
        """Return the read-only connection, acquiring one if needed.

        An already registered read-only connection is returned as-is and
        this manager takes no ownership. Otherwise the registered writable
        connection, if any, is shared through a read-only wrapper, or a
        new read-only connection is requested from the database manager.
        """
        local = self.dbmgr._transaction_local
        current = local.rdb  # outermost readonly db
        if current:
            return current
        writable = local.wdb
        if writable:
            # reuse wrapped connection
            current = ConnectionWrapper(writable.cnx, writable.log,
                                        readonly=True)
        else:
            current = self.dbmgr.get_connection(readonly=True)
        local.rdb = self.db = current
        return current

    def __exit__(self, et, ev, tb):
        """Unregister and release the connection if this manager owns it."""
        owned = self.db
        if not owned:
            return
        local = self.dbmgr._transaction_local
        local.rdb = None
        if not local.wdb:
            owned.close()
111
112
class ConnectionBase(metaclass=ABCMeta):
    """Abstract base class for database connection classes.

    A backend connection class cannot be instantiated until it provides
    every method below.
    """

    @abstractmethod
    def cast(self, column, type):
        """Returns a clause casting `column` as `type`."""

    @abstractmethod
    def concat(self, *args):
        """Returns a clause concatenating the sequence `args`."""

    @abstractmethod
    def drop_column(self, table, column):
        """Drops the `column` from `table`."""

    @abstractmethod
    def drop_table(self, table):
        """Drops the `table`."""

    @abstractmethod
    def get_column_names(self, table):
        """Returns the list of the column names in `table`."""

    @abstractmethod
    def get_last_id(self, cursor, table, column='id'):
        """Returns the current value of the primary key sequence for
        `table`. The `column` of the primary key may be specified; it
        defaults to `id`.
        """

    @abstractmethod
    def get_sequence_names(self):
        """Returns a list of the sequence names."""

    @abstractmethod
    def get_table_names(self):
        """Returns a list of the table names."""

    @abstractmethod
    def has_table(self, table):
        """Returns whether the table exists."""

    @abstractmethod
    def like(self):
        """Returns a case-insensitive `LIKE` clause."""

    @abstractmethod
    def like_escape(self, text):
        """Returns `text` escaped for use in a `LIKE` clause."""

    @abstractmethod
    def prefix_match(self):
        """Returns a case-sensitive prefix-matching operator."""

    @abstractmethod
    def prefix_match_value(self, prefix):
        """Returns a value for the case-sensitive prefix-matching
        operator.
        """

    @abstractmethod
    def quote(self, identifier):
        """Returns the quoted `identifier`."""

    @abstractmethod
    def reset_tables(self):
        """Deletes all data from the tables and resets autoincrement
        indexes.

        :return: list of names of the tables that were reset.
        """

    @abstractmethod
    def update_sequence(self, cursor, table, column='id'):
        """Updates the current value of the primary key sequence for
        `table`. The `column` of the primary key may be specified; it
        defaults to `id`.
        """
202
203
class IDatabaseConnector(Interface):
    """Extension point interface for components that support the
    connection to relational databases.

    As for every Trac `Interface`, the methods are listed without the
    implicit `self` argument of the implementing component.
    """

    def get_supported_schemes():
        """Return the connection URL schemes supported by the
        connector, and their relative priorities as an iterable of
        `(scheme, priority)` tuples.

        If `priority` is a negative number, this is indicative of an
        error condition with the connector. An error message should be
        attached to the `error` attribute of the connector.
        """

    def get_connection(path, log=None, **kwargs):
        """Create a new connection to the database."""

    def get_exceptions():
        """Return an object (typically a module) containing all the
        backend-specific exception types as attributes, named
        according to the Python Database API
        (https://peps.python.org/pep-0249/).
        """

    def init_db(path, schema=None, log=None, **kwargs):
        """Initialize the database."""

    def destroy_db(path, log=None, **kwargs):
        """Destroy the database."""

    def db_exists(path, log=None, **kwargs):
        """Return `True` if the database exists."""

    def to_sql(table):
        """Return the DDL statements necessary to create the specified
        table, including indices."""

    def backup(dest):
        """Backup the database to a location defined by
        trac.backup_dir"""

    def get_system_info():
        """Yield a sequence of `(name, version)` tuples describing the
        name and version information of external packages used by the
        connector.
        """
251
252
class DatabaseManager(Component):
    """Component used to manage the `IDatabaseConnector` implementations."""

    implements(IEnvironmentSetupParticipant, ISystemInfoProvider)

    connectors = ExtensionPoint(IDatabaseConnector)

    connection_uri = Option('trac', 'database', 'sqlite:db/trac.db',
        """Database connection
        [wiki:TracEnvironment#DatabaseConnectionStrings string] for this
        project""")

    backup_dir = Option('trac', 'backup_dir', 'db',
        """Database backup location""")

    timeout = IntOption('trac', 'timeout', '20',
        """Timeout value for database connection, in seconds.
        Use '0' to specify ''no timeout''.""")

    debug_sql = BoolOption('trac', 'debug_sql', False,
        """Show the SQL queries in the Trac log, at DEBUG level.
        """)

    def __init__(self):
        # Connection pool, created lazily by `get_connection`.
        self._cnx_pool = None
        # `ThreadLocal` holding the outermost writable (`wdb`) and
        # read-only (`rdb`) connections used by the database context
        # managers.
        self._transaction_local = ThreadLocal(wdb=None, rdb=None)

    def init_db(self):
        """Create the database with the default Trac schema.

        Both the `initial_database_version` and `database_version`
        entries are then set to `db_default.db_version`.
        """
        connector, args = self.get_connector()
        args['schema'] = db_default.schema
        connector.init_db(**args)
        version = db_default.db_version
        self.set_database_version(version, 'initial_database_version')
        self.set_database_version(version)

    def insert_default_data(self):
        """Insert the default data provided by `db_default.get_data`."""
        self.insert_into_tables(db_default.get_data)

    def destroy_db(self):
        """Shut down the connection pool, then destroy the database."""
        connector, args = self.get_connector()
        # Connections to on-disk db must be closed before deleting it.
        self.shutdown()
        connector.destroy_db(**args)

    def db_exists(self):
        """Return whether the database exists, as reported by the
        connector."""
        connector, args = self.get_connector()
        return connector.db_exists(**args)

    def create_tables(self, schema):
        """Create the specified tables.

        :param schema: an iterable of table objects.

        :since: version 1.0.2
        """
        connector = self.get_connector()[0]
        with self.env.db_transaction as db:
            for table in schema:
                for sql in connector.to_sql(table):
                    db(sql)

    def drop_columns(self, table, columns):
        """Drops the specified columns from table.

        :param table: a `Table` object or table name.
        :param columns: an iterable of column names.
        :raises OperationalError: (from `env.db_exc`) if the table does
                                  not exist.

        :since: version 1.2
        """
        table_name = table.name if isinstance(table, Table) else table
        with self.env.db_transaction as db:
            if not db.has_table(table_name):
                raise self.env.db_exc.OperationalError('Table %s not found' %
                                                       db.quote(table_name))
            for col in columns:
                db.drop_column(table_name, col)

    def drop_tables(self, schema):
        """Drop the specified tables.

        :param schema: an iterable of `Table` objects or table names.

        :since: version 1.0.2
        """
        with self.env.db_transaction as db:
            for table in schema:
                table_name = table.name if isinstance(table, Table) else table
                db.drop_table(table_name)

    def insert_into_tables(self, data_or_callable):
        """Insert data into existing tables.

        :param data_or_callable: Nested tuples of table names, column names
                                 and row data::

                                   (table1,
                                    (column1, column2),
                                    ((row1col1, row1col2),
                                     (row2col1, row2col2)),
                                    table2, ...)

                                 or a callable that takes a single parameter
                                 `db` and returns the aforementioned nested
                                 tuple.
        :since: version 1.1.3
        """
        with self.env.db_transaction as db:
            data = data_or_callable(db) if callable(data_or_callable) \
                   else data_or_callable
            for table, cols, vals in data:
                db.executemany("INSERT INTO %s (%s) VALUES (%s)"
                               % (db.quote(table), ','.join(cols),
                                  ','.join(['%s'] * len(cols))), vals)

    def reset_tables(self):
        """Deletes all data from the tables and resets autoincrement indexes.

        :return: list of names of the tables that were reset.

        :since: version 1.1.3
        """
        with self.env.db_transaction as db:
            return db.reset_tables()

    def upgrade_tables(self, new_schema):
        """Upgrade table schema to `new_schema`, preserving data in
        columns that exist in the current schema and `new_schema`.

        Each existing table is copied to a temporary `<name>_old` table,
        dropped, recreated from `new_schema`, and refilled with the shared
        columns; sequences of auto-increment columns are then updated.

        :param new_schema: tuple or list of `Table` objects

        :since: version 1.2
        """
        with self.env.db_transaction as db:
            cursor = db.cursor()
            for new_table in new_schema:
                temp_table_name = new_table.name + '_old'
                has_table = self.has_table(new_table)
                if has_table:
                    old_column_names = set(self.get_column_names(new_table))
                    new_column_names = {col.name for col in new_table.columns}
                    column_names = old_column_names & new_column_names
                    if column_names:
                        cols_to_copy = ','.join(db.quote(name)
                                                for name in column_names)
                        cursor.execute("""
                            CREATE TEMPORARY TABLE %s AS SELECT * FROM %s
                            """ % (db.quote(temp_table_name),
                                   db.quote(new_table.name)))
                    self.drop_tables((new_table,))
                self.create_tables((new_table,))
                # `column_names`/`cols_to_copy` are only bound when
                # `has_table` is true; the `and` short-circuits before them.
                if has_table and column_names:
                    cursor.execute("""
                        INSERT INTO %s (%s) SELECT %s FROM %s
                        """ % (db.quote(new_table.name), cols_to_copy,
                               cols_to_copy, db.quote(temp_table_name)))
                    for col in new_table.columns:
                        if col.auto_increment:
                            db.update_sequence(cursor, new_table.name,
                                               col.name)
                    self.drop_tables((temp_table_name,))

    def get_connection(self, readonly=False):
        """Get a database connection from the pool.

        The pool (of size 5) is created on first use from the configured
        connector. `[trac] timeout` is passed to the pool; `0` is turned
        into `None`.

        If `readonly` is `True`, the returned connection will purposely
        lack the `rollback` and `commit` methods.
        """
        if not self._cnx_pool:
            connector, args = self.get_connector()
            self._cnx_pool = ConnectionPool(5, connector, **args)
        db = self._cnx_pool.get_cnx(self.timeout or None)
        if readonly:
            db = ConnectionWrapper(db, readonly=True)
        return db

    def get_database_version(self, name='database_version'):
        """Returns the database version from the SYSTEM table as an int,
        or `False` if the entry is not found.

        :param name: The name of the entry that contains the database version
                     in the SYSTEM table. Defaults to `database_version`,
                     which contains the database version for Trac.
        """
        with self.env.db_query as db:
            for value, in db("""
                    SELECT value FROM {0} WHERE name=%s
                    """.format(db.quote('system')), (name,)):
                return int(value)
            else:
                return False

    def get_exceptions(self):
        """Return the exceptions object of the configured connector."""
        return self.get_connector()[0].get_exceptions()

    def get_sequence_names(self):
        """Returns a list of the sequence names.

        :since: 1.3.2
        """
        with self.env.db_query as db:
            return db.get_sequence_names()

    def get_table_names(self):
        """Returns a list of the table names.

        :since: 1.1.6
        """
        with self.env.db_query as db:
            return db.get_table_names()

    def get_column_names(self, table):
        """Returns a list of the column names for `table`.

        :param table: a `Table` object or table name.
        :raises OperationalError: (from `env.db_exc`) if the table does
                                  not exist.

        :since: 1.2
        """
        table_name = table.name if isinstance(table, Table) else table
        with self.env.db_query as db:
            if not db.has_table(table_name):
                raise self.env.db_exc.OperationalError('Table %s not found' %
                                                       db.quote(table_name))
            return db.get_column_names(table_name)

    def has_table(self, table):
        """Returns whether the table exists.

        :param table: a `Table` object or table name.
        """
        table_name = table.name if isinstance(table, Table) else table
        with self.env.db_query as db:
            return db.has_table(table_name)

    def set_database_version(self, version, name='database_version'):
        """Sets the database version in the SYSTEM table.

        A missing entry is inserted; an existing entry is only updated
        (and the change logged) when it differs from `version`.

        :param version: an integer database version.
        :param name: The name of the entry that contains the database version
                     in the SYSTEM table. Defaults to `database_version`,
                     which contains the database version for Trac.
        """
        current_database_version = self.get_database_version(name)
        if current_database_version is False:
            with self.env.db_transaction as db:
                db("""
                    INSERT INTO {0} (name, value) VALUES (%s, %s)
                    """.format(db.quote('system')), (name, version))
        elif version != current_database_version:
            # Reuse the value read above instead of querying again.
            with self.env.db_transaction as db:
                db("""
                    UPDATE {0} SET value=%s WHERE name=%s
                    """.format(db.quote('system')), (version, name))
            self.log.info("Upgraded %s from %d to %d",
                          name, current_database_version, version)

    def needs_upgrade(self, version, name='database_version'):
        """Checks the database version to determine if an upgrade is needed.

        :param version: the expected integer database version.
        :param name: the name of the entry in the SYSTEM table that contains
                     the database version. Defaults to `database_version`,
                     which contains the database version for Trac.

        :return: `True` if the stored version is less than the expected
                 version, `False` if it is equal to the expected version.
        :raises TracError: if the stored version is greater than the expected
                           version.
        """
        # A missing entry (`False`) compares like version 0.
        dbver = self.get_database_version(name)
        if dbver == version:
            return False
        elif dbver > version:
            raise TracError(_("Need to downgrade %(name)s.", name=name))
        self.log.info("Need to upgrade %s from %d to %d",
                      name, dbver, version)
        return True

    def upgrade(self, version, name='database_version', pkg='trac.upgrades'):
        """Invokes `do_upgrade(env, version, cursor)` in module
        `"%s/db%i.py" % (pkg, version)`, for each required version upgrade.

        Each step runs in its own transaction, together with the update
        of the stored version.

        :param version: the expected integer database version.
        :param name: the name of the entry in the SYSTEM table that contains
                     the database version. Defaults to `database_version`,
                     which contains the database version for Trac.
        :param pkg: the package containing the upgrade modules.

        :raises TracError: if the package or module doesn't exist.
        """
        dbver = self.get_database_version(name)
        for i in range(dbver + 1, version + 1):
            module = '%s.db%i' % (pkg, i)
            try:
                upgrader = importlib.import_module(module)
            except ImportError:
                raise TracError(_("No upgrade module %(module)s.py",
                                  module=module))
            with self.env.db_transaction as db:
                cursor = db.cursor()
                upgrader.do_upgrade(self.env, i, cursor)
                self.set_database_version(i, name)

    def shutdown(self, tid=None):
        """Shut down the connection pool, if one exists.

        `tid` is forwarded to the pool's `shutdown`; the pool itself is
        discarded only when no `tid` is given.
        """
        if self._cnx_pool:
            self._cnx_pool.shutdown(tid)
            if not tid:
                self._cnx_pool = None

    def backup(self, dest=None):
        """Save a backup of the database.

        :param dest: base filename to write to. When omitted, a name of
                     the form `<scheme>.<db version>.<timestamp>.bak` in
                     `[trac] backup_dir` (relative to the environment
                     directory unless absolute) is used.

        Returns the file actually written.
        """
        connector = self.get_connector()[0]
        if not dest:
            backup_dir = self.backup_dir
            if not os.path.isabs(backup_dir):
                backup_dir = os.path.join(self.env.path, backup_dir)
            db_str = self.config.get('trac', 'database')
            db_name = db_str.split(":", 1)[0]
            dest_name = '%s.%i.%d.bak' % (db_name, self.env.database_version,
                                          int(time.time()))
            dest = os.path.join(backup_dir, dest_name)
        else:
            backup_dir = os.path.dirname(dest)
        # A bare file name has an empty directory part, for which
        # `os.makedirs('')` would raise; `exist_ok` covers a concurrent
        # creation between the check and the call.
        if backup_dir and not os.path.exists(backup_dir):
            os.makedirs(backup_dir, exist_ok=True)
        return connector.backup(dest)

    def get_connector(self):
        """Return the connector and its arguments for `connection_uri`.

        Among the connectors supporting the URI scheme, the one reporting
        the highest priority is selected; on a tie, the first one found
        wins.

        :return: a `(connector, args)` tuple, where `args` holds the
                 keyword arguments parsed from the connection string. A
                 relative SQLite path is made absolute to the environment
                 directory, and `log` is added when `[trac] debug_sql` is
                 enabled.
        :raises TracError: if no connector supports the scheme, or if the
                           selected connector reports a negative priority.
        """
        scheme, args = parse_connection_uri(self.connection_uri)
        candidates = [
            (priority, connector)
            for connector in self.connectors
            for scheme_, priority in connector.get_supported_schemes()
            if scheme_ == scheme
        ]
        if not candidates:
            raise TracError(_('Unsupported database type "%(scheme)s"',
                              scheme=scheme))
        # Compare on the priority only: a plain `max(candidates)` would
        # fall back to comparing the connector objects on a priority tie,
        # which raises `TypeError` unless they define an ordering.
        priority, connector = max(candidates, key=lambda c: c[0])
        if priority < 0:
            raise TracError(connector.error)

        if scheme == 'sqlite':
            if args['path'] == ':memory:':
                # Special case for SQLite in-memory database, always get
                # the /same/ connection over
                pass
            elif not os.path.isabs(args['path']):
                # Special case for SQLite to support a path relative to the
                # environment directory
                args['path'] = os.path.join(self.env.path,
                                            args['path'].lstrip('/'))

        if self.debug_sql:
            args['log'] = self.log
        return connector, args

    # IEnvironmentSetupParticipant methods

    def environment_created(self):
        pass

    def environment_needs_upgrade(self):
        return self.needs_upgrade(db_default.db_version)

    def upgrade_environment(self):
        self.upgrade(db_default.db_version)

    # ISystemInfoProvider methods

    def get_system_info(self):
        connector = self.get_connector()[0]
        yield from connector.get_system_info()
625
626
def get_column_names(cursor):
    """Retrieve column names from a cursor, if possible.

    Returns an empty list when the cursor has no `description`; names
    given as `bytes` are decoded as UTF-8.
    """
    description = cursor.description
    if not description:
        return []
    names = []
    for column in description:
        label = column[0]
        if isinstance(label, bytes):
            label = str(label, 'utf-8')
        names.append(label)
    return names
631
632
def parse_connection_uri(db_str):
    """Parse the database connection string.

    The database connection string for an environment is specified through
    the `database` option in the `[trac]` section of trac.ini.

    Accepted forms are ``scheme://[user[:password]@]host[:port]/path``,
    ``scheme:/path`` and ``scheme:///path``. For the `sqlite` scheme only,
    strings not starting with ``/`` (relative or in-memory paths such as
    ``sqlite:db/trac.db`` or ``sqlite::memory:``) are accepted as well. A
    trailing ``?name=value&...`` query string is parsed into `params`, with
    percent-decoded values; the user name and password are percent-decoded
    too.

    :return: a tuple containing the scheme and a dictionary of attributes:
             `user`, `password`, `host`, `port`, `path`, `params`. Only
             attributes with a non-empty value are included.
    :raises ConfigurationError: if `db_str` is empty or malformed.
    :since: 1.1.3
    """
    if not db_str:
        section = tag.a("[trac]",
                        title=_("TracIni documentation"),
                        class_='trac-target-new',
                        href='https://trac.edgewall.org/wiki/TracIni'
                             '#trac-section')
        raise ConfigurationError(
            tag_("Database connection string is empty. Set the %(option)s "
                 "configuration option in the %(section)s section of "
                 "trac.ini. Please refer to the %(doc)s for help.",
                 option=tag.code("database"), section=section,
                 doc=_doc_db_str()))

    try:
        scheme, rest = db_str.split(':', 1)
    except ValueError:
        raise _invalid_db_str(db_str)

    if not rest.startswith('/'):
        if scheme == 'sqlite' and rest:
            # Support for relative and in-memory SQLite connection strings
            host = None
            path = rest
        else:
            raise _invalid_db_str(db_str)
    else:
        if not rest.startswith('//'):
            host = None
            rest = rest[1:]
        elif rest.startswith('///'):
            host = None
            rest = rest[3:]
        else:
            rest = rest[2:]
            if '/' in rest:
                host, rest = rest.split('/', 1)
            else:
                host = rest
                rest = ''
        path = None

    if host and '@' in host:
        # Split at the *last* '@': the host[:port] part cannot contain
        # one, so any earlier '@' belongs to the user info (e.g. an
        # unescaped '@' in the password).
        user, host = host.rsplit('@', 1)
        if ':' in user:
            user, password = user.split(':', 1)
        else:
            password = None
        if user:
            user = urllib.parse.unquote(user)
        if password:
            password = unicode_passwd(urllib.parse.unquote(password))
    else:
        user = password = None

    if host and ':' in host:
        host, port = host.split(':', 1)
        try:
            port = int(port)
        except ValueError:
            raise _invalid_db_str(db_str)
    else:
        port = None

    if not path:
        path = '/' + rest
    if os.name == 'nt':
        # Support local paths containing drive letters on Win32
        if len(rest) > 1 and rest[1] == '|':
            path = "%s:%s" % (rest[0], rest[2:])

    params = {}
    if '?' in path:
        path, qs = path.split('?', 1)
        qs = qs.split('&')
        for param in qs:
            try:
                name, value = param.split('=', 1)
            except ValueError:
                raise _invalid_db_str(db_str)
            value = urllib.parse.unquote(value)
            params[name] = value

    args = zip(('user', 'password', 'host', 'port', 'path', 'params'),
               (user, password, host, port, path, params))
    return scheme, {key: value for key, value in args if value}
728
729
def _invalid_db_str(db_str):
    """Build (without raising it) the `ConfigurationError` reported for
    the malformed database connection string `db_str`."""
    message = tag_("Invalid format %(db_str)s for the database connection string. "
                   "Please refer to the %(doc)s for help.",
                   db_str=tag.code(db_str), doc=_doc_db_str())
    return ConfigurationError(message)
735
736
def _doc_db_str():
    """Return a link element pointing to the "Database Connection
    Strings" section of the TracIni wiki page."""
    label = _("documentation")
    tooltip = _("Database Connection Strings documentation")
    url = 'https://trac.edgewall.org/wiki/TracIni#DatabaseConnectionStrings'
    return tag.a(label, title=tooltip, class_='trac-target-new', href=url)
Note: See TracBrowser for help on using the repository browser.