Flask Data Models: Working with Databases Using flask-sqlalchemy

Summary: Flask, flask-sqlalchemy

flask-sqlalchemy overview

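SQLAlchemy is an ORM (object-relational mapping) framework implemented in Python. It lets you work with the database through classes and objects instead of hand-written SQL: each record in a database table is mapped to an object, which abstracts the database away.
flask-sqlalchemy is a Flask extension that simplifies working with SQLAlchemy, providing useful defaults and extra helpers that make common tasks easier.

Installing flask-sqlalchemy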

pip install flask-sqlalchemy

ORM mapping

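  • A class attribute corresponds to a column of the table.
  • Advantages of an ORM framework: it simplifies database access code by removing repetitive steps such as opening a connection, creating a cursor, writing the SQL statement, reading the data, closing the cursor, and closing the connection. This improves development efficiency, lets developers focus on web features rather than the underlying data driver, and gives database access code a uniform style.
  • Disadvantage of an ORM framework: it sacrifices some execution efficiency, especially for complex SQL statements.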
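
To make the repetitive steps in the list above concrete, here is a minimal sketch of the hand-written access code that an ORM hides. It uses Python's standard-library sqlite3 module with an in-memory database instead of MySQL, an assumption made only so the snippet runs without a database server; the table and values are illustrative.

```python
import sqlite3

# Open the connection and cursor by hand: boilerplate an ORM takes care of.
conn = sqlite3.connect(":memory:")
cur = conn.cursor()
try:
    # Write the SQL statements yourself.
    cur.execute(
        "CREATE TABLE test ("
        "id INTEGER PRIMARY KEY AUTOINCREMENT, "
        "name TEXT NOT NULL, "
        "score REAL NOT NULL DEFAULT 0.0)"
    )
    cur.execute("INSERT INTO test (name, score) VALUES (?, ?)", ("gp", 93.5))
    conn.commit()

    # Read the data back: rows arrive as plain tuples, not objects.
    cur.execute("SELECT id, name, score FROM test WHERE name = ?", ("gp",))
    rows = cur.fetchall()
finally:
    # Close the cursor and the connection explicitly.
    cur.close()
    conn.close()

print(rows)
```

With an ORM, the same work is expressed through a model class and a session, as the rest of this article shows.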

    Initializing flask-sqlalchemy

    To access MySQL with flask-sqlalchemy, first define the configuration used to initialize it. The configuration lives in its own file, config.py:

    DIALECT = 'mysql'
    DRIVER = 'pymysql'
    USERNAME = 'gp'
    PASSWORD = '123456'
    HOST = '127.0.0.1'
    PORT = '3306'
    DATABASE = 'pira'
    DB_URI = '{}+{}://{}:{}@{}:{}/{}?charset=utf8'.format(DIALECT, DRIVER, USERNAME, PASSWORD, HOST, PORT, DATABASE)
    SQLALCHEMY_DATABASE_URI = DB_URI
    SQLALCHEMY_TRACK_MODIFICATIONS = False
    SQLALCHEMY_ECHO = True
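
The format call in config.py works as long as the credentials contain no characters that are special in a URL. As a hedged sketch, the helper below (build_db_uri is a hypothetical name, not part of flask-sqlalchemy) escapes the user name and password with the standard library's quote_plus before assembling the same URI.

```python
from urllib.parse import quote_plus


def build_db_uri(dialect, driver, username, password, host, port, database,
                 charset="utf8"):
    """Assemble a SQLAlchemy URI, escaping characters such as '@' or '/'
    that would otherwise be misread as URL delimiters."""
    return "{}+{}://{}:{}@{}:{}/{}?charset={}".format(
        dialect, driver, quote_plus(username), quote_plus(password),
        host, port, database, charset,
    )


# Same values as config.py: nothing needs escaping, so the URI is unchanged.
plain_uri = build_db_uri("mysql", "pymysql", "gp", "123456",
                         "127.0.0.1", "3306", "pira")

# A made-up password containing '@' and '/' is percent-encoded.
escaped_uri = build_db_uri("mysql", "pymysql", "gp", "p@ss/word",
                           "127.0.0.1", "3306", "pira")
print(plain_uri)
print(escaped_uri)
```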
    

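    The initialization script imports config.py, loads the configuration with app.config.from_object, and creates the database object with SQLAlchemy(app). The pira database must be created in MySQL beforehand.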

    from flask import Flask, render_template, request
    from flask_sqlalchemy import SQLAlchemy
    import config
    app = Flask(__name__)
    app.config.from_object(config)
    db = SQLAlchemy(app)
    

    You can also set app.config directly with key/value pairs, defining the three settings SQLALCHEMY_DATABASE_URI, SQLALCHEMY_TRACK_MODIFICATIONS, and SQLALCHEMY_ECHO:

    app = Flask(__name__)
    app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql+pymysql://gp:123456@127.0.0.1:3306/pira?charset=utf8'
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
    app.config['SQLALCHEMY_ECHO'] = True
    db = SQLAlchemy(app)
    

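    The settings used for database initialization are: SQLALCHEMY_DATABASE_URI (the connection URI), SQLALCHEMY_TRACK_MODIFICATIONS (whether to track object modifications and emit signals; False avoids the extra overhead), and SQLALCHEMY_ECHO (whether to log the SQL statements that are executed).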

    Mapping flask-sqlalchemy models to tables

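    Once the db object has been created with SQLAlchemy(app) (or bound later with db.init_app(app)), define a class that inherits from db.Model. The __tablename__ attribute gives the table name, and each class attribute defined with db.Column becomes a column with the same name and the declared type. Call db.create_all() to create the table, instantiate the class, and call db.session.add and db.session.commit to insert a row.
    Note: db.create_all() only creates tables that do not yet exist in the database. On each call, a table is created for every model class that has no corresponding table; existing tables are left unchanged.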

    from flask import Flask
    from flask_sqlalchemy import SQLAlchemy
    import config
    app = Flask(__name__)
    app.config.from_object(config)
    db = SQLAlchemy(app)
    class Test(db.Model):
        __tablename__ = 'test'
        id = db.Column(db.Integer, primary_key=True, autoincrement=True)
        name = db.Column(db.String(50), nullable=False)
        score = db.Column(db.Float, nullable=False, default=0.0)
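    # Flask-SQLAlchemy 3.x requires an application context here; it is harmless on 2.x
    with app.app_context():
        db.create_all()
        person1 = Test(name='gp', score=93.5)
        db.session.add(person1)
        db.session.commit()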
    if __name__ == '__main__':
        app.run(host="0.0.0.0", port=5000)
    

    Check the table data in MySQL:

    mysql> select * from test;
    +----+------+-------+
    | id | name | score |
    +----+------+-------+
    |  1 | gp   |  93.5 |
    +----+------+-------+
    

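    Commonly used column types in flask-sqlalchemy include db.Integer, db.String, and db.Float (all used in this article), as well as db.Text, db.Boolean, db.Date, and db.DateTime.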

    Creating, reading, updating, and deleting data

    Inserting data: instantiate the model class, call db.session.add() to stage the insert, then db.session.commit() to commit it.

    from flask import Flask, request
    from flask_sqlalchemy import SQLAlchemy
    import config
    app = Flask(__name__)
    app.config.from_object(config)
    db = SQLAlchemy(app)
    class Test(db.Model):
        __tablename__ = 'test'
        id = db.Column(db.Integer, primary_key=True, autoincrement=True)
        name = db.Column(db.String(50), nullable=False)
        score = db.Column(db.Float, nullable=False, default=0.0)
    @app.route('/', methods=['GET'])
    def index():
        data = request.args.to_dict()
        name = data.get('name')
        score = data.get('score')
        d1 = Test(name=name, score=score)
        db.session.add(d1)
        db.session.commit()
        return 'inserted: name={}, score={}'.format(name, score)
    if __name__ == '__main__':
        app.run(host="0.0.0.0", port=5000)
    

    Send three GET requests from a browser, for example http://127.0.0.1:5000/?name=sjl&score=66.2, with the name and score parameters in the URL, then query the database to see the stored rows:

    mysql> select * from test;
    +----+------+-------+
    | id | name | score |
    +----+------+-------+
    |  1 | gp   |  93.5 |
    |  2 | wf   |   3.3 |
    |  3 | wbb  |    55 |
    |  4 | sjl  |  66.2 |
    +----+------+-------+
    

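    Querying data: to query with flask-sqlalchemy, call query.filter on the model class to select the matching records, then call first() or all() to return the first row or all rows, similar to fetchone and fetchall on a database cursor. If nothing matches, first() returns Python's None and all() returns an empty list [].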

    from flask import Flask, request
    from flask_sqlalchemy import SQLAlchemy
    import config
    app = Flask(__name__)
    app.config.from_object(config)
    db = SQLAlchemy(app)
    class Test(db.Model):
        __tablename__ = 'test'
        id = db.Column(db.Integer, primary_key=True, autoincrement=True)
        name = db.Column(db.String(50), nullable=False)
        score = db.Column(db.Float, nullable=False, default=0.0)
    @app.route('/<string:name>.html')
    def fetch_data(name: str):
        data = Test.query.filter(Test.name == name).first()
        if data is None:  # first() returns None when no row matches
            return 'no record named ' + name, 404
        return 'score=' + str(data.score)
    if __name__ == '__main__':
        app.run(host="0.0.0.0", port=5000)
    

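    Columns are referenced on the model class as ClassName.column_name for filtering. The result is one record, that is, an instance of the class whose attributes are the column values. With SQLALCHEMY_ECHO enabled, the underlying SQL appears in the server log: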

    2021-01-11 09:56:06,214 INFO sqlalchemy.engine.base.Engine SELECT test.id AS test_id, test.name AS test_name, test.score AS test_score 
    FROM test 
    WHERE test.name = %(name_1)s 
     LIMIT %(param_1)s
    2021-01-11 09:56:06,214 INFO sqlalchemy.engine.base.Engine {'name_1': 'wf', 'param_1': 1}
    

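    To fetch multiple rows, use all(). It returns a list; loop over it and read each record's values as record.column_name. Each element is otherwise a model instance, not a raw value.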

    @app.route('/gt<string:score>')
    def get_gt_data(score: str):
        data = Test.query.filter(Test.score > float(score)).all()
        for d in data:
            print(d.name)
        return 'query succeeded'
    
    2021-01-11 10:05:04,277 INFO sqlalchemy.engine.base.Engine SELECT test.id AS test_id, test.name AS test_name, test.score AS test_score 
    FROM test 
    WHERE test.score > %(score_1)s
    2021-01-11 10:05:04,277 INFO sqlalchemy.engine.base.Engine {'score_1': 3.0}
    

    To apply several filter conditions in one query, pass them to filter() separated by commas; they are combined with AND.

    industry_avg_score = PiraScore.query.filter(PiraScore.industry == industry_code, PiraScore.datetime == '2020-12-07').first()
    

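    Sorting and other clauses: further clauses such as limit, order_by, group_by, and count can be chained onto the query. For example, ascending and descending order by name: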

    data = Test.query.filter(Test.score > float(score)).order_by(Test.name).all()
    
    data = Test.query.filter(Test.score > float(score)).order_by(Test.name.desc()).all()
    

    Random ordering: call rand() on SQLAlchemy's func, which renders MySQL's RAND() function.

    from sqlalchemy import func
    data = Test.query.order_by(func.rand()).all()
    

    Counting: call count() directly on the query; the result is an int.

    @app.route('/count')
    def get_count_book():
        # res = db.session.query(db.func.count(Book.id)).scalar()
        res = Book.query.count()
        print(res)  # 2
        print(type(res))  # int
        return 'query succeeded'
    

    Aggregate functions max, min, and avg: use func.avg, func.max, and func.min from SQLAlchemy.

    from sqlalchemy import func
    industry_avg_score = PiraScore.query.filter(PiraScore.industry == industry_code, PiraScore.datetime == '2020-12-08').with_entities(func.avg(PiraScore.score)).first()[0]
    industry_max_score = PiraScore.query.filter(PiraScore.industry == industry_code, PiraScore.datetime == '2020-12-08').with_entities(func.max(PiraScore.score)).first()[0]
    industry_min_score = PiraScore.query.filter(PiraScore.industry == industry_code, PiraScore.datetime == '2020-12-08').with_entities(func.min(PiraScore.score)).first()[0]
    

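    Here with_entities selects only the columns you need, with multiple entries separated by commas. For example, combined with group_by it returns each group's key together with the aggregated value: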

    industry_scores = PiraScore.query.filter(PiraScore.industry == industry_code).group_by(PiraScore.datetime).with_entities(PiraScore.datetime, func.avg(PiraScore.score)).all()
    
    [('2020-10-31', Decimal('33.3807')), ('2020-11-01', Decimal('33.5413')), ('2020-11-02', Decimal('33.7248')), ('2020-11-03', Decimal('33.6560')), ('2020-11-04', Decimal('33.8165')), ('2020-11-05', Decimal('33.8165')), ('2020-11-06', Decimal('33.6835')), ('2020-11-07', Decimal('33.7202')), ('2020-11-08', Decimal('33.7248')), ('2020-11-09', Decimal('33.7523')), ('2020-11-10', Decimal('33.9495')), ('2020-11-11', Decimal('34.1330')), ('2020-11-12', Decimal('34.4495')), ('2020-11-13', Decimal('34.5092')), ('2020-11-14', Decimal('34.7339')), ('2020-11-15', Decimal('34.8257')), ('2020-11-16', Decimal('34.7615')), ('2020-11-17', Decimal('34.7798')), ('2020-11-18', Decimal('34.8716')), ('2020-11-19', Decimal('34.9174')), ('2020-11-20', Decimal('35.0413')), ('2020-11-21', Decimal('35.0826')), ('2020-11-22', Decimal('35.0688')), ('2020-11-23', Decimal('35.1422')), ('2020-11-24', Decimal('35.1101')), ('2020-11-25', Decimal('35.0917')), ('2020-11-26', Decimal('35.1651')), ('2020-11-27', Decimal('35.0917')), ('2020-11-28', Decimal('35.1147')), ('2020-11-29', Decimal('35.1468')), ('2020-11-30', Decimal('35.0413')), ('2020-12-01', Decimal('34.9587')), ('2020-12-02', Decimal('34.9128')), ('2020-12-03', Decimal('34.9725')), ('2020-12-04', Decimal('34.9541')), ('2020-12-05', Decimal('35.0596')), ('2020-12-06', Decimal('34.9862')), ('2020-12-07', Decimal('34.4780'))]
    

    A slightly more complex query adds ordering by date and returns the 15 most recent rows:

    industry_scores = PiraScore.query.filter(PiraScore.industry == industry_code).group_by(PiraScore.datetime).with_entities(PiraScore.datetime,func.avg(PiraScore.score)).order_by(PiraScore.datetime.desc()).limit(15).all()
    
    [('2020-12-07', Decimal('34.4780')), ('2020-12-06', Decimal('34.9862')), ('2020-12-05', Decimal('35.0596')), ('2020-12-04', Decimal('34.9541')), ('2020-12-03', Decimal('34.9725')), ('2020-12-02', Decimal('34.9128')), ('2020-12-01', Decimal('34.9587')), ('2020-11-30', Decimal('35.0413')), ('2020-11-29', Decimal('35.1468')), ('2020-11-28', Decimal('35.1147')), ('2020-11-27', Decimal('35.0917')), ('2020-11-26', Decimal('35.1651')), ('2020-11-25', Decimal('35.0917')), ('2020-11-24', Decimal('35.1101')), ('2020-11-23', Decimal('35.1422'))]
    

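    Join queries: for an outer join, call outerjoin; the order determines the primary table (the model being queried is the left side of the LEFT OUTER JOIN). For an inner join, call join. Specify the join condition inside the join method, and use with_entities to pick the columns you want.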

    risk_issue = EntLabelDetail.query.outerjoin(
                LabelDescribe, EntLabelDetail.label_code == LabelDescribe.label_code) \
                .filter(EntLabelDetail.ent_name == fullname, LabelDescribe.score > 5) \
                .with_entities(EntLabelDetail.ent_name, EntLabelDetail.datetime, LabelDescribe.label_name,
                               LabelDescribe.score).all()
    

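    Subqueries: use subquery(), in two steps. First build the inner query and turn it into a subquery object; then write the main query as usual, referring to the subquery's columns (through .c) in the filter condition.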

    # Define the subquery object
    subquery = EntIndustryInfo.query.filter(EntIndustryInfo.ent_name == ent_1).subquery()
    # Use the subquery's column in the main query's filter condition
    industry_ents = EntIndustryInfo.query.filter(EntIndustryInfo.ind2_name == subquery.c.ind2_name)\
        .with_entities(EntIndustryInfo.ent_name).limit(10)
    for i in industry_ents:
        print(i)
    

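    In the SQL that is actually executed, the subquery appears alongside the main table in the FROM clause, and the subquery's column is compared with the main table's column: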

    2021-01-13 10:26:47,903 INFO sqlalchemy.engine.base.Engine SELECT pira_ent_industry.ent_name AS pira_ent_industry_ent_name 
    FROM pira_ent_industry, (SELECT pira_ent_industry.id AS id, pira_ent_industry.ent_name AS ent_name, pira_ent_industry.ind1_code AS ind1_code, pira_ent_industry.ind1_name AS ind1_name, pira_ent_industry.ind2_name AS ind2_name 
    FROM pira_ent_industry 
    WHERE pira_ent_industry.ent_name = %(ent_name_1)s) AS anon_1 
    WHERE pira_ent_industry.ind2_name = anon_1.ind2_name 
     LIMIT %(param_1)s
    

    Updating data: first use filter to fetch one or more model instances, change their attributes, and then commit.

    @app.route('/alter')
    def alter():
        data = Test.query.filter(Test.name == 'wf').first()
        if data is None:  # nothing to update
            return 'record not found', 404
        data.score = 0.0
        db.session.commit()
        return 'update succeeded'
    

    To update multiple rows, loop over the results and change each instance's attribute:

    @app.route('/alter')
    def alter():
        data = Test.query.filter(Test.score > 1.0).all()
        for d in data:
            d.score = 0.0
        db.session.commit()
        return 'update succeeded'
    
    mysql> select * from test;
    +----+------+-------+
    | id | name | score |
    +----+------+-------+
    |  1 | gp   |     0 |
    |  2 | wf   |     0 |
    |  3 | wbb  |     0 |
    |  4 | sjl  |     0 |
    +----+------+-------+
    

    Deleting data: use filter to find the instance, call db.session.delete() on it, then commit.

    @app.route('/delete')
    def delete():
        data = Test.query.filter(Test.name == 'wf').first()
        if data is None:  # nothing to delete
            return 'record not found', 404
        db.session.delete(data)
        db.session.commit()
        return 'delete succeeded'
    

    Circular imports with model objects

    The circular import problem arises for the following reasons:

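  • The model script and the main view script are separate files: all model classes live together in one script, which the main view script imports.
  • The model script needs the db object to exist before it can define anything, because the classes inherit from db.Model and use db.Column.
  • The db object is defined in the main view script, because the app it needs is there, yet the main view script imports the model script at the very top.
  • As a result, the main view script needs the model script at startup, while the model script needs the main view script at startup, and the import fails.

    The solution is as follows:

  • Define the db object in its own script, without passing an app yet.
  • Put the model definitions in a separate script that imports the empty db object from the db script, so the definitions are syntactically valid and the references resolve.
  • In the main view script, import db and the models, then call db.init_app(app) to bind the app to the empty db.

    The implementation uses three scripts: the db script external.py, the model script models.py, and the main view script main.py.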

    # external.py
    from flask_sqlalchemy import SQLAlchemy
    db = SQLAlchemy()
    

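    models.py defines the models against the imported db object; the main view script main.py then imports both db and the models and performs the initialization: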

    # models.py
    from external import db
    class Book(db.Model):
        __tablename__ = 'book'
        id = db.Column(db.Integer, primary_key=True, autoincrement=True)
        author = db.Column(db.String(100), nullable=False)
    class Author(db.Model):
        __tablename__ = 'author'
        id = db.Column(db.Integer, primary_key=True, autoincrement=True)
        name = db.Column(db.String(100), nullable=False)
        country = db.Column(db.String(100), nullable=False)
    
    # main.py
    from flask import Flask, request
    import config
    from external import db
    from models import Book, Author
    app = Flask(__name__)
    app.config.from_object(config)
    db.init_app(app)
    # db.create_all()
    with app.app_context():
        db.create_all()
    @app.route('/')
    def index():
        return 'welcome'
    @app.route('/add_book')
    def add_book():
        db.session.add(Book(author="gp"))
        db.session.commit()
        return "book inserted"
    @app.route('/add_author')
    def add_author():
        db.session.add(Author(name="gp", country="china"))
        db.session.commit()
        return "author inserted"
    if __name__ == '__main__':
        app.run(host="0.0.0.0", port=5000)
    

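    In the main view script, after db.init_app(app) the tables must be created inside the app.app_context() context; otherwise db.create_all() raises an error. Calling create_all repeatedly when the tables already exist does not raise an error.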

    Database migration with flask-migrate

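    To modify a table that has already been mapped, for example to add a column, change a column type, or rename a column, db.create_all() does not help because it only takes effect when the table does not exist. You would have to drop the original table and create a new one from scratch, losing all existing data. The flask-migrate extension performs database migrations so the data is preserved.

    Installing flask-migrate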

    pip install flask-migrate
    

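    The approach shown here also needs Flask-Script, which lets you operate Flask from the command line. Note that this applies to Flask-Migrate versions before 3.0; Flask-Migrate 3.0 removed Flask-Script support (including MigrateCommand) in favor of the flask db command.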

    pip install Flask-Script
    

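    Steps for using flask-migrate
    (1) Write the database migration script manager.py
    (2) Prepare the data models
    (3) Run the migration commands
    First write the migration script manager.py; this is boilerplate: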

    from flask_migrate import Migrate, MigrateCommand
    from flask_script import Manager
    from main import app, db
    migrate = Migrate(app, db)  # specify the app and db to migrate
    manager = Manager(app)
    manager.add_command('db', MigrateCommand)
    if __name__ == '__main__':
        manager.run()
    

    Update the data model: in models.py, change the Author class by adding two columns (age and sex), then save the file.

    class Author(db.Model):
        __tablename__ = 'author'
        id = db.Column(db.Integer, primary_key=True, autoincrement=True)
        name = db.Column(db.String(100), nullable=False)
        country = db.Column(db.String(100), nullable=False)
        age = db.Column(db.Integer, nullable=False)
        sex = db.Column(db.String(10), nullable=False)
    

    Run the following commands in order:

    python manager.py db init
    python manager.py db migrate
    python manager.py db upgrade
    

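    The first migration requires the init command, which creates a migrations directory. The Python scripts under its versions subdirectory record the changes of each migration, for example: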

    revision = 'dae0a4b7523c'
    down_revision = 'ac73329210e6'
    branch_labels = None
    depends_on = None
    def upgrade():
        # ### commands auto generated by Alembic - please adjust! ###
        op.add_column('author', sa.Column('age', sa.Integer(), nullable=False))
        # ### end Alembic commands ###
    def downgrade():
        # ### commands auto generated by Alembic - please adjust! ###
        op.drop_column('author', 'age')
        # ### end Alembic commands ###
    

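    The script records the current migration's version id (revision) and the previous one (down_revision); upgrade and downgrade record the operations for upgrading and downgrading. Here you can see that this revision adds an integer column age to the author table.
    Checking MySQL confirms that the migration succeeded and the original data is still there: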

    mysql> select * from author where sex is not null;
    +----+------+---------+-----+-----+
    | id | name | country | age | sex |
    +----+------+---------+-----+-----+
    |  1 | gp   | china   |   0 |     |
    +----+------+---------+-----+-----+