Merge pull request #3720 from wangyefeng/v5-weight

用ArrayList重新实现权重随机类
This commit is contained in:
Golden Looly 2024-09-04 20:08:13 +08:00 committed by GitHub
commit 7c6978c990
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 185 additions and 0 deletions

View File

@ -0,0 +1,136 @@
package cn.hutool.core.lang;
import cn.hutool.core.util.RandomUtil;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
/**
* 动态按权重随机的随机池底层是list实现
*
* @param <E> 元素类型
* @author 王叶峰
* @date 2024-07-29
*/
public class WeightListPool<E> {
/**
* 随机元素池
*/
private final ArrayList<EWeight<E>> randomPool;
public WeightListPool() {
randomPool = new ArrayList<>();
}
public WeightListPool(int poolSize) {
randomPool = new ArrayList<>(poolSize);
}
public void add(E e, double weight) {
Assert.isTrue(weight > 0, "权重必须大于0");
randomPool.add(new EWeight<>(e, sumWeight() + weight));
}
public boolean remove(E e) {
boolean removed = false;
double weight = 0;
int i = 0;
Iterator<EWeight<E>> iterator = randomPool.iterator();
while (iterator.hasNext()) {
EWeight<E> ew = iterator.next();
if (!removed && ew.e.equals(e)) {
iterator.remove();
weight = ew.sumWeight - (i == 0 ? 0 : randomPool.get(i - 1).sumWeight);// 权重=当前权重-上一个权重
removed = true;
}
if (removed) {
// 重新计算后续权重
ew.sumWeight -= weight;
}
i++;
}
return removed;
}
private double sumWeight() {
if (randomPool.isEmpty()) {
return 0;
}
return randomPool.get(randomPool.size() - 1).sumWeight;
}
private void checkEmptyPool() {
if (isEmpty()) {
throw new IllegalArgumentException("随机池为空!");
}
}
public E random() {
checkEmptyPool();
if (randomPool.size() == 1) {
return randomPool.get(0).e;
}
ThreadLocalRandom random = RandomUtil.getRandom();
double randVal = random.nextDouble() * sumWeight();
return binarySearch(randVal);
}
/**
* 二分查找小于等于key的最大值的元素
*
* @param key 目标值
* @return 随机池的一个元素或者null 当key大于所有元素的总权重时返回null
*/
private E binarySearch(double key) {
int low = 0;
int high = randomPool.size() - 1;
while (low <= high) {
int mid = (low + high) >>> 1;
double midVal = randomPool.get(mid).sumWeight;
if (midVal < key) {
low = mid + 1;
} else if (midVal > key) {
high = mid - 1;
} else {
return randomPool.get(mid).e;
}
}
return randomPool.get(low).e;
}
/**
* 按照给定的总权重随机出一个元素
*
* @param weight 总权重
* @return 随机池的一个元素或者null
*/
public E randomByWeight(double weight) {
Assert.isTrue(weight >= sumWeight(), "权重必须大于当前总权重!");
ThreadLocalRandom random = RandomUtil.getRandom();
double randVal = random.nextDouble() * sumWeight();
if (randVal > sumWeight()) {
return null;
}
return binarySearch(randVal);
}
public boolean isEmpty() {
return randomPool.isEmpty();
}
private static class EWeight<E> {
final E e;
double sumWeight;
public EWeight(E e, double sumWeight) {
this.e = e;
this.sumWeight = sumWeight;
}
}
}

View File

@ -0,0 +1,49 @@
package cn.hutool.core.lang;
import cn.hutool.core.util.RandomUtil;
import org.junit.jupiter.api.Test;
import java.util.HashMap;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class WeightListPoolTest {
@Test
public void weightRandomTest() {
Map<Integer, Times> timesMap = new HashMap<>();
int size = 100;
double sumWeight = 0.0;
WeightListPool<Integer> pool = new WeightListPool<>(size);
for (int i = 0; i < size; i++) {
double weight = RandomUtil.randomDouble(100);
pool.add(i, weight);
sumWeight += weight;
timesMap.put(i, new Times(weight));
}
double d = 0.0001;// 随机误差
int times = 100000000;// 随机次数
for (int i = 0; i < times; i++) {
timesMap.get(pool.random()).num++;
}
double finalSumWeight = sumWeight;
timesMap.forEach((key, times1) -> {
double expected = times1.weight / finalSumWeight;// 期望概率
double actual = timesMap.get(key).num * 1.0 / times;// 真实随机概率
assertTrue(Math.abs(actual - expected) < d);// 检验随机误差是否在误差范围内
});
}
private static class Times {
int num;
double weight;
public Times(double weight) {
this.weight = weight;
}
}
}