    close_price double,
    created_at timestamp default now(),
    updated_at timestamp default now(),
    primary key (id)
) COMMENT = '股票价格表';
create index ids_stockprices on stock_prices(ticker, as_of_date);
create index ids_stockpricestage on stock_prices_stage(ticker, as_of_date);
2. Using Airflow Connections to manage database connection information
Building on the code from the previous section, we now also load the data that was saved to files into the database. The V2 code is:
download_stock_price_v2.py
2.1 The traditional, hard-coded connection
"""Example DAG demonstrating the usage of the BashOperator."""
from datetime import timedelta
from textwrap import dedent
import yfinance as yf
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
from airflow.models import Variable
import mysql.connector
def download_price(*args, **context):
stock_list = get_tickers(context)
for ticker in stock_list:
dat = yf.Ticker(ticker)
hist = dat.history(period="1mo")
with open(get_file_path(ticker), 'w') as writer:
hist.to_csv(writer, index=True)
print("Finished downloading price data for " + ticker)
def get_file_path(ticker):
return f'./{ticker}.csv'
def load_price_data(ticker):
with open(get_file_path(ticker), 'r') as reader:
lines = reader.readlines()
return [[ticker] + line.split(',')[:5] for line in lines if line[:4] != 'Date']
def get_tickers(context):
stock_list = Variable.get("stock_list_json", deserialize_json=True)
stocks = context["dag_run"].conf.get("stocks")
if stocks:
stock_list = stocks
return stock_list
def save_to_mysql_stage(*args, **context):
tickers = get_tickers(context)
mydb = mysql.connector.connect(
host="98.14.13.15",
user="root",
password="Quant888",
database="demodb",
port=3307
mycursor = mydb.cursor()
for ticker in tickers:
val = load_price_data(ticker)
print(f"{ticker} length={len(val)} {val[1]}")
sql = """INSERT INTO stock_prices_stage
(ticker, as_of_date, open_price, high_price, low_price, close_price)
VALUES (%s,%s,%s,%s,%s,%s)"""
mycursor.executemany(sql, val)
mydb.commit()
print(mycursor.rowcount, "record inserted.")
default_args = {
'owner': 'airflow'
with DAG(
dag_id='download_stock_price_v2',
default_args=default_args,
description='download stock price and save to local csv files and save to database',
schedule_interval=None,
start_date=days_ago(2),
tags=['quantdata'],
) as dag:
dag.doc_md = """
This DAG download stock price
download_task = PythonOperator(
task_id="download_prices",
python_callable=download_price,
provide_context=True
save_to_mysql_task = PythonOperator(
task_id="save_to_database",
python_callable=save_to_mysql_stage,
provide_context=True
download_task >> save_to_mysql_task
Then trigger the DAG manually from the Airflow UI. The first two runs failed; after some debugging, it ran successfully.
You can see the data has landed in the database:
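Incidentally, the Variable and the manual trigger can also be handled from the command line; a minimal sketch, assuming the Airflow 2.x CLI and the Variable name used by get_tickers (the ticker values here are just placeholders, not from this article):

$ # Set the default ticker list read by get_tickers()
$ airflow variables set stock_list_json '["IBM", "GE", "AAPL", "MSFT"]'
$ # Trigger the DAG; --conf lands in dag_run.conf and overrides the default list
$ airflow dags trigger download_stock_price_v2 --conf '{"stocks": ["AAPL", "MSFT"]}'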
2.2 Managing connection information with Airflow Connections
The demo above has a problem: the database credentials are hard-coded into the script, which makes later maintenance painful. Airflow provides Connections for exactly this; the connection information can be stored there instead.
When selecting the connection type, the MySQL type is missing:
Conn Type missing? Make sure you've installed the corresponding Airflow Provider Package.
See the official documentation:
airflow.apache.org/docs/apache…
airflow.apache.org/docs/apache…
airflow.apache.org/docs/#provi…
$ pip install apache-airflow-providers-mysql
After refreshing the connections page, the MySQL connection type now appears:
Fill in the database connection details:
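Equivalently, if you prefer the CLI to the web UI, the same Connection can be created with airflow connections add; a sketch, assuming Airflow 2.x, with the values mirroring the hard-coded ones used earlier:

$ airflow connections add 'demodb' \
    --conn-type 'mysql' \
    --conn-host '98.14.13.15' \
    --conn-login 'root' \
    --conn-password 'Quant888' \
    --conn-schema 'demodb' \
    --conn-port '3307'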
Then modify the code:
def save_to_mysql_stage(*args, **context):
    tickers = get_tickers(context)

    # Old approach: hard-coded connection details
    # mydb = mysql.connector.connect(
    #     host="98.14.14.145",
    #     user="root",
    #     password="Quant888",
    #     database="demodb",
    #     port=3307
    # )

    # New approach: look up the 'demodb' Connection configured in Airflow
    from airflow.hooks.base_hook import BaseHook
    conn = BaseHook.get_connection('demodb')
    mydb = mysql.connector.connect(
        host=conn.host,
        user=conn.login,
        password=conn.password,
        database=conn.schema,
        port=conn.port
    )
    mycursor = mydb.cursor()

    for ticker in tickers:
        val = load_price_data(ticker)
        print(f"{ticker} length={len(val)} {val[1]}")

        sql = """INSERT INTO stock_prices_stage
                 (ticker, as_of_date, open_price, high_price, low_price, close_price)
                 VALUES (%s,%s,%s,%s,%s,%s)"""
        mycursor.executemany(sql, val)
        mydb.commit()
        print(mycursor.rowcount, "record inserted.")
Full code:
"""Example DAG demonstrating the usage of the BashOperator."""
from datetime import timedelta
from textwrap import dedent
import yfinance as yf
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
from airflow.models import Variable
import mysql.connector
def download_price(*args, **context):
stock_list = get_tickers(context)
for ticker in stock_list:
dat = yf.Ticker(ticker)
hist = dat.history(period="1mo")
with open(get_file_path(ticker), 'w') as writer:
hist.to_csv(writer, index=True)
print("Finished downloading price data for " + ticker)
def get_file_path(ticker):
return f'./{ticker}.csv'
def load_price_data(ticker):
with open(get_file_path(ticker), 'r') as reader:
lines = reader.readlines()
return [[ticker] + line.split(',')[:5] for line in lines if line[:4] != 'Date']
def get_tickers(context):
stock_list = Variable.get("stock_list_json", deserialize_json=True)
stocks = context["dag_run"].conf.get("stocks")
if stocks:
stock_list = stocks
return stock_list
def save_to_mysql_stage(*args, **context):
tickers = get_tickers(context)
# 连接数据库(硬编码方式连接)
mydb = mysql.connector.connect(
host="98.14.13.14",
user="root",
password="Quan888",
database="demodb",
port=3307
from airflow.hooks.base_hook import BaseHook
conn = BaseHook.get_connection('demodb')
mydb = mysql.connector.connect(
host=conn.host,
user=conn.login,
password=conn.password,
database=conn.schema,
port=conn.port
mycursor = mydb.cursor()
for ticker in tickers:
val = load_price_data(ticker)
print(f"{ticker} length={len(val)} {val[1]}")
sql = """INSERT INTO stock_prices_stage
(ticker, as_of_date, open_price, high_price, low_price, close_price)
VALUES (%s,%s,%s,%s,%s,%s)"""
mycursor.executemany(sql, val)
mydb.commit()
print(mycursor.rowcount, "record inserted.")
default_args = {
'owner': 'airflow'
with DAG(
dag_id='download_stock_price_v2',
default_args=default_args,
description='download stock price and save to local csv files and save to database',
schedule_interval=None,
start_date=days_ago(2),
tags=['quantdata'],
) as dag:
dag.doc_md = """
This DAG download stock price
download_task = PythonOperator(
task_id="download_prices",
python_callable=download_price,
provide_context=True
save_to_mysql_task = PythonOperator(
task_id="save_to_database",
python_callable=save_to_mysql_stage,
provide_context=True
download_task >> save_to_mysql_task
3. Using MySqlOperator to run database operations
Create a new SQL file under the dags/ directory to merge the staging table's data into the main table.
merge_stock_price.sql
UPDATE stock_prices p, stock_prices_stage s
SET p.open_price = s.open_price,
    p.high_price = s.high_price,
    p.low_price = s.low_price,
    p.close_price = s.close_price,
    p.updated_at = now()
WHERE p.ticker = s.ticker
  AND p.as_of_date = s.as_of_date;

INSERT INTO stock_prices
    (ticker, as_of_date, open_price, high_price, low_price, close_price)
SELECT ticker, as_of_date, open_price, high_price, low_price, close_price
FROM stock_prices_stage s
WHERE NOT EXISTS
    (SELECT 1 FROM stock_prices p
     WHERE p.ticker = s.ticker
       AND p.as_of_date = s.as_of_date);
TRUNCATE TABLE stock_prices_stage;
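As a side note, if stock_prices had a UNIQUE key on (ticker, as_of_date) (the DDL above only creates a non-unique index), the update-then-insert pair could be collapsed into a single MySQL upsert; a sketch under that assumption, not the article's script:

INSERT INTO stock_prices
    (ticker, as_of_date, open_price, high_price, low_price, close_price)
SELECT ticker, as_of_date, open_price, high_price, low_price, close_price
FROM stock_prices_stage
ON DUPLICATE KEY UPDATE
    open_price  = VALUES(open_price),
    high_price  = VALUES(high_price),
    low_price   = VALUES(low_price),
    close_price = VALUES(close_price),
    updated_at  = now();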
Then create the MySQL task in download_stock_price_v2.py. The operator must be imported first:
from airflow.providers.mysql.operators.mysql import MySqlOperator

mysql_task = MySqlOperator(
    task_id="merge_stock_price",
    mysql_conn_id='demodb',
    sql="merge_stock_price.sql",
    dag=dag,
)

download_task >> save_to_mysql_task >> mysql_task
Full code:
"""Example DAG demonstrating the usage of the BashOperator."""
from datetime import timedelta
from textwrap import dedent
import yfinance as yf
import mysql.connector
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.mysql.operators.mysql import MySqlOperator
from airflow.utils.dates import days_ago
from airflow.models import Variable
def download_price(*args, **context):
stock_list = get_tickers(context)
for ticker in stock_list:
dat = yf.Ticker(ticker)
hist = dat.history(period="1mo")
with open(get_file_path(ticker), 'w') as writer:
hist.to_csv(writer, index=True)
print("Finished downloading price data for " + ticker)
def get_file_path(ticker):
return f'./{ticker}.csv'
def load_price_data(ticker):
with open(get_file_path(ticker), 'r') as reader:
lines = reader.readlines()
return [[ticker] + line.split(',')[:5] for line in lines if line[:4] != 'Date']
def get_tickers(context):
stock_list = Variable.get("stock_list_json", deserialize_json=True)
stocks = context["dag_run"].conf.get("stocks")
if stocks:
stock_list = stocks
return stock_list
def save_to_mysql_stage(*args, **context):
tickers = get_tickers(context)
# 连接数据库(硬编码方式连接)
mydb = mysql.connector.connect(
host="98.14.14.15",
user="root",
password="Quan888",
database="demodb",
port=3307
from airflow.hooks.base_hook import BaseHook
conn = BaseHook.get_connection('demodb')
mydb = mysql.connector.connect(
host=conn.host,
user=conn.login,
password=conn.password,
database=conn.schema,
port=conn.port
mycursor = mydb.cursor()
for ticker in tickers:
val = load_price_data(ticker)
print(f"{ticker} length={len(val)} {val[1]}")
sql = """INSERT INTO stock_prices_stage
(ticker, as_of_date, open_price, high_price, low_price, close_price)
VALUES (%s,%s,%s,%s,%s,%s)"""
mycursor.executemany(sql, val)
mydb.commit()
print(mycursor.rowcount, "record inserted.")
default_args = {
'owner': 'airflow'
with DAG(
dag_id='download_stock_price_v2',
default_args=default_args,
description='download stock price and save to local csv files and save to database',
schedule_interval=None,
start_date=days_ago(2),
tags=['quantdata'],
) as dag:
dag.doc_md = """
This DAG download stock price
download_task = PythonOperator(
task_id="download_prices",
python_callable=download_price,
provide_context=True
save_to_mysql_task = PythonOperator(
task_id="save_to_database",
python_callable=save_to_mysql_stage,
provide_context=True
mysql_task = MySqlOperator(
task_id="merge_stock_price",
mysql_conn_id='demodb',
sql="merge_stock_price.sql",
dag=dag,
download_task >> save_to_mysql_task >> mysql_task
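Before a full run, the merge task can also be exercised on its own with Airflow's standard task-test command; a sketch (the date is just an arbitrary execution date):

$ airflow tasks test download_stock_price_v2 merge_stock_price 2021-06-01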
Then trigger the DAG manually; you can see it ran successfully:
Checking the relevant tables confirms the data was merged and updated as expected.
4. Passing data between tasks with XComs
The XComs concept
XComs (short for "cross-communications") are a mechanism that lets tasks talk to each other, since by default tasks are entirely isolated and may run on completely different machines.
An XCom is identified by a key (essentially its name), together with the task_id and dag_id it came from. XComs can hold any (serializable) value, but they are designed for small amounts of data only; do not use them to pass around large values such as dataframes.
In one sentence: XComs let multiple tasks communicate, i.e. pass data between one another.
XComs are explicitly "pushed" and "pulled" to/from their storage using the xcom_push and xcom_pull methods on Task Instances. Many operators will auto-push their results into an XCom key called return_value if the do_xcom_push argument is set to True (as it is by default), and @task functions do this as well.
value = task_instance.xcom_pull(task_ids='pushing_task')
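To make the push/pull pairing concrete, here is a minimal sketch with hypothetical task ids and key names (not from this article's DAG); 'pushing_task' would be the task_id of the PythonOperator wrapping producer:

def producer(**context):
    # Explicit push under a custom key
    context['ti'].xcom_push(key='sample_key', value=[1, 2, 3])
    # Returning a value auto-pushes it under the key 'return_value'
    return 'hello'

def consumer(**context):
    ti = context['ti']
    custom = ti.xcom_pull(task_ids='pushing_task', key='sample_key')
    ret = ti.xcom_pull(task_ids='pushing_task')  # default key is 'return_value'
    print(custom, ret)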
Use case: add a ticker that does not exist, validate each ticker during download, and pass only the tickers that actually exist downstream.
Modify the download code in download_stock_price_v2.py:
Then, when saving to the MySQL stage table, read the already-filtered tickers returned by the previous step via XCom.
Full code for download_stock_price_v2.py:
"""Example DAG demonstrating the usage of the BashOperator."""
from datetime import timedelta
from textwrap import dedent
import yfinance as yf
import mysql.connector
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.mysql.operators.mysql import MySqlOperator
from airflow.utils.dates import days_ago
from airflow.models import Variable
def download_price(*args, **context):
stock_list = get_tickers(context)
valid_tickers = []
for ticker in stock_list:
dat = yf.Ticker(ticker)
hist = dat.history(period="1mo")
if hist.shape[0] > 0:
valid_tickers.append(ticker)
else:
continue
with open(get_file_path(ticker), 'w') as writer:
hist.to_csv(writer, index=True)
print("Finished downloading price data for " + ticker)
return valid_tickers
def get_file_path(ticker):
return f'./{ticker}.csv'
def load_price_data(ticker):
with open(get_file_path(ticker), 'r') as reader:
lines = reader.readlines()
return [[ticker] + line.split(',')[:5] for line in lines if line[:4] != 'Date']
def get_tickers(context):
stock_list = Variable.get("stock_list_json", deserialize_json=True)
stocks = context["dag_run"].conf.get("stocks")
if stocks:
stock_list = stocks
return stock_list
def save_to_mysql_stage(*args, **context):
tickers = context['ti'].xcom_pull(task_ids='download_prices')
print(f"received tickers:{tickers}")
# 连接数据库(硬编码方式连接)
mydb = mysql.connector.connect(
host="98.14.14.15",
user="root",
password="Quant888",
database="demodb",
port=3307
from airflow.hooks.base_hook import BaseHook
conn = BaseHook.get_connection('demodb')
mydb = mysql.connector.connect(
host=conn.host,
user=conn.login,
password=conn.password,
database=conn.schema,
port=conn.port
mycursor = mydb.cursor()
for ticker in tickers:
val = load_price_data(ticker)
print(f"{ticker} length={len(val)} {val[1]}")
sql = """INSERT INTO stock_prices_stage
(ticker, as_of_date, open_price, high_price, low_price, close_price)
VALUES (%s,%s,%s,%s,%s,%s)"""
mycursor.executemany(sql, val)
mydb.commit()
print(mycursor.rowcount, "record inserted.")
default_args = {
'owner': 'airflow'
with DAG(
dag_id='download_stock_price_v2',
default_args=default_args,
description='download stock price and save to local csv files and save to database',
schedule_interval=None,
start_date=days_ago(2),
tags=['quantdata'],
) as dag:
dag.doc_md = """
This DAG download stock price
download_task = PythonOperator(
task_id="download_prices",
python_callable=download_price,
provide_context=True
save_to_mysql_task = PythonOperator(
task_id="save_to_database",
python_callable=save_to_mysql_stage,
provide_context=True
mysql_task = MySqlOperator(
task_id="merge_stock_price",
mysql_conn_id='demodb',
sql="merge_stock_price.sql",
dag=dag,
download_task >> save_to_mysql_task >> mysql_task
Then add a non-existent ticker (FBXXOO) in Variables to verify the XCom data passing:
Trigger the DAG manually; the logs show that tickers = context['ti'].xcom_pull(task_ids='download_prices') received the data passed from the previous task.
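As the quoted documentation notes, @task functions auto-push their return values, so the same handoff can be written without an explicit xcom_pull. A minimal TaskFlow-style sketch (Airflow 2.x; hypothetical names, not this article's DAG):

from airflow.decorators import dag, task
from airflow.utils.dates import days_ago

@dag(schedule_interval=None, start_date=days_ago(2), tags=['quantdata'])
def xcom_taskflow_demo():

    @task
    def download_prices():
        # ... download and validate here; a hypothetical result:
        return ['AAPL', 'MSFT']

    @task
    def save_to_database(tickers):
        # The argument is wired up from the upstream task's XCom automatically
        print(f"received tickers: {tickers}")

    save_to_database(download_prices())

demo_dag = xcom_taskflow_demo()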
Related articles:
Airflow concepts documentation
Airflow XComs official documentation