# 1.导入模块
import argparse
def get_parser():
# 2.创建解析器
parser = argparse.ArgumentParser()
# 3.添加参数
parser.add_argument("--trainData", type=str, default="data/train.json")
parser.add_argument("--validData", type=str, default="data/dev.json")
parser.add_argument("--num_epochs", type=int, default=20)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--lr", type=float, default=5e-5)
parser.add_argument("--freeze_bert", default=False, action="store_true",
help="If provided, freeze the layers of bert")
parser.add_argument("--emb_dropout", default=True,action="store_true",
help="If provided, add dropout to the output embedding of bert"
)
parser.add_argument("--dropout_rate", type=float, default=0.5)
parser.add_argument("--num_layers", dest="num_layers", default=1)
parser.add_argument("--hidden_units", dest="hidden_units", default=128)
parser.add_argument("--log_dir", type=str, default="logs/")
parser.add_argument("--model_path", type=str, default="checkpoints/")
# 4.解析参数
hp = parser.parse_args()
return hp
if __name__ == "__main__":
args = get_parser()
print(args.__dict__)