bigdata-ibecs/App/views.py
2025-06-23 14:15:50 +08:00

420 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# views.py 路由 + 视图函数
import os
import io
import base64
import jieba
import numpy as np
import pymysql
from flask import request, jsonify
from flask import Blueprint
import hashlib
from matplotlib import pyplot as plt
from sqlalchemy import desc
from .utils.api_utils import APIUtils
from .models import *
from .utils.prediction import get_finance_job_data, preprocess_data, arima_forecast, prepare_ml_data, \
train_random_forest
# Blueprint on which every route in this module is registered.
blus = Blueprint("user", __name__)

# Connection settings for the raw pymysql queries issued by some views below.
# Each value can be overridden through an environment variable so that the
# credentials need not live in source control; the defaults are the values
# that were previously hard-coded, so behaviour is unchanged when the
# variables are not set.
# NOTE(review): the fallback still embeds a root password - remove the
# defaults once every deployment provides the environment variables.
db_config = {
    'host': os.environ.get('IBECS_DB_HOST', '192.168.229.122'),
    'user': os.environ.get('IBECS_DB_USER', 'root'),
    'password': os.environ.get('IBECS_DB_PASSWORD', '123456'),
    'database': os.environ.get('IBECS_DB_NAME', 'bigdata_ibecs'),
    'charset': 'utf8mb4',
}
# Registration
@blus.route('/api/register', methods=['POST'])
def user_register():
    """Create a new user from a JSON body containing ``username`` and ``password``.

    Returns an error response with status 400 when the body fails
    ``APIUtils.validate_json`` or when the username is already taken;
    otherwise stores the user with ``role=1`` and returns a success response.
    """
    required_fields = ['username', 'password']
    is_valid, message = APIUtils.validate_json(request.json, required_fields)
    if not is_valid:
        return APIUtils.error_response(message, status_code=400)

    username = request.json['username']
    password = request.json['password']

    # Reject duplicate usernames.
    existing_user = User.query.filter_by(username=username).first()
    if existing_user:
        return APIUtils.error_response("用户名已经存在!", status_code=400)

    # Store the SHA-256 hex digest; user_login compares against the same form.
    # NOTE(review): unsalted SHA-256 is weak for password storage - consider
    # a dedicated password hash in a coordinated migration.
    hashed_password = hashlib.sha256(password.encode()).hexdigest()

    new_user = User(username=username, password=hashed_password, role=1)
    db.session.add(new_user)
    db.session.commit()

    # The message previously read "登录成功!" (login succeeded), which was
    # wrong for a registration endpoint.
    return APIUtils.success_response(message="注册成功!")
@blus.route('/items/<int:item_id>', methods=['DELETE'])
def delete_item(item_id):
    """Delete the JobPosition with primary key ``item_id``.

    ``get_or_404`` aborts the request with a 404 when no such row exists.
    """
    target = JobPosition.query.get_or_404(item_id)
    session = db.session
    session.delete(target)
    session.commit()
    return APIUtils.success_response(message="删除成功!")
@blus.route('/items', methods=['GET'])
def get_items():
    """Return one page of JobPosition rows, optionally filtered.

    Query parameters:
        current: page number (default 1).
        size: rows per page (default 10).
        companyName: substring to match against ``company_name``.
        city: substring to match against ``city``.
    """
    args = request.args
    page = args.get('current', 1, type=int)
    size = args.get('size', 10, type=int)
    company_keyword = args.get('companyName', '', type=str)
    city_keyword = args.get('city', '', type=str)

    # Apply a LIKE filter for every keyword that was actually supplied.
    query = JobPosition.query
    if company_keyword:
        company_pattern = f'%{company_keyword}%'
        query = query.filter(JobPosition.company_name.like(company_pattern))
    if city_keyword:
        city_pattern = f'%{city_keyword}%'
        query = query.filter(JobPosition.city.like(city_pattern))

    pagination = query.paginate(page=page, per_page=size, error_out=False)

    rows = [item.to_dict() for item in pagination.items]
    page_info = {
        'total': pagination.total,   # total number of matching rows
        'current': page,             # requested page number
        'size': size,                # requested page size
        'pages': pagination.pages,   # total number of pages
    }
    response = {'list': rows, 'page': page_info}
    return APIUtils.success_response(data=response, message="获取数据成功")
@blus.route('/api/login', methods=['POST'])
def user_login():
    """Check a username/password pair supplied as JSON.

    Returns an error response with status 400 when the body fails
    ``APIUtils.validate_json`` and with status 401 when the username is
    unknown or the password digest does not match. On success the response
    data carries ``token``, ``userId``, ``username`` and ``role``.
    """
    required_fields = ['username', 'password']
    is_valid, message = APIUtils.validate_json(request.json, required_fields)
    if not is_valid:
        return APIUtils.error_response(message, status_code=400)

    username = request.json['username']
    password = request.json['password']

    user = User.query.filter_by(username=username).first()
    if user is None:
        # Bad credentials are a client error; this used to be reported as 500.
        return APIUtils.error_response("用户名错误或不存在!", status_code=401)

    hashed_password = hashlib.sha256(password.encode()).hexdigest()
    if hashed_password != user.password:
        # 401 matches what change_password returns for a wrong password.
        return APIUtils.error_response("密码错误或不存在!", status_code=401)

    # NOTE(review): the "token" is simply the user id, so it is guessable;
    # kept as-is because user_info looks the user up by this value.
    return APIUtils.success_response(
        data={
            'token': user.id,
            'userId': user.id,
            'username': user.username,
            'role': user.role,
        },
        message="登录成功!",
    )
@blus.route('/sys/user/info', methods=['GET'])
def user_info():
    """Return the profile of the user identified by the ``token`` request header.

    The header value is used as the user id. When the header is missing or
    no user has that id, an error response with status 401 is returned
    (previously this path raised AttributeError on ``None``).
    """
    token = request.headers.get('token')
    user = User.query.filter_by(id=token).first() if token else None
    if user is None:
        return APIUtils.error_response("用户不存在或未登录!", status_code=401)

    return APIUtils.success_response(
        data={
            'token': user.id,
            'userId': user.id,
            'username': user.username,
            'role': user.role,
        },
        message="登录成功!",
    )
@blus.route('/change_password', methods=['POST'])
def change_password():
    """Replace a user's password after verifying the current one.

    Expects a JSON body with ``username``, ``old_password`` and
    ``new_password``. Error statuses: 400 when validation fails, 404 when
    the user is unknown, 401 when the old password does not match.
    """
    payload = request.json
    fields_ok, problem = APIUtils.validate_json(
        payload, ['username', 'old_password', 'new_password']
    )
    if not fields_ok:
        return APIUtils.error_response(problem, status_code=400)

    name = payload['username']
    current_plain = payload['old_password']
    replacement_plain = payload['new_password']

    account = User.query.filter_by(username=name).first()
    if account is None:
        return APIUtils.error_response("用户不存在!", status_code=404)

    # Passwords are stored as SHA-256 hex digests, so compare digests.
    current_digest = hashlib.sha256(current_plain.encode()).hexdigest()
    if current_digest != account.password:
        return APIUtils.error_response("原始密码不正确!", status_code=401)

    # Persist the digest of the new password.
    account.password = hashlib.sha256(replacement_plain.encode()).hexdigest()
    db.session.commit()
    return APIUtils.success_response(message="密码修改成功!")
@blus.route('/api/users/<int:user_id>', methods=['DELETE'])
def delete_user(user_id):
    """Delete the user with primary key ``user_id``.

    Returns 404 when the user does not exist and 403 when the username is
    ``admin`` (compared case-insensitively).
    """
    target = User.query.get(user_id)
    if target is None:
        return APIUtils.error_response("用户不存在!", status_code=404)

    # The administrator account is protected from deletion.
    is_admin_account = target.username.lower() == 'admin'
    if is_admin_account:
        return APIUtils.error_response("无法删除管理员账户!", status_code=403)

    session = db.session
    session.delete(target)
    session.commit()
    return APIUtils.success_response(message="用户删除成功!")
# User management
@blus.route('/api/users/page', methods=['GET'])
def get_users():
    """Return one page of users, optionally filtered by username substring.

    Query parameters:
        page: page number (default 1).
        limit: rows per page (default 10).
        username: optional substring matched with LIKE.
    """
    args = request.args
    page = args.get('page', 1, type=int)
    per_page = args.get('limit', 10, type=int)
    name_filter = args.get('username', type=str)  # None when not supplied

    query = User.query
    if name_filter:
        query = query.filter(User.username.like(f'%{name_filter}%'))

    users_pagination = query.paginate(page=page, per_page=per_page, error_out=False)

    # NOTE(review): the stored password value is sent to the client here;
    # confirm the front end needs it, otherwise drop the field.
    users_list = [
        {
            'id': account.id,
            'username': account.username,
            'password': account.password,
            'role': account.role,
        }
        for account in users_pagination.items
    ]

    page_info = {
        'total': users_pagination.total,      # total number of matching users
        'page': users_pagination.page,        # current page number
        'limit': users_pagination.per_page,   # rows per page
    }
    response = {'list': users_list, 'page': page_info}
    return APIUtils.success_response(data=response, message="获取用户列表成功")
# Add user
@blus.route('/api/users', methods=['POST'])
def add_user():
    """Create a user from a JSON body with ``username``, ``password`` and optional ``role``.

    ``role`` defaults to 1. The password is stored as a SHA-256 hex digest,
    the same form user_register writes and user_login compares against;
    storing the raw value (as before) produced accounts that could never
    log in.
    """
    # A JSON body of ``null`` yields None; treat it as an empty object.
    data = request.get_json() or {}
    username = data.get('username')
    password = data.get('password')

    # Both fields are required. The old message also mentioned a phone
    # number, which this endpoint never checks.
    if not all([username, password]):
        return APIUtils.error_response(message="用户名和密码不能为空", code=400)

    # Reject duplicate usernames.
    if User.query.filter_by(username=username).first():
        return APIUtils.error_response(message="用户名已存在")

    hashed_password = hashlib.sha256(str(password).encode()).hexdigest()
    new_user = User(
        username=username,
        password=hashed_password,
        role=data.get('role', 1),  # defaults to an ordinary user
    )
    db.session.add(new_user)
    db.session.commit()
    return APIUtils.success_response(message="用户添加成功")
# Update user
@blus.route('/api/users/<int:user_id>', methods=['PUT'])
def update_user(user_id):
    """Update the ``username`` and/or ``role`` of the user with id ``user_id``.

    Only keys present in the JSON body are applied; the password is not
    touched by this endpoint. Returns an error response when the user does
    not exist or the new username belongs to a different user.
    """
    payload = request.get_json()
    account = User.query.get(user_id)
    if not account:
        return APIUtils.error_response(message="用户不存在", code=404)

    if 'username' in payload:
        requested_name = payload['username']
        # The new name must not already belong to another user.
        conflict = User.query.filter(
            User.username == requested_name,
            User.id != user_id,
        ).first()
        if conflict:
            return APIUtils.error_response(message="用户名已存在", code=400)
        account.username = requested_name

    if 'role' in payload:
        account.role = payload['role']

    db.session.commit()
    return APIUtils.success_response(message="用户信息更新成功")
def plot_to_base64(plt_figure):
    """Render ``plt_figure`` as a PNG and return it as a base64 text string.

    ``plt_figure`` only needs a ``savefig(file, format=..., dpi=...)``
    method; the image is written to an in-memory buffer at 100 dpi.
    """
    png_buffer = io.BytesIO()
    plt_figure.savefig(png_buffer, format='png', dpi=100)
    # getvalue() returns the whole buffer regardless of the stream position.
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode('utf-8')
def generate_prediction_report(job_count_forecast, salary_forecast, forecast_dates):
    """Summarise the forecasts as a dict of plain Python values.

    Args:
        job_count_forecast: sequence of forecast job counts.
        salary_forecast: sequence of forecast average salaries.
        forecast_dates: sequence of objects with ``strftime``; the first and
            last entries delimit the reported range.

    Returns:
        dict keyed by the Chinese report labels. Numeric entries are plain
        ``int``/``float`` so they serialise to JSON whatever the input dtype.
    """
    # Work on float arrays so positional indexing and the numpy reductions
    # behave the same for lists, arrays and Series.
    job_counts = np.asarray(job_count_forecast, dtype=float)
    salaries = np.asarray(salary_forecast, dtype=float)
    first_salary = float(salaries[0])
    last_salary = float(salaries[-1])

    # A zero starting salary has no meaningful percentage change; this
    # avoids a ZeroDivisionError or an inf/nan result.
    if first_salary != 0:
        change_pct = round((last_salary - first_salary) / first_salary * 100, 2)
        salary_change = f"{change_pct}%"
    else:
        salary_change = "N/A"

    # An unchanged forecast used to be reported as "下降" (falling).
    if last_salary > first_salary:
        trend = "上升"
    elif last_salary < first_salary:
        trend = "下降"
    else:
        trend = "持平"

    # The two dates were previously concatenated with no separator.
    range_start = forecast_dates[0].strftime('%Y-%m-%d')
    range_end = forecast_dates[-1].strftime('%Y-%m-%d')

    report = {
        "预测时间范围": f"{range_start} 至 {range_end}",
        "总预测招聘岗位数": int(np.sum(job_counts)),
        "日均预测招聘数": round(float(np.mean(job_counts)), 1),
        "预测平均薪资变化": salary_change,
        "预测最高薪资": round(float(np.max(salaries)), 2),
        "预测最低薪资": round(float(np.min(salaries)), 2),
        "预测趋势": trend,
    }
    return report
def _render_forecast_plot(history_index, history_values, forecast_dates,
                          forecast_values, title, ylabel):
    """Draw history and forecast on one chart and return it as base64 PNG text.

    The figure is closed in ``finally`` so it is released even when
    plotting or encoding raises.
    """
    figure = plt.figure(figsize=(12, 6))
    try:
        plt.plot(history_index, history_values, label='历史数据')
        plt.plot(forecast_dates, forecast_values, label='预测数据', color='red')
        plt.title(title)
        plt.xlabel('日期')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid()
        return plot_to_base64(figure)
    finally:
        plt.close(figure)


@blus.route('/api/prediction', methods=['GET'])
def get_prediction():
    """Return 30-step forecasts of job count and average salary.

    The JSON response carries a summary report, two base64-encoded PNG
    charts and the raw forecast series. Any exception is turned into a
    status-500 JSON payload that includes the exception text.
    """
    try:
        # 1. Load the data.
        job_data = get_finance_job_data()

        # 2. Preprocess it.
        daily_job_data, raw_data = preprocess_data(job_data)

        # 3. Time-series forecasts.
        job_count_forecast, forecast_dates = arima_forecast(daily_job_data, 'job_count', 30)
        salary_forecast, _ = arima_forecast(daily_job_data, 'salary_avg', 30)

        # 4. Machine-learning models.
        # NOTE(review): the trained models are not used below; the calls are
        # kept in case train_random_forest has side effects - confirm and
        # remove them if it does not.
        ml_data = prepare_ml_data(daily_job_data)
        job_count_model = train_random_forest(ml_data, 'job_count')
        salary_model = train_random_forest(ml_data, 'salary_avg')

        # 5. Render both charts as base64 PNG.
        job_count_plot = _render_forecast_plot(
            daily_job_data.index, daily_job_data['job_count'],
            forecast_dates, job_count_forecast,
            '金融行业招聘数量趋势预测', '数量',
        )
        salary_plot = _render_forecast_plot(
            daily_job_data.index, daily_job_data['salary_avg'],
            forecast_dates, salary_forecast,
            '金融行业平均薪资趋势预测', '薪资',
        )

        # 6. Build the summary report.
        prediction_report = generate_prediction_report(
            job_count_forecast, salary_forecast, forecast_dates
        )

        response_data = {
            "report": prediction_report,
            "plots": {
                "job_count": job_count_plot,
                "salary": salary_plot
            },
            "forecast_data": {
                "dates": [date.strftime('%Y-%m-%d') for date in forecast_dates],
                "job_count": job_count_forecast.tolist(),
                "salary": salary_forecast.tolist()
            }
        }
        return jsonify({
            "status": "success",
            "message": "预测数据获取成功",
            "data": response_data
        })
    except Exception as e:
        # Top-level boundary for this endpoint: report the failure as JSON.
        return jsonify({
            "status": "error",
            "message": f"预测数据获取失败: {str(e)}",
            "data": None
        }), 500
# SQL query
@blus.route('/api/mysql', methods=['POST'])
def mysql():
    """Execute the SQL string in the JSON body's ``sql`` field and return the rows.

    Rows are fetched with a ``DictCursor`` and returned as a JSON array. A
    missing or empty ``sql`` field, or any database or execution error,
    yields an ``APIUtils.error_response``.

    SECURITY NOTE(review): this runs caller-supplied SQL verbatim using the
    credentials in ``db_config``. It must be restricted to trusted callers
    or replaced with fixed, parameterised queries.
    """
    data = request.get_json()
    # Tolerate a missing key or a non-object body instead of raising.
    sql = data.get('sql') if isinstance(data, dict) else None
    if not sql:
        return APIUtils.error_response(message="没有sql参数")

    connection = None
    try:
        connection = pymysql.connect(**db_config)
        with connection.cursor(pymysql.cursors.DictCursor) as cursor:
            cursor.execute(sql)
            results = cursor.fetchall()
        # Build the JSON response explicitly rather than returning the raw
        # result, so the reply does not depend on how the framework treats
        # list or tuple return values.
        return jsonify(list(results))
    except pymysql.MySQLError as e:
        return APIUtils.error_response(message=f"数据库连接失败:{str(e)}")
    except Exception as e:
        return APIUtils.error_response(message=f"查询执行失败:{str(e)}")
    finally:
        # The connection was previously never closed.
        if connection is not None:
            connection.close()
@blus.route('/api/word', methods=['GET'])
def word():
    """Return word frequencies over up to 1000 job descriptions.

    Descriptions are tokenised with ``jieba.cut``; blank tokens and tokens
    listed in ``utils/stopwords.txt`` (next to this module) are dropped.
    The response data is a list of ``{"name": token, "value": count}``.
    A ``pymysql.MySQLError`` is converted into an error response.
    """
    from collections import Counter

    # Resolve the stop-word file relative to this module.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    stopwords_file = os.path.join(base_dir, 'utils', 'stopwords.txt')

    # Bound before the try so that ``finally`` never touches an undefined
    # name when reading the file or connecting fails.
    connection = None
    try:
        # A set keeps the per-token membership test cheap.
        with open(stopwords_file, encoding='utf-8') as f:
            stopwords = {line.strip() for line in f}
        stopwords.discard('')

        connection = pymysql.connect(**db_config)
        with connection.cursor(pymysql.cursors.DictCursor) as cursor:
            cursor.execute("SELECT job_description FROM job_positions LIMIT 1000")
            rows = cursor.fetchall()

        word_counts = Counter()
        for row in rows:
            job_desc = row['job_description']
            if not job_desc:
                # NULL or empty descriptions have nothing to tokenise.
                continue
            for token in jieba.cut(job_desc):
                token = token.strip()
                if token and token not in stopwords:
                    word_counts[token] += 1

        result = [
            {"name": token, "value": count}
            for token, count in word_counts.items()
        ]
        return APIUtils.success_response(data=result)
    except pymysql.MySQLError as err:
        return APIUtils.error_response(message=str(err))
    finally:
        if connection is not None:
            connection.close()
@blus.route('/api/caiji', methods=['GET'])
def caiji():
    """Return the same word-frequency payload as the ``/api/word`` endpoint.

    This view was a line-for-line copy of ``word``; it now delegates to it
    so the two endpoints cannot drift apart.
    """
    return word()