# views.py — Flask routes and view functions for the "user" blueprint.
import os
import io
import base64
import hashlib
from collections import Counter

import jieba
import numpy as np
import pymysql
from flask import Blueprint, current_app, jsonify, request
from matplotlib import pyplot as plt
from sqlalchemy import desc

from .utils.api_utils import APIUtils
from .models import *
from .utils.prediction import get_finance_job_data, preprocess_data, arima_forecast, prepare_ml_data, \
    train_random_forest

blus = Blueprint("user", __name__)

# Connection settings for the raw pymysql queries in this module.
# Each value can be overridden through an environment variable; the defaults
# are the previously hard-coded development settings, so behaviour is unchanged
# when the variables are not set.
# NOTE(review): credentials committed to source are unsafe for shared
# deployments — prefer setting the IBECS_DB_* variables in the environment.
db_config = {
    'host': os.environ.get('IBECS_DB_HOST', '192.168.229.122'),
    'user': os.environ.get('IBECS_DB_USER', 'root'),
    'password': os.environ.get('IBECS_DB_PASSWORD', '123456'),
    'database': os.environ.get('IBECS_DB_NAME', 'bigdata_ibecs'),
    'charset': 'utf8mb4',
}


def _hash_password(password):
    """Return the hex SHA-256 digest used to store and compare user passwords.

    Every code path that writes or checks a password must go through this
    helper so stored values stay comparable.

    NOTE(review): unsalted SHA-256 is weak for password storage. Switching to a
    salted KDF would require migrating existing rows, so the scheme is kept.
    """
    return hashlib.sha256(password.encode()).hexdigest()


# Registration
@blus.route('/api/register', methods=['POST'])
def user_register():
    """Create a regular user (role 1) from a JSON body with ``username`` and ``password``.

    Returns a 400 error if a field is missing or the username is already taken.
    """
    required_fields = ['username', 'password']
    is_valid, message = APIUtils.validate_json(request.json, required_fields)
    if not is_valid:
        return APIUtils.error_response(message, status_code=400)

    username = request.json['username']
    password = request.json['password']

    # Reject duplicate usernames.
    if User.query.filter_by(username=username).first():
        return APIUtils.error_response("用户名已经存在!", status_code=400)

    new_user = User(username=username, password=_hash_password(password), role=1)
    db.session.add(new_user)
    db.session.commit()
    # The message previously read "login succeeded", which is wrong for a registration.
    return APIUtils.success_response(message="注册成功!")


@blus.route('/items/<int:item_id>', methods=['DELETE'])
def delete_item(item_id):
    """Delete the JobPosition with primary key ``item_id``; 404 if it does not exist."""
    item = JobPosition.query.get_or_404(item_id)
    db.session.delete(item)
    db.session.commit()
    return APIUtils.success_response(message="删除成功!")


@blus.route('/items', methods=['GET'])
def get_items():
    """Return a page of JobPosition rows with optional fuzzy filters.

    Query parameters: ``current`` (page number, default 1), ``size`` (page
    size, default 10), ``companyName`` and ``city`` (substring LIKE filters,
    ignored when empty).
    """
    page = request.args.get('current', 1, type=int)
    size = request.args.get('size', 10, type=int)
    company_name = request.args.get('companyName', '', type=str)
    city = request.args.get('city', '', type=str)

    query = JobPosition.query
    if company_name:
        query = query.filter(JobPosition.company_name.like(f'%{company_name}%'))
    if city:
        query = query.filter(JobPosition.city.like(f'%{city}%'))

    # error_out=False: an out-of-range page yields an empty list instead of a 404.
    pagination = query.paginate(page=page, per_page=size, error_out=False)

    response = {
        'list': [item.to_dict() for item in pagination.items],
        'page': {
            'total': pagination.total,  # total number of records
            'current': page,            # requested page number
            'size': size,               # requested page size
            'pages': pagination.pages,  # total number of pages
        },
    }
    return APIUtils.success_response(data=response, message="获取数据成功")


@blus.route('/api/login', methods=['POST'])
def user_login():
    """Check ``username``/``password`` from the JSON body and return the user's session data.

    NOTE(review): the returned "token" is simply the user id, which is
    guessable; ``user_info`` trusts it verbatim. Kept for client compatibility.
    """
    required_fields = ['username', 'password']
    is_valid, message = APIUtils.validate_json(request.json, required_fields)
    if not is_valid:
        return APIUtils.error_response(message, status_code=400)

    username = request.json['username']
    password = request.json['password']

    user = User.query.filter_by(username=username).first()
    if user is None:
        return APIUtils.error_response("用户名错误或不存在!", status_code=500)

    if _hash_password(password) != user.password:
        return APIUtils.error_response("密码错误或不存在!", status_code=500)

    return APIUtils.success_response(
        data={'token': user.id, 'userId': user.id, 'username': user.username, 'role': user.role},
        message="登录成功!",
    )


@blus.route('/sys/user/info', methods=['GET'])
def user_info():
    """Return the user identified by the ``token`` request header (the id issued at login).

    Responds with a 401 error when the header is missing or matches no user,
    instead of crashing on ``None``.
    """
    token = request.headers.get('token')
    user = User.query.filter_by(id=token).first() if token else None
    if user is None:
        return APIUtils.error_response("用户不存在!", status_code=401)
    return APIUtils.success_response(
        data={'token': user.id, 'userId': user.id, 'username': user.username, 'role': user.role},
        message="登录成功!",
    )


@blus.route('/change_password', methods=['POST'])
def change_password():
    """Replace a user's password after verifying the old one.

    JSON body: ``username``, ``old_password``, ``new_password``.
    Returns 404 for an unknown user and 401 when the old password is wrong.
    """
    required_fields = ['username', 'old_password', 'new_password']
    is_valid, message = APIUtils.validate_json(request.json, required_fields)
    if not is_valid:
        return APIUtils.error_response(message, status_code=400)

    username = request.json['username']
    old_password = request.json['old_password']
    new_password = request.json['new_password']

    user = User.query.filter_by(username=username).first()
    if user is None:
        return APIUtils.error_response("用户不存在!", status_code=404)

    if _hash_password(old_password) != user.password:
        return APIUtils.error_response("原始密码不正确!", status_code=401)

    user.password = _hash_password(new_password)
    db.session.commit()
    return APIUtils.success_response(message="密码修改成功!")


@blus.route('/api/users/<int:user_id>', methods=['DELETE'])
def delete_user(user_id):
    """Delete the user with id ``user_id``.

    Returns 404 if the user is missing and 403 for the ``admin`` account
    (matched case-insensitively).
    """
    user = User.query.get(user_id)
    if user is None:
        return APIUtils.error_response("用户不存在!", status_code=404)

    # The admin account must never be removed.
    if user.username.lower() == 'admin':
        return APIUtils.error_response("无法删除管理员账户!", status_code=403)

    db.session.delete(user)
    db.session.commit()
    return APIUtils.success_response(message="用户删除成功!")


# User management
@blus.route('/api/users/page', methods=['GET'])
def get_users():
    """Return a page of users, optionally filtered by a ``username`` substring.

    Query parameters: ``page`` (default 1), ``limit`` (default 10), ``username``.
    """
    page = request.args.get('page', 1, type=int)
    per_page = request.args.get('limit', 10, type=int)
    username = request.args.get('username', type=str)

    query = User.query
    if username:
        query = query.filter(User.username.like(f'%{username}%'))

    users_pagination = query.paginate(page=page, per_page=per_page, error_out=False)

    # NOTE(review): this exposes each user's stored password hash to the
    # client; kept only because the response shape may be relied on by the UI.
    users_list = [
        {
            'id': user.id,
            'username': user.username,
            'password': user.password,
            'role': user.role,
        }
        for user in users_pagination.items
    ]

    response = {
        'list': users_list,
        'page': {
            'total': users_pagination.total,     # total number of records
            'page': users_pagination.page,       # current page number
            'limit': users_pagination.per_page,  # records per page
        },
    }
    return APIUtils.success_response(data=response, message="获取用户列表成功")


# Create user
@blus.route('/api/users', methods=['POST'])
def add_user():
    """Create a user from a JSON body with ``username``, ``password`` and optional ``role`` (default 1).

    The password is hashed with the same scheme as registration/login.
    Previously it was stored in plaintext, so such users could never log in.
    """
    # silent=True: a missing or non-JSON body becomes {} and hits the
    # validation error below instead of raising.
    data = request.get_json(silent=True) or {}

    if not all([data.get('username'), data.get('password')]):
        return APIUtils.error_response(message="用户名、手机号和密码不能为空", status_code=400)

    if User.query.filter_by(username=data['username']).first():
        return APIUtils.error_response(message="用户名已存在")

    new_user = User(
        username=data['username'],
        password=_hash_password(data['password']),
        role=data.get('role', 1),  # default: regular user
    )
    db.session.add(new_user)
    db.session.commit()
    return APIUtils.success_response(message="用户添加成功")


# Update user
@blus.route('/api/users/<int:user_id>', methods=['PUT'])
def update_user(user_id):
    """Update ``username`` and/or ``role`` of user ``user_id`` from the JSON body.

    Passwords are not changed here (see ``change_password``). Returns 404 for
    an unknown user and 400 if the new username belongs to another user.
    """
    data = request.get_json(silent=True) or {}
    user = User.query.get(user_id)
    if not user:
        return APIUtils.error_response(message="用户不存在", status_code=404)

    if 'username' in data:
        # The new username must not be used by any *other* user.
        if User.query.filter(User.username == data['username'], User.id != user_id).first():
            return APIUtils.error_response(message="用户名已存在", status_code=400)
        user.username = data['username']

    if 'role' in data:
        user.role = data['role']

    db.session.commit()
    return APIUtils.success_response(message="用户信息更新成功")


def plot_to_base64(plt_figure):
    """Render a matplotlib figure (or the pyplot module) to PNG and return it base64-encoded."""
    buf = io.BytesIO()
    plt_figure.savefig(buf, format='png', dpi=100)
    return base64.b64encode(buf.getvalue()).decode('utf-8')


def generate_prediction_report(job_count_forecast, salary_forecast, forecast_dates):
    """Summarise the job-count and salary forecasts as a dict of display strings/numbers.

    ``forecast_dates`` must support ``[0]``/``[-1]`` and ``strftime``.
    The forecasts are converted with ``np.asarray``, so ndarrays and pandas
    Series both work; positional ``[-1]`` on a date-indexed Series is
    deprecated.
    """
    jobs = np.asarray(job_count_forecast)
    salary = np.asarray(salary_forecast)
    first_salary, last_salary = salary[0], salary[-1]

    report = {
        "预测时间范围": f"{forecast_dates[0].strftime('%Y-%m-%d')} 至 {forecast_dates[-1].strftime('%Y-%m-%d')}",
        "总预测招聘岗位数": int(np.sum(jobs)),
        "日均预测招聘数": round(np.mean(jobs), 1),
        "预测平均薪资变化": f"{round((last_salary - first_salary) / first_salary * 100, 2)}%",
        "预测最高薪资": round(np.max(salary), 2),
        "预测最低薪资": round(np.min(salary), 2),
        "预测趋势": "上升" if last_salary > first_salary else "下降",
    }
    return report


def _trend_plot_base64(history_index, history_values, forecast_dates, forecast_values, title, ylabel):
    """Draw a history-vs-forecast line chart and return it as a base64 PNG.

    The figure is always closed, even if drawing fails, so figures do not
    accumulate in pyplot across requests.
    """
    fig = plt.figure(figsize=(12, 6))
    try:
        ax = fig.add_subplot()
        ax.plot(history_index, history_values, label='历史数据')
        ax.plot(forecast_dates, forecast_values, label='预测数据', color='red')
        ax.set_title(title)
        ax.set_xlabel('日期')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid()
        return plot_to_base64(fig)
    finally:
        plt.close(fig)


@blus.route('/api/prediction', methods=['GET'])
def get_prediction():
    """Forecast finance-sector job counts and average salaries for the next 30 days.

    Returns JSON with a text report, two base64 PNG charts and the raw forecast
    series. Any failure is logged and reported as a 500 JSON error.
    """
    try:
        # 1. Load data.
        job_data = get_finance_job_data()

        # 2. Preprocess.
        daily_job_data, raw_data = preprocess_data(job_data)

        # 3. Time-series (ARIMA) forecasts.
        job_count_forecast, forecast_dates = arima_forecast(daily_job_data, 'job_count', 30)
        salary_forecast, _ = arima_forecast(daily_job_data, 'salary_avg', 30)

        # 4. Machine-learning models.
        # NOTE(review): these models are trained on every request but never
        # used in the response — consider removing or caching them.
        ml_data = prepare_ml_data(daily_job_data)
        job_count_model = train_random_forest(ml_data, 'job_count')
        salary_model = train_random_forest(ml_data, 'salary_avg')

        # 5. Charts as base64 PNGs.
        job_count_plot = _trend_plot_base64(
            daily_job_data.index, daily_job_data['job_count'],
            forecast_dates, job_count_forecast,
            '金融行业招聘数量趋势预测', '数量',
        )
        salary_plot = _trend_plot_base64(
            daily_job_data.index, daily_job_data['salary_avg'],
            forecast_dates, salary_forecast,
            '金融行业平均薪资趋势预测', '薪资',
        )

        # 6. Report.
        prediction_report = generate_prediction_report(job_count_forecast, salary_forecast, forecast_dates)

        response_data = {
            "report": prediction_report,
            "plots": {
                "job_count": job_count_plot,
                "salary": salary_plot,
            },
            "forecast_data": {
                "dates": [date.strftime('%Y-%m-%d') for date in forecast_dates],
                "job_count": job_count_forecast.tolist(),
                "salary": salary_forecast.tolist(),
            },
        }

        return jsonify({
            "status": "success",
            "message": "预测数据获取成功",
            "data": response_data,
        })

    except Exception as e:
        # Top-level boundary: record the traceback server-side, then report the error.
        current_app.logger.exception("prediction endpoint failed")
        return jsonify({
            "status": "error",
            "message": f"预测数据获取失败: {str(e)}",
            "data": None,
        }), 500


# Raw SQL query
@blus.route('/api/mysql', methods=['POST'])
def mysql():
    """Execute the SQL string in the JSON body's ``sql`` field and return all rows as a list of dicts.

    SECURITY NOTE(review): this executes arbitrary client-supplied SQL using
    ``db_config`` credentials. Unless access is restricted elsewhere, it is an
    unrestricted SQL-injection surface; consider removing it or limiting it to
    trusted administrators.

    Nothing is committed here. Since pymysql's default is autocommit off,
    data-modifying statements are presumably rolled back when the connection
    closes — confirm against the server configuration.
    """
    # silent=True + default: a missing body or missing key hits the validation
    # error below instead of raising KeyError/TypeError.
    data = request.get_json(silent=True) or {}
    sql = data.get('sql')
    if not sql:
        return APIUtils.error_response(message="没有sql参数")

    connection = None
    try:
        connection = pymysql.connect(**db_config)
        with connection.cursor(pymysql.cursors.DictCursor) as cursor:
            cursor.execute(sql)
            results = cursor.fetchall()
        # list(): statements without a result set make fetchall() return an
        # empty tuple, which Flask would misread as a (body, status) tuple.
        return list(results)
    except pymysql.MySQLError as e:
        return APIUtils.error_response(message=f"数据库连接失败:{str(e)}")
    except Exception as e:
        return APIUtils.error_response(message=f"查询执行失败:{str(e)}")
    finally:
        # Always release the connection; previously it was never closed.
        if connection is not None:
            connection.close()


def _job_description_word_counts():
    """Count jieba tokens across up to 1000 job descriptions, ignoring stopwords.

    Stopwords come from ``utils/stopwords.txt`` next to this module.
    Returns ``[{"name": token, "value": count}, ...]`` in first-seen order.

    Raises:
        OSError: if the stopword file cannot be read.
        pymysql.MySQLError: on database failure.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    stopwords_file = os.path.join(base_dir, 'utils', 'stopwords.txt')

    # A set gives O(1) stopword membership tests.
    with open(stopwords_file, encoding='utf-8') as f:
        stopwords = {line.strip() for line in f if line.strip()}

    connection = pymysql.connect(**db_config)
    try:
        with connection.cursor(pymysql.cursors.DictCursor) as cursor:
            cursor.execute("SELECT job_description FROM job_positions LIMIT 1000")
            rows = cursor.fetchall()
    finally:
        connection.close()

    word_counts = Counter()
    for row in rows:
        job_desc = row['job_description']
        if not job_desc:
            # NULL/empty descriptions: jieba cannot tokenize None.
            continue
        for token in jieba.cut(job_desc):
            token = token.strip()
            if token and token not in stopwords:
                word_counts[token] += 1

    return [{"name": name, "value": count} for name, count in word_counts.items()]


@blus.route('/api/word', methods=['GET'])
def word():
    """Return token frequencies from job descriptions (``{"name", "value"}`` items)."""
    try:
        result = _job_description_word_counts()
    except pymysql.MySQLError as err:
        return APIUtils.error_response(message=str(err))
    return APIUtils.success_response(data=result)


@blus.route('/api/caiji', methods=['GET'])
def caiji():
    """Same payload as ``/api/word``: token frequencies from job descriptions."""
    try:
        result = _job_description_word_counts()
    except pymysql.MySQLError as err:
        return APIUtils.error_response(message=str(err))
    return APIUtils.success_response(data=result)