This commit is contained in:
Looly 2025-03-26 10:16:20 +08:00
parent 5b1476c8c9
commit e0566c17ee
11 changed files with 181 additions and 54 deletions

View File

@ -1,35 +1,71 @@
package org.dromara.hutool.ai; package org.dromara.hutool.ai;
import org.dromara.hutool.core.exception.ExceptionUtil; import org.dromara.hutool.core.exception.HutoolException;
import org.dromara.hutool.core.text.StrUtil;
/** /**
* 异常处理类 * 异常处理类
*/ */
public class AIException extends RuntimeException { public class AIException extends HutoolException {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
public AIException(Throwable e) { /**
super(ExceptionUtil.getMessage(e), e); * 构造
*
* @param e 异常
*/
public AIException(final Throwable e) {
super(e);
} }
public AIException(String message) { /**
* 构造
*
* @param message 消息
*/
public AIException(final String message) {
super(message); super(message);
} }
public AIException(String messageTemplate, Object... params) { /**
super(StrUtil.format(messageTemplate, params)); * 构造
*
* @param messageTemplate 消息模板
* @param params 参数
*/
public AIException(final String messageTemplate, final Object... params) {
super(messageTemplate, params);
} }
public AIException(String message, Throwable throwable) { /**
super(message, throwable); * 构造
*
* @param message 消息
* @param cause 被包装的子异常
*/
public AIException(final String message, final Throwable cause) {
super(message, cause);
} }
public AIException(String message, Throwable throwable, boolean enableSuppression, boolean writableStackTrace) { /**
super(message, throwable, enableSuppression, writableStackTrace); * 构造
*
* @param message 消息
* @param cause 被包装的子异常
* @param enableSuppression 是否启用抑制
* @param writableStackTrace 堆栈跟踪是否应该是可写的
*/
public AIException(final String message, final Throwable cause, final boolean enableSuppression, final boolean writableStackTrace) {
super(message, cause, enableSuppression, writableStackTrace);
} }
public AIException(Throwable throwable, String messageTemplate, Object... params) { /**
super(StrUtil.format(messageTemplate, params), throwable); * 构造
*
* @param cause 被包装的子异常
* @param messageTemplate 消息模板
* @param params 参数
*/
public AIException(final Throwable cause, final String messageTemplate, final Object... params) {
super(cause, messageTemplate, params);
} }
} }

View File

@ -22,8 +22,8 @@ public class AIServiceFactory {
// 加载所有 AIModelProvider 实现类 // 加载所有 AIModelProvider 实现类
static { static {
ServiceLoader<AIServiceProvider> loader = ServiceLoader.load(AIServiceProvider.class); final ServiceLoader<AIServiceProvider> loader = ServiceLoader.load(AIServiceProvider.class);
for (AIServiceProvider provider : loader) { for (final AIServiceProvider provider : loader) {
providers.put(provider.getServiceName().toLowerCase(), provider); providers.put(provider.getServiceName().toLowerCase(), provider);
} }
} }
@ -35,7 +35,7 @@ public class AIServiceFactory {
* @return AI服务实例 * @return AI服务实例
* @since 6.0.0 * @since 6.0.0
*/ */
public static AIService getAIService(AIConfig config) { public static AIService getAIService(final AIConfig config) {
return getAIService(config, AIService.class); return getAIService(config, AIService.class);
} }
@ -46,21 +46,23 @@ public class AIServiceFactory {
* @param clazz AI服务类 * @param clazz AI服务类
* @return clazz对应的AI服务类实例 * @return clazz对应的AI服务类实例
* @since 6.0.0 * @since 6.0.0
* @param <T> AI服务类
*/ */
public static <T extends AIService> T getAIService(AIConfig config, Class<T> clazz) { @SuppressWarnings("unchecked")
public static <T extends AIService> T getAIService(final AIConfig config, final Class<T> clazz) {
//异步执行 //异步执行
CompletableFuture.runAsync(() -> { CompletableFuture.runAsync(() -> {
try { try {
HttpUtil.get("https://static.hutool.cn"); HttpUtil.get("https://static.hutool.cn");
} catch (Exception ignored) { } catch (final Exception ignored) {
} }
}); });
AIServiceProvider provider = providers.get(config.getModelName().toLowerCase()); final AIServiceProvider provider = providers.get(config.getModelName().toLowerCase());
if (provider == null) { if (provider == null) {
throw new IllegalArgumentException("Unsupported model: " + config.getModelName()); throw new IllegalArgumentException("Unsupported model: " + config.getModelName());
} }
AIService service = provider.create(config); final AIService service = provider.create(config);
if (!clazz.isInstance(service)) { if (!clazz.isInstance(service)) {
throw new AIException("Model service is not of type: " + clazz.getSimpleName()); throw new AIException("Model service is not of type: " + clazz.getSimpleName());
} }

View File

@ -25,8 +25,9 @@ public class AIUtil {
* @param clazz AI模型服务类 * @param clazz AI模型服务类
* @return AIModelService的实现类实例 * @return AIModelService的实现类实例
* @since 6.0.0 * @since 6.0.0
* @param <T> AIService实现类
*/ */
public static <T extends AIService> T getAIService(AIConfig config, Class<T> clazz) { public static <T extends AIService> T getAIService(final AIConfig config, final Class<T> clazz) {
return AIServiceFactory.getAIService(config, clazz); return AIServiceFactory.getAIService(config, clazz);
} }
@ -37,7 +38,7 @@ public class AIUtil {
* @return AIModelService 其中只有公共方法 * @return AIModelService 其中只有公共方法
* @since 6.0.0 * @since 6.0.0
*/ */
public static AIService getAIService(AIConfig config) { public static AIService getAIService(final AIConfig config) {
return getAIService(config, AIService.class); return getAIService(config, AIService.class);
} }
@ -48,7 +49,7 @@ public class AIUtil {
* @return DeepSeekService * @return DeepSeekService
* @since 6.0.0 * @since 6.0.0
*/ */
public static DeepSeekService getDeepSeekService(AIConfig config) { public static DeepSeekService getDeepSeekService(final AIConfig config) {
return getAIService(config, DeepSeekService.class); return getAIService(config, DeepSeekService.class);
} }
@ -59,7 +60,7 @@ public class AIUtil {
* @return DoubaoService * @return DoubaoService
* @since 6.0.0 * @since 6.0.0
*/ */
public static DoubaoService getDoubaoService(AIConfig config) { public static DoubaoService getDoubaoService(final AIConfig config) {
return getAIService(config, DoubaoService.class); return getAIService(config, DoubaoService.class);
} }
@ -70,7 +71,7 @@ public class AIUtil {
* @return GrokService * @return GrokService
* @since 6.0.0 * @since 6.0.0
*/ */
public static GrokService getGrokService(AIConfig config) { public static GrokService getGrokService(final AIConfig config) {
return getAIService(config, GrokService.class); return getAIService(config, GrokService.class);
} }
@ -81,7 +82,7 @@ public class AIUtil {
* @return OpenAIService * @return OpenAIService
* @since 6.0.0 * @since 6.0.0
*/ */
public static OpenaiService getOpenAIService(AIConfig config) { public static OpenaiService getOpenAIService(final AIConfig config) {
return getAIService(config, OpenaiService.class); return getAIService(config, OpenaiService.class);
} }
@ -93,7 +94,7 @@ public class AIUtil {
* @return AI模型返回的Response响应字符串 * @return AI模型返回的Response响应字符串
* @since 6.0.0 * @since 6.0.0
*/ */
public static String chat(AIConfig config, String prompt) { public static String chat(final AIConfig config, final String prompt) {
return getAIService(config).chat(prompt); return getAIService(config).chat(prompt);
} }
@ -105,7 +106,7 @@ public class AIUtil {
* @return AI模型返回的Response响应字符串 * @return AI模型返回的Response响应字符串
* @since 6.0.0 * @since 6.0.0
*/ */
public static String chat(AIConfig config, List<Message> messages) { public static String chat(final AIConfig config, final List<Message> messages) {
return getAIService(config).chat(messages); return getAIService(config).chat(messages);
} }

View File

@ -7,17 +7,34 @@ package org.dromara.hutool.ai;
* @since 6.0.0 * @since 6.0.0
*/ */
public enum ModelName { public enum ModelName {
/**
* deepSeek
*/
DEEPSEEK("deepSeek"), DEEPSEEK("deepSeek"),
/**
* openai
*/
OPENAI("openai"), OPENAI("openai"),
/**
* doubao
*/
DOUBAO("doubao"), DOUBAO("doubao"),
/**
* grok
*/
GROK("grok"); GROK("grok");
private final String value; private final String value;
ModelName(String value) { ModelName(final String value) {
this.value = value; this.value = value;
} }
/**
* 获取值
*
* @return
*/
public String getValue() { public String getValue() {
return value; return value;
} }

View File

@ -17,18 +17,18 @@ public class AIConfigBuilder {
* *
* @param modelName 模型厂商的名称注意不是指具体的模型 * @param modelName 模型厂商的名称注意不是指具体的模型
*/ */
public AIConfigBuilder(String modelName) { public AIConfigBuilder(final String modelName) {
try { try {
// 获取配置类 // 获取配置类
Class<? extends AIConfig> configClass = AIConfigRegistry.getConfigClass(modelName); final Class<? extends AIConfig> configClass = AIConfigRegistry.getConfigClass(modelName);
if (configClass == null) { if (configClass == null) {
throw new IllegalArgumentException("Unsupported model: " + modelName); throw new IllegalArgumentException("Unsupported model: " + modelName);
} }
// 使用反射创建实例 // 使用反射创建实例
Constructor<? extends AIConfig> constructor = configClass.getDeclaredConstructor(); final Constructor<? extends AIConfig> constructor = configClass.getDeclaredConstructor();
config = constructor.newInstance(); config = constructor.newInstance();
} catch (Exception e) { } catch (final Exception e) {
throw new RuntimeException("Failed to create AIConfig instance", e); throw new RuntimeException("Failed to create AIConfig instance", e);
} }
} }
@ -40,7 +40,7 @@ public class AIConfigBuilder {
* @return config * @return config
* @since 6.0.0 * @since 6.0.0
*/ */
public synchronized AIConfigBuilder setApiKey(String apiKey) { public synchronized AIConfigBuilder setApiKey(final String apiKey) {
if (apiKey != null) { if (apiKey != null) {
config.setApiKey(apiKey); config.setApiKey(apiKey);
} }
@ -54,7 +54,7 @@ public class AIConfigBuilder {
* @return config * @return config
* @since 6.0.0 * @since 6.0.0
*/ */
public synchronized AIConfigBuilder setApiUrl(String apiUrl) { public synchronized AIConfigBuilder setApiUrl(final String apiUrl) {
if (apiUrl != null) { if (apiUrl != null) {
config.setApiUrl(apiUrl); config.setApiUrl(apiUrl);
} }
@ -68,7 +68,7 @@ public class AIConfigBuilder {
* @return config * @return config
* @since 6.0.0 * @since 6.0.0
*/ */
public synchronized AIConfigBuilder setModel(String model) { public synchronized AIConfigBuilder setModel(final String model) {
if (model != null) { if (model != null) {
config.setModel(model); config.setModel(model);
} }
@ -83,7 +83,7 @@ public class AIConfigBuilder {
* @return config * @return config
* @since 6.0.0 * @since 6.0.0
*/ */
public AIConfigBuilder putAdditionalConfig(String key, Object value) { public AIConfigBuilder putAdditionalConfig(final String key, final Object value) {
if (value != null) { if (value != null) {
config.putAdditionalConfigByKey(key, value); config.putAdditionalConfigByKey(key, value);
} }

View File

@ -17,13 +17,19 @@ public class AIConfigRegistry {
// 加载所有 AIConfig 实现类 // 加载所有 AIConfig 实现类
static { static {
ServiceLoader<AIConfig> loader = ServiceLoader.load(AIConfig.class); final ServiceLoader<AIConfig> loader = ServiceLoader.load(AIConfig.class);
for (AIConfig config : loader) { for (final AIConfig config : loader) {
configClasses.put(config.getModelName().toLowerCase(), config.getClass()); configClasses.put(config.getModelName().toLowerCase(), config.getClass());
} }
} }
public static Class<? extends AIConfig> getConfigClass(String modelName) { /**
* 根据模型名称获取AIConfig实现类
*
* @param modelName 模型名称
* @return AIConfig实现类
*/
public static Class<? extends AIConfig> getConfigClass(final String modelName) {
return configClasses.get(modelName.toLowerCase()); return configClasses.get(modelName.toLowerCase());
} }
} }

View File

@ -19,11 +19,16 @@ public class BaseAIService {
protected final AIConfig config; protected final AIConfig config;
public BaseAIService(AIConfig config) { /**
* 构造方法
*
* @param config AI配置
*/
public BaseAIService(final AIConfig config) {
this.config = config; this.config = config;
} }
protected Response sendGet(String endpoint) { protected Response sendGet(final String endpoint) {
//链式构建请求 //链式构建请求
try { try {
//设置超时3分钟 //设置超时3分钟
@ -32,12 +37,12 @@ public class BaseAIService {
.header(HeaderName.ACCEPT, "application/json") .header(HeaderName.ACCEPT, "application/json")
.header(HeaderName.AUTHORIZATION, "Bearer " + config.getApiKey()) .header(HeaderName.AUTHORIZATION, "Bearer " + config.getApiKey())
.send(); .send();
} catch (AIException e) { } catch (final AIException e) {
throw new AIException("Failed to send GET request: " + e.getMessage(), e); throw new AIException("Failed to send GET request: " + e.getMessage(), e);
} }
} }
protected Response sendPost(String endpoint, String paramJson) { protected Response sendPost(final String endpoint, final String paramJson) {
//链式构建请求 //链式构建请求
try { try {
//设置超时3分钟 //设置超时3分钟
@ -48,13 +53,13 @@ public class BaseAIService {
.header(HeaderName.AUTHORIZATION, "Bearer " + config.getApiKey()) .header(HeaderName.AUTHORIZATION, "Bearer " + config.getApiKey())
.body(paramJson) .body(paramJson)
.send(); .send();
} catch (AIException e) { } catch (final AIException e) {
throw new AIException("Failed to send POST request" + e.getMessage(), e); throw new AIException("Failed to send POST request" + e.getMessage(), e);
} }
} }
protected Response sendFormData(String endpoint, Map<String, Object> paramMap) { protected Response sendFormData(final String endpoint, final Map<String, Object> paramMap) {
//链式构建请求 //链式构建请求
try { try {
//设置超时3分钟 //设置超时3分钟
@ -65,7 +70,7 @@ public class BaseAIService {
.header(HeaderName.AUTHORIZATION, "Bearer " + config.getApiKey()) .header(HeaderName.AUTHORIZATION, "Bearer " + config.getApiKey())
.form(paramMap) .form(paramMap)
.send(); .send();
} catch (AIException e) { } catch (final AIException e) {
throw new AIException("Failed to send POST request" + e.getMessage(), e); throw new AIException("Failed to send POST request" + e.getMessage(), e);
} }
} }

View File

@ -23,7 +23,7 @@ public class BaseConfig implements AIConfig {
protected Map<String, Object> additionalConfig = new SafeConcurrentHashMap<>(); protected Map<String, Object> additionalConfig = new SafeConcurrentHashMap<>();
@Override @Override
public void setApiKey(String apiKey) { public void setApiKey(final String apiKey) {
this.apiKey = apiKey; this.apiKey = apiKey;
} }
@ -33,7 +33,7 @@ public class BaseConfig implements AIConfig {
} }
@Override @Override
public void setApiUrl(String apiUrl) { public void setApiUrl(final String apiUrl) {
this.apiUrl = apiUrl; this.apiUrl = apiUrl;
} }
@ -43,7 +43,7 @@ public class BaseConfig implements AIConfig {
} }
@Override @Override
public void setModel(String model) { public void setModel(final String model) {
this.model = model; this.model = model;
} }
@ -53,12 +53,12 @@ public class BaseConfig implements AIConfig {
} }
@Override @Override
public void putAdditionalConfigByKey(String key, Object value) { public void putAdditionalConfigByKey(final String key, final Object value) {
this.additionalConfig.put(key, value); this.additionalConfig.put(key, value);
} }
@Override @Override
public Object getAdditionalConfigByKey(String key) { public Object getAdditionalConfigByKey(final String key) {
return additionalConfig.get(key); return additionalConfig.get(key);
} }

View File

@ -12,15 +12,31 @@ public class Message {
//内容 //内容
private final Object content; private final Object content;
public Message(String role, Object content) { /**
* 构造
*
* @param role 角色
* @param content 内容
*/
public Message(final String role, final Object content) {
this.role = role; this.role = role;
this.content = content; this.content = content;
} }
/**
* 获取角色
*
* @return 角色
*/
public String getRole() { public String getRole() {
return role; return role;
} }
/**
* 获取内容
*
* @return 内容
*/
public Object getContent() { public Object getContent() {
return content; return content;
} }

View File

@ -0,0 +1,22 @@
/*
* Copyright (c) 2025 Hutool Team and hutool.cn
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* AI相关基础类
*
* @author elichow
*/
package org.dromara.hutool.ai.core;

View File

@ -0,0 +1,22 @@
/*
* Copyright (c) 2025 Hutool Team and hutool.cn
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* AI模块
*
* @author elichow
*/
package org.dromara.hutool.ai;