Edgewall Software

Ticket #8172: plugin_env_setup.py

File plugin_env_setup.py, 7.0 KB (added by Felix Schwarz <felix.schwarz@…>, 3 years ago)

If there is a chance to get this into Trac core, we can polish this class (including docs) and make Trac use it, too.

Line 
1# -*- coding: utf8 -*-
2#   Copyright 2009 agile42 GmbH All rights reserved
3#
4#   Licensed under the Apache License, Version 2.0 (the "License");
5#   you may not use this file except in compliance with the License.
6#   You may obtain a copy of the License at
7#
8#       http://www.apache.org/licenses/LICENSE-2.0
9#
10#   Unless required by applicable law or agreed to in writing, software
11#   distributed under the License is distributed on an "AS IS" BASIS,
12#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13#   See the License for the specific language governing permissions and
14#   limitations under the License.
15#
16#   Authors:
17#        - Felix Schwarz <felix.schwarz__at__agile42.com>
18
19from trac.core import Component, implements, TracError
20from trac.db import DatabaseManager
21from trac.env import IEnvironmentSetupParticipant
22from trac.util.translation import _
23
# Names exported by ``from <this module> import *``: only the base component.
__all__ = ["PluginEnvironmentSetup"]
25
26
class PluginEnvironmentSetup(Component):
    """Reusable IEnvironmentSetupParticipant for plugins with own DB tables.

    The plugin's schema version is kept in Trac's ``system`` table, in the
    row whose ``name`` column equals ``self.name``. Upgrades are performed by
    modules ``db<N>`` (``db2``, ``db3``, ...) inside the package returned by
    ``get_package_name()``; each of them must provide
    ``do_upgrade(env, version, cursor, connector)`` which returns a true value
    on success.

    Subclasses must implement ``get_expected_db_version()`` and the ``name``
    property; they may override ``environment_created()`` to fill in initial
    data.
    """

    # Abstract Trac components are not registered themselves; only concrete
    # subclasses take part in the environment setup.
    abstract = True

    implements(IEnvironmentSetupParticipant)

    # --------------------------------------------------------------------------
    # Template methods

    def get_expected_db_version(self):
        "Return the DB version (as integer) which is needed by the plugin."
        raise NotImplementedError()

    def get_package_name(self):
        """Return the package name where the upgrade scripts can be found.

        By default this is the ``db.upgrades`` subpackage of the package which
        contains the module of the concrete subclass (e.g.
        ``myplugin.db.upgrades`` for a class defined in ``myplugin.setup``)."""
        pkg_name = self.__module__.rsplit('.', 1)[0]
        return pkg_name + '.db.upgrades'

    def name(self):
        "Return the name of this plugin - like 'Trac' for the core Trac."
        raise NotImplementedError()
    name = property(name)

    # --------------------------------------------------------------------------
    # IEnvironmentSetupParticipant methods

    def environment_created(self):
        """Fill in default data when a new environment was created (this
        implementation just sets the db version identifier)."""
        db, handle_ta = self.get_db_for_write()
        self.set_db_version(db)
        if handle_ta:
            db.commit()

    def environment_needs_upgrade(self, db=None):
        """Return True if the stored DB version is older than the expected
        one and False if both are equal.

        Raises TracError if the stored version is newer than the expected one
        because such a database cannot be upgraded by this plugin version."""
        db = self.get_db_for_read(db)
        db_ver = self.get_db_version(db)
        expected_db_version = self.get_expected_db_version()
        if db_ver == expected_db_version:
            return False
        elif db_ver > expected_db_version:
            msg = _('Database newer than %s version (has version %s, is version %s)')
            raise TracError(msg % (self.name, db_ver, expected_db_version))
        return True

    def upgrade_environment(self, db=None):
        """Run all pending upgrade scripts and store the new DB version.

        If no version is stored yet (version 0), ``environment_created()`` is
        called instead of running upgrade scripts.

        If an upgrade fails the transaction is rolled back - even if `db` was
        passed in by the caller, because Trac does not issue a rollback
        itself. Exceptions raised by upgrade scripts are logged and re-raised
        with their original traceback; a script which only reports failure
        (returns a false value) is logged but no exception is raised.
        """
        import sys
        db, handle_ta = self.get_db_for_write(db)
        db_ver = self.get_db_version(db)
        if db_ver == 0:
            # environment_created() fetches its own connection and commits it.
            return self.environment_created()
        cursor = db.cursor()

        try:
            successful_upgrade = self.run_upgrade_scripts(cursor, db_ver)
        except Exception:
            # sys.exc_info() instead of "except Exception, e" so this works on
            # old and new Python versions alike.
            error = sys.exc_info()[1]
            error_msg = _('Exception while upgrading %s database: %s')
            # Lazy logger arguments: a failing str() conversion (e.g. a
            # non-ASCII message) must not mask the original exception.
            self.env.log.error(error_msg, self.name, error)
            self._rollback_failed_upgrade(db)
            # Re-raise while the exception is still being handled; a bare
            # ``raise`` after the except block has ended is not reliable
            # (it is a RuntimeError on Python 3).
            raise

        if not successful_upgrade:
            self._rollback_failed_upgrade(db)
            return
        self.set_db_version(db)
        if handle_ta:
            db.commit()
        msg = _('Upgraded %s database version from %d to %d')
        expected_db_version = self.get_expected_db_version()
        self.env.log.info(msg, self.name, db_ver, expected_db_version)

    def _rollback_failed_upgrade(self, db):
        "Roll back the transaction of a failed upgrade and log the failure."
        # Trac never issues a rollback implicitly during the upgrade which
        # seems to be wrong.
        db.rollback()
        error_msg = _('Upgrading %s tables failed!') % self.name
        self.env.log.error(error_msg)

    # --------------------------------------------------------------------------
    # Custom utility methods

    def get_db_for_read(self, db=None):
        """Return `db`, or a connection from ``self.env.get_db_cnx()`` if `db`
        is None."""
        # The idea is that maybe in the future there is a different connection
        # pool for read connections.
        if db is None:
            db = self.env.get_db_cnx()
        return db

    def get_db_for_write(self, db=None):
        """Return a tuple ``(db, handle_ta)``.

        If `db` is None a connection is fetched via ``self.env.get_db_cnx()``
        and ``handle_ta`` is True: the caller owns the transaction and must
        commit it. Otherwise `db` is returned unchanged with ``handle_ta``
        False, leaving the commit to whoever passed the connection in."""
        handle_ta = False
        if db is None:
            db = self.env.get_db_cnx()
            handle_ta = True
        return (db, handle_ta)

    def _fetch_db_version(self, db, name):
        """Return the raw ``value`` stored for `name` in the ``system`` table,
        or 0 if there is no such row."""
        cursor = db.cursor()
        # Bind the name as a query parameter (like the INSERT/UPDATE helpers
        # do) so that a quote in the name cannot break or alter the SQL.
        cursor.execute("SELECT value FROM system WHERE name=%s LIMIT 1",
                       (name,))
        row = cursor.fetchone()
        if row is not None and len(row) > 0:
            return row[0]
        return 0

    def get_db_version(self, db, name=None):
        """Return the DB version (0 if no version information was read).
        If name was given, use this identifier instead of the value in
        self.name.

        Specifying another name is useful if your plugin changed its name but
        you need to check if an old version of your plugin is present."""
        name = name or self.name
        db_version = int(self._fetch_db_version(db, name))
        return db_version

    def run_upgrade_scripts(self, cursor, current_db_version):
        """Run the upgrade scripts for every version after
        `current_db_version` up to the expected DB version, in order.

        Returns True if all scripts succeeded (or there was nothing to do);
        otherwise returns the false value of the first failing script and
        skips the remaining ones. Raises TracError if the upgrade package
        has no ``db<N>`` module for a needed version."""
        dbm = DatabaseManager(self.env)
        # Only the connector is needed. Do not unpack into ``_``: that would
        # shadow the translation function used below.
        connector = dbm._get_connector()[0]
        pkg_name = self.get_package_name()
        expected_db_version = self.get_expected_db_version()
        upgrade_was_successful = True
        for i in range(current_db_version + 1, expected_db_version + 1):
            name = 'db%i' % i
            filename = '%s.py' % name
            upgrades = __import__(pkg_name, globals(), locals(), [name])
            try:
                script = getattr(upgrades, name)
            except AttributeError:
                msg = _('No upgrade module for version %(num)i (%(filename)s)')
                # TracError() takes no formatting keywords, so interpolate
                # the message before raising.
                raise TracError(msg % {'num': i, 'filename': filename})
            upgrade_was_successful = script.do_upgrade(self.env, i, cursor,
                                                       connector)
            if not upgrade_was_successful:
                msg = _('Upgrade script %s did not complete successfully')
                self.env.log.error(msg, filename)
                break
        return upgrade_was_successful

    def _insert_version_number(self, cursor, version):
        "Insert a new ``system`` row storing `version` for this plugin."
        cursor.execute('INSERT INTO system (name, value) VALUES (%s, %s)',
                       (self.name, version))

    def _update_version_number(self, cursor, version):
        "Overwrite the ``system`` row of this plugin with `version`."
        cursor.execute('UPDATE system SET value=%s WHERE name=%s',
                       (version, self.name))

    def set_db_version(self, db):
        """Write the DB version of this plugin to the db (or update an existing
        row if one already exists)."""
        cursor = db.cursor()
        latest_version = self.get_expected_db_version()
        was_upgrade = (self.get_db_version(db) > 0)
        if was_upgrade:
            self._update_version_number(cursor, latest_version)
        else:
            self._insert_version_number(cursor, latest_version)
181