Source code for slidingpuzzle.nn.models

# Copyright 2023 Stephen Dunn

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Defines a PyTorch model to evaluate sliding puzzle boards.
"""

from typing import Optional

import torch
import torch.nn as nn

import slidingpuzzle.nn.paths as paths
import slidingpuzzle.board as board_


# Default device used by save_model()/load_model() when the caller passes
# ``device=None``: CUDA when torch reports it as available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Floating-point dtype used when constructing the model layers below.
DTYPE = torch.float32
# Version tags; each model stores its tag as ``self.version`` and
# save_model()/load_model() forward it to ``paths.get_model_path``.
VERSION_1 = "v1"
VERSION_2 = "v2"


class Model_v1(nn.Module):
    """
    Fully-connected regressor that takes a board and predicts the estimated
    distance to the goal.

    Layout: a flatten + Linear/ReLU/BatchNorm input stage, five bias-free
    Linear/ReLU/BatchNorm hidden stages of width 512, and a bias-free Linear
    layer producing a single value.

    Trained with::

        SGD(lr = 0.001, momentum = 0.9)
        MSELoss

    Total parameters: 1322496 (for a 3x3 board)
    """

    def __init__(self, h: int, w: int) -> None:
        super().__init__()
        # These three attributes are required: save_model() reads them to
        # pick the output path and to build the example board for tracing.
        self.version = VERSION_1
        self.h = h
        self.w = w

        width = 512
        depth = 5

        def hidden_stage() -> nn.Sequential:
            # One bias-free hidden stage; BatchNorm follows the activation.
            return nn.Sequential(
                nn.Linear(width, width, bias=False, dtype=DTYPE),
                nn.ReLU(),
                nn.BatchNorm1d(width, dtype=DTYPE),
            )

        # Input stage: collapse the board to h*w features, then project
        # up to the hidden width.
        self.input = nn.Sequential(
            nn.Flatten(),
            nn.Linear(h * w, width, dtype=DTYPE),
            nn.ReLU(),
            nn.BatchNorm1d(width, dtype=DTYPE),
        )
        self.linears = nn.Sequential(*(hidden_stage() for _ in range(depth)))
        self.output = nn.Linear(width, 1, bias=False, dtype=DTYPE)

    def forward(self, x):
        """Run the input stage, the hidden stack and the output layer in turn."""
        features = self.input(x)
        features = self.linears(features)
        return self.output(features)
class Model_v2(nn.Module):
    """
    A multi-headed attention-based network with about the same number of
    parameters as v1. (WIP)

    Every tile position gets its own ``Linear(1, hidden_size)`` embedding; the
    embedded tiles are passed through a sequence of self-attention blocks
    (attention -> Linear -> ReLU -> BatchNorm), a final attention block that
    reduces each tile to one value, and an output layer that maps the
    ``h * w`` per-tile values to a single estimate.

    Trained with::

        SGD(lr = 0.0001, momentum = 0.9)
        MSELoss

    Total parameters: 1170767 (for a 3x3 board)
    """

    def __init__(self, h: int, w: int) -> None:
        super().__init__()
        # These three attributes are required: save_model() reads them to
        # pick the output path and to build the example board for tracing.
        self.version = VERSION_2
        self.h = h
        self.w = w

        size = h * w
        # The attention layers use ``size`` heads, so the embedding width
        # must be divisible by h*w; a multiple of ``size`` guarantees that.
        hidden_size = size * 32

        self.flatten = nn.Flatten()
        # One independent scalar -> vector embedding per tile position.
        self.embed = nn.ModuleList(
            nn.Linear(1, hidden_size, dtype=DTYPE) for _ in range(size)
        )

        self.first_attention = nn.MultiheadAttention(
            hidden_size, size, batch_first=True, dtype=DTYPE
        )
        self.first_attention_linear = nn.Linear(hidden_size, hidden_size, dtype=DTYPE)
        # dtype is passed explicitly so this layer follows DTYPE like every
        # other layer in the model (it previously used the torch default).
        self.first_attention_norm = nn.BatchNorm1d(size, dtype=DTYPE)

        n_layers = 1
        self.attentions = nn.ModuleList(
            nn.MultiheadAttention(hidden_size, size, batch_first=True, dtype=DTYPE)
            for _ in range(n_layers)
        )
        self.linears = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size, dtype=DTYPE) for _ in range(n_layers)
        )
        self.norms = nn.ModuleList(
            nn.BatchNorm1d(size, dtype=DTYPE) for _ in range(n_layers)
        )

        self.last_attention = nn.MultiheadAttention(
            hidden_size, size, batch_first=True, dtype=DTYPE
        )
        # Collapses each tile's hidden vector to a single value.
        self.last_attention_linear = nn.Linear(hidden_size, 1, dtype=DTYPE)
        self.output = nn.Linear(size, 1, dtype=DTYPE)

    def forward(self, x):
        """
        Estimate the distance to the goal for a batch of boards.

        Args:
            x: Board tensor; assumed to be a batch of h x w boards, i.e.
                flattening everything after the batch dimension yields
                ``h * w`` values per sample (TODO confirm against callers).

        Returns:
            The output of the final ``Linear(size, 1)`` layer, one value per
            sample.
        """
        # first embed and create initial q, k, v
        x = self.flatten(x)
        x = torch.split(x, 1, dim=-1)  # break up board into individual tile tensors
        # Stack the per-tile embeddings along dim 1, giving one hidden vector
        # per tile for every sample.
        x = torch.stack([torch.relu(e(t)) for e, t in zip(self.embed, x)], dim=1)

        # attention layers (self-attention: query, key and value are all x)
        x, _ = self.first_attention(x, x, x)
        x = torch.relu(self.first_attention_linear(x))
        x = self.first_attention_norm(x)
        for atn, lin, nrm in zip(self.attentions, self.linears, self.norms):
            x, _ = atn(x, x, x)
            x = torch.relu(lin(x))
            x = nrm(x)
        x, _ = self.last_attention(x, x, x)
        x = torch.relu(self.last_attention_linear(x))

        # output: drop the trailing singleton dim, then reduce the tiles
        x = self.flatten(x)
        x = self.output(x)
        return x
def save_model(model: nn.Module, device: Optional[str] = None) -> None:
    """
    Save a frozen version of the model into the "models" dir.

    The model is put into eval mode, moved to ``device``, traced with
    ``torch.jit.trace`` on a freshly created board, frozen, and written to
    the path returned by ``paths.get_model_path``.

    Args:
        model: The trained model. It must expose ``h``, ``w`` and ``version``
            attributes. Note that it is switched to eval mode and moved to
            ``device`` in place.
        device: The device to place the model and the example input on for
            tracing. If not provided, the module-level ``DEVICE`` is used.
    """
    if device is None:
        device = DEVICE
    model.eval()
    model.to(device)
    path = paths.get_model_path(model.h, model.w, model.version)
    board = board_.new_board(model.h, model.w)
    # Build the example input with the module-wide DTYPE so it always matches
    # the dtype the layers were constructed with; unsqueeze adds the batch dim.
    example_inputs = torch.tensor(board, dtype=DTYPE).unsqueeze(0).to(device)
    traced_model = torch.jit.trace(model, example_inputs)
    frozen_model = torch.jit.freeze(traced_model)
    frozen_model.save(str(path))
def load_model(
    h: int, w: int, version: str, device: Optional[str] = None
) -> torch.ScriptModule:
    """
    Reload a pre-trained frozen model.

    Args:
        h: Board height, used to locate the saved model file.
        w: Board width, used to locate the saved model file.
        version: Model version tag, used to locate the saved model file.
        device: Device to map the loaded model onto. Falls back to the
            module-level ``DEVICE`` when ``None``.

    Returns:
        The loaded TorchScript module, switched to eval mode.
    """
    target_device = DEVICE if device is None else device
    location = str(paths.get_model_path(h, w, version))
    scripted = torch.jit.load(location, map_location=target_device)
    scripted.eval()
    return scripted
def get_num_params(model: nn.Module) -> int:
    """
    Count the trainable parameters of a model.

    Args:
        model: A model

    Returns:
        The summed element count of every parameter whose ``requires_grad``
        flag is set
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            # Frozen parameters are excluded from the count.
            continue
        total += param.numel()
    return total