mirror of
https://gitee.com/dromara/hutool.git
synced 2025-04-05 17:37:59 +08:00
commit
7c6978c990
@ -0,0 +1,136 @@
|
||||
package cn.hutool.core.lang;
|
||||
|
||||
import cn.hutool.core.util.RandomUtil;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
|
||||
/**
|
||||
* 动态按权重随机的随机池,底层是list实现。
|
||||
*
|
||||
* @param <E> 元素类型
|
||||
* @author 王叶峰
|
||||
* @date 2024-07-29
|
||||
*/
|
||||
public class WeightListPool<E> {
|
||||
|
||||
/**
|
||||
* 随机元素池
|
||||
*/
|
||||
private final ArrayList<EWeight<E>> randomPool;
|
||||
|
||||
public WeightListPool() {
|
||||
randomPool = new ArrayList<>();
|
||||
}
|
||||
|
||||
public WeightListPool(int poolSize) {
|
||||
randomPool = new ArrayList<>(poolSize);
|
||||
}
|
||||
|
||||
public void add(E e, double weight) {
|
||||
Assert.isTrue(weight > 0, "权重必须大于0!");
|
||||
randomPool.add(new EWeight<>(e, sumWeight() + weight));
|
||||
}
|
||||
|
||||
public boolean remove(E e) {
|
||||
boolean removed = false;
|
||||
double weight = 0;
|
||||
int i = 0;
|
||||
Iterator<EWeight<E>> iterator = randomPool.iterator();
|
||||
while (iterator.hasNext()) {
|
||||
EWeight<E> ew = iterator.next();
|
||||
if (!removed && ew.e.equals(e)) {
|
||||
iterator.remove();
|
||||
weight = ew.sumWeight - (i == 0 ? 0 : randomPool.get(i - 1).sumWeight);// 权重=当前权重-上一个权重
|
||||
removed = true;
|
||||
}
|
||||
if (removed) {
|
||||
// 重新计算后续权重
|
||||
ew.sumWeight -= weight;
|
||||
}
|
||||
i++;
|
||||
}
|
||||
return removed;
|
||||
}
|
||||
|
||||
private double sumWeight() {
|
||||
if (randomPool.isEmpty()) {
|
||||
return 0;
|
||||
}
|
||||
return randomPool.get(randomPool.size() - 1).sumWeight;
|
||||
}
|
||||
|
||||
private void checkEmptyPool() {
|
||||
if (isEmpty()) {
|
||||
throw new IllegalArgumentException("随机池为空!");
|
||||
}
|
||||
}
|
||||
|
||||
public E random() {
|
||||
checkEmptyPool();
|
||||
|
||||
if (randomPool.size() == 1) {
|
||||
return randomPool.get(0).e;
|
||||
}
|
||||
ThreadLocalRandom random = RandomUtil.getRandom();
|
||||
double randVal = random.nextDouble() * sumWeight();
|
||||
return binarySearch(randVal);
|
||||
}
|
||||
|
||||
/**
|
||||
* 二分查找小于等于key的最大值的元素
|
||||
*
|
||||
* @param key 目标值
|
||||
* @return 随机池的一个元素或者null 当key大于所有元素的总权重时,返回null
|
||||
*/
|
||||
private E binarySearch(double key) {
|
||||
int low = 0;
|
||||
int high = randomPool.size() - 1;
|
||||
|
||||
while (low <= high) {
|
||||
int mid = (low + high) >>> 1;
|
||||
double midVal = randomPool.get(mid).sumWeight;
|
||||
|
||||
if (midVal < key) {
|
||||
low = mid + 1;
|
||||
} else if (midVal > key) {
|
||||
high = mid - 1;
|
||||
} else {
|
||||
return randomPool.get(mid).e;
|
||||
}
|
||||
}
|
||||
return randomPool.get(low).e;
|
||||
}
|
||||
|
||||
/**
|
||||
* 按照给定的总权重随机出一个元素
|
||||
*
|
||||
* @param weight 总权重
|
||||
* @return 随机池的一个元素或者null
|
||||
*/
|
||||
public E randomByWeight(double weight) {
|
||||
Assert.isTrue(weight >= sumWeight(), "权重必须大于当前总权重!");
|
||||
ThreadLocalRandom random = RandomUtil.getRandom();
|
||||
double randVal = random.nextDouble() * sumWeight();
|
||||
if (randVal > sumWeight()) {
|
||||
return null;
|
||||
}
|
||||
return binarySearch(randVal);
|
||||
}
|
||||
|
||||
public boolean isEmpty() {
|
||||
return randomPool.isEmpty();
|
||||
}
|
||||
|
||||
private static class EWeight<E> {
|
||||
final E e;
|
||||
double sumWeight;
|
||||
|
||||
public EWeight(E e, double sumWeight) {
|
||||
this.e = e;
|
||||
this.sumWeight = sumWeight;
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
package cn.hutool.core.lang;
|
||||
|
||||
import cn.hutool.core.util.RandomUtil;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class WeightListPoolTest {
|
||||
|
||||
@Test
|
||||
public void weightRandomTest() {
|
||||
Map<Integer, Times> timesMap = new HashMap<>();
|
||||
int size = 100;
|
||||
double sumWeight = 0.0;
|
||||
WeightListPool<Integer> pool = new WeightListPool<>(size);
|
||||
for (int i = 0; i < size; i++) {
|
||||
double weight = RandomUtil.randomDouble(100);
|
||||
pool.add(i, weight);
|
||||
sumWeight += weight;
|
||||
timesMap.put(i, new Times(weight));
|
||||
}
|
||||
|
||||
double d = 0.0001;// 随机误差
|
||||
int times = 100000000;// 随机次数
|
||||
for (int i = 0; i < times; i++) {
|
||||
timesMap.get(pool.random()).num++;
|
||||
}
|
||||
double finalSumWeight = sumWeight;
|
||||
timesMap.forEach((key, times1) -> {
|
||||
double expected = times1.weight / finalSumWeight;// 期望概率
|
||||
double actual = timesMap.get(key).num * 1.0 / times;// 真实随机概率
|
||||
assertTrue(Math.abs(actual - expected) < d);// 检验随机误差是否在误差范围内
|
||||
});
|
||||
}
|
||||
|
||||
private static class Times {
|
||||
|
||||
int num;
|
||||
|
||||
double weight;
|
||||
|
||||
public Times(double weight) {
|
||||
this.weight = weight;
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user