Edgewall Software

Ticket #7600: postgres_quoting.2.patch

File postgres_quoting.2.patch, 4.7 KB (added by felix.schwarz@…, 4 years ago)

re-did my patch, this time with unit tests

  • trac/db/postgres_backend.py

     
    7474        cnx.commit() 
    7575 
    7676    def to_sql(self, table): 
    77         sql = ["CREATE TABLE %s (" % table.name] 
     77        sql = ['CREATE TABLE "%s" (' % table.name] 
    7878        coldefs = [] 
    7979        for column in table.columns: 
    8080            ctype = column.type 
    8181            if column.auto_increment: 
    82                 ctype = "SERIAL" 
     82                ctype = 'SERIAL' 
    8383            if len(table.key) == 1 and column.name in table.key: 
    84                 ctype += " PRIMARY KEY" 
    85             coldefs.append("    %s %s" % (column.name, ctype)) 
     84                ctype += ' PRIMARY KEY' 
     85            coldefs.append('    "%s" %s' % (column.name, ctype)) 
    8686        if len(table.key) > 1: 
    87             coldefs.append("    CONSTRAINT %s_pk PRIMARY KEY (%s)" 
    88                            % (table.name, ','.join(table.key))) 
     87            coldefs.append('    CONSTRAINT "%s_pk" PRIMARY KEY ("%s")' 
     88                           % (table.name, '","'.join(table.key))) 
    8989        sql.append(',\n'.join(coldefs) + '\n)') 
    9090        yield '\n'.join(sql) 
    9191        for index in table.indices: 
    9292            unique = index.unique and 'UNIQUE' or '' 
    93             yield "CREATE %s INDEX %s_%s_idx ON %s (%s)" % (unique, table.name,  
     93            yield 'CREATE %s INDEX "%s_%s_idx" ON "%s" ("%s")' % (unique, table.name,  
    9494                   '_'.join(index.columns), table.name, ','.join(index.columns)) 
    9595 
    9696 
  • trac/db/tests/postgres_test.py

     
     1# -*- coding: utf-8 -*- 
     2 
     3import re 
     4import unittest 
     5 
     6from trac.db import Table, Column, Index 
     7from trac.db.postgres_backend import PostgreSQLConnector 
     8from trac.test import EnvironmentStub 
     9 
     10 
     11class PostgresTableCreationSQLTest(unittest.TestCase): 
     12    def setUp(self): 
     13        self.env = EnvironmentStub() 
     14        self.db = self.env.get_db_cnx() 
     15     
     16    def _unroll_generator(self, generator): 
     17        items = [] 
     18        for item in generator: 
     19            items.append(item) 
     20        return items 
     21     
     22    def _normalize_sql(self, sql_generator): 
     23        normalized_commands = [] 
     24        whitespace_regex = re.compile(r'\s+') 
     25        commands = self._unroll_generator(sql_generator) 
     26        for command in commands: 
     27            command = command.replace('\n', '') 
     28            command = whitespace_regex.sub(' ', command) 
     29            normalized_commands.append(command) 
     30        return normalized_commands 
     31     
     32    def test_quote_table_name(self): 
     33        table = Table('foo bar') 
     34        table[Column('name'),] 
     35        sql_generator = PostgreSQLConnector(self.env).to_sql(table) 
     36        sql_commands = self._normalize_sql(sql_generator) 
     37        self.assertEqual(1, len(sql_commands)) 
     38        self.assertEqual('CREATE TABLE "foo bar" ( "name" text)', sql_commands[0]) 
     39     
     40    def test_quote_column_names(self): 
     41        table = Table('foo') 
     42        table[Column('my name'),] 
     43        sql_generator = PostgreSQLConnector(self.env).to_sql(table) 
     44        sql_commands = self._normalize_sql(sql_generator) 
     45        self.assertEqual(1, len(sql_commands)) 
     46        self.assertEqual('CREATE TABLE "foo" ( "my name" text)', sql_commands[0]) 
     47     
     48    def test_quote_compound_primary_key_declaration(self): 
     49        table = Table('foo bar', key=['my name', 'your name']) 
     50        table[Column('my name'), Column('your name'),] 
     51        sql_generator = PostgreSQLConnector(self.env).to_sql(table) 
     52        sql_commands = self._normalize_sql(sql_generator) 
     53        self.assertEqual(1, len(sql_commands)) 
     54        expected_sql = 'CREATE TABLE "foo bar" ( "my name" text, ' + \ 
     55                       '"your name" text, CONSTRAINT "foo bar_pk" ' +\ 
     56                       'PRIMARY KEY ("my name","your name"))' 
     57        self.assertEqual(expected_sql, sql_commands[0]) 
     58     
     59    def test_quote_index_declaration(self): 
     60        table = Table('foo') 
     61        table[Column('my name'), Index(['my name'])] 
     62        sql_generator = PostgreSQLConnector(self.env).to_sql(table) 
     63        sql_commands = self._normalize_sql(sql_generator) 
     64        self.assertEqual(2, len(sql_commands)) 
     65        self.assertEqual('CREATE TABLE "foo" ( "my name" text)', sql_commands[0]) 
     66        index_sql = 'CREATE INDEX "foo_my name_idx" ON "foo" ("my name")' 
     67        self.assertEqual(index_sql, sql_commands[1]) 
     68 
     69 
     70def suite(): 
     71    return unittest.makeSuite(PostgresTableCreationSQLTest, 'test') 
     72 
     73if __name__ == '__main__': 
     74    unittest.main(defaultTest='suite')