Records: Source Code Reading and Practice

Introduction

Records is a third-party library for running raw SQL queries against most relational databases. Created by kennethreitz, it is currently just over 500 lines of code, very simple yet quite powerful. Supported databases include RedShift, Postgres, MySQL, SQLite, Oracle, and MS-SQL.

Records repository: github.com/kennethreitz

SQL for Humans™ pypi.python.org/pypi/re

records.py

Below is a walkthrough of the source code, in order. Corrections are welcome if you spot any mistakes!

# -*- coding: utf-8 -*-

PEP 263: recommends declaring the source encoding at the top of a Python file.

import

import os
# for working with files and directories
from sys import stdout
# Python's standard output stream
from collections import OrderedDict
# OrderedDict: a dict that keeps insertion order
from contextlib import contextmanager
# context manager helper
from inspect import isclass
# isclass: check whether an object is a class
import tablib
# third-party library for exporting data in common formats
from docopt import docopt
# command-line argument parsing
from sqlalchemy import create_engine, exc, inspect, text
# sqlalchemy provides the SQL toolkit and ORM facilities

What is an ORM?

ORM stands for Object-Relational Mapping. In Python it is a mapping between objects in a relational database and Python objects; with that mapping in place, we can operate on the database simply by working with Python objects.
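As an illustration only (records itself sticks to raw SQL rather than the ORM layer), here is a minimal SQLAlchemy ORM sketch; the User model, the pc_user table and the in-memory SQLite URL are invented for the example and are not part of records:

# Minimal SQLAlchemy ORM sketch (hypothetical model, for illustration only)
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()

class User(Base):                       # a Python class mapped to a table
    __tablename__ = 'pc_user'           # hypothetical table name
    id = Column(Integer, primary_key=True)
    name = Column(String(20))
    age = Column(Integer)

engine = create_engine('sqlite://')     # in-memory SQLite, just for the demo
Base.metadata.create_all(engine)

Session = sessionmaker(bind=engine)
session = Session()
session.add(User(name='zhang1', age=13))    # insert by manipulating a Python object
session.commit()
print(session.query(User).first().name)     # queries return Python objects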

The isexception function

# isexception: check whether an object is an instance or subclass of Exception
def isexception(obj):
    """Given an object, return a boolean indicating whether it is an instance
    or subclass of :py:class:`Exception`.
    """
    if isinstance(obj, Exception):
        return True
    if isclass(obj) and issubclass(obj, Exception): 
        return True
    return False
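A quick illustration of what it accepts:

print(isexception(ValueError))        # True  (a subclass of Exception)
print(isexception(ValueError('x')))   # True  (an instance of Exception)
print(isexception('boom'))            # False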

The Record class

# Record: the class that stores a single row of data
class Record(object):
    """A row, from a query, from a database."""
    __slots__ = ('_keys', '_values')    # restrict the set of allowed attributes (only affects new-style classes)
    def __init__(self, keys, values):
        self._keys = keys   # a single leading underscore marks a protected attribute
        self._values = values
        # Ensure that lengths match properly.
        # assert raises an error when the expression is False; here it ensures _keys and _values have the same length
        assert len(self._keys) == len(self._values)
    # keys and values are getter methods for reading the protected attributes
    def keys(self):
        """Returns the list of column names from the query."""
        return self._keys
    def values(self):
        """Returns the list of values from the query."""
        return self._values
    # __repr__ defines the string shown for the object in the interpreter
    def __repr__(self):
        return '<Record {}>'.format(self.export('json')[1:-1])
    # __getitem__ defines what indexing with obj[key] returns
    def __getitem__(self, key):
        # Support for index-based lookup.
        # integer-based lookup
        if isinstance(key, int):    # if key is an int
            return self.values()[key]
        # Support for string-based lookup.
        # string-based lookup
        if key in self.keys():  # if key is one of the column names
            i = self.keys().index(key) # position of that key
            if self.keys().count(key) > 1:  # if the key appears more than once among the columns
                raise KeyError("Record contains multiple '{}' fields.".format(key))
            return self.values()[i] 
        raise KeyError("Record contains no '{}' field.".format(key))
    # __getattr__ attribute lookup: defines what obj.key returns
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(e)
    # __dir__ defines the attribute list returned by dir(obj)
    def __dir__(self):
        standard = dir(super(Record, self))
        # Merge standard attrs with generated ones (from column names).
        return sorted(standard + [str(k) for k in self.keys()])
    # get: dict-style get; returns the default when the key is not found
    def get(self, key, default=None):
        """Returns the value for a given key, or default."""
        try:
            return self[key]
        except KeyError:
            return default
    # as_dict: build a dict; ordered controls whether an OrderedDict is returned
    def as_dict(self, ordered=False):
        """Returns the row as a dictionary, as ordered."""
        items = zip(self.keys(), self.values())
        return OrderedDict(items) if ordered else dict(items)
    # the @property decorator turns the method into attribute access (obj.dataset)
    # dataset: put keys and values into a tablib Dataset object
    @property
    def dataset(self):
        """A Tablib Dataset containing the row."""
        data = tablib.Dataset()
        data.headers = self.keys()
        # _reduce_datetimes is defined further down; it converts datetime values to strings
        row = _reduce_datetimes(self.values())
        data.append(row)
        return data
    # export: call the Dataset's export method to output data in the given format
    def export(self, format, **kwargs):
        """Exports the row to the given format."""
        return self.dataset.export(format, **kwargs)

A Record object stores a single row: keys holds the column names and values holds the corresponding values. It supports lookup by integer index, by column name, and via the get method, and the row can be exported to a chosen format through tablib.
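For illustration, a small sketch that builds a Record by hand; in normal use Records constructs these from query results:

from records import Record

row = Record(['name', 'age'], ['zhang1', 13])
print(row[0])                    # integer lookup     -> 'zhang1'
print(row['age'])                # column-name lookup -> 13
print(row.name)                  # attribute lookup via __getattr__
print(row.get('email', 'n/a'))   # dict-style get with a default
print(row.as_dict())             # {'name': 'zhang1', 'age': 13}
print(row.export('json'))        # export through tablib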

The RecordCollection class

# RecordCollection: the class that stores multiple rows of data
class RecordCollection(object):
    """A set of excellent Records from a query."""
    def __init__(self, rows):
        self._rows = rows
        self._all_rows = [] # cache of the rows that have already been iterated
        self.pending = True
    def __repr__(self):
        return '<RecordCollection size={} pending={}>'.format(len(self), self.pending)
    # __iter__ makes the object iterable; rows are produced with yield
    def __iter__(self):
        """Iterate over all rows, consuming the underlying generator
        only when necessary."""
        i = 0
        while True:
            # Other code may have iterated between yields,
            # so always check the cache.
            # look in the cache first (len(self) is the length of _all_rows), then keep consuming the generator
            if i < len(self):
                yield self[i]
            else:
                # Throws StopIteration when done.
                # Prevent StopIteration bubbling from generator, following https://www.python.org/dev/peps/pep-0479/
                try:
                    yield next(self)
                except StopIteration:
                    return
            i += 1
    # next lets __next__ be called explicitly
    def next(self):
        return self.__next__()
    # __next__: implementing both __iter__ and __next__ makes instances of this class iterators
    def __next__(self):
        try:
            nextrow = next(self._rows)  # next() is Python's built-in iterator call
            self._all_rows.append(nextrow)  # cache every fetched row in _all_rows
            return nextrow
        except StopIteration:
            self.pending = False
            raise StopIteration('RecordCollection contains no more rows.')
    # __getitem__ defines indexing; the key can only be an int or a slice
    # unlike Record, this implementation is optimized for slice operations
    def __getitem__(self, key):
        is_int = isinstance(key, int)
        # Convert RecordCollection[1] into slice.
        if is_int:
            key = slice(key, key + 1)   # convert an integer index into a slice object
        while len(self) < key.stop or key.stop is None: # keep consuming rows while the slice reaches past the cache, or has no stop
            try:
                next(self)  
            except StopIteration:
                break
        rows = self._all_rows[key]  # pull the requested rows out of the cache
        # for an int key return the single row, otherwise return a RecordCollection wrapping the rows
        if is_int:
            return rows[0]
        else:
            return RecordCollection(iter(rows))
    # __len__ defines len(obj); here it is the length of _all_rows
    def __len__(self):
        return len(self._all_rows)
    # export: export the data set in the given format
    def export(self, format, **kwargs):
        """Export the RecordCollection to a given format (courtesy of Tablib)."""
        return self.dataset.export(format, **kwargs)
    # dataset: load the records into a Dataset object and return it (accessed as an attribute)
    @property
    def dataset(self):
        """A Tablib Dataset representation of the RecordCollection."""
        # Create a new Tablib Dataset.
        data = tablib.Dataset()
        # If the RecordCollection is empty, just return the empty set
        # Check number of rows by typecasting to list
        # if the collection is empty, just return the empty Dataset
        if len(list(self)) == 0:
            return data
        # Set the column names as headers on Tablib Dataset.
        # set the Dataset headers to the column names
        first = self[0]
        data.headers = first.keys()
        for row in self.all():
            row = _reduce_datetimes(row.values())
            data.append(row)
        return data
    # all: return every row, optionally as a list of dicts or OrderedDicts
    def all(self, as_dict=False, as_ordereddict=False):
        """Returns a list of all rows for the RecordCollection. If they haven't
        been fetched yet, consume the iterator and cache the results."""
        # By calling list it calls the __iter__ method
        # list(self) drives the class's __iter__ method
        rows = list(self)
        if as_dict:
            return [r.as_dict() for r in rows]
        elif as_ordereddict:
            return [r.as_dict(ordered=True) for r in rows]
        return rows
    # as_dict: return all rows in dict form (delegates to all)
    def as_dict(self, ordered=False):
        return self.all(as_dict=not(ordered), as_ordereddict=ordered)
    # first: return the first record of the RecordCollection
    def first(self, default=None, as_dict=False, as_ordereddict=False):
        """Returns a single record for the RecordCollection, or `default`. If
        `default` is an instance or subclass of Exception, then raise it
        instead of returning it."""
        # Try to get a record, or return/raise default.
        # try to fetch a record; otherwise return the default, or raise it if it is an exception
        try:
            record = self[0]    # __getitem__ returns the first element of _all_rows
        except IndexError:
            if isexception(default): # the module-level isexception helper
                raise default
            return default
        # Cast and return.
        if as_dict:
            return record.as_dict()
        elif as_ordereddict:
            return record.as_dict(ordered=True)
        else:
            return record
    # one: return the record of a RecordCollection that holds exactly one record, otherwise raise
    def one(self, default=None, as_dict=False, as_ordereddict=False):
        """Returns a single record for the RecordCollection, ensuring that it
        is the only record, or returns `default`. If `default` is an instance
        or subclass of Exception, then raise it instead of returning it."""
        # Ensure that we don't have more than one row.
        try:
            self[1]
        except IndexError:  # no second row, so return the first one
            return self.first(default=default, as_dict=as_dict, as_ordereddict=as_ordereddict)
        else:   # a second row exists, so raise
            raise ValueError('RecordCollection contained more than one row. '
                             'Expects only one row when using '
                             'RecordCollection.one')
    # scalar: return the first column of the record that satisfies one()
    def scalar(self, default=None):
        """Returns the first column of the first row, or `default`."""
        row = self.one()
        return row[0] if row else default

A RecordCollection is the object a Records query returns: an iterator over multiple result rows. The big advantage of an iterator is that it saves memory; indexing also works by iterating, and every row fetched is cached, so a repeated lookup is served straight from the cache, which greatly reduces the average lookup cost.
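A small sketch, again constructing the collection by hand purely for illustration, showing the lazy consumption and caching:

from records import Record, RecordCollection

rows = RecordCollection(Record(['n'], [i]) for i in range(5))
print(len(rows))        # 0 -- nothing has been consumed yet
print(rows[2].n)        # 2 -- consumes rows 0..2 and caches them in _all_rows
print(len(rows))        # 3 -- three rows are now cached
print(len(rows.all()))  # 5 -- the rest is consumed; later reads hit the cache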

The Database class

# Database: the class that handles the database connection and SQL queries
class Database(object):
    """A Database. Encapsulates a url and an SQLAlchemy engine with a pool of
    connections.
    def __init__(self, db_url=None, **kwargs):
        # If no db_url was provided, fallback to $DATABASE_URL.
        # if no db_url is provided, fall back to the environment variable
        self.db_url = db_url or os.environ.get('DATABASE_URL')
        # raise if no db_url could be found
        if not self.db_url:
            raise ValueError('You must provide a db_url.')
        # Create an engine.
        # build the engine (connections are drawn from its pool)
        self._engine = create_engine(self.db_url, **kwargs)
        self.open = True
    # close: dispose of the engine
    def close(self):
        """Closes the Database."""
        self._engine.dispose()
        self.open = False
    # with-statement support
    # __enter__: the return value is bound to the variable after `as`
    def __enter__(self):
        return self
    # __exit__ is called once the with block has finished
    def __exit__(self, exc, val, traceback):
        self.close()
    def __repr__(self):
        return '<Database open={}>'.format(self.open)
    # get_table_names: return a list of all table names in the connected database
    def get_table_names(self, internal=False):
        """Returns a list of table names for the connected database."""
        # Setup SQLAlchemy for Database inspection.
        return inspect(self._engine).get_table_names()
    # get_connection: obtain and return a Connection object (see the Connection class)
    def get_connection(self):
        """Get a connection to this Database. Connections are retrieved from a
        pool.
        # 如果引擎关闭则报错
        if not self.open:
            raise exc.ResourceClosedError('Database closed.')
        return Connection(self._engine.connect())
    # the following four methods all delegate to a Connection object (defined in Connection below)
    def query(self, query, fetchall=False, **params):
        """Executes the given SQL query against the Database. Parameters can,
        optionally, be provided. Returns a RecordCollection, which can be
        iterated over to get result rows as dictionaries.
        with self.get_connection() as conn:
            return conn.query(query, fetchall, **params)
    def bulk_query(self, query, *multiparams):
        """Bulk insert or update."""
        with self.get_connection() as conn:
            conn.bulk_query(query, *multiparams)
    def query_file(self, path, fetchall=False, **params):
        """Like Database.query, but takes a filename to load a query from."""
        with self.get_connection() as conn:
            return conn.query_file(path, fetchall, **params)
    def bulk_query_file(self, path, *multiparams):
        """Like Database.bulk_query, but takes a filename to load a query from."""
        with self.get_connection() as conn:
            conn.bulk_query_file(path, *multiparams)
    # the @contextmanager decorator takes a generator; the value it yields becomes the variable in `with ... as var`
    # transaction: run statements inside a database transaction
    @contextmanager
    def transaction(self):
        """A context manager for executing a transaction on this Database."""
        conn = self.get_connection()    # obtain a Connection object
        tx = conn.transaction() # obtain a Transaction object
        try:
            yield conn
            tx.commit()
        except:
            tx.rollback()
        finally:
            conn.close()

The Database class is the main object Records is used through: it connects to the database, lists table names, executes SQL statements, and runs transactions.

1. Database connection: delegates directly to sqlalchemy's connection machinery; the connection is established by passing in a Database URL.

A typical Database URL has the form: dialect+driver://username:password@host:port/database
Connecting to MySQL from Python via PyMySQL: mysql+pymysql://scott:tiger@localhost/foo

2. Table names: fetched via sqlalchemy's inspect.

3. SQL execution: the main functionality is implemented in the Connection class; here it is invoked through a with statement.

4. Transactions: the core transaction operations in SQL are BEGIN, COMMIT and ROLLBACK. A Python context manager can run code before and after a block, so here the manager acquires a Connection and a Transaction object before the block executes, and the try ... except ... finally afterwards defines what the Transaction does when the block succeeds, fails, or finishes. The benefit is a greatly simplified transaction workflow: the user only has to focus on writing SQL, which reflects the author's intent in creating the package, "Just write SQL. No bells, no whistles."
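For illustration, a minimal usage sketch of the transaction context manager, assuming a throwaway in-memory SQLite database (the table t is made up for the example):

from records import Database

db = Database('sqlite://')                     # throwaway in-memory SQLite, for the demo only
db.query('CREATE TABLE t (x integer)')
with db.transaction() as conn:                 # BEGIN
    conn.query('INSERT INTO t (x) VALUES (1)')
    # leaving the block normally triggers tx.commit();
    # any exception inside the block triggers tx.rollback()
print(db.query('SELECT count(*) AS n FROM t').scalar())   # 1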

The Connection class

# Connection: a database connection object
class Connection(object):
    """A Database connection."""
    def __init__(self, connection):
        self._conn = connection # self._engine.connect()
        self.open = not connection.closed
    # close: close the connection
    def close(self):
        self._conn.close()
        self.open = False
    # with-statement support (__enter__ / __exit__)
    def __enter__(self):
        return self
    def __exit__(self, exc, val, traceback):
        self.close()
    def __repr__(self):
        return '<Connection open={}>'.format(self.open)
    # query: execute a SQL statement
    def query(self, query, fetchall=False, **params):
        """Executes the given SQL query against the connected Database.
        Parameters can, optionally, be provided. Returns a RecordCollection,
        which can be iterated over to get result rows as dictionaries.
        # Execute the given query.
        # 执行给定语句
        # text() 在此处的作用是将 SQL 语句格式化,使其能够通过外部参数动态调整
        # **params 可变关键字参数,以字典形式传入
        cursor = self._conn.execute(text(query), **params) # TODO: PARAMS GO HERE
        # Row-by-row Record generator.
        # cursor is a sqlalchemy ResultProxy object
        # cursor.keys() returns the column names
        # row_gen is a generator yielding one Record per row
        row_gen = (Record(cursor.keys(), row) for row in cursor)
        # Convert psycopg2 results to RecordCollection.
        # store the results in a RecordCollection
        results = RecordCollection(row_gen)
        # Fetch all results if desired.
        # fetchall=True consumes all results right away
        if fetchall:
            results.all()
        return results
    # bulk_query: execute a SQL statement in bulk
    def bulk_query(self, query, *multiparams):
        """Bulk insert or update."""
        # *multiparams are positional varargs, passed in as a tuple
        self._conn.execute(text(query), *multiparams)
    # query_file: load a query from a .sql file and execute it
    def query_file(self, path, fetchall=False, **params):
        """Like Connection.query, but takes a filename to load a query from."""
        # If the path doesn't exist
        if not os.path.exists(path):
            raise IOError("File '{}' not found!".format(path))
        # If it's a directory
        if os.path.isdir(path):
            raise IOError("'{}' is a directory!".format(path))
        # Read the given .sql file into memory.
        with open(path) as f:
            query = f.read()
        # Defer processing to self.query method.
        return self.query(query=query, fetchall=fetchall, **params)
    # bulk_query_file: load statements from a .sql file and bulk-execute them
    def bulk_query_file(self, path, *multiparams):
        """Like Connection.bulk_query, but takes a filename to load a query
        from.
         # If path doesn't exists
        if not os.path.exists(path):
            raise IOError("File '{}'' not found!".format(path))
        # If it's a directory
        if os.path.isdir(path):
            raise IOError("'{}' is a directory!".format(path))
        # Read the given .sql file into memory.
        with open(path) as f:
            query = f.read()
        self._conn.execute(text(query), *multiparams)
    # transaction: return a Transaction object; call commit or rollback on it
    def transaction(self):
        """Returns a transaction object. Call ``commit`` or ``rollback``
        on the returned object as appropriate."""
        return self._conn.begin()

It is easy to see that Connection mainly provides context-manager support, query for returning SQL results, and transaction for returning a transaction object.

  1. The context-manager support lets a connection be used in a with statement, ensuring the Connection is closed promptly once the SQL has finished (see the sketch after this list);
  2. query does three things: it obtains the ResultProxy returned by the sqlalchemy query, wraps each row in a Record object, and wraps those Record objects in a RecordCollection;
  3. transaction returns sqlalchemy's Transaction object, which is put to use in the Database class.
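A minimal usage sketch of point 1, assuming the db object and the pc_user table from the practice section below:

# the with statement closes the Connection as soon as the block ends
with db.get_connection() as conn:
    rows = conn.query('SELECT * FROM pc_user')
    print(rows.first())
print(conn.open)   # False -- the connection went back to the engine's pool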

The _reduce_datetimes function

# _reduce_datetimes: receives a row and converts datetime values to ISO-format strings
def _reduce_datetimes(row):
    """Receives a row, converts datetimes to strings."""
    row = list(row)
    # modify the row in place by index
    for i in range(len(row)):
        if hasattr(row[i], 'isoformat'):
            row[i] = row[i].isoformat()
    return tuple(row)
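A quick illustration of its effect:

from datetime import datetime
from records import _reduce_datetimes   # importing the private helper just for the demo

row = ('zhang1', datetime(2018, 1, 2, 3, 4, 5))
print(_reduce_datetimes(row))   # ('zhang1', '2018-01-02T03:04:05')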

The cli function

# cli: the command line interface
def cli():
    supported_formats = 'csv tsv json yaml html xls xlsx dbf latex ods'.split()
    formats_lst=", ".join(supported_formats)
    cli_docs ="""Records: SQL for Humans™
A Kenneth Reitz project.
Usage:
  records <query> [<format>] [<params>...] [--url=<url>]
  records (-h | --help)
Options:
  -h --help     Show this screen.
  --url=<url>   The database URL to use. Defaults to $DATABASE_URL.
Supported Formats:
   %(formats_lst)s
   Note: xls, xlsx, dbf, and ods formats are binary, and should only be
         used with redirected output e.g. '$ records sql xls > sql.xls'.
Query Parameters:
    Query parameters can be specified in key=value format, and injected
    into your query in :key format e.g.:
    $ records 'select * from repos where language ~= :lang' lang=python
Notes:
  - While you may specify a database connection string with --url, records
    will automatically default to the value of $DATABASE_URL, if available.
  - Query is intended to be the path of a SQL file, however a query string
    can be provided instead. Use this feature discernfully; it's dangerous.
  - Records is intended for report-style exports of database queries, and
    has not yet been optimized for extremely large data dumps.
    """ % dict(formats_lst=formats_lst)
    # Parse the command-line arguments.
    arguments = docopt(cli_docs)
    query = arguments['<query>']
    params = arguments['<params>']
    format = arguments.get('<format>')
    if format and "=" in format:
        del arguments['<format>']
        arguments['<params>'].append(format)
        format = None
    if format and format not in supported_formats:
        print('%s format not supported.' % format)
        print('Supported formats are %s.' % formats_lst)
        exit(62)
    # Can't send an empty list if params aren't expected.
    try:
        params = dict([i.split('=') for i in params])
    except ValueError:
        print('Parameters must be given in key=value format.')
        exit(64)
    # Be ready to fail on missing packages
    try:
        # Create the Database.
        db = Database(arguments['--url'])
        # Execute the query, if it is a found file.
        if os.path.isfile(query):
            rows = db.query_file(query, **params)
        # Execute the query, if it appears to be a query string.
        elif len(query.split()) > 2:
            rows = db.query(query, **params)
        # Otherwise, say the file wasn't found.
        else:
            print('The given query could not be found.')
            exit(66)
        # Print results in desired format.
        if format:
            content = rows.export(format)
            if isinstance(content, bytes):
                print_bytes(content)
            else:
                print(content)
        else:
            print(rows.dataset)
    except ImportError as impexc:
        print(impexc.msg)
        print("Used database or format require a package, which is missing.")
        print("Try to install missing packages.")
        exit(60)

The print_bytes function

# print_bytes: write a bytes object to standard output
def print_bytes(content):
    try:
        stdout.buffer.write(content)
    except AttributeError:
        stdout.write(content)

if __name__ == '__main__':

# Run the CLI when executed directly.
# start the CLI when the module is run directly
if __name__ == '__main__':
    cli()

Practice

Installation

$ pip install records
$ pipenv install records[pandas]    # recommended way to install

Executing a statement

import records
# connect to the database
db = records.Database('mysql+pymysql://root:@localhost:3306/dev01_git')
rows = db.query('select * from pc_user')

Creating a table

# connect to the database
db = records.Database('mysql+pymysql://root:@localhost:3306/dev01_git')
# create the table
sql_create_table = """CREATE TABLE IF NOT EXISTS pc_user (
    name varchar(20),
    age int
) DEFAULT CHARSET=utf8 ;"""
db.query(sql_create_table)

Inserting data

# insert a single row
user = {"name": "zhang1", "age": 13}
db.query('INSERT INTO pc_user(name,age) values (:name, :age)', **user)
# insert multiple rows
users = [
    {"name": "zhang2", "age": 14},
    {"name": "zhang3", "age": 15},
    {"name": "zhang4", "age": 16}
]
db.bulk_query('INSERT INTO pc_user(name,age) values (:name, :age)', users)

Querying data

rows = db.query('SELECT * FROM pc_user;')
# fetch all rows
print(rows.all())
# all rows as dicts
print(rows.all(as_dict=True))
# the first record
print(rows.first())
# the first record as a dict
print(rows.first(as_dict=True))
# the first record as an OrderedDict
print(rows.first(as_ordereddict=True))
# exactly one record (raises ValueError if there is more than one row)
print(rows.one())
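scalar() from the source above returns the first column of that single row, which is convenient for aggregates:

# first column of the single result row, e.g. a count
print(db.query('SELECT count(*) AS n FROM pc_user;').scalar())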

Transactions

with db.transaction() as tx:
    user = {"name": "zhang5", "age": 20}
    tx.query('INSERT INTO pc_user(name,age) values (:name, :age)', **user)
    tx.query('sof') # invalid statement; the whole transaction is rolled back automatically

Exporting data

# export as JSON
rows = db.query('SELECT * FROM pc_user;')
json_rows = rows.export('json')
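Other tablib formats (csv, tsv, yaml, and so on) work the same way, for example:

# same export call, different format name
csv_rows = rows.export('csv')
yaml_rows = rows.export('yaml')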