1 change: 1 addition & 0 deletions .gitignore
@@ -28,3 +28,4 @@ schematics_data
 output
 datasets
 htmlcov
+output_2
113 changes: 45 additions & 68 deletions minecraft_copilot_ml/__main__.py
@@ -5,35 +5,26 @@
 from typing import List, Optional, Set, Tuple
 
 import boto3
+import lightning as pl
 import numpy as np
-import pytorch_lightning as pl
 import torch
+from improved_diffusion.unet import UNetModel  # type: ignore[import-untyped]
+from lightning.pytorch.callbacks import ModelCheckpoint
+from lightning.pytorch.loggers import CSVLogger
 from loguru import logger
-from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loggers import CSVLogger
 from sklearn.model_selection import train_test_split  # type: ignore
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 from minecraft_copilot_ml.data_loader import (
     MinecraftSchematicsDataset,
     MinecraftSchematicsDatasetItemType,
-    get_working_files_and_unique_blocks_and_counts,
+    get_working_files_and_unique_blocks,
     list_schematic_files_in_folder,
 )
-from minecraft_copilot_ml.model import UNet3d
-
-
-def export_to_onnx(model: UNet3d, path_to_output: str) -> None:
-    torch.onnx.export(
-        model,
-        torch.randn(1, 1, 16, 16, 16).to("cuda" if torch.cuda.is_available() else "cpu"),
-        path_to_output,
-        input_names=["input"],
-        output_names=["output"],
-        # https://onnxruntime.ai/docs/reference/compatibility.html
-        opset_version=17,
-    )
+from minecraft_copilot_ml.model import MinecraftCopilotTrainer
+
+device_name = torch.cuda.get_device_name()
+if device_name is not None and device_name == "GeForce RTX 3090":
+    torch.set_float32_matmul_precision("medium")
 
 
 def main(argparser: argparse.ArgumentParser) -> None:
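Note on the module-level guard above: torch.cuda.get_device_name() raises a RuntimeError when no CUDA device is present rather than returning None, so the is-not-None check never fires on CPU-only hosts. A more defensive sketch of the same idea (the availability check and the substring match are assumptions, not part of this PR):

    import torch

    # Only query the device name when CUDA is actually available; newer drivers
    # report the name as "NVIDIA GeForce RTX 3090", so match by substring.
    if torch.cuda.is_available():
        device_name = torch.cuda.get_device_name()
        if "GeForce RTX 3090" in device_name:
            # TF32 matmuls: faster on Ampere tensor cores at slightly reduced precision.
            torch.set_float32_matmul_precision("medium")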
@@ -58,74 +49,60 @@ def main(argparser: argparse.ArgumentParser) -> None:
     schematics_list_files = schematics_list_files[start:end]
     # Set the dictionary size to the number of unique blocks in the dataset.
     # And also select the right files to load.
-    unique_blocks_dict, unique_counts_coefficients, loaded_schematic_files = (
-        get_working_files_and_unique_blocks_and_counts(schematics_list_files)
-    )
+    unique_blocks_dict, loaded_schematic_files = get_working_files_and_unique_blocks(schematics_list_files)
 
     logger.info(f"Unique blocks: {unique_blocks_dict}")
     logger.info(f"Number of unique blocks: {len(unique_blocks_dict)}")
     logger.info(f"Number of loaded schematics files: {len(loaded_schematic_files)}")
-    logger.info(f"Unique counts coefficients: {unique_counts_coefficients}")
 
-    train_schematics_list_files, test_schematics_list_files = train_test_split(
-        loaded_schematic_files, test_size=0.2, random_state=42
-    )
-    train_schematics_dataset = MinecraftSchematicsDataset(train_schematics_list_files)
-    val_schematics_dataset = MinecraftSchematicsDataset(test_schematics_list_files)
+    schematics_dataset = MinecraftSchematicsDataset(loaded_schematic_files)
 
     def collate_fn(batch: List[MinecraftSchematicsDatasetItemType]) -> MinecraftSchematicsDatasetItemType:
-        block_map, noisy_block_map, mask, loss_mask = zip(*batch)
-        return np.stack(block_map), np.stack(noisy_block_map), np.stack(mask), np.stack(loss_mask)
-
-    train_schematics_dataloader = DataLoader(
-        train_schematics_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
+        block_map, block_map_mask = zip(*batch)
+        return np.stack(block_map), np.stack(block_map_mask)
+
+    num_workers = os.cpu_count()
+    if num_workers is None:
+        num_workers = 0
+
+    schematics_dataloader = DataLoader(
+        schematics_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        collate_fn=collate_fn,
+        num_workers=num_workers,
     )
-    val_schematics_dataloader = DataLoader(val_schematics_dataset, batch_size=batch_size, collate_fn=collate_fn)
 
-    model = UNet3d(unique_blocks_dict, unique_counts_coefficients=unique_counts_coefficients)
+    unet_model = UNetModel(
+        in_channels=len(unique_blocks_dict),
+        model_channels=32,
+        out_channels=len(unique_blocks_dict),
+        num_res_blocks=2,
+        num_heads=2,
+        attention_resolutions=[1],
+        dropout=0.1,
+        channel_mult=(1, 2, 4, 8),
+        dims=3,
+    )
+    model = MinecraftCopilotTrainer(unet_model, unique_blocks_dict, save_dir=path_to_output)
     csv_logger = CSVLogger(save_dir=path_to_output)
-    model_checkpoint = ModelCheckpoint(path_to_output, monitor="val_loss", save_top_k=1, save_last=True, mode="min")
-    trainer = pl.Trainer(logger=csv_logger, callbacks=model_checkpoint, max_epochs=epochs, log_every_n_steps=1)
-    trainer.fit(model, train_schematics_dataloader, val_schematics_dataloader)
+    model_checkpoint = ModelCheckpoint(path_to_output, save_last=True, mode="min")
+    trainer = pl.Trainer(
+        logger=csv_logger, callbacks=model_checkpoint, max_epochs=epochs, log_every_n_steps=1, accelerator="gpu"
+    )
+    trainer.fit(model, schematics_dataloader)
 
     # Save the best and last model locally
-    logger.info(f"Best val_loss is: {model_checkpoint.best_model_score}")
-    best_model = UNet3d.load_from_checkpoint(
-        model_checkpoint.best_model_path,
-        unique_blocks_dict=unique_blocks_dict,
-        unique_counts_coefficients=unique_counts_coefficients,
-    )
-    torch.save(best_model, os.path.join(path_to_output, "best_model.pth"))
-    last_model = UNet3d.load_from_checkpoint(
+    last_model = MinecraftCopilotTrainer.load_from_checkpoint(
         model_checkpoint.last_model_path,
+        unet_model=unet_model,
         unique_blocks_dict=unique_blocks_dict,
-        unique_counts_coefficients=unique_counts_coefficients,
+        save_dir=path_to_output,
     )
     torch.save(last_model, os.path.join(path_to_output, "last_model.pth"))
-    export_to_onnx(best_model, os.path.join(path_to_output, "best_model.onnx"))
-    export_to_onnx(last_model, os.path.join(path_to_output, "last_model.onnx"))
     with open(os.path.join(path_to_output, "unique_blocks_dict.json"), "w") as f:
         json.dump(unique_blocks_dict, f)
 
     # Save the best and last model to S3
     s3_client = boto3.client(
         "s3",
         region_name="eu-west-3",
         aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
         aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
     )
-    s3_client.upload_file(os.path.join(path_to_output, "best_model.pth"), "minecraft-copilot-models", "best_model.pth")
     s3_client.upload_file(os.path.join(path_to_output, "last_model.pth"), "minecraft-copilot-models", "last_model.pth")
-    s3_client.upload_file(
-        os.path.join(path_to_output, "best_model.onnx"), "minecraft-copilot-models", "best_model.onnx"
-    )
-    s3_client.upload_file(
-        os.path.join(path_to_output, "last_model.onnx"), "minecraft-copilot-models", "last_model.onnx"
-    )
     s3_client.upload_file(
         os.path.join(path_to_output, "unique_blocks_dict.json"), "minecraft-copilot-models", "unique_blocks_dict.json"
     )
 
 
 if __name__ == "__main__":
     argparser = argparse.ArgumentParser()
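For reference, the reworked collate_fn stacks two arrays per sample instead of four. A minimal sketch of the resulting batch shapes, using dummy 16x16x16 object-dtype maps rather than real schematics (the dummy data and batch size are illustrative assumptions, not from the PR):

    import numpy as np

    # Mirrors the new collate_fn: each dataset item is (block_map, block_map_mask).
    def collate_fn(batch):
        block_map, block_map_mask = zip(*batch)
        return np.stack(block_map), np.stack(block_map_mask)

    # Two dummy samples: a block-name map and the boolean crop mask.
    sample = (
        np.full((16, 16, 16), "minecraft:air", dtype=object),
        np.zeros((16, 16, 16), dtype=bool),
    )
    maps, masks = collate_fn([sample, sample])
    print(maps.shape, masks.shape)  # (2, 16, 16, 16) (2, 16, 16, 16)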
45 changes: 16 additions & 29 deletions minecraft_copilot_ml/data_loader.py
@@ -1,5 +1,4 @@
 # flake8: noqa: E203
-import gc
 import os
 import re
 from pathlib import Path
@@ -43,6 +42,10 @@
     "4766.schematic",
     "10380.schematic",
     "12695.schematic",
+    "8675.schematic",
+    "10220.schematic",
+    "5096.schematic",
+    "14191.schematic"
 ]


@@ -73,6 +76,9 @@ def litematic_to_numpy_minecraft_map(
             for z, k in zip(reg.zrange(), range(len(reg.zrange()))):
                 b = reg.getblock(x, y, z)
                 numpy_map[i, j, k] = b.blockid
+    numpy_map[numpy_map == "None"] = "minecraft:air"
+    numpy_map[numpy_map == None] = "minecraft:air"
+    del nbt_loaded
     return numpy_map


@@ -94,13 +100,15 @@ def schematic_to_numpy_minecraft_map(
         raise Exception(f"Could not find Blocks or BlockData in {nbt_file}. Known keys: {res.keys()}")
     block_map = np.asarray(block_data).reshape(res["Height"], res["Length"], res["Width"])
     block_map = np.vectorize(palette.get)(block_map)
+    block_map[block_map == "None"] = "minecraft:air"
+    block_map[block_map == None] = "minecraft:air"
+    del res
     return block_map
 
 
 def nbt_to_numpy_minecraft_map(
     nbt_file: str,
 ) -> np.ndarray:
-    gc.collect()
     if any([Path(nbt_file).parts[-1] == x for x in list_of_forbidden_files]):
         raise Exception(
             f"File {nbt_file} is forbidden. Skipping. If this file is here it is because it generates a SIGKILL."
@@ -171,7 +179,7 @@ def get_random_block_map_and_mask_coordinates(
     )
 
 
-MinecraftSchematicsDatasetItemType = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
+MinecraftSchematicsDatasetItemType = Tuple[np.ndarray, np.ndarray]
 
 
 class MinecraftSchematicsDataset(Dataset):
@@ -195,27 +203,13 @@ def __getitem__(self, idx: int) -> MinecraftSchematicsDatasetItemType:
             minimum_height,
             minimum_depth,
         ) = get_random_block_map_and_mask_coordinates(numpy_minecraft_map, 16, 16, 16)
-        focused_block_map = block_map[
-            random_roll_x_value : random_roll_x_value + minimum_width,
-            random_y_height_value : random_y_height_value + minimum_height,
-            random_roll_z_value : random_roll_z_value + minimum_depth,
-        ]
-        focused_noisy_block_map, unraveled_indices_of_noise = create_noisy_block_map(focused_block_map)
-        noisy_block_map = block_map.copy()
-        noisy_block_map[
-            random_roll_x_value : random_roll_x_value + minimum_width,
-            random_y_height_value : random_y_height_value + minimum_height,
-            random_roll_z_value : random_roll_z_value + minimum_depth,
-        ] = focused_noisy_block_map
         block_map_mask = np.zeros((16, 16, 16), dtype=bool)
         block_map_mask[
             random_roll_x_value : random_roll_x_value + minimum_width,
             random_y_height_value : random_y_height_value + minimum_height,
             random_roll_z_value : random_roll_z_value + minimum_depth,
         ] = True
-        loss_mask = np.zeros((16, 16, 16), dtype=bool)
-        loss_mask[unraveled_indices_of_noise] = True
-        return block_map, noisy_block_map, block_map_mask, loss_mask
+        return block_map, block_map_mask
 
 
 def list_schematic_files_in_folder(path_to_schematics: str) -> list[str]:
@@ -229,22 +223,17 @@ def list_schematic_files_in_folder(path_to_schematics: str) -> list[str]:
     return schematics_list_files
 
 
-def get_working_files_and_unique_blocks_and_counts(
+def get_working_files_and_unique_blocks(
     schematics_list_files: list[str],
-) -> Tuple[Dict[str, int], np.ndarray, list[str]]:
+) -> Tuple[Dict[str, int], list[str]]:
     unique_blocks: Set[str] = set()
-    unique_counts: Dict[str, int] = {}
     loaded_schematic_files: List[str] = []
     tqdm_list_files = tqdm(schematics_list_files, smoothing=0)
     for nbt_file in tqdm_list_files:
         tqdm_list_files.set_description(f"Processing {nbt_file}")
         try:
             numpy_minecraft_map = nbt_to_numpy_minecraft_map(nbt_file)
-            unique_blocks_in_map, unique_counts_in_map = np.unique(numpy_minecraft_map, return_counts=True)
-            for block, count in zip(unique_blocks_in_map, unique_counts_in_map):
-                if block not in unique_counts:
-                    unique_counts[block] = 0
-                unique_counts[block] += count
+            unique_blocks_in_map = set(numpy_minecraft_map.flatten())
             for block in unique_blocks_in_map:
                 if block not in unique_blocks:
                     logger.info(f"Found new block: {block}")
@@ -255,6 +244,4 @@ def get_working_files_and_unique_blocks_and_counts(
             logger.exception(e)
             continue
     unique_blocks_dict = {block: idx for idx, block in enumerate(unique_blocks)}
-    unique_counts_coefficients = np.array([unique_counts[block] for block in unique_blocks_dict])
-    unique_counts_coefficients = unique_counts_coefficients.max() / unique_counts_coefficients
-    return unique_blocks_dict, loaded_schematic_files
+    return unique_blocks_dict, loaded_schematic_files
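The new == "None" and == None lines in both loaders rely on numpy's elementwise comparison over object arrays; `is None` would not broadcast, which is why flake8's usual E711 advice does not apply here. A small sketch of that behavior on an illustrative array (the three-block example is an assumption, not from the PR):

    import numpy as np

    # Object-dtype map containing a literal "None" string and a real None, as
    # unknown palette entries can produce; `== None` compares element by element.
    numpy_map = np.array(["minecraft:stone", "None", None], dtype=object)
    numpy_map[numpy_map == "None"] = "minecraft:air"
    numpy_map[numpy_map == None] = "minecraft:air"  # noqa: E711
    print(numpy_map)  # ['minecraft:stone' 'minecraft:air' 'minecraft:air']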
2 changes: 1 addition & 1 deletion minecraft_copilot_ml/metrics_graph.ipynb
@@ -344,7 +344,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.6"
+   "version": "3.10.12"
   }
  },
 "nbformat": 4,