Source code for slidingpuzzle.nn.heuristics

# Copyright 2023 Stephen Dunn

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Defines neural network-guided heuristics.
"""
import torch

import slidingpuzzle.nn.models as models
import slidingpuzzle.nn.paths as paths
from slidingpuzzle.board import Board
from slidingpuzzle.heuristics import Heuristic


model_heuristics = {}


[docs] def get_heuristic_key(h: int, w: int, version: str): board_size_str = paths.get_board_size_str(h, w) return f"{board_size_str}_{version}"
[docs] def make_heuristic(model: torch.nn.Module | torch.ScriptModule): device = models.DEVICE dtype = models.DTYPE def heuristic(board: tuple[list[int], ...]) -> float: tensor = torch.tensor(board, dtype=dtype).unsqueeze(0).to(device) return model(tensor).item() return heuristic
[docs] def set_heuristic(model: torch.nn.Module | torch.ScriptModule): key = get_heuristic_key(model.w, model.h, model.version) heuristic = make_heuristic(model) model_heuristics[key] = heuristic return heuristic
[docs] def get_heuristic(h: int, w: int, version: str) -> Heuristic: key = get_heuristic_key(h, w, version) heuristic = model_heuristics.get(key, None) if heuristic is None: model = models.load_model(h, w, version) key = get_heuristic_key(h, w, version) heuristic = make_heuristic(model) model_heuristics[key] = heuristic return heuristic
################################################################## # methods beyond here correspond to predefined model classes # you can add your own model heuristics below ##################################################################
[docs] def v1_distance(board: Board) -> float: """ A neural network that estimates distance to goal Args: board: The board Returns: An estimated number of moves to reach the goal """ heuristic = get_heuristic(*board.shape, models.VERSION_1) return heuristic(board)
[docs] def v2_distance(board: Board) -> float: """ A neural network that estimates distance to goal Args: board: The board Returns: An estimated number of moves to reach the goal """ heuristic = get_heuristic(*board.shape, models.VERSION_2) return heuristic(board)