1
0
mirror of https://gitee.com/dromara/hutool.git synced 2025-04-05 17:37:59 +08:00
This commit is contained in:
Looly 2025-03-26 10:16:20 +08:00
parent 5b1476c8c9
commit e0566c17ee
11 changed files with 181 additions and 54 deletions

View File

@ -1,35 +1,71 @@
package org.dromara.hutool.ai;
import org.dromara.hutool.core.exception.ExceptionUtil;
import org.dromara.hutool.core.text.StrUtil;
import org.dromara.hutool.core.exception.HutoolException;
/**
* 异常处理类
*/
public class AIException extends RuntimeException {
public class AIException extends HutoolException {
private static final long serialVersionUID = 1L;
public AIException(Throwable e) {
super(ExceptionUtil.getMessage(e), e);
/**
* 构造
*
* @param e 异常
*/
public AIException(final Throwable e) {
super(e);
}
public AIException(String message) {
/**
* 构造
*
* @param message 消息
*/
public AIException(final String message) {
super(message);
}
public AIException(String messageTemplate, Object... params) {
super(StrUtil.format(messageTemplate, params));
/**
* 构造
*
* @param messageTemplate 消息模板
* @param params 参数
*/
public AIException(final String messageTemplate, final Object... params) {
super(messageTemplate, params);
}
public AIException(String message, Throwable throwable) {
super(message, throwable);
/**
* 构造
*
* @param message 消息
* @param cause 被包装的子异常
*/
public AIException(final String message, final Throwable cause) {
super(message, cause);
}
public AIException(String message, Throwable throwable, boolean enableSuppression, boolean writableStackTrace) {
super(message, throwable, enableSuppression, writableStackTrace);
/**
* 构造
*
* @param message 消息
* @param cause 被包装的子异常
* @param enableSuppression 是否启用抑制
* @param writableStackTrace 堆栈跟踪是否应该是可写的
*/
public AIException(final String message, final Throwable cause, final boolean enableSuppression, final boolean writableStackTrace) {
super(message, cause, enableSuppression, writableStackTrace);
}
public AIException(Throwable throwable, String messageTemplate, Object... params) {
super(StrUtil.format(messageTemplate, params), throwable);
/**
* 构造
*
* @param cause 被包装的子异常
* @param messageTemplate 消息模板
* @param params 参数
*/
public AIException(final Throwable cause, final String messageTemplate, final Object... params) {
super(cause, messageTemplate, params);
}
}

View File

@ -22,8 +22,8 @@ public class AIServiceFactory {
// 加载所有 AIModelProvider 实现类
static {
ServiceLoader<AIServiceProvider> loader = ServiceLoader.load(AIServiceProvider.class);
for (AIServiceProvider provider : loader) {
final ServiceLoader<AIServiceProvider> loader = ServiceLoader.load(AIServiceProvider.class);
for (final AIServiceProvider provider : loader) {
providers.put(provider.getServiceName().toLowerCase(), provider);
}
}
@ -35,7 +35,7 @@ public class AIServiceFactory {
* @return AI服务实例
* @since 6.0.0
*/
public static AIService getAIService(AIConfig config) {
public static AIService getAIService(final AIConfig config) {
return getAIService(config, AIService.class);
}
@ -46,21 +46,23 @@ public class AIServiceFactory {
* @param clazz AI服务类
* @return clazz对应的AI服务类实例
* @since 6.0.0
* @param <T> AI服务类
*/
public static <T extends AIService> T getAIService(AIConfig config, Class<T> clazz) {
@SuppressWarnings("unchecked")
public static <T extends AIService> T getAIService(final AIConfig config, final Class<T> clazz) {
//异步执行
CompletableFuture.runAsync(() -> {
try {
HttpUtil.get("https://static.hutool.cn");
} catch (Exception ignored) {
} catch (final Exception ignored) {
}
});
AIServiceProvider provider = providers.get(config.getModelName().toLowerCase());
final AIServiceProvider provider = providers.get(config.getModelName().toLowerCase());
if (provider == null) {
throw new IllegalArgumentException("Unsupported model: " + config.getModelName());
}
AIService service = provider.create(config);
final AIService service = provider.create(config);
if (!clazz.isInstance(service)) {
throw new AIException("Model service is not of type: " + clazz.getSimpleName());
}

View File

@ -25,8 +25,9 @@ public class AIUtil {
* @param clazz AI模型服务类
* @return AIModelService的实现类实例
* @since 6.0.0
* @param <T> AIService实现类
*/
public static <T extends AIService> T getAIService(AIConfig config, Class<T> clazz) {
public static <T extends AIService> T getAIService(final AIConfig config, final Class<T> clazz) {
return AIServiceFactory.getAIService(config, clazz);
}
@ -37,7 +38,7 @@ public class AIUtil {
* @return AIModelService 其中只有公共方法
* @since 6.0.0
*/
public static AIService getAIService(AIConfig config) {
public static AIService getAIService(final AIConfig config) {
return getAIService(config, AIService.class);
}
@ -48,7 +49,7 @@ public class AIUtil {
* @return DeepSeekService
* @since 6.0.0
*/
public static DeepSeekService getDeepSeekService(AIConfig config) {
public static DeepSeekService getDeepSeekService(final AIConfig config) {
return getAIService(config, DeepSeekService.class);
}
@ -59,7 +60,7 @@ public class AIUtil {
* @return DoubaoService
* @since 6.0.0
*/
public static DoubaoService getDoubaoService(AIConfig config) {
public static DoubaoService getDoubaoService(final AIConfig config) {
return getAIService(config, DoubaoService.class);
}
@ -70,7 +71,7 @@ public class AIUtil {
* @return GrokService
* @since 6.0.0
*/
public static GrokService getGrokService(AIConfig config) {
public static GrokService getGrokService(final AIConfig config) {
return getAIService(config, GrokService.class);
}
@ -81,7 +82,7 @@ public class AIUtil {
* @return OpenAIService
* @since 6.0.0
*/
public static OpenaiService getOpenAIService(AIConfig config) {
public static OpenaiService getOpenAIService(final AIConfig config) {
return getAIService(config, OpenaiService.class);
}
@ -93,7 +94,7 @@ public class AIUtil {
* @return AI模型返回的Response响应字符串
* @since 6.0.0
*/
public static String chat(AIConfig config, String prompt) {
public static String chat(final AIConfig config, final String prompt) {
return getAIService(config).chat(prompt);
}
@ -105,7 +106,7 @@ public class AIUtil {
* @return AI模型返回的Response响应字符串
* @since 6.0.0
*/
public static String chat(AIConfig config, List<Message> messages) {
public static String chat(final AIConfig config, final List<Message> messages) {
return getAIService(config).chat(messages);
}

View File

@ -7,17 +7,34 @@ package org.dromara.hutool.ai;
* @since 6.0.0
*/
public enum ModelName {
/**
* deepSeek
*/
DEEPSEEK("deepSeek"),
/**
* openai
*/
OPENAI("openai"),
/**
* doubao
*/
DOUBAO("doubao"),
/**
* grok
*/
GROK("grok");
private final String value;
ModelName(String value) {
ModelName(final String value) {
this.value = value;
}
/**
* 获取值
*
* @return
*/
public String getValue() {
return value;
}

View File

@ -17,18 +17,18 @@ public class AIConfigBuilder {
*
* @param modelName 模型厂商的名称注意不是指具体的模型
*/
public AIConfigBuilder(String modelName) {
public AIConfigBuilder(final String modelName) {
try {
// 获取配置类
Class<? extends AIConfig> configClass = AIConfigRegistry.getConfigClass(modelName);
final Class<? extends AIConfig> configClass = AIConfigRegistry.getConfigClass(modelName);
if (configClass == null) {
throw new IllegalArgumentException("Unsupported model: " + modelName);
}
// 使用反射创建实例
Constructor<? extends AIConfig> constructor = configClass.getDeclaredConstructor();
final Constructor<? extends AIConfig> constructor = configClass.getDeclaredConstructor();
config = constructor.newInstance();
} catch (Exception e) {
} catch (final Exception e) {
throw new RuntimeException("Failed to create AIConfig instance", e);
}
}
@ -40,7 +40,7 @@ public class AIConfigBuilder {
* @return config
* @since 6.0.0
*/
public synchronized AIConfigBuilder setApiKey(String apiKey) {
public synchronized AIConfigBuilder setApiKey(final String apiKey) {
if (apiKey != null) {
config.setApiKey(apiKey);
}
@ -54,7 +54,7 @@ public class AIConfigBuilder {
* @return config
* @since 6.0.0
*/
public synchronized AIConfigBuilder setApiUrl(String apiUrl) {
public synchronized AIConfigBuilder setApiUrl(final String apiUrl) {
if (apiUrl != null) {
config.setApiUrl(apiUrl);
}
@ -68,7 +68,7 @@ public class AIConfigBuilder {
* @return config
* @since 6.0.0
*/
public synchronized AIConfigBuilder setModel(String model) {
public synchronized AIConfigBuilder setModel(final String model) {
if (model != null) {
config.setModel(model);
}
@ -83,7 +83,7 @@ public class AIConfigBuilder {
* @return config
* @since 6.0.0
*/
public AIConfigBuilder putAdditionalConfig(String key, Object value) {
public AIConfigBuilder putAdditionalConfig(final String key, final Object value) {
if (value != null) {
config.putAdditionalConfigByKey(key, value);
}

View File

@ -17,13 +17,19 @@ public class AIConfigRegistry {
// 加载所有 AIConfig 实现类
static {
ServiceLoader<AIConfig> loader = ServiceLoader.load(AIConfig.class);
for (AIConfig config : loader) {
final ServiceLoader<AIConfig> loader = ServiceLoader.load(AIConfig.class);
for (final AIConfig config : loader) {
configClasses.put(config.getModelName().toLowerCase(), config.getClass());
}
}
public static Class<? extends AIConfig> getConfigClass(String modelName) {
/**
* 根据模型名称获取AIConfig实现类
*
* @param modelName 模型名称
* @return AIConfig实现类
*/
public static Class<? extends AIConfig> getConfigClass(final String modelName) {
return configClasses.get(modelName.toLowerCase());
}
}

View File

@ -19,11 +19,16 @@ public class BaseAIService {
protected final AIConfig config;
public BaseAIService(AIConfig config) {
/**
* 构造方法
*
* @param config AI配置
*/
public BaseAIService(final AIConfig config) {
this.config = config;
}
protected Response sendGet(String endpoint) {
protected Response sendGet(final String endpoint) {
//链式构建请求
try {
//设置超时3分钟
@ -32,12 +37,12 @@ public class BaseAIService {
.header(HeaderName.ACCEPT, "application/json")
.header(HeaderName.AUTHORIZATION, "Bearer " + config.getApiKey())
.send();
} catch (AIException e) {
} catch (final AIException e) {
throw new AIException("Failed to send GET request: " + e.getMessage(), e);
}
}
protected Response sendPost(String endpoint, String paramJson) {
protected Response sendPost(final String endpoint, final String paramJson) {
//链式构建请求
try {
//设置超时3分钟
@ -48,13 +53,13 @@ public class BaseAIService {
.header(HeaderName.AUTHORIZATION, "Bearer " + config.getApiKey())
.body(paramJson)
.send();
} catch (AIException e) {
} catch (final AIException e) {
throw new AIException("Failed to send POST request" + e.getMessage(), e);
}
}
protected Response sendFormData(String endpoint, Map<String, Object> paramMap) {
protected Response sendFormData(final String endpoint, final Map<String, Object> paramMap) {
//链式构建请求
try {
//设置超时3分钟
@ -65,7 +70,7 @@ public class BaseAIService {
.header(HeaderName.AUTHORIZATION, "Bearer " + config.getApiKey())
.form(paramMap)
.send();
} catch (AIException e) {
} catch (final AIException e) {
throw new AIException("Failed to send POST request" + e.getMessage(), e);
}
}

View File

@ -23,7 +23,7 @@ public class BaseConfig implements AIConfig {
protected Map<String, Object> additionalConfig = new SafeConcurrentHashMap<>();
@Override
public void setApiKey(String apiKey) {
public void setApiKey(final String apiKey) {
this.apiKey = apiKey;
}
@ -33,7 +33,7 @@ public class BaseConfig implements AIConfig {
}
@Override
public void setApiUrl(String apiUrl) {
public void setApiUrl(final String apiUrl) {
this.apiUrl = apiUrl;
}
@ -43,7 +43,7 @@ public class BaseConfig implements AIConfig {
}
@Override
public void setModel(String model) {
public void setModel(final String model) {
this.model = model;
}
@ -53,12 +53,12 @@ public class BaseConfig implements AIConfig {
}
@Override
public void putAdditionalConfigByKey(String key, Object value) {
public void putAdditionalConfigByKey(final String key, final Object value) {
this.additionalConfig.put(key, value);
}
@Override
public Object getAdditionalConfigByKey(String key) {
public Object getAdditionalConfigByKey(final String key) {
return additionalConfig.get(key);
}

View File

@ -12,15 +12,31 @@ public class Message {
//内容
private final Object content;
public Message(String role, Object content) {
/**
* 构造
*
* @param role 角色
* @param content 内容
*/
public Message(final String role, final Object content) {
this.role = role;
this.content = content;
}
/**
* 获取角色
*
* @return 角色
*/
public String getRole() {
return role;
}
/**
* 获取内容
*
* @return 内容
*/
public Object getContent() {
return content;
}

View File

@ -0,0 +1,22 @@
/*
* Copyright (c) 2025 Hutool Team and hutool.cn
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* AI相关基础类
*
* @author elichow
*/
package org.dromara.hutool.ai.core;

View File

@ -0,0 +1,22 @@
/*
* Copyright (c) 2025 Hutool Team and hutool.cn
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* AI模块
*
* @author elichow
*/
package org.dromara.hutool.ai;