禁忌搜索是一种可以用于解决组合优化问题的启发式算法,通过引入记忆机制跳出局部最优,避免重复搜索。该算法从一个初始解开始,通过邻域搜索策略来寻找当前解的邻域解,并在邻域解中选择一个最优解作为下一次迭代的当前解,为了避免算法陷入局部最优,引入禁忌表来记录已经访问过的操作,禁止算法在一定迭代次数内再次选择这些被禁忌的操作,另外算法可以设置一些特赦条件,使得被禁忌的操作可以解除禁忌,从而探索更优的解空间。

算法流程
在这里插入图片描述

旅行商问题
假设有 4 个城市A、B、C、D,旅行商需要从一个城市出发,遍历所有城市且每个城市只经过一次,最后回到起始城市,要求找到最短的旅行路线,城市距离矩阵如下,最短的旅行路线为 A → B → D → C → A
在这里插入图片描述

python版本

from collections import deque

DISTANCE_MATRIX = [
    [0, 2, 9, 10],
    [2, 0, 6, 4],
    [9, 6, 0, 8],
    [10, 4, 8, 0]
]

NUM_CITIES = 4        # 城市数量
TABU_TENURE = 2       # 禁忌表长度
MAX_ITERATIONS = 100  # 最大迭代次数

def main():
    best_solution = tabu_search()
    print(f"最优路径: {format_path(best_solution)}")
    print(f"最短距离: {calculate_distance(best_solution)}")

def format_path(path):
    cities = ["A", "B", "C", "D"]
    return " → ".join(cities[idx] for idx in path) + " → " + cities[0]

# 禁忌搜索核心算法
def tabu_search():
    # 初始化解
    current_solution = generate_initial_solution()
    best_solution = current_solution.copy()
    best_distance = calculate_distance(best_solution)

    # 禁忌表
    tabu_list = deque(maxlen=TABU_TENURE)

    # 迭代搜索
    for _ in range(MAX_ITERATIONS):
        best_candidate = None
        best_candidate_dist = float('inf')
        move = None

        # 生成邻域解
        for i in range(1, NUM_CITIES):
            for j in range(i+1, NUM_CITIES):
                # 避免重复交换
                swap_key = f"{i}-{j}"

                # 生成候选解
                candidate = current_solution.copy()
                swap(candidate, i, j)
                candidate_dist = calculate_distance(candidate)

                # 检查是否满足特赦的条件
                is_aspiration = candidate_dist < best_distance

                # 选择最优候选解或者满足特赦条件的候选解
                if swap_key not in tabu_list or is_aspiration:
                    if candidate_dist < best_candidate_dist:
                        best_candidate = candidate.copy()
                        best_candidate_dist = candidate_dist
                        move = swap_key

        # 更新当前解
        if best_candidate is not None:
            current_solution = best_candidate.copy()

            # 更新禁忌表
            tabu_list.append(move)

            # 更新全局最优解
            if best_candidate_dist < best_distance:
                best_solution = best_candidate.copy()
                best_distance = best_candidate_dist

    return best_solution

def generate_initial_solution():
    return list(range(NUM_CITIES))

def swap(array, i, j):
    array[i], array[j] = array[j], array[i]

# 计算路径总距离
def calculate_distance(path):
    distance = 0
    for i in range(NUM_CITIES):
        from_city = path[i]
        to_city = path[(i+1) % NUM_CITIES]
        distance += DISTANCE_MATRIX[from_city][to_city]
    return distance

if __name__ == "__main__":
    main()

java版本

public class TabuSearchTSP {

    // 城市距离矩阵
    private static final int[][] DISTANCE_MATRIX = {
            {0, 2, 9, 10},
            {2, 0, 6, 4},
            {9, 6, 0, 8},
            {10, 4, 8, 0}
    };

    private static final int NUM_CITIES = 4;      // 城市数量
    private static final int TABU_TENURE = 2;     // 禁忌表长度
    private static final int MAX_ITERATIONS = 100; // 最大迭代次数

    public static void main(String[] args) {
        int[] bestSolution = tabuSearch();
        System.out.println("最优路径: " + formatPath(bestSolution));
        System.out.println("最短距离: " + calculateDistance(bestSolution));
    }
    private static String formatPath(int[] path) {
        String[] cities = {"A", "B", "C", "D"};
        StringBuilder sb = new StringBuilder();
        for (int idx : path) {
            sb.append(cities[idx]).append(" → ");
        }
        sb.append(cities[0]);
        return sb.toString();
    }
    // 禁忌搜索核心算法
    private static int[] tabuSearch() {
        // 初始化解
        int[] currentSolution = generateInitialSolution();
        int[] bestSolution = currentSolution.clone();
        int bestDistance = calculateDistance(bestSolution);

        // 禁忌表
        Queue<String> tabuList = new LinkedList<>();

        // 迭代搜索
        for (int iter = 0; iter < MAX_ITERATIONS; iter++) {
            int[] bestCandidate = null;
            int bestCandidateDist = Integer.MAX_VALUE;
            String move = null;

            // 生成邻域解
            for (int i = 1; i < NUM_CITIES; i++) {
                for (int j = i+1; j < NUM_CITIES; j++) {
                    // 避免重复交换
                    String swapKey = i + "-" + j;

                    // 生成候选解
                    int[] candidate = currentSolution.clone();
                    swap(candidate, i, j);
                    int candidateDist = calculateDistance(candidate);

                    // 检查是否满足特赦的条件
                    boolean isAspiration = candidateDist < bestDistance;

                    // 选择最优候选解或者满足特赦条件的候选解
                    if (!tabuList.contains(swapKey) || isAspiration) {
                        if (candidateDist < bestCandidateDist) {
                            bestCandidate = candidate.clone();
                            bestCandidateDist = candidateDist;
                            move = swapKey;
                        }
                    }
                }
            }

            // 更新当前解
            if (bestCandidate != null) {
                currentSolution = bestCandidate.clone();

                // 更新禁忌表
                tabuList.add(move);
                if (tabuList.size() > TABU_TENURE) {
                    tabuList.poll();
                }

                // 更新全局最优解
                if (bestCandidateDist < bestDistance) {
                    bestSolution = bestCandidate.clone();
                    bestDistance = bestCandidateDist;
                }
            }
        }
        return bestSolution;
    }

    private static int[] generateInitialSolution() {
        int[] solution = new int[NUM_CITIES];
        for (int i = 0; i < NUM_CITIES; i++) {
            solution[i] = i;
        }
        return solution;
    }

    private static void swap(int[] array, int i, int j) {
        int temp = array[i];
        array[i] = array[j];
        array[j] = temp;
    }

    // 计算路径总距离
    private static int calculateDistance(int[] path) {
        int distance = 0;
        for (int i = 0; i < NUM_CITIES; i++) {
            int from = path[i];
            int to = path[(i+1)%NUM_CITIES];
            distance += DISTANCE_MATRIX[from][to];
        }
        return distance;
    }
}

在这里插入图片描述

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐