Source code for slidingpuzzle.nn.paths
# Copyright 2023 Stephen Dunn
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for dealing with paths required during model training and inference.
"""
import shutil
from pathlib import Path
# Top-level output directories. These are relative paths, so they resolve
# against the process's current working directory.
CHECKPOINT_DIR: str = "checkpoints"  # checkpoints, one subdirectory per board size
DATASET_DIR: str = "datasets"  # example files named examples_<h>x<w>.json
MODELS_DIR: str = "models"  # model files named <version>_<h>x<w>.pt
TENSORBOARD_DIR: str = "tensorboard"  # TensorBoard log directories
[docs]
def clear_training(h: int, w: int) -> None:
    """
    Delete the training artifacts for one board size.

    Removes the checkpoint directory first, then the TensorBoard log directory.
    Missing directories or deletion failures are ignored rather than raised.

    Note that resolving the log directory goes through ``get_path``, which
    creates ``TENSORBOARD_DIR`` itself if it does not already exist.

    Args:
        h: Board height
        w: Board width
    """
    # Resolve each directory lazily, in the same order as the deletions.
    for dir_getter in (get_checkpoint_dir, get_log_dir):
        shutil.rmtree(dir_getter(h, w), ignore_errors=True)
[docs]
def get_board_size_str(h: int, w: int) -> str:
    """
    Encode a board's dimensions as a compact ``"<height>x<width>"`` string.

    Args:
        h: Board height
        w: Board width

    Returns:
        The height and width joined by ``"x"``, e.g. ``"3x4"``.
    """
    return "x".join(str(dim) for dim in (h, w))
[docs]
def get_path(dirname: str | Path, filename: str | Path) -> Path:
"""
Creates intermediate directorys for dirname and returns the full path from dirname
to filename.
Args:
dirname: A directory or path to a directory.
filename: The filename the path will point to.
Returns:
The path object.
"""
dirpath = Path(dirname)
dirpath.mkdir(exist_ok=True, parents=True)
return dirpath / filename
[docs]
def get_checkpoint_dir(h: int, w: int) -> Path:
    """
    Build the directory path under which checkpoints for a board size live.

    The path is computed only; no directory is created here.

    Args:
        h: Board height
        w: Board width

    Returns:
        ``CHECKPOINT_DIR/<h>x<w>``
    """
    return Path(CHECKPOINT_DIR, get_board_size_str(h, w))
[docs]
def get_checkpoint_path(h: int, w: int, tag: str) -> Path:
    """
    Get the path to a tagged checkpoint file for a board size.

    As a side effect, the checkpoint directory for this board size is created
    if it does not already exist (via ``get_path``). The checkpoint file itself
    is not created.

    Args:
        h: Board height
        w: Board width
        tag: Checkpoint tag (required). It is embedded verbatim in the filename
            ``checkpoint_<tag>``, and no extension is appended.

    Returns:
        The path to the checkpoint file, ``CHECKPOINT_DIR/<h>x<w>/checkpoint_<tag>``
    """
    return get_path(get_checkpoint_dir(h, w), f"checkpoint_{tag}")
[docs]
def get_examples_path(h: int, w: int) -> Path:
    """
    Get the path of the examples JSON file for a board size.

    ``DATASET_DIR`` is created if missing (via ``get_path``). The file itself
    is not created.

    Args:
        h: Board height
        w: Board width

    Returns:
        ``DATASET_DIR/examples_<h>x<w>.json``
    """
    examples_filename = f"examples_{get_board_size_str(h, w)}.json"
    return get_path(DATASET_DIR, examples_filename)
[docs]
def get_model_path(h: int, w: int, version: str) -> Path:
    """
    Get the path of a saved model file for a board size and model version.

    ``MODELS_DIR`` is created if missing (via ``get_path``). The file itself
    is not created.

    Args:
        h: Board height
        w: Board width
        version: Model version string, used as the filename prefix.

    Returns:
        ``MODELS_DIR/<version>_<h>x<w>.pt``
    """
    size = get_board_size_str(h, w)
    model_filename = f"{version}_{size}.pt"
    return get_path(MODELS_DIR, model_filename)
[docs]
def get_log_dir(h: int, w: int) -> Path:
    """
    Get the TensorBoard log directory for a board size.

    Because the path is built with ``get_path``, only the parent
    ``TENSORBOARD_DIR`` is created if missing. The returned per-board-size
    directory itself is not created here.

    Args:
        h: Board height
        w: Board width

    Returns:
        ``TENSORBOARD_DIR/slidingpuzzle_<h>x<w>``
    """
    run_name = f"slidingpuzzle_{get_board_size_str(h, w)}"
    return get_path(TENSORBOARD_DIR, run_name)