Recommendation algorithm push

闵宪瑞 2025-04-07 10:50:40 +08:00
parent 9041756e80
commit 5d94fa03dc
21 changed files with 576 additions and 22 deletions

54
.gitignore vendored
@@ -1,26 +1,38 @@
# ---> Java
# Compiled class file
*.class
target/
!.mvn/wrapper/maven-wrapper.jar
!**/src/main/**/target/
!**/src/test/**/target/
# Log file
*.log
### IntelliJ IDEA ###
.idea/modules.xml
.idea/jarRepositories.xml
.idea/compiler.xml
.idea/libraries/
*.iws
*.iml
*.ipr
# BlueJ files
*.ctxt
### Eclipse ###
.apt_generated
.classpath
.factorypath
.project
.settings
.springBeans
.sts4-cache
# Mobile Tools for Java (J2ME)
.mtj.tmp/
### NetBeans ###
/nbproject/private/
/nbbuild/
/dist/
/nbdist/
/.nb-gradle/
build/
!**/src/main/**/build/
!**/src/test/**/build/
# Package Files #
*.jar
*.war
*.nar
*.ear
*.zip
*.tar.gz
*.rar
# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
hs_err_pid*
replay_pid*
### VS Code ###
.vscode/
### Mac OS ###
.DS_Store

8
.idea/.gitignore generated vendored Normal file
@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

7
.idea/codeStyles/Project.xml generated Normal file
@@ -0,0 +1,7 @@
<component name="ProjectCodeStyleConfiguration">
<code_scheme name="Project" version="173">
<ScalaCodeStyleSettings>
<option name="MULTILINE_STRING_CLOSING_QUOTES_ON_NEW_LINE" value="true" />
</ScalaCodeStyleSettings>
</code_scheme>
</component>

5
.idea/codeStyles/codeStyleConfig.xml generated Normal file
@@ -0,0 +1,5 @@
<component name="ProjectCodeStyleConfiguration">
<state>
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
</state>
</component>

6
.idea/encodings.xml generated Normal file
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Encoding">
<file url="file://$PROJECT_DIR$/src/main/java" charset="UTF-8" />
</component>
</project>

6
.idea/inspectionProfiles/Project_Default.xml generated Normal file

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
</profile>
</component>

12
.idea/misc.xml generated Normal file
@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ExternalStorageConfigurationManager" enabled="true" />
<component name="MavenProjectsManager">
<option name="originalFiles">
<list>
<option value="$PROJECT_DIR$/pom.xml" />
</list>
</option>
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="corretto-21" project-jdk-type="JavaSDK" />
</project>

6
.idea/vcs.xml generated Normal file
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

README.md

@@ -1,3 +1,3 @@
# spark-rem
Recommendation algorithm template
Recommendation algorithm

77
pom.xml Normal file
@@ -0,0 +1,77 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.6.2</version>
<relativePath/>
</parent>
<groupId>org.example</groupId>
<artifactId>spark-rem</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<!-- Spark and other dependencies -->
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.12</artifactId>
<version>3.1.2</version>
<exclusions>
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.12</artifactId>
<version>3.1.2</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.12</artifactId>
<version>3.1.2</version>
</dependency>
<dependency>
<groupId>org.codehaus.janino</groupId>
<artifactId>janino</artifactId>
<version>3.0.15</version> <!-- use an appropriate version -->
</dependency>
<!-- MySQL JDBC Driver -->
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>8.0.28</version> <!-- use a MySQL driver version that matches your server -->
</dependency>
</dependencies>
</project>

112
src/main/java/com/OfflineRecommender.java Normal file

@@ -0,0 +1,112 @@
//package com;
//
//import com.model.ProductRating;
//import org.apache.spark.SparkConf;
//import org.apache.spark.sql.Dataset;
//import org.apache.spark.sql.Encoders;
//import org.apache.spark.sql.SparkSession;
//import org.apache.spark.sql.Row;
//import org.apache.spark.api.java.JavaRDD;
//import org.apache.spark.api.java.JavaPairRDD;
//import org.apache.spark.mllib.recommendation.ALS;
//import org.apache.spark.mllib.recommendation.Rating;
//import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
//import java.util.Properties;
//import java.util.List;
//
//public class OfflineRecommender {
//
// // Helper that reads a table from MySQL
// public static Dataset<Row> readFromMySQL(SparkSession spark, String url, String table, Properties properties) {
// // Read the table via JDBC and return it as a Dataset
// return spark.read().jdbc(url, table, properties);
// }
//
// public static void main(String[] args) {
//
// // Create the SparkConf and set the Spark configuration
// SparkConf sparkConf = new SparkConf()
// .setMaster("local[*]")
// .setAppName("OfflineRecommender");
//
// // Create the SparkSession
// SparkSession spark = SparkSession.builder().config(sparkConf).getOrCreate();
//
//
// // MySQL connection settings
// String url = "jdbc:mysql://localhost:3306/test_rem"; // change to your database URL
// String table = "rating"; // change to the table you want to read
//
// // MySQL connection properties
// Properties properties = new Properties();
// properties.put("user", "root"); // change to your username
// properties.put("password", "123456"); // change to your password
// properties.put("driver", "com.mysql.cj.jdbc.Driver");
//
// // Read the ratings table from MySQL
// Dataset<Row> df = readFromMySQL(spark, url, table, properties);
//
// // Rename columns with alias so they match the ProductRating bean fields
// Dataset<Row> renamedDF = df
// .select(
// df.col("user_id").alias("userId"),
// df.col("product_id").alias("productId"),
// df.col("score"),
// df.col("create_time").alias("createTime")
// )
// .na().drop(); // drop rows containing nulls
//
// // Convert Dataset<Row> to Dataset<ProductRating> via a bean encoder
// Dataset<ProductRating> productRatingDataset = renamedDF.as(Encoders.bean(ProductRating.class));
//
// // Show the mapped entity data
// productRatingDataset.show();
//
// // Convert the Dataset to a JavaRDD
// JavaRDD<ProductRating> ratingRDD = productRatingDataset.javaRDD();
//
// // Distinct user and product id RDDs
// JavaRDD<Integer> userRDD = ratingRDD.map(ProductRating::getUserId).distinct();
// JavaRDD<Integer> productRDD = ratingRDD.map(ProductRating::getProductId).distinct();
//
// // Build the ALS training set
// JavaRDD<Rating> trainData =
// ratingRDD.map(rating -> new Rating(rating.getUserId(), rating.getProductId(), rating.getScore())
// );
//
// // Print the training set (debug output; collects to the driver)
// trainData.collect().forEach(train -> {
// System.out.println("UserId: " + train.user() + ", ProductId: " + train.product() + ", Rating: " + train.rating());
// });
//
// // Train the latent-factor model with ALS
// int rank = 50; // number of latent factors
// int iterations = 5; // number of ALS iterations
// double lambda = 0.01; // regularization parameter
//
// // Fit the ALS model
// MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(trainData), rank, iterations, lambda);
//
// // Build every (userId, productId) pair to score
// JavaPairRDD<Integer, Integer> userProducts = userRDD.cartesian(productRDD);
//
// // With the trained model, predict a rating for each pair: RDD[Rating] (userId, productId, rating)
// JavaRDD<Rating> preRatings = model.predict(userProducts);
//
// preRatings.foreach(rating -> {
// System.out.println(rating);
// });
// // Keep only predictions with a positive rating
// JavaRDD<Rating> filteredRatings = preRatings.filter(rating -> rating.rating() > 0);
// // Sort by predicted rating in descending order and take the top 10
// List<Rating> top10Ratings = filteredRatings
// .sortBy((Rating rating) -> rating.rating(), false, 1) // false = descending
// .take(10); // take the 10 highest-rated recommendations
// // Print the top 10 recommendations
// top10Ratings.forEach(rating -> {
// System.out.println("UserId: " + rating.user() + ", ProductId: " + rating.product() + ", Rating: " + rating.rating());
// });
//
// }
//
//}

12
src/main/java/com/SpringBootRecommendApplication.java Normal file

@@ -0,0 +1,12 @@
package com;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
@SpringBootApplication
public class SpringBootRecommendApplication {
public static void main(String[] args) {
SpringApplication.run(SpringBootRecommendApplication.class, args);
}
}

23
src/main/java/com/config/SparkConfig.java Normal file

@@ -0,0 +1,23 @@
package com.config;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.SparkSession;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
public class SparkConfig {
@Bean
public SparkSession sparkSession() {
// Create the SparkConf and set the Spark configuration
SparkConf sparkConf = new SparkConf()
.setMaster("local[*]") // local mode, using all available cores
.setAppName("OfflineRecommender"); // Spark application name
// Create and return the SparkSession
return SparkSession.builder()
.config(sparkConf)
.getOrCreate(); // initialize the SparkSession
}
}
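A possible follow-up, not part of this commit: the master URL could be read from configuration instead of being hard-coded to local[*]. A minimal sketch, assuming a spark.master property that does not yet exist in application.yml:

package com.config;

import org.apache.spark.SparkConf;
import org.apache.spark.sql.SparkSession;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

// Hypothetical variant of SparkConfig: the master URL comes from configuration,
// falling back to local[*]; "spark.master" is an assumed property name.
@Configuration
public class SparkConfig {

    @Value("${spark.master:local[*]}")
    private String master;

    @Bean
    public SparkSession sparkSession() {
        SparkConf sparkConf = new SparkConf()
                .setMaster(master)
                .setAppName("OfflineRecommender");
        return SparkSession.builder().config(sparkConf).getOrCreate();
    }
}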

34
src/main/java/com/controller/RecommendController.java Normal file

@@ -0,0 +1,34 @@
package com.controller;
import com.model.ProductRemRating;
import com.service.RecommendService;
import org.apache.spark.mllib.recommendation.Rating;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import java.util.List;
/**
* Recommendation API
*/
@RestController
@RequestMapping("/api/recommend")
@CrossOrigin
public class RecommendController {
@Autowired private RecommendService recommendService;
/**
* Recommend products across all users
* @param total number of recommendations to return
* @return list of userId, productId and predicted recommendation score
*/
@GetMapping("all")
public List<ProductRemRating> allUserRecommend(@RequestParam Integer total){
List<ProductRemRating> ratings = recommendService.AllRecommend(total);
return ratings;
}
}
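For reference, a request to this endpoint and an illustrative response are sketched below; the port (9090) and path come from application.yml further down and the mappings above, while the returned values are made up for illustration.

GET http://localhost:9090/api/recommend/all?total=10

[
  {"userId": 1, "productId": 42, "rating": 4.73},
  {"userId": 7, "productId": 13, "rating": 4.52}
]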

68
src/main/java/com/model/ProductRating.java Normal file

@@ -0,0 +1,68 @@
package com.model;
/**
* Product rating entity (maps the rating table read from MySQL)
*/
public class ProductRating {
private Integer userId;
private Integer productId;
private Double score;
private String createTime;
public ProductRating() {
// no-args constructor (required by the Spark bean encoder)
}
public ProductRating(Integer userId, Integer productId, Double score, String createTime) {
this.userId = userId;
this.productId = productId;
this.score = score;
this.createTime = createTime;
}
public Integer getProductId() {
return productId;
}
public Double getScore() {
return score;
}
public void setProductId(Integer productId) {
this.productId = productId;
}
public void setScore(Double score) {
this.score = score;
}
public Integer getUserId() {
return userId;
}
public String getCreateTime() {
return createTime;
}
public void setUserId(Integer userId) {
this.userId = userId;
}
public void setCreateTime(String createTime) {
this.createTime = createTime;
}
@Override
public String toString() {
return "ProductRating{" +
"userId=" + userId +
", productId=" + productId +
", score=" + score +
", createTime='" + createTime + '\'' +
'}';
}
}

17
src/main/java/com/model/ProductRemRating.java Normal file

@@ -0,0 +1,17 @@
package com.model;
import lombok.Data;
/**
* Recommendation result returned by the API
*/
@Data
public class ProductRemRating {
private Integer userId;
private Integer productId;
private Double rating;
}

16
src/main/java/com/service/RecommendService.java Normal file

@@ -0,0 +1,16 @@
package com.service;
import com.model.ProductRemRating;
import org.apache.spark.mllib.recommendation.Rating;
import java.util.List;
public interface RecommendService {
/**
* Recommend across all users
* @param total number of recommendations to return
* @return recommendation list
*/
List<ProductRemRating> AllRecommend(Integer total);
}

117
src/main/java/com/service/impl/RecommendServiceImpl.java Normal file

@@ -0,0 +1,117 @@
package com.service.impl;
import com.model.ProductRating;
import com.model.ProductRemRating;
import com.service.RecommendService;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.apache.spark.mllib.recommendation.Rating;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.ALS;
@Service
public class RecommendServiceImpl implements RecommendService {
@Autowired
private SparkSession spark; // SparkSession bean injected from SparkConfig
@Value("${spark.mysql.url}")
private String url;
@Value("${spark.mysql.table}")
private String table;
@Value("${spark.mysql.user}")
private String user;
@Value("${spark.mysql.password}")
private String password;
@Override
public List<ProductRemRating> AllRecommend(Integer total) {
// MySQL connection properties (user/password come from application.yml)
Properties properties = new Properties();
properties.put("user", user);
properties.put("password", password);
properties.put("driver", "com.mysql.cj.jdbc.Driver");
// Read the ratings table from MySQL
Dataset<Row> df = readFromMySQL(spark, url, table, properties);
// Rename columns with alias so they match the ProductRating bean fields
Dataset<Row> renamedDF = df
.select(
df.col("user_id").alias("userId"),
df.col("item_id").alias("productId"),
df.col("star").alias("score"),
df.col("create_time").alias("createTime")
)
.na().drop(); // drop rows containing nulls
// Convert Dataset<Row> to Dataset<ProductRating> via a bean encoder
Dataset<ProductRating> productRatingDataset = renamedDF.as(Encoders.bean(ProductRating.class));
JavaRDD<ProductRating> ratingRDD = productRatingDataset.javaRDD();
// Distinct user and product id RDDs
JavaRDD<Integer> userRDD = ratingRDD.map(ProductRating::getUserId).distinct();
JavaRDD<Integer> productRDD = ratingRDD.map(ProductRating::getProductId).distinct();
// Build the ALS training set (one Rating per user/product/score triple)
JavaRDD<Rating> trainData =
ratingRDD.map(rating -> new Rating(rating.getUserId(), rating.getProductId(), rating.getScore())
);
// Print the training set (debug output; collects to the driver)
trainData.collect().forEach(train -> {
System.out.println("UserId: " + train.user() + ", ProductId: " + train.product() + ", Rating: " + train.rating());
});
// Train the latent-factor model with ALS
int rank = 50; // number of latent factors
int iterations = 5; // number of ALS iterations
double lambda = 0.01; // regularization parameter
// Fit the ALS model
MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(trainData), rank, iterations, lambda);
// Build every (userId, productId) pair to score
JavaPairRDD<Integer, Integer> userProducts = userRDD.cartesian(productRDD);
// With the trained model, predict a rating for each pair: RDD[Rating] (userId, productId, rating)
JavaRDD<Rating> preRatings = model.predict(userProducts);
preRatings.foreach(rating -> {
System.out.println(rating); // debug: print each prediction
});
// Keep only predictions with a positive rating
JavaRDD<Rating> filteredRatings = preRatings.filter(rating -> rating.rating() > 0);
// Sort by predicted rating in descending order (false = descending) and take the top `total`
List<Rating> topRatings = filteredRatings
.sortBy((Rating rating) -> rating.rating(), false, 1)
.take(total); // take the `total` highest-rated recommendations
List<ProductRemRating> list = new ArrayList<>();
// Map each of the top predictions to a ProductRemRating and print it
topRatings.forEach(rating -> {
ProductRemRating productRemRating = new ProductRemRating();
productRemRating.setUserId(rating.user());
productRemRating.setProductId(rating.product());
productRemRating.setRating(rating.rating());
list.add(productRemRating);
System.out.println("UserId: " + rating.user() + ", ProductId: " + rating.product() + ", Rating: " + rating.rating());
});
return list;
}
public static Dataset<Row> readFromMySQL(SparkSession spark, String url, String table, Properties properties) {
// Read a table from MySQL and return it as a Dataset<Row>
return spark.read().jdbc(url, table, properties);
}
}
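As a side note, MLlib's MatrixFactorizationModel also exposes recommendProducts, which returns the top-N products for one user without scoring the full user/product cartesian product built above. A minimal sketch, assuming `model` is the ALS model trained in AllRecommend; the user id and N are example values:

Rating[] topForUser = model.recommendProducts(1, 10); // hypothetical: top 10 products for user 1
for (Rating r : topForUser) {
    System.out.println("ProductId: " + r.product() + ", Rating: " + r.rating());
}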

4
src/main/java/com/util/MySQLReader.java Normal file

@@ -0,0 +1,4 @@
package com.util;
public class MySQLReader {
}

8
src/main/resources/application.yml Normal file

@@ -0,0 +1,8 @@
server:
port: 9090
spark:
mysql:
url: jdbc:mysql://localhost:3306/racket
user: root
password: 123456
table: tb_order

4
src/main/resources/log4j.properties Normal file

@@ -0,0 +1,4 @@
log4j.rootLogger=info, stdout
log4j.appender.stdout=org.apache.log4j.ConsoleAppender
log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SSS} %5p --- [%50t] %-80c(line:%5L) : %m%n