# views.py — route definitions and view functions (路由 + 视图函数)
|
||
import base64
import functools
import hashlib
import hmac
import io
import os
import re
from collections import Counter

import jieba
import numpy as np
import pymysql
from flask import Blueprint
from flask import request, jsonify
from matplotlib import pyplot as plt
from sqlalchemy import desc

from .utils.api_utils import APIUtils
from .models import *
from .utils.prediction import (
    arima_forecast,
    get_finance_job_data,
    prepare_ml_data,
    preprocess_data,
    train_random_forest,
)
|
||
|
||
# Blueprint named "user" that every route in this module is attached to
# (presumably registered on the Flask app elsewhere).
blus = Blueprint(name="user", import_name=__name__)
|
||
# Connection settings for the raw pymysql connections opened in this module
# (e.g. the SQL console and word-frequency endpoints).
# Each value can be overridden via an environment variable so credentials do
# not need to live in source control; the defaults reproduce the previously
# hard-coded settings, so behaviour is unchanged when nothing is set.
# NOTE(review): the defaults connect as MySQL `root` with a weak password —
# deploy with a least-privilege account supplied through the environment.
db_config = {
    'host': os.environ.get('DB_HOST', '192.168.229.122'),
    'user': os.environ.get('DB_USER', 'root'),
    'password': os.environ.get('DB_PASSWORD', '123456'),
    'database': os.environ.get('DB_NAME', 'bigdata_ibecs'),
    'charset': os.environ.get('DB_CHARSET', 'utf8mb4'),
}
|
||
# Registration
@blus.route('/api/register', methods=['POST'])
def user_register():
    """Register a new ordinary user (role 1).

    Expects a JSON body ``{"username": str, "password": str}``. The password is
    stored as an unsalted SHA-256 hex digest, which is the format
    ``/api/login`` compares against.

    Returns:
        Success response "注册成功!" on creation; 400 if fields are missing,
        are not strings, or the username is already taken.
    """
    required_fields = ['username', 'password']
    # silent=True: a missing/invalid JSON body becomes {} and fails validation
    # with a 400 instead of passing None into the validator.
    payload = request.get_json(silent=True) or {}
    is_valid, message = APIUtils.validate_json(payload, required_fields)
    if not is_valid:
        return APIUtils.error_response(message, status_code=400)

    username = payload['username']
    password = payload['password']
    # Non-string values (numbers, lists, ...) would crash on .encode() below.
    if not isinstance(username, str) or not isinstance(password, str):
        return APIUtils.error_response("用户名和密码必须是字符串", status_code=400)

    # Reject duplicate usernames.
    if User.query.filter_by(username=username).first():
        return APIUtils.error_response("用户名已经存在!", status_code=400)

    # NOTE(review): unsalted SHA-256 is weak for password storage; moving to a
    # salted KDF requires changing login/change_password at the same time.
    hashed_password = hashlib.sha256(password.encode()).hexdigest()

    new_user = User(username=username, password=hashed_password, role=1)
    db.session.add(new_user)
    db.session.commit()

    # Bug fix: this previously reported "登录成功!" (login succeeded).
    return APIUtils.success_response(message="注册成功!")
|
||
|
||
@blus.route('/items/<int:item_id>', methods=['DELETE'])
def delete_item(item_id):
    """Delete the ``JobPosition`` row with primary key ``item_id``.

    ``get_or_404`` aborts with HTTP 404 when no such row exists.
    """
    doomed = JobPosition.query.get_or_404(item_id)
    session = db.session
    session.delete(doomed)
    session.commit()
    return APIUtils.success_response(message="删除成功!")
|
||
@blus.route('/items', methods=['GET'])
def get_items():
    """Return one page of job positions, optionally filtered by company/city.

    Query parameters:
        current: page number to fetch (default 1).
        size: rows per page (default 10).
        companyName: substring matched with LIKE against ``company_name``.
        city: substring matched with LIKE against ``city``.

    Returns:
        ``{"list": [...], "page": {"total", "current", "size", "pages"}}``.
    """
    args = request.args
    page = args.get('current', 1, type=int)
    size = args.get('size', 10, type=int)
    company_name = args.get('companyName', '', type=str)
    city = args.get('city', '', type=str)

    # Collect one fuzzy-match condition per non-empty search term.
    conditions = []
    if company_name:
        conditions.append(JobPosition.company_name.like(f'%{company_name}%'))
    if city:
        conditions.append(JobPosition.city.like(f'%{city}%'))

    query = JobPosition.query
    for condition in conditions:
        query = query.filter(condition)

    pagination = query.paginate(page=page, per_page=size, error_out=False)

    page_info = {
        'total': pagination.total,   # total matching rows
        'current': page,             # page number as requested
        'size': size,                # page size as requested
        'pages': pagination.pages,   # total number of pages
    }
    payload = {
        'list': [row.to_dict() for row in pagination.items],
        'page': page_info,
    }
    return APIUtils.success_response(data=payload, message="获取数据成功")
|
||
|
||
@blus.route('/api/login', methods=['POST'])
def user_login():
    """Authenticate a user by username and password.

    Expects a JSON body ``{"username": str, "password": str}``. The password is
    hashed with unsalted SHA-256 and compared with the stored ``User.password``.

    Returns:
        On success ``{token, userId, username, role}`` with message "登录成功!".
        400 when fields are missing or are not strings; 500 (status kept for
        existing clients) when the user is unknown or the password is wrong.
    """
    required_fields = ['username', 'password']
    payload = request.get_json(silent=True) or {}
    is_valid, message = APIUtils.validate_json(payload, required_fields)
    if not is_valid:
        return APIUtils.error_response(message, status_code=400)

    username = payload['username']
    password = payload['password']
    # Non-string values would crash on .encode() below.
    if not isinstance(username, str) or not isinstance(password, str):
        return APIUtils.error_response("用户名和密码必须是字符串", status_code=400)

    user = User.query.filter_by(username=username).first()
    if user is None:
        return APIUtils.error_response("用户名错误或不存在!", status_code=500)

    hashed_password = hashlib.sha256(password.encode()).hexdigest()
    stored_password = str(user.password or '')
    # Constant-time comparison so response timing does not reveal how much of
    # the hash matched. Compare bytes: compare_digest rejects non-ASCII str,
    # and stored values are not guaranteed to be ASCII hex digests.
    if not hmac.compare_digest(hashed_password.encode('utf-8'),
                               stored_password.encode('utf-8')):
        return APIUtils.error_response("密码错误或不存在!", status_code=500)

    # NOTE(review): the "token" is the bare user id, which any client can
    # forge; a signed/opaque session token should replace it.
    return APIUtils.success_response(
        data={'token': user.id, 'userId': user.id, 'username': user.username, 'role': user.role},
        message="登录成功!",
    )
|
||
|
||
@blus.route('/sys/user/info', methods=['GET'])
def user_info():
    """Return the profile of the user identified by the ``token`` header.

    The header value is treated as the user's numeric id, because that is what
    ``/api/login`` hands out as the token in this module.

    Returns:
        ``{token, userId, username, role}`` on success; 401 when the header is
        missing, is not a plain number, or matches no user.
    """
    token = request.headers.get('token', '')
    # Accept only plain ASCII digits. MySQL typically coerces strings such as
    # "1abc" to 1 when comparing them with an integer column (assuming `id` is
    # an integer column).
    if not (token.isascii() and token.isdigit()):
        return APIUtils.error_response("token无效或缺失!", status_code=401)

    user = User.query.filter_by(id=int(token)).first()
    if user is None:
        # Without this check, `user.id` below raises AttributeError (HTTP 500).
        return APIUtils.error_response("用户不存在或token无效!", status_code=401)

    # NOTE(review): because the token is the bare user id, any client can read
    # any user's info by guessing ids.
    return APIUtils.success_response(
        data={'token': user.id, 'userId': user.id, 'username': user.username, 'role': user.role},
        message="登录成功!",
    )
|
||
|
||
@blus.route('/change_password', methods=['POST'])
def change_password():
    """Change a user's password after verifying the old one.

    Expects a JSON body ``{"username", "old_password", "new_password"}`` (all
    strings). Passwords are handled as unsalted SHA-256 hex digests, the same
    format used by registration and login.

    Returns:
        "密码修改成功!" on success; 400 for missing or non-string fields, 404 for
        an unknown user, and 401 when the old password does not match.
    """
    required_fields = ['username', 'old_password', 'new_password']
    payload = request.get_json(silent=True) or {}
    is_valid, message = APIUtils.validate_json(payload, required_fields)
    if not is_valid:
        return APIUtils.error_response(message, status_code=400)

    username = payload['username']
    old_password = payload['old_password']
    new_password = payload['new_password']
    # Non-string values would crash on .encode() below.
    if not all(isinstance(v, str) for v in (username, old_password, new_password)):
        return APIUtils.error_response("用户名和密码必须是字符串", status_code=400)

    user = User.query.filter_by(username=username).first()
    if user is None:
        return APIUtils.error_response("用户不存在!", status_code=404)

    hashed_old_password = hashlib.sha256(old_password.encode()).hexdigest()
    stored_password = str(user.password or '')
    # Constant-time comparison on bytes (compare_digest rejects non-ASCII str).
    if not hmac.compare_digest(hashed_old_password.encode('utf-8'),
                               stored_password.encode('utf-8')):
        return APIUtils.error_response("原始密码不正确!", status_code=401)

    user.password = hashlib.sha256(new_password.encode()).hexdigest()
    db.session.commit()
    return APIUtils.success_response(message="密码修改成功!")
|
||
|
||
@blus.route('/api/users/<int:user_id>', methods=['DELETE'])
def delete_user(user_id):
    """Delete a user by primary key, refusing to remove the ``admin`` account.

    Returns 404 when the id is unknown and 403 when the username equals
    ``admin`` in any letter case.
    """
    target = User.query.get(user_id)
    if target is None:
        return APIUtils.error_response("用户不存在!", status_code=404)

    # The built-in administrator account is protected regardless of case.
    protected = target.username.lower() == 'admin'
    if protected:
        return APIUtils.error_response("无法删除管理员账户!", status_code=403)

    session = db.session
    session.delete(target)
    session.commit()
    return APIUtils.success_response(message="用户删除成功!")
|
||
# User management
@blus.route('/api/users/page', methods=['GET'])
def get_users():
    """Return one page of users, optionally filtered by a username substring.

    Query parameters:
        page: page number (default 1).
        limit: rows per page (default 10).
        username: optional substring matched with LIKE.

    Returns:
        ``{"list": [{id, username, role}, ...], "page": {total, page, limit}}``.
        Stored passwords (hashes, or plaintext for legacy rows) are
        deliberately NOT included. Returning them exposed every account's
        credential material to whoever could call this endpoint.
    """
    page = request.args.get('page', 1, type=int)
    per_page = request.args.get('limit', 10, type=int)
    username = request.args.get('username', type=str)

    query = User.query
    if username:
        query = query.filter(User.username.like(f'%{username}%'))

    users_pagination = query.paginate(page=page, per_page=per_page, error_out=False)

    users_list = [
        {'id': user.id, 'username': user.username, 'role': user.role}
        for user in users_pagination.items
    ]

    response = {
        'list': users_list,
        'page': {
            'total': users_pagination.total,     # total matching rows
            'page': users_pagination.page,       # current page number
            'limit': users_pagination.per_page,  # rows per page
        },
    }
    return APIUtils.success_response(data=response, message="获取用户列表成功")
|
||
|
||
# Add user
@blus.route('/api/users', methods=['POST'])
def add_user():
    """Create a user from a JSON body ``{"username", "password", "role"?}``.

    The password is stored as an unsalted SHA-256 hex digest, which is what
    login compares against. Previously it was stored in plaintext, so users
    created here could never log in. ``role`` defaults to 1.

    Returns:
        "用户添加成功" on success; 400 when username/password are missing or
        not strings, or the username already exists.
    """
    # NOTE(review): no authorization is visible here, so any caller can choose
    # an arbitrary `role` — confirm this endpoint is protected upstream.
    data = request.get_json(silent=True) or {}
    username = data.get('username')
    password = data.get('password')

    if not (username and password):
        # `status_code` matches the other error_response calls in this module.
        return APIUtils.error_response(message="用户名和密码不能为空", status_code=400)
    if not isinstance(username, str) or not isinstance(password, str):
        return APIUtils.error_response(message="用户名和密码必须是字符串", status_code=400)

    if User.query.filter_by(username=username).first():
        return APIUtils.error_response(message="用户名已存在", status_code=400)

    new_user = User(
        username=username,
        password=hashlib.sha256(password.encode()).hexdigest(),
        role=data.get('role', 1),  # default: ordinary user
    )
    db.session.add(new_user)
    db.session.commit()
    return APIUtils.success_response(message="用户添加成功")
|
||
|
||
# Update user
@blus.route('/api/users/<int:user_id>', methods=['PUT'])
def update_user(user_id):
    """Update a user's ``username`` and/or ``role`` from a JSON body.

    The password is not changed here (see ``/change_password``).

    Returns:
        "用户信息更新成功" on success; 404 for an unknown id; 400 when the new
        username is empty, not a string, or already used by another user.
    """
    data = request.get_json(silent=True) or {}
    user = User.query.get(user_id)
    if not user:
        # `status_code` matches the other error_response calls in this module.
        return APIUtils.error_response(message="用户不存在", status_code=404)

    if 'username' in data:
        new_username = data['username']
        if not isinstance(new_username, str) or not new_username.strip():
            return APIUtils.error_response(message="用户名不能为空", status_code=400)
        # The new username must not belong to a different user.
        if User.query.filter(User.username == new_username, User.id != user_id).first():
            return APIUtils.error_response(message="用户名已存在", status_code=400)
        # NOTE(review): renaming the account away from "admin" defeats the
        # delete protection in delete_user — consider forbidding it.
        user.username = new_username

    if 'role' in data:
        user.role = data['role']

    db.session.commit()
    return APIUtils.success_response(message="用户信息更新成功")
|
||
|
||
def plot_to_base64(plt_figure):
    """Render a matplotlib figure to a base64-encoded PNG string.

    Args:
        plt_figure: any object with ``savefig(buffer, format=..., dpi=...)``,
            e.g. a ``Figure`` or the ``matplotlib.pyplot`` module itself.

    Returns:
        str: the PNG bytes (saved at 100 dpi), base64-encoded as UTF-8 text.
    """
    with io.BytesIO() as sink:
        plt_figure.savefig(sink, format='png', dpi=100)
        png_bytes = sink.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')
|
||
|
||
def generate_prediction_report(job_count_forecast, salary_forecast, forecast_dates):
    """Summarise job-count and salary forecasts into a report dict (Chinese keys).

    Args:
        job_count_forecast: sequence of forecast job counts, one per date.
        salary_forecast: sequence of forecast average salaries, one per date.
        forecast_dates: sequence of date-like objects supporting ``strftime``.

    Returns:
        dict with the forecast date range, total and mean job counts, the
        relative salary change (first vs last value, as "x%", or "N/A" when the
        first value is 0), salary max/min, and the trend:
        "上升" (up), "下降" (down) or "持平" (flat).

    Raises:
        ValueError: if any of the inputs is empty.
    """
    # np.asarray makes [0]/[-1] positional for lists, ndarrays and pandas
    # Series alike (integer keys on a date-indexed Series are label-based).
    jobs = np.asarray(job_count_forecast, dtype=float)
    salaries = np.asarray(salary_forecast, dtype=float)
    dates = list(forecast_dates)
    if jobs.size == 0 or salaries.size == 0 or not dates:
        raise ValueError("forecast inputs must not be empty")

    first_salary = float(salaries[0])
    last_salary = float(salaries[-1])

    if first_salary == 0:
        # A relative change is undefined for a zero baseline.
        salary_change = "N/A"
    else:
        salary_change = f"{round((last_salary - first_salary) / first_salary * 100, 2)}%"

    if last_salary > first_salary:
        trend = "上升"
    elif last_salary < first_salary:
        trend = "下降"
    else:
        trend = "持平"  # previously an unchanged salary was reported as "下降"

    return {
        "预测时间范围": f"{dates[0].strftime('%Y-%m-%d')} 至 {dates[-1].strftime('%Y-%m-%d')}",
        "总预测招聘岗位数": int(np.sum(jobs)),
        "日均预测招聘数": round(float(np.mean(jobs)), 1),
        "预测平均薪资变化": salary_change,
        "预测最高薪资": round(float(np.max(salaries)), 2),
        "预测最低薪资": round(float(np.min(salaries)), 2),
        "预测趋势": trend,
    }
|
||
def _render_forecast_plot(history, column, forecast_dates, forecast_values, title, ylabel):
    """Plot history vs forecast for one column and return it as a base64 PNG.

    Draws on an explicit Figure/Axes rather than pyplot's implicit "current
    figure", and always closes the figure, even if plotting fails, so open
    figures do not accumulate across requests.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.plot(history.index, history[column], label='历史数据')
        ax.plot(forecast_dates, forecast_values, label='预测数据', color='red')
        ax.set_title(title)
        ax.set_xlabel('日期')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid()
        return plot_to_base64(fig)
    finally:
        plt.close(fig)


@blus.route('/api/prediction', methods=['GET'])
def get_prediction():
    """Forecast finance-sector job counts and average salary for the next 30 days.

    Pipeline: fetch data -> preprocess -> ARIMA forecasts -> random-forest
    training -> two base64 PNG charts -> summary report.

    Returns:
        JSON ``{"status", "message", "data": {"report", "plots", "forecast_data"}}``
        on success, or ``{"status": "error", ...}`` with HTTP 500 on any failure.
    """
    try:
        # 1. Fetch raw data.
        job_data = get_finance_job_data()

        # 2. Preprocess into a daily series (the second return value is unused here).
        daily_job_data, _raw_data = preprocess_data(job_data)

        # 3. Time-series forecasts, 30 steps ahead.
        job_count_forecast, forecast_dates = arima_forecast(daily_job_data, 'job_count', 30)
        salary_forecast, _ = arima_forecast(daily_job_data, 'salary_avg', 30)

        # 4. Machine-learning models.
        # NOTE(review): the trained models are never used in the response;
        # confirm whether these calls are needed (they cost time per request).
        ml_data = prepare_ml_data(daily_job_data)
        train_random_forest(ml_data, 'job_count')
        train_random_forest(ml_data, 'salary_avg')

        # 5. Charts, as base64 PNGs.
        job_count_plot = _render_forecast_plot(
            daily_job_data, 'job_count', forecast_dates, job_count_forecast,
            '金融行业招聘数量趋势预测', '数量',
        )
        salary_plot = _render_forecast_plot(
            daily_job_data, 'salary_avg', forecast_dates, salary_forecast,
            '金融行业平均薪资趋势预测', '薪资',
        )

        # 6. Summary report.
        prediction_report = generate_prediction_report(job_count_forecast, salary_forecast, forecast_dates)

        response_data = {
            "report": prediction_report,
            "plots": {
                "job_count": job_count_plot,
                "salary": salary_plot,
            },
            "forecast_data": {
                "dates": [date.strftime('%Y-%m-%d') for date in forecast_dates],
                "job_count": job_count_forecast.tolist(),
                "salary": salary_forecast.tolist(),
            },
        }

        return jsonify({
            "status": "success",
            "message": "预测数据获取成功",
            "data": response_data,
        })

    except Exception as e:
        # Top-level boundary: report any pipeline failure as a JSON 500.
        return jsonify({
            "status": "error",
            "message": f"预测数据获取失败: {str(e)}",
            "data": None,
        }), 500
|
||
|
||
# Leading keywords of the statements the ad-hoc SQL endpoint may run.
_READ_ONLY_SQL_KEYWORDS = ('select', 'show', 'describe', 'desc', 'explain')
# Fragments that would let an otherwise read-only SELECT read or write
# server-side files.
_FORBIDDEN_SQL_FRAGMENTS = ('outfile', 'dumpfile', 'load_file')


def _is_read_only_sql(sql):
    """Return True if ``sql`` starts with a read-only keyword and uses no file I/O.

    This is a coarse allow-list, not a SQL parser: anything that does not
    start with a plain read-only keyword (including leading comments or
    parentheses) is rejected.
    """
    lowered = sql.lstrip().lower()
    match = re.match(r'[a-z]+', lowered)
    if match is None or match.group(0) not in _READ_ONLY_SQL_KEYWORDS:
        return False
    return not any(fragment in lowered for fragment in _FORBIDDEN_SQL_FRAGMENTS)


# SQL query console
@blus.route('/api/mysql', methods=['POST'])
def mysql():
    """Execute a client-supplied, read-only SQL statement and return its rows.

    Expects a JSON body ``{"sql": str}``. Previously any statement, including
    DDL such as DROP, was executed with the credentials in ``db_config``.
    Statements are now restricted to SELECT/SHOW/DESCRIBE/EXPLAIN. Writes were
    never committed by this endpoint anyway, because the connection is closed
    without ``commit()``.

    Returns:
        The list of row dicts on success; an error response otherwise.
    """
    # NOTE(review): this still lets callers read any table the DB account can
    # see (and run slow queries). Restrict it to admins and give it a
    # dedicated read-only MySQL account.
    data = request.get_json(silent=True) or {}
    sql = data.get('sql')

    if not sql:
        return APIUtils.error_response(message="没有sql参数")
    if not isinstance(sql, str):
        return APIUtils.error_response(message="sql参数必须是字符串", status_code=400)
    if not _is_read_only_sql(sql):
        return APIUtils.error_response(
            message="仅允许执行只读查询(SELECT/SHOW/DESCRIBE/EXPLAIN)", status_code=403)

    connection = None
    try:
        connection = pymysql.connect(**db_config)
        with connection.cursor(pymysql.cursors.DictCursor) as cursor:
            cursor.execute(sql)
            return cursor.fetchall()
    except pymysql.MySQLError as e:
        return APIUtils.error_response(message=f"数据库连接失败:{str(e)}")
    except Exception as e:
        return APIUtils.error_response(message=f"查询执行失败:{str(e)}")
    finally:
        # The original never closed the connection, leaking one per request.
        if connection is not None:
            connection.close()
|
||
|
||
@functools.lru_cache(maxsize=1)
def _load_stopwords():
    """Read ``utils/stopwords.txt`` (one word per line) into a frozenset.

    Cached for the life of the process, so the file is read once; edits to it
    take effect after a restart. Failures are not cached, so a missing file is
    retried on the next call.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    stopwords_file = os.path.join(base_dir, 'utils', 'stopwords.txt')
    with open(stopwords_file, encoding='utf-8') as f:
        return frozenset(stripped for stripped in (line.strip() for line in f) if stripped)


def _job_description_word_frequencies():
    """Count non-stop-words across up to 1000 job descriptions using jieba.

    Returns:
        list[dict]: ``[{"name": word, "value": count}, ...]`` in first-seen order.

    Raises:
        OSError: if the stop-word file cannot be read.
        pymysql.MySQLError: on any database error.
    """
    stopwords = _load_stopwords()

    connection = pymysql.connect(**db_config)
    try:
        with connection.cursor(pymysql.cursors.DictCursor) as cursor:
            cursor.execute("SELECT job_description FROM job_positions LIMIT 1000")
            rows = cursor.fetchall()
    finally:
        connection.close()

    counts = Counter()
    for row in rows:
        description = row['job_description']
        if not description:
            # NULL/empty descriptions have nothing to tokenise; jieba.cut is
            # not expected to accept None.
            continue
        for token in jieba.cut(description):
            token = token.strip()
            if token and token not in stopwords:
                counts[token] += 1

    return [{"name": token, "value": count} for token, count in counts.items()]


def _word_frequency_response():
    """Wrap ``_job_description_word_frequencies`` in the module's API responses."""
    try:
        result = _job_description_word_frequencies()
    except pymysql.MySQLError as err:
        return APIUtils.error_response(message=str(err))
    except OSError as err:
        return APIUtils.error_response(message=f"读取停用词文件失败:{err}")
    return APIUtils.success_response(data=result)


@blus.route('/api/word', methods=['GET'])
def word():
    """Return word frequencies of job descriptions (e.g. for a word cloud).

    Previously, a failed DB connection or unreadable stop-word file made the
    ``finally: connection.close()`` raise UnboundLocalError; both now return
    an error response instead.
    """
    return _word_frequency_response()


@blus.route('/api/caiji', methods=['GET'])
def caiji():
    """Return the same payload as ``/api/word``.

    Kept as a separate route because it was a verbatim copy of that endpoint.
    """
    return _word_frequency_response()
|