Updated example
perara committed Jan 10, 2025
1 parent 5f078e8 commit 9cc34ac
Showing 5 changed files with 646 additions and 201 deletions.
13 changes: 6 additions & 7 deletions per_jsp/examples/python/main.py
@@ -3,7 +3,6 @@
Main entry point for the Job Shop Scheduling System.
Handles problem generation, solving, and result visualization.
"""
-
import logging
import argparse
import json
@@ -233,11 +232,11 @@ def create_scheduler(args) -> BaseScheduler:
    if args.algorithm == "greedy":
        return GreedyScheduler(use_longest=args.use_longest)
    elif args.algorithm == "q-learning":
-        return QLearningScheduler(
-            learning_rate=args.learning_rate,
-            discount_factor=args.discount_factor,
-            epsilon=args.epsilon,
-            episodes=args.episodes
+        return QLearningScheduler(
+            learning_rate=0.1,
+            discount_factor=0.95,
+            exploration_rate=1.0,
+            episodes=1000
        )
    else:
        raise ValueError(f"Unknown algorithm: {args.algorithm}")
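
For comparison, a sketch of the same branch that keeps the hyperparameters driven by the CLI flags defined in setup_argument_parser (hypothetical, not part of this commit; it assumes the --epsilon flag is reused as the new scheduler's initial exploration rate):

        # Hypothetical alternative: wire the CLI flags through instead of hard-coding values.
        return QLearningScheduler(
            learning_rate=args.learning_rate,
            discount_factor=args.discount_factor,
            exploration_rate=args.epsilon,   # assumption: --epsilon maps to the initial exploration rate
            episodes=args.episodes
        )
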
@@ -267,7 +266,7 @@ def setup_argument_parser() -> argparse.ArgumentParser:
help="Discount factor for Q-learning")
parser.add_argument("--epsilon", type=float, default=0.1,
help="Exploration rate for Q-learning")
parser.add_argument("--episodes", type=int, default=10000,
parser.add_argument("--episodes", type=int, default=1000,
help="Number of episodes for Q-learning")

# Automatic generation parameters
251 changes: 141 additions & 110 deletions per_jsp/python/per_jsp/algorithms/q_learning.py
@@ -1,141 +1,172 @@
import numpy as np
import time
import random
from typing import List, Tuple, Callable
import logging
from typing import List, Tuple, Dict
from .base import BaseScheduler
from dataclasses import dataclass
import time

from per_jsp.algorithms.base import BaseScheduler
from per_jsp.environment.job_shop_environment import JobShopEnvironment, Action

logger = logging.getLogger(__name__)

class QLearningScheduler(BaseScheduler):
"""Q-learning algorithm for job shop scheduling."""
"""Direct Python implementation of the C++ Q-learning scheduler."""

def __init__(self,
learning_rate: float = 0.1,
discount_factor: float = 0.95,
epsilon: float = 0.1,
episodes: int = 100):
"""
Initialize Q-learning scheduler.
Args:
learning_rate: Learning rate for Q-value updates
discount_factor: Discount factor for future rewards
epsilon: Probability of random exploration
episodes: Number of training episodes
"""
exploration_rate: float = 1.0,
episodes: int = 1000):
self.learning_rate = learning_rate
self.discount_factor = discount_factor
self.epsilon = epsilon
self.exploration_rate = exploration_rate
self.episodes = episodes
self.q_table: Dict[int, Dict[Tuple[int, int, int], float]] = {}

-    def _get_state_key(self, env: JobShopEnvironment) -> int:
-        """Create a unique key for the current state."""
-        state_tuple = (
-            tuple(env.current_state.machine_availability),
-            tuple(env.current_state.next_operation_for_job),
-            tuple(env.current_state.completed_jobs)

+        self.q_table = None
+        self.best_time = float('inf')
+        self.best_schedule = []
+        self.rng = np.random.RandomState()

+    def _initialize_q_table(self, env: JobShopEnvironment) -> None:
+        """Initialize Q-table with proper dimensions."""
+        max_operations = max(len(job.operations) for job in env.jobs)
+        self.q_table = np.zeros((
+            len(env.jobs),       # Number of jobs
+            env.num_machines,    # Number of machines
+            max_operations       # Max operations per job
+        ))

+    def _calculate_priority(self, env: JobShopEnvironment, action: Action) -> float:
+        """Calculate priority score for an action."""
+        # Calculate remaining processing time for the job
+        remaining_time = sum(
+            op.duration
+            for op in env.jobs[action.job].operations[action.operation:]
        )
-        return hash(state_tuple)
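
With this change the Q-table moves from a dict keyed by a hashed environment state to a dense numpy array indexed by (job, machine, operation); for example, q_table[2, 0, 1] would hold the learned value of scheduling job 2's operation 1 on machine 0 (indices here are purely illustrative).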

-    def _get_q_value(self, state_key: int, action: Action) -> float:
-        """Get Q-value for state-action pair."""
-        if state_key not in self.q_table:
-            self.q_table[state_key] = {}
-        action_key = (action.job, action.machine, action.operation)
-        return self.q_table[state_key].get(action_key, 0.0)

-    def _update_q_value(self, state_key: int, action: Action, value: float):
-        """Update Q-value for state-action pair."""
-        if state_key not in self.q_table:
-            self.q_table[state_key] = {}
-        action_key = (action.job, action.machine, action.operation)
-        self.q_table[state_key][action_key] = value

-    def _calculate_reward(self, env: JobShopEnvironment, action: Action) -> float:
-        """Calculate reward for taking an action."""
-        job = env.jobs[action.job]
-        operation = job.operations[action.operation]

-        # Combine multiple factors for reward
-        completion_bonus = 1.0 if env.current_state.completed_jobs[action.job] else 0.0
-        duration_penalty = -operation.duration / 100.0  # Penalize longer operations
-        machine_utilization = -env.current_state.machine_availability[action.machine] / 1000.0

-        return completion_bonus + duration_penalty + machine_utilization

+        # Calculate machine utilization
+        machine_time = env.current_state.machine_availability[action.machine]
+        total_time = max(1, env.total_time)
+        machine_utilization = machine_time / total_time

+        # Priority combines remaining time and machine availability
+        return remaining_time * (1 - machine_utilization)
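
As a rough worked example with made-up numbers: a job with 40 time units of processing remaining, considered for a machine that has been busy 30 out of 60 elapsed time units (utilization 0.5), gets priority 40 * (1 - 0.5) = 20, so exploration favours jobs with more remaining work on less-utilized machines.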

+    def _select_action(self, env: JobShopEnvironment) -> Action:
+        """Select action using epsilon-greedy strategy with priority-based exploration."""
+        possible_actions = env.get_possible_actions()

+        if not possible_actions:
+            return None

+        if self.rng.random() < self.exploration_rate:
+            # Use priorities for smarter exploration
+            priorities = [self._calculate_priority(env, action) for action in possible_actions]
+            total_priority = sum(priorities)
+            if total_priority > 0:
+                priorities = [p/total_priority for p in priorities]
+                return possible_actions[self.rng.choice(len(possible_actions), p=priorities)]
+            return self.rng.choice(possible_actions)

+        # Greedy selection based on Q-values
+        best_q = float('-inf')
+        best_action = possible_actions[0]

+        for action in possible_actions:
+            q_value = self.q_table[action.job, action.machine, action.operation]
+            if q_value > best_q:
+                best_q = q_value
+                best_action = action

+        return best_action

+    def _update_q_value(self, env: JobShopEnvironment, action: Action, prev_time: int) -> None:
+        """Update Q-value for the given action."""
+        # Calculate time-based reward
+        time_reward = -(env.total_time - prev_time)

+        # Calculate utilization-based reward
+        utils = env.get_machine_utilization()
+        util_reward = np.mean(utils) * 100

+        # Combined reward
+        reward = time_reward + util_reward

+        # Get maximum future Q-value
+        possible_actions = env.get_possible_actions()
+        max_future_q = 0.0
+        if possible_actions:
+            max_future_q = max(
+                self.q_table[a.job, a.machine, a.operation]
+                for a in possible_actions
+            )

+        # Update Q-value
+        current_q = self.q_table[action.job, action.machine, action.operation]
+        new_q = (1 - self.learning_rate) * current_q + self.learning_rate * (
+            reward + self.discount_factor * max_future_q
+        )
+        self.q_table[action.job, action.machine, action.operation] = new_q
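
This is the standard tabular Q-learning update, Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (r + gamma * max_a' Q(s', a')), with alpha = learning_rate, gamma = discount_factor, the reward r combining the makespan increase since the previous step with mean machine utilization, and the state collapsed to the (job, machine, operation) index.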

+    def _run_episode(self, env: JobShopEnvironment, max_steps: int = 1000) -> List[Action]:
+        """Run a single episode."""
+        env.reset()
+        episode_actions = []

+        while not env.is_done() and len(episode_actions) < max_steps:
+            prev_time = env.total_time

+            action = self._select_action(env)
+            if action is None:
+                break

+            env.step(action)
+            episode_actions.append(action)

+            self._update_q_value(env, action, prev_time)

+        return episode_actions

    def solve(self, env: JobShopEnvironment, max_steps: int = 1000) -> Tuple[List[Action], int]:
-        """Solve using Q-learning algorithm."""
+        """Solve using Q-learning."""
        start_time = time.time()
-        best_actions = []
-        best_makespan = float('inf')

-        # Training phase
-        logger.debug(f"Starting Q-learning training for {self.episodes} episodes...")
+        # Initialize Q-table
+        if self.q_table is None:
+            self._initialize_q_table(env)

+        logger.info(f"Starting Q-learning training for {self.episodes} episodes...")

        for episode in range(self.episodes):
+            # Run episode
+            episode_actions = self._run_episode(env, max_steps)

+            # Evaluate episode
            env.reset()
-            episode_actions = []
-            step_count = 0

-            while not env.is_done() and step_count < max_steps:
-                state_key = self._get_state_key(env)
-                possible_actions = env.get_possible_actions()

-                if not possible_actions:
-                    break

-                # Epsilon-greedy action selection
-                if np.random.random() < self.epsilon:
-                    action = np.random.choice(possible_actions)
-                else:
-                    action = max(
-                        possible_actions,
-                        key=lambda a: self._get_q_value(state_key, a)
-                    )

-                # Take action and observe result
+            for action in episode_actions:
                env.step(action)
-                episode_actions.append(action)
-                reward = self._calculate_reward(env, action)

-                # Update Q-value
-                new_state_key = self._get_state_key(env)
-                new_possible_actions = env.get_possible_actions()

-                if new_possible_actions:
-                    max_future_q = max(
-                        self._get_q_value(new_state_key, a)
-                        for a in new_possible_actions
-                    )
-                else:
-                    max_future_q = 0

-                old_q = self._get_q_value(state_key, action)
-                new_q = (1 - self.learning_rate) * old_q + self.learning_rate * (
-                    reward + self.discount_factor * max_future_q
-                )
-                self._update_q_value(state_key, action, new_q)

-                step_count += 1

-            # Check if this episode found a better solution
-            if env.total_time < best_makespan:
-                best_makespan = env.total_time
-                best_actions = episode_actions.copy()

+            # Track best solution
+            if env.total_time < self.best_time:
+                self.best_time = env.total_time
+                self.best_schedule = episode_actions.copy()
+                logger.info(f"New best makespan: {self.best_time}")

+            # Decay exploration rate
+            self.exploration_rate *= 0.9999

            if (episode + 1) % 10 == 0:
-                logger.debug(f"Episode {episode + 1}/{self.episodes}, "
-                             f"Best makespan: {best_makespan}")
+                logger.info(f"Episode {episode + 1}/{self.episodes}, "
+                            f"Best makespan: {self.best_time}")

-        # Final run with best policy
+        # Final run with best actions
        env.reset()
-        for action in best_actions:
+        for action in self.best_schedule:
            env.step(action)

        solve_time = time.time() - start_time
-        logger.debug(f"Q-learning solved in {solve_time:.2f} seconds")
-        logger.debug(f"Final makespan: {env.total_time}")
+        logger.info(f"Q-learning solved in {solve_time:.2f} seconds")
+        logger.info(f"Final makespan: {env.total_time}")

-        return best_actions, env.total_time
+        return self.best_schedule, env.total_time
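
For context, a minimal usage sketch of the new scheduler (assumptions: the module path matches the file layout above, and JobShopEnvironment can be constructed from a list of Job objects built elsewhere, e.g. by one of the problem generators in main.py):

    from per_jsp.environment.job_shop_environment import JobShopEnvironment
    from per_jsp.algorithms.q_learning import QLearningScheduler

    # `jobs` is assumed to be a list of Job objects prepared by the caller.
    env = JobShopEnvironment(jobs)
    scheduler = QLearningScheduler(learning_rate=0.1, discount_factor=0.95,
                                   exploration_rate=1.0, episodes=1000)
    actions, makespan = scheduler.solve(env, max_steps=1000)
    print(f"Best makespan: {makespan} across {len(actions)} scheduled operations")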