From 9b5f949a0b0de44c7d36ea7f1aa3cbd03746edb0 Mon Sep 17 00:00:00 2001 From: Angel Valentin Date: Sat, 19 Jul 2025 17:02:15 -0400 Subject: [PATCH] Have a successful pipe communication between AI and balatro with some intial iteration going --- .gitignore | 10 +- CLAUDE.md | 21 +++ README.md | 5 + RLBridge/ai.lua | 77 +++++++-- RLBridge/communication.lua | 175 ++++++++++---------- RLBridge/init.lua | 1 + RLBridge/output.lua | 22 ++- RLBridge/utils.lua | 20 ++- ai/agent/__init__.py | 0 ai/agent/balatro_agent.py | 95 ----------- ai/agent/policy.py | 83 ---------- ai/agent/training.py | 153 ------------------ ai/environment/balatro_env.py | 297 ++++++++++++++++++---------------- ai/environment/reward.py | 128 ++++++--------- ai/file_watcher.py | 111 ------------- ai/requirements.txt | 3 +- ai/setup.sh | 4 +- ai/test_env.py | 96 +++++++++++ ai/train_balatro.py | 273 +++++++++++++++++++++++++++++++ ai/utils/communication.py | 192 ++++++++++++++++++++++ ai/utils/debug.py | 39 +++++ ai/utils/logging.py | 74 --------- ai/utils/mappers.py | 245 ++++++++++++++++++++++++++++ ai/utils/serialization.py | 122 -------------- ai/utils/validation.py | 89 ++++++++++ 25 files changed, 1364 insertions(+), 971 deletions(-) delete mode 100644 ai/agent/__init__.py delete mode 100644 ai/agent/balatro_agent.py delete mode 100644 ai/agent/policy.py delete mode 100644 ai/agent/training.py delete mode 100755 ai/file_watcher.py create mode 100644 ai/test_env.py create mode 100755 ai/train_balatro.py create mode 100644 ai/utils/communication.py create mode 100644 ai/utils/debug.py delete mode 100644 ai/utils/logging.py create mode 100644 ai/utils/mappers.py delete mode 100644 ai/utils/serialization.py create mode 100644 ai/utils/validation.py diff --git a/.gitignore b/.gitignore index 3c9ff16..1c600d3 100644 --- a/.gitignore +++ b/.gitignore @@ -56,13 +56,17 @@ logs/ *.temp .cache/ -# Model checkpoints and training artifacts +# Training artifacts (generated during RL training sessions) +models/ +tensorboard_logs/ +training.log +training_monitor.csv + +# Other ML artifacts *.pth *.pt *.ckpt checkpoints/ -models/ -tensorboard_logs/ wandb/ # Large reference files (don't commit decompiled game source) diff --git a/CLAUDE.md b/CLAUDE.md index 51d18e1..24b9599 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -55,11 +55,32 @@ Balatro uses numbered states defined in `globals.lua`: - `G.GAME` - Main game data structure - `G.FUNCS` - Game function callbacks +## Communication Architecture + +### Dual Pipe Communication +The mod uses separate named pipes for request/response communication with the AI: +- **Request pipe path**: `/tmp/balatro_request` (Balatro writes, AI reads) +- **Response pipe path**: `/tmp/balatro_response` (AI writes, Balatro reads) +- **Protocol**: Persistent pipe handles with blocking reads +- **Flow**: + 1. Balatro opens persistent handles to both pipes + 2. Balatro writes JSON request to request pipe + 3. AI reads request from request pipe + 4. AI writes JSON response to response pipe + 5. 
Balatro performs blocking read from response pipe + +### Benefits of Dual Pipes +- Clear separation of request/response channels +- Persistent handles avoid repeated open/close overhead +- Blocking reads ensure proper synchronization +- Each pipe has single direction, eliminating read/write conflicts + ## Current Status The mod can successfully: - Start a run automatically - Detect blind selection state - Automatically select the next blind +- Communicate with AI via dual pipe system ## Communication with Claude - Please keep chats as efficient as possible and brief. No fluff talk, just get to the point diff --git a/README.md b/README.md index 94e2431..628279d 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,12 @@ Made a symlink -> ln -s ~/dev/balatro-rl/RLBridge /mnt/gamerlinuxssd/SteamLibrar - [ ] Training loop integration ### Game Features +- [ ] Always have restart_run as an action option assuming the game is ongoing - [ ] Blind selection choices (skip vs select) - [ ] Extended game state (money, discards, hands played) - [ ] Shop interactions - [ ] Joker management + +### RL Enhancements +- [ ] Add a "Replay System" to analyze successful actions. For example, save seed, have an action log for reproduction etc +Can probably do this by adding it before the game is reset like a check on what criteria I want to save for, and save diff --git a/RLBridge/ai.lua b/RLBridge/ai.lua index 5c16bf0..1247d83 100644 --- a/RLBridge/ai.lua +++ b/RLBridge/ai.lua @@ -13,6 +13,10 @@ local utils = require("utils") local last_state_hash = nil local last_actions_hash = nil local pending_action = nil +local retry_count = 0 +local need_retry_request = false +local rl_training_active = false +local last_key_pressed = nil --- Initialize AI system --- Sets up communication and prepares the AI for operation @@ -20,12 +24,50 @@ local pending_action = nil function AI.init() utils.log_ai("Initializing AI system...") communication.init() + + -- Hook into Love2D keyboard events + if love and love.keypressed then + local original_keypressed = love.keypressed + love.keypressed = function(key) + -- Store the key press for our AI + last_key_pressed = key + + -- Call original function + if original_keypressed then + original_keypressed(key) + end + end + else + end end --- Main AI update loop (called every frame) --- Monitors game state changes, handles communication, and executes AI actions --- @return nil function AI.update() + -- Check for key press to start/stop RL training + if last_key_pressed then + if last_key_pressed == "r" then + if not rl_training_active then + rl_training_active = true + utils.log_ai("๐Ÿš€ RL Training STARTED (R pressed)") + end + elseif last_key_pressed == "t" then + if rl_training_active then + rl_training_active = false + utils.log_ai("โน๏ธ RL Training STOPPED (T pressed)") + end + end + + -- Clear the key press + last_key_pressed = nil + end + + -- Don't process AI requests unless RL training is active + if not rl_training_active then + return + end + -- Get current game state local current_state = output.get_game_state() local available_actions = action.get_available_actions() @@ -40,29 +82,40 @@ function AI.update() return end - -- Create simple hash to detect state changes + -- Create hash to detect state changes local state_hash = AI.hash_state(current_state) local actions_hash = AI.hash_actions(available_actions) - -- Request action from AI if state or actions changed - if state_hash ~= last_state_hash or actions_hash ~= last_actions_hash then + if state_hash ~= last_state_hash or 
actions_hash ~= last_actions_hash or need_retry_request then + -- State has changed if state_hash ~= last_state_hash then utils.log_ai("State changed to: " .. - current_state.state .. " (" .. utils.state_name(current_state.state) .. ")") + current_state.state .. " (" .. utils.get_state_name(current_state.state) .. ")") action.reset_state() last_state_hash = state_hash end + -- Available actions have changed if actions_hash ~= last_actions_hash then utils.log_ai("Available actions changed: " .. table.concat(utils.get_action_names(available_actions), ", ")) last_actions_hash = actions_hash end - -- Request action from AI via file I/O - local ai_response = communication.request_action(current_state, available_actions) - if ai_response and ai_response.action ~= "no_action" then - pending_action = ai_response + -- Request action from AI + need_retry_request = false + local ai_response = communication.request_action(current_state, available_actions, retry_count) + + if ai_response then + -- Handling handshake + if ai_response.action == "ready" then + utils.log_ai("Handshake complete - AI ready") + -- Don't return - continue normal game loop + -- Force a state check on next frame to send real request + last_state_hash = nil + else + pending_action = ai_response + end end end @@ -71,8 +124,12 @@ function AI.update() local result = action.execute_action(pending_action.action, pending_action.params) if result.success then utils.log_ai("Action executed successfully: " .. pending_action.action) + retry_count = 0 else - utils.log_ai("Action failed: " .. (result.error or "Unknown error")) + -- Update retry_count to indicate state change + utils.log_ai("Action failed: " .. (result.error or "Unknown error") .. " RETRYING...") + retry_count = retry_count + 1 + need_retry_request = true end pending_action = nil end @@ -83,7 +140,7 @@ end --- @param game_state table Current game state data --- @return string Hash representing the current state function AI.hash_state(game_state) - return game_state.state .. "_" .. (game_state.round or 0) .. "_" .. (game_state.chips or 0) + return game_state.state .. "_" .. 
(game_state.chips or 0) -- TODO update hash to be more unique end --- Create simple hash of available actions diff --git a/RLBridge/communication.lua b/RLBridge/communication.lua index 539b26e..cbdd273 100644 --- a/RLBridge/communication.lua +++ b/RLBridge/communication.lua @@ -1,124 +1,117 @@ --- RLBridge communication module ---- Handles file-based communication between the game and external AI system +--- Handles dual pipe communication with persistent handles between the game and external AI system local COMM = {} local utils = require("utils") -local action = require("actions") local json = require("dkjson") --- File-based communication settings +-- Dual pipe communication settings with persistent handles local comm_enabled = false -local request_file = "/tmp/balatro_request.json" -local response_file = "/tmp/balatro_response.json" -local request_counter = 0 -local last_response_time = 0 +local request_pipe = "/tmp/balatro_request" +local response_pipe = "/tmp/balatro_response" +local request_handle = nil +local response_handle = nil ---- Check if response file has new content ---- @param expected_id number Expected request ID ---- @return table|nil Response data if new response available, nil otherwise -local function check_for_response(expected_id) - local response_file_handle = io.open(response_file, "r") - if not response_file_handle then - return nil - end - - local response_json = response_file_handle:read("*all") - response_file_handle:close() - - if not response_json or response_json == "" then - return nil - end - - local response_data = json.decode(response_json) - if not response_data then - return nil - end - - -- Check if this is a new response for our request - if response_data.id == expected_id and response_data.timestamp then - if response_data.timestamp > last_response_time then - last_response_time = response_data.timestamp - return response_data - end - end - - return nil -end - ---- Initialize file-based communication ---- Sets up file-based communication channel with external AI system +--- Initialize dual pipe communication with persistent handles +--- Sets up persistent pipe handles with external AI system --- @return nil function COMM.init() - utils.log_comm("Initializing file-based communication...") - comm_enabled = true - last_response_time = 0 -- Reset response tracking - - utils.log_comm("Ready for file-based requests") - utils.log_comm("Request file: " .. request_file) - utils.log_comm("Response file: " .. response_file) + utils.log_comm("Initializing dual pipe communication with persistent handles...") + utils.log_comm("Note: Pipes will be opened when first needed (lazy initialization)") + comm_enabled = true -- Enable communication, pipes will open on first use end ---- Send game turn request to AI and get action via files +--- Lazy initialization of pipe handles when first needed +--- @return boolean True if pipes are ready, false otherwise +function COMM.ensure_pipes_open() + if request_handle and response_handle then + return true -- Already open + end + + -- Try to open pipes (this will block until AI creates them) + utils.log_comm("Opening persistent pipe handles...") + + -- Open response pipe for reading (keep open) + response_handle = io.open(response_pipe, "r") + if not response_handle then + utils.log_comm("ERROR: Cannot open response pipe for reading: " .. 
response_pipe)
+        return false
+    end
+
+    -- Open request pipe for writing (keep open)
+    request_handle = io.open(request_pipe, "w")
+    if not request_handle then
+        utils.log_comm("ERROR: Cannot open request pipe for writing: " .. request_pipe)
+        if response_handle then
+            response_handle:close()
+            response_handle = nil
+        end
+        return false
+    end
+
+    utils.log_comm("Persistent pipe handles opened successfully")
+    utils.log_comm("Request pipe (write): " .. request_pipe)
+    utils.log_comm("Response pipe (read): " .. response_pipe)
+    return true
+end
+
+--- Send game turn request to AI and get action via persistent pipe handles
 --- @param game_state table Current game state data
 --- @param available_actions table Available actions list
 --- @return table|nil Action response from AI, nil if error
-function COMM.request_action(game_state, available_actions)
+function COMM.request_action(game_state, available_actions, retry_count)
     if not comm_enabled then
+        utils.log_comm("ERROR: Communication not enabled")
         return nil
     end
 
-    request_counter = request_counter + 1
+    -- Lazy initialization - open pipes when first needed
+    if not COMM.ensure_pipes_open() then
+        utils.log_comm("ERROR: Failed to open pipe handles")
+        return nil
+    end
 
     local request = {
-        id = request_counter,
-        timestamp = os.time(),
         game_state = game_state,
-        available_actions = available_actions or {}
+        available_actions = available_actions or {},
+        retry_count = retry_count,
     }
 
-    utils.log_comm("Requesting action for state: " .. tostring(game_state.state))
+    utils.log_comm(utils.get_timestamp() .. " Sending action request for state: " ..
+        tostring(game_state.state) .. " (" .. utils.get_state_name(game_state.state) .. ")")
 
-    -- Write request to file
+    -- Encode request as JSON
     local json_data = json.encode(request)
     if not json_data then
        utils.log_comm("ERROR: Failed to encode request JSON")
        return nil
    end
 
-    local request_file_handle = io.open(request_file, "w")
-    if not request_file_handle then
-        utils.log_comm("ERROR: Cannot write to request file: " .. request_file)
+    -- Write request to persistent handle
+    request_handle:write(json_data .. "\n")
+    request_handle:flush() -- Ensure data is sent immediately
+    utils.log_comm(utils.get_timestamp() .. " Request sent...")
+
+    -- Read response from persistent handle (blocking)
+    utils.log_comm(utils.get_timestamp() .. " Waiting for AI response...")
+    local response_json = response_handle:read("*line")
+
+    if not response_json or response_json == "" then
+        utils.log_comm("ERROR: No response received from AI")
         return nil
     end
 
-    request_file_handle:write(json_data)
-    request_file_handle:close()
-
-    -- Wait for response file (with timeout)
-    local max_wait_time = 2 -- seconds
-    local wait_interval = 0.05 -- seconds
-    local total_waited = 0
-
-    while total_waited < max_wait_time do
-        local response_data = check_for_response(request_counter)
-        if response_data then
-            utils.log_comm("AI action: " .. tostring(response_data.action))
-            return response_data
-        end
-
-        -- Sleep for a short time using busy wait
-        local start_time = os.clock()
-        while (os.clock() - start_time) < wait_interval do
-            -- Busy wait for short duration
-        end
-        total_waited = total_waited + wait_interval
+    local response_data = json.decode(response_json)
+    if not response_data then
+        utils.log_comm("ERROR: Failed to decode response JSON")
+        return nil
     end
 
-    utils.log_comm("ERROR: Timeout waiting for AI response")
-    return nil
+    utils.log_comm("AI action: " .. 
tostring(response_data.action)) + return response_data end ---- Check if file communication is enabled +--- Check if pipe communication is enabled --- Returns the current communication status --- @return boolean True if enabled, false otherwise function COMM.is_connected() @@ -126,12 +119,22 @@ function COMM.is_connected() end --- Close communication ---- Terminates the communication channel with the AI system +--- Terminates the persistent pipe handles with the AI system --- @return nil function COMM.close() comm_enabled = false - last_response_time = 0 - utils.log_comm("File communication disabled") + + if request_handle then + request_handle:close() + request_handle = nil + end + + if response_handle then + response_handle:close() + response_handle = nil + end + + utils.log_comm("Persistent pipe communication closed") end return COMM diff --git a/RLBridge/init.lua b/RLBridge/init.lua index 0b88179..f295c7d 100644 --- a/RLBridge/init.lua +++ b/RLBridge/init.lua @@ -15,6 +15,7 @@ G.SETTINGS.reduced_motion = true function INIT.start_run() print("Starting balatro run") + -- Initialize AI system local ai = require("ai") ai.init() diff --git a/RLBridge/output.lua b/RLBridge/output.lua index 2cba9fb..66f9260 100644 --- a/RLBridge/output.lua +++ b/RLBridge/output.lua @@ -4,15 +4,20 @@ local O = {} -local actions = require("actions") - --- Get comprehensive game state for AI --- Collects all relevant game information into a structured format --- @return table Complete game state data for AI processing function O.get_game_state() + -- Calculate game_over + local game_over = 0 + if G.STATE == G.STATES.GAME_OVER then + game_over = 1 + end + return { -- Basic state info state = G.STATE, + game_over = game_over, -- Game progression -- round = G.GAME and G.GAME.round or 0, @@ -27,16 +32,9 @@ function O.get_game_state() -- Hand info (if applicable) hand = O.get_hand_info(), - -- -- Jokers info - -- jokers = O.get_jokers_info(), - - -- -- Chips/score info - -- chips = G.GAME and G.GAME.chips or 0, - -- chip_target = G.GAME and G.GAME.blind and G.GAME.blind.chips or 0, - - -- -- Action feedback for AI - -- last_action_success = true, -- Did last action work? - -- last_action_error = nil, -- What went wrong? + -- Chips/score info + chips = G.GAME and G.GAME.chips or 0, + chip_target = G.GAME and G.GAME.blind and G.GAME.blind.chips or 0, } end diff --git a/RLBridge/utils.lua b/RLBridge/utils.lua index 4385e15..72dc2f6 100644 --- a/RLBridge/utils.lua +++ b/RLBridge/utils.lua @@ -46,7 +46,7 @@ end --- Maps numeric state IDs to their string representations for debugging --- @param state_id number Numeric state identifier --- @return string Human-readable state name or "UNKNOWN_STATE" -function UTIL.state_name(state_id) +function UTIL.get_state_name(state_id) for key, value in pairs(G.STATES) do if value == state_id then return key @@ -133,4 +133,22 @@ function UTIL.get_action_names(available_actions) return names end +function UTIL.contains(tbl, val) + for x, _ in ipairs(tbl) do + if tbl[x] == val then + return true + end + end + return false +end + +--- Get current timestamp as formatted string +--- Returns current time in HH:MM:SS.mmm format for debugging +--- @return string Formatted timestamp +function UTIL.get_timestamp() + local time = os.time() + local ms = (os.clock() % 1) * 1000 + return os.date("%H:%M:%S", time) .. 
string.format(".%03d", ms) +end + return UTIL diff --git a/ai/agent/__init__.py b/ai/agent/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/ai/agent/balatro_agent.py b/ai/agent/balatro_agent.py deleted file mode 100644 index b18936f..0000000 --- a/ai/agent/balatro_agent.py +++ /dev/null @@ -1,95 +0,0 @@ -""" -Balatro RL Agent -Contains the AI logic for making decisions in Balatro -""" - -from typing import Dict, Any, List -import random -import logging - - -class BalatroAgent: - """Main AI agent for playing Balatro""" - - def __init__(self): - self.logger = logging.getLogger(__name__) - - # TODO: Initialize your RL model here - # self.model = load_model("path/to/model") - # self.training = True - - def get_action(self, request: Dict[str, Any]) -> Dict[str, Any]: - """ - Given game state, return the action to take - - Args: - game_state: Current game state from Balatro mod - - Returns: - Action dictionary to send back to mod - """ - try: - # Extract relevant info from game state - available_actions = request.get('available_actions', []) - game_state = request.get('game_state', {}) - - self.logger.info(f"Processing state {game_state.get('state', 0)} with {len(available_actions)} actions") - - # TODO: Replace with actual RL logic - action = self._get_random_action(available_actions) # TODO we are getting random actions for testing - - # TODO: If training, update model with reward - # if self.training: - # self._update_model(game_state, action, reward) - - return action - - except Exception as e: - self.logger.error(f"Error in get_action: {e}") - return {"action": "no_action"} - - def _get_random_action(self, available_actions: List) -> Dict[str, Any]: - """Temporary random action selection - replace with RL logic""" - - # available_actions comes as [1, 2] format - if not available_actions or not isinstance(available_actions, List): - return {"action": "no_action"} - - # Simple random action selection - selected_action_id = random.choice(available_actions) - self.logger.info(f"Selected action ID: {selected_action_id} from available: {available_actions}") - - # TODO this is for testing, the AI doesn't really care about what the action_names are - # but we do so we can test out actions and make sure they work - # Map action IDs to names for logic (but return the ID) - action_names = {1: "start_run", 2: "select_blind", 3: "select_hand"} - action_name = action_names.get(selected_action_id, "unknown") - - # Pick random cards for now - params = {} - if action_name == "select_hand": - params["card_indices"] = [1, 3, 5] - - # Return the action ID (number), not the name - return {"action": selected_action_id, "params": params} - - def train_step(self, game_state: Dict, action: Dict, reward: float, next_state: Dict): - """ - Perform one training step - - TODO: Implement RL training logic here - - Update Q-values - - Update neural network weights - - Store experience in replay buffer - """ - pass - - def save_model(self, path: str): - """Save the trained model""" - # TODO: Implement model saving - pass - - def load_model(self, path: str): - """Load a trained model""" - # TODO: Implement model loading - pass diff --git a/ai/agent/policy.py b/ai/agent/policy.py deleted file mode 100644 index c8baeba..0000000 --- a/ai/agent/policy.py +++ /dev/null @@ -1,83 +0,0 @@ -""" -Neural network policy for Balatro RL agent -""" - -import torch -import torch.nn as nn -import numpy as np -from typing import Dict, Any - -class BalatroPolicy(nn.Module): - """ - Neural network policy for Balatro 
decision making - """ - - def __init__(self, state_dim: int = 128, action_dim: int = 64, hidden_dim: int = 256): - super().__init__() - - self.state_dim = state_dim - self.action_dim = action_dim - - # Feature extraction layers - self.feature_net = nn.Sequential( - nn.Linear(state_dim, hidden_dim), - nn.ReLU(), - nn.Linear(hidden_dim, hidden_dim), - nn.ReLU(), - nn.Linear(hidden_dim, hidden_dim), - nn.ReLU() - ) - - # Action value head - self.value_head = nn.Linear(hidden_dim, 1) - - # Action probability head - self.action_head = nn.Linear(hidden_dim, action_dim) - - def forward(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Forward pass through the network - - Args: - state: Game state tensor - - Returns: - Tuple of (action_logits, state_value) - """ - features = self.feature_net(state) - - action_logits = self.action_head(features) - state_value = self.value_head(features) - - return action_logits, state_value - - def get_action(self, state: np.ndarray, available_actions: list) -> tuple[int, float]: - """ - Get action from policy given current state - - Args: - state: Current game state as numpy array - available_actions: List of available action indices - - Returns: - Tuple of (action_index, action_probability) - """ - with torch.no_grad(): - state_tensor = torch.FloatTensor(state).unsqueeze(0) - action_logits, _ = self.forward(state_tensor) - - # Mask unavailable actions - masked_logits = action_logits.clone() - if available_actions: - mask = torch.full_like(action_logits, float('-inf')) - mask[0, available_actions] = 0 - masked_logits = action_logits + mask - - # Get action probabilities - action_probs = torch.softmax(masked_logits, dim=-1) - - # Sample action - action_dist = torch.distributions.Categorical(action_probs) - action = action_dist.sample() - - return action.item(), action_probs[0, action].item() \ No newline at end of file diff --git a/ai/agent/training.py b/ai/agent/training.py deleted file mode 100644 index 46e071c..0000000 --- a/ai/agent/training.py +++ /dev/null @@ -1,153 +0,0 @@ -""" -Training logic for Balatro RL agent -""" - -import torch -import torch.optim as optim -import numpy as np -from typing import Dict, List, Any -from .policy import BalatroPolicy -import logging - -class BalatroTrainer: - """ - Handles training of the Balatro RL agent - """ - - def __init__(self, policy: BalatroPolicy, learning_rate: float = 3e-4): - self.policy = policy - self.optimizer = optim.Adam(policy.parameters(), lr=learning_rate) - self.logger = logging.getLogger(__name__) - - # Training buffers - self.states = [] - self.actions = [] - self.rewards = [] - self.values = [] - self.log_probs = [] - - def collect_experience(self, state: np.ndarray, action: int, reward: float, - value: float, log_prob: float): - """ - Collect experience for training - - Args: - state: Game state - action: Action taken - reward: Reward received - value: State value estimate - log_prob: Log probability of action - """ - self.states.append(state) - self.actions.append(action) - self.rewards.append(reward) - self.values.append(value) - self.log_probs.append(log_prob) - - def compute_returns(self, next_value: float = 0.0, gamma: float = 0.99) -> List[float]: - """ - Compute discounted returns - - Args: - next_value: Value of next state (for bootstrapping) - gamma: Discount factor - - Returns: - List of discounted returns - """ - returns = [] - R = next_value - - for reward in reversed(self.rewards): - R = reward + gamma * R - returns.insert(0, R) - - return returns - - def 
update_policy(self, gamma: float = 0.99, entropy_coef: float = 0.01, - value_coef: float = 0.5) -> Dict[str, float]: - """ - Update policy using collected experience - - Args: - gamma: Discount factor - entropy_coef: Entropy regularization coefficient - value_coef: Value loss coefficient - - Returns: - Dictionary of training metrics - """ - if not self.states: - return {} - - # Convert to tensors - states = torch.FloatTensor(np.array(self.states)) - actions = torch.LongTensor(self.actions) - old_log_probs = torch.FloatTensor(self.log_probs) - values = torch.FloatTensor(self.values) - - # Compute returns - returns = self.compute_returns(gamma=gamma) - returns = torch.FloatTensor(returns) - - # Compute advantages - advantages = returns - values - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) - - # Forward pass - action_logits, new_values = self.policy(states) - - # Compute losses - action_dist = torch.distributions.Categorical(logits=action_logits) - new_log_probs = action_dist.log_prob(actions) - entropy = action_dist.entropy().mean() - - # Policy loss (PPO-style, but simplified) - ratio = torch.exp(new_log_probs - old_log_probs) - policy_loss = -(ratio * advantages).mean() - - # Value loss - value_loss = torch.nn.functional.mse_loss(new_values.squeeze(), returns) - - # Total loss - total_loss = policy_loss + value_coef * value_loss - entropy_coef * entropy - - # Update - self.optimizer.zero_grad() - total_loss.backward() - torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5) - self.optimizer.step() - - # Clear buffers - self.clear_buffers() - - # Return metrics - return { - 'policy_loss': policy_loss.item(), - 'value_loss': value_loss.item(), - 'entropy': entropy.item(), - 'total_loss': total_loss.item() - } - - def clear_buffers(self): - """Clear experience buffers""" - self.states.clear() - self.actions.clear() - self.rewards.clear() - self.values.clear() - self.log_probs.clear() - - def save_model(self, path: str): - """Save model checkpoint""" - torch.save({ - 'policy_state_dict': self.policy.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict() - }, path) - self.logger.info(f"Model saved to {path}") - - def load_model(self, path: str): - """Load model checkpoint""" - checkpoint = torch.load(path) - self.policy.load_state_dict(checkpoint['policy_state_dict']) - self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - self.logger.info(f"Model loaded from {path}") \ No newline at end of file diff --git a/ai/environment/balatro_env.py b/ai/environment/balatro_env.py index ad770d3..0c56a22 100644 --- a/ai/environment/balatro_env.py +++ b/ai/environment/balatro_env.py @@ -1,177 +1,190 @@ """ Balatro RL Environment -Wraps the HTTP communication in a gym-like interface for RL training +Wraps the pipe-based communication with Balatro mod in a standard RL interface. +This acts as a translator between Balatro's JSON pipe communication and +RL libraries that expect gym-style step()/reset() methods. 
""" import numpy as np from typing import Dict, Any, Tuple, List, Optional import logging +import gymnasium as gym +from gymnasium import spaces +from ..utils.debug import tprint -from ..utils.serialization import GameStateValidator +from ..utils.communication import BalatroPipeIO +from .reward import BalatroRewardCalculator +from ..utils.mappers import BalatroStateMapper, BalatroActionMapper -class BalatroEnvironment: +class BalatroEnv(gym.Env): """ - Gym-like environment interface for Balatro RL training + Standard RL Environment wrapper for Balatro - This class will eventually wrap the HTTP communication - and provide a standard RL interface with step(), reset(), etc. + Translates between: + - Balatro mod's JSON pipe communication (/tmp/balatro_request, /tmp/balatro_response) + - Standard RL interface (step, reset, observation spaces) + + This allows RL libraries like Stable-Baselines3 to train on Balatro + without knowing about the underlying pipe communication system. """ def __init__(self): + super().__init__() self.logger = logging.getLogger(__name__) self.current_state = None + self.prev_state = None self.game_over = False - # TODO: Define action and observation spaces - # self.action_space = ... - # self.observation_space = ... - - def reset(self) -> Dict[str, Any]: - """Reset the environment for a new episode""" - self.current_state = None - self.game_over = False + # Initialize communication and reward systems + self.pipe_io = BalatroPipeIO() + self.reward_calculator = BalatroRewardCalculator() + + # Define Gymnasium spaces + # Action Spaces; This should describe the type and shape of the action + # Constants + self.MAX_ACTIONS = 10 + action_selection = np.array([self.MAX_ACTIONS]) + card_indices = np.array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2]) # Handles up to 10 cards in a hand + self.action_space = spaces.MultiDiscrete(np.concatenate([ + action_selection, + card_indices + ])) + ACTION_SLICE_LAYOUT = [ + ("action_selection", 1), + ("card_indices", 10) + ] + slices = self._build_action_slices(ACTION_SLICE_LAYOUT) - # TODO: Send reset signal to Balatro mod - # This might involve starting a new run - - return self._get_initial_state() + # Observation space: This should describe the type and shape of the observation + # Constants + self.OBSERVATION_SIZE = 70 # you can get this value by running test_env.py. + self.observation_space = spaces.Box( + low=-np.inf, # lowest bound of observation data + high=np.inf, # highest bound of observation data + shape=(self.OBSERVATION_SIZE,), # Adjust based on actual state size which This is a 1D array + dtype=np.float32 # Data type of the numbers + ) + + # Initialize mappers + self.state_mapper = BalatroStateMapper(observation_size=self.OBSERVATION_SIZE, max_actions=self.MAX_ACTIONS) + self.action_mapper = BalatroActionMapper(action_slices=slices) - def step(self, action: Dict[str, Any]) -> Tuple[Dict, float, bool, Dict]: + def reset(self, seed=None, options=None): #TODO add types """ - Take an action in the environment + Reset the environment for a new episode + + In Balatro context, this means starting a new run. + Communicates with Balatro mod via pipes to initiate reset. 
+
+        Returns:
+            Initial observation/game state
+        """
+        self.current_state = None
+        self.prev_state = None
+        self.game_over = False
+
+        # Reset reward tracking
+        self.reward_calculator.reset()
+
+        # Wait for initial request from Balatro (game start)
+        initial_request = self.pipe_io.wait_for_request()
+        if not initial_request:
+            raise RuntimeError("Failed to receive initial request from Balatro")
+
+        # Send dummy response to complete the handshake
+        success = self.pipe_io.send_response({"action": "ready"})
+        if not success:
+            raise RuntimeError("Failed to complete handshake")
+
+        # Process initial state for SB3
+        self.current_state = initial_request
+        initial_observation = self.state_mapper.process_game_state(self.current_state)
+
+        return initial_observation, {}
+
+    def step(self, action): #TODO add types
+        """
+        Take an action in the Balatro environment.
+        Sends the action to the Balatro mod via JSON pipe, waits for the response,
+        calculates the reward, and returns the standard RL step format.
 
         Args:
-            action: Action to take
+            action: Action dictionary (e.g., {"action": 1, "params": {...}})
 
         Returns:
-            Tuple of (observation, reward, done, info)
+            Tuple of (observation, reward, terminated, truncated, info) where:
+            - observation: Processed game state for the neural network
+            - reward: Calculated reward for this step
+            - terminated/truncated: Whether the episode ended (game over) or was cut short
+            - info: Additional debug information
         """
-        # TODO: Send action to Balatro via HTTP
-        # TODO: Receive new game state
-        # TODO: Calculate reward
-        # TODO: Check if episode is done
+        # Store previous state for reward calculation
+        self.prev_state = self.current_state
+
+        # Send action response to Balatro mod
+        response_data = self.action_mapper.process_action(rl_action=action)
 
-        observation = self._process_game_state(self.current_state)
-        reward = self._calculate_reward()
-        done = self.game_over
-        info = {}
+        success = self.pipe_io.send_response(response_data)
+        if not success:
+            raise RuntimeError("Failed to send response to Balatro")
 
-        return observation, reward, done, info
-
-    def _get_initial_state(self) -> Dict[str, Any]:
-        """Get initial game state"""
-        # TODO: Implement
-        return {}
-
-    def _process_game_state(self, raw_state: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Process raw game state into format suitable for RL agent
+        # Wait for next request with new game state
+        next_request = self.pipe_io.wait_for_request()
+        if not next_request:
+            # If no response, assume the game ended (terminated=True, truncated=False)
+            self.game_over = True
+            observation = self.state_mapper.process_game_state(self.current_state)
+            reward = 0.0
+            return observation, reward, True, False, {"timeout": True}
 
-        This might involve:
-        - Extracting relevant features
-        - Normalizing values
-        - Converting to numpy arrays
-        """
-        if not raw_state:
-            return {}
+        # Update current state
+        self.current_state = next_request
 
-        # Validate state structure
-        try:
-            GameStateValidator.validate_game_state(raw_state)
-        except ValueError as e:
-            self.logger.error(f"Invalid game state: {e}")
-            return {}
+        # Process new state for SB3
+        observation = self.state_mapper.process_game_state(self.current_state)
 
-        # TODO: Extract and process features
-        processed = {
-            'hand_features': self._extract_hand_features(raw_state.get('hand', {})),
-            'game_features': self._extract_game_features(raw_state),
-            'available_actions': raw_state.get('available_actions', [])
+        # Calculate reward using expert reward calculator
+        reward = self.reward_calculator.calculate_reward(
+            current_state=self.current_state,
+            
prev_state=self.prev_state if self.prev_state else {} + ) + + # Check if episode is done + done = bool(self.current_state.get('game_over', 0)) # TODO send a game_over in our state + terminated = done + truncated = False # Not using time limits for now + + info = { + 'balatro_state': self.current_state.get('state', 0), + 'available_actions': next_request.get('available_actions', []) } + # TODO have a feeling observation is not being return when we run out of actions which causes problems + return observation, reward, terminated, truncated, info + + def cleanup(self): + """ + Clean up environment resources - return processed - - def _extract_hand_features(self, hand: Dict[str, Any]) -> np.ndarray: - """Extract numerical features from hand""" - # TODO: Convert hand cards to numerical representation - # This might include: - # - Card values (2-14 for 2-A) - # - Suits (0-3 for different suits) - # - Special abilities/bonuses + Call this when shutting down to clean up pipe communication. + """ + self.pipe_io.cleanup() + + @staticmethod + def _build_action_slices(layout: List[Tuple[str, int]]) -> Dict[str, slice]: + """ + Create slices for our actions so that we can precisely extract the + right params to send over to balatro - cards = hand.get('cards', []) - features = [] - - for card in cards: - # Extract card features - suit_encoding = self._encode_suit(card.get('suit', '')) - value_encoding = self._encode_value(card.get('base', {}).get('value', '')) - nominal = card.get('base', {}).get('nominal', 0) - - # Add ability features - ability = card.get('ability', {}) - ability_features = [ - ability.get('t_chips', 0), - ability.get('t_mult', 0), - ability.get('x_mult', 1), - ability.get('mult', 0) - ] - - card_features = [suit_encoding, value_encoding, nominal] + ability_features - features.extend(card_features) - - # Pad or truncate to fixed size (e.g., 8 cards max * 7 features each) - max_cards = 8 - features_per_card = 7 - - if len(features) < max_cards * features_per_card: - features.extend([0] * (max_cards * features_per_card - len(features))) - else: - features = features[:max_cards * features_per_card] - - return np.array(features, dtype=np.float32) - - def _extract_game_features(self, state: Dict[str, Any]) -> np.ndarray: - """Extract game-level features""" - # TODO: Extract features like: - # - Current chips - # - Target chips - # - Money - # - Round/ante - # - Discards remaining - # - Hands remaining - - features = [ - state.get('state', 0), # Game state - len(state.get('available_actions', [])), # Number of available actions - ] - - return np.array(features, dtype=np.float32) - - def _encode_suit(self, suit: str) -> int: - """Encode suit as integer""" - suit_map = {'Hearts': 0, 'Diamonds': 1, 'Clubs': 2, 'Spades': 3} - return suit_map.get(suit, 0) - - def _encode_value(self, value: str) -> int: - """Encode card value as integer""" - if value.isdigit(): - return int(value) - - value_map = { - 'Jack': 11, 'Queen': 12, 'King': 13, 'Ace': 14, - 'A': 14, 'K': 13, 'Q': 12, 'J': 11 - } - return value_map.get(value, 0) - - def _calculate_reward(self) -> float: - """Calculate reward for current state""" - # TODO: Implement reward calculation - # This might be based on: - # - Chips scored - # - Blind completion - # - Round progression - # - Final score - - return 0.0 \ No newline at end of file + Args: + layout: Our ACTION_SLICE_LAYOUT that contains action name and size + Return: + A dictionary containing a key being our action space slice name, and + the slice + """ + slices = {} + start = 0 + 
for action_name, size in layout: + slices[action_name] = slice(start, start + size) + start += size + return slices diff --git a/ai/environment/reward.py b/ai/environment/reward.py index 1f65d78..3c53240 100644 --- a/ai/environment/reward.py +++ b/ai/environment/reward.py @@ -1,5 +1,11 @@ """ -Reward calculation for Balatro RL environment +Balatro Reward System - The Expert on What's Good/Bad + +This module is the single source of truth for reward calculation in Balatro RL. +All reward logic is centralized here to make experimentation and tuning easier. + +The BalatroRewardCalculator analyzes game state changes and assigns rewards +that teach the AI what constitutes good vs bad Balatro gameplay. """ from typing import Dict, Any @@ -7,114 +13,86 @@ import numpy as np class BalatroRewardCalculator: """ - Calculates rewards for Balatro RL training + Expert reward calculator for Balatro RL training + + This is the single authority on what constitutes good/bad play in Balatro. + Centralizes all reward logic for easy experimentation and tuning. + + Reward philosophy: + - Positive rewards for progress (chips, rounds, money) + - Bonus rewards for efficiency (fewer hands/discards used) + - Large rewards for major milestones (completing antes) + - Negative rewards for game over (based on progress made) """ def __init__(self): - self.prev_score = 0 - self.prev_money = 0 - self.prev_round = 0 - self.prev_ante = 0 + self.chips = 0 def calculate_reward(self, current_state: Dict[str, Any], prev_state: Dict[str, Any] = None) -> float: """ - Calculate reward based on game state changes + Main reward calculation method - analyzes state changes and assigns rewards + + This is the core method that determines what the AI should optimize for. + Examines differences between previous and current game state to calculate + appropriate rewards for the action that caused the transition. 
Args: - current_state: Current game state - prev_state: Previous game state (optional) + current_state: Current Balatro game state + prev_state: Previous Balatro game state (None for first step) Returns: - Calculated reward + Float reward value (positive = good, negative = bad, zero = neutral) """ reward = 0.0 # Extract relevant metrics - current_score = current_state.get('score', 0) - current_money = current_state.get('money', 0) - current_round = current_state.get('round', 0) - current_ante = current_state.get('ante', 0) - game_over = current_state.get('game_over', False) + current_chips = current_state.get('chips', 0) + game_over = current_state.get('game_over', False) #TODO fix # Score-based rewards - score_diff = current_score - self.prev_score - if score_diff > 0: - reward += score_diff * 0.001 # Small reward for score increases - - # Money-based rewards - money_diff = current_money - self.prev_money - if money_diff > 0: - reward += money_diff * 0.01 # Reward for gaining money - - # Round progression rewards - if current_round > self.prev_round: - reward += 10.0 # Big reward for completing rounds - - # Ante progression rewards - if current_ante > self.prev_ante: - reward += 50.0 # Very big reward for completing antes + chip_diff = current_chips - self.chips + if chip_diff > 0: + reward += chip_diff * 0.001 # Small reward for score increases # Game over penalty if game_over: # Penalty based on how early the game ended - max_ante = 8 # Typical max ante in Balatro - completion_ratio = current_ante / max_ante - reward += completion_ratio * 100.0 # Reward based on progress + # max_ante = 8 # Typical max ante in Balatro + # completion_ratio = current_ante / max_ante + # reward += completion_ratio * 100.0 # Reward based on progress + reward -= 10 #TODO keeping this simple for now - # Efficiency rewards (optional) - # reward += self._calculate_efficiency_reward(current_state) - # Update previous state - self.prev_score = current_score - self.prev_money = current_money - self.prev_round = current_round - self.prev_ante = current_ante - - return reward - - def _calculate_efficiency_reward(self, state: Dict[str, Any]) -> float: - """ - Calculate rewards for efficient play - - Args: - state: Current game state - - Returns: - Efficiency reward - """ - reward = 0.0 - - # Reward for using fewer discards - discards_remaining = state.get('discards', 0) - if discards_remaining > 0: - reward += discards_remaining * 0.1 - - # Reward for using fewer hands - hands_remaining = state.get('hands', 0) - if hands_remaining > 0: - reward += hands_remaining * 0.5 + self.chips = current_chips return reward def reset(self): - """Reset reward calculator for new episode""" - self.prev_score = 0 - self.prev_money = 0 - self.prev_round = 0 - self.prev_ante = 0 + """ + Reset reward calculator state for new episode + + Called at the start of each new Balatro run to clear + previous state tracking variables. + """ + self.chips = 0 def get_shaped_reward(self, state: Dict[str, Any], action: str) -> float: """ - Get shaped reward based on specific actions + Calculate action-specific reward shaping + + Provides small rewards/penalties for specific actions to guide + AI behavior beyond just game state changes. Use sparingly to + avoid overriding the main reward signal. 
Args: - state: Current game state - action: Action taken + state: Current game state when action was taken + action: Action that was taken (for action-specific rewards) Returns: - Shaped reward + Small shaped reward value (usually -1 to +1) """ + # TODO look at this this could help but currently not in a working state reward = 0.0 # Action-specific rewards @@ -134,4 +112,4 @@ class BalatroRewardCalculator: # Neutral - depends on if it's a good purchase reward += 0.0 - return reward \ No newline at end of file + return reward diff --git a/ai/file_watcher.py b/ai/file_watcher.py deleted file mode 100755 index 0f10509..0000000 --- a/ai/file_watcher.py +++ /dev/null @@ -1,111 +0,0 @@ -#!/usr/bin/env python3 -""" -File-based communication handler for Balatro RL -Watches for request files and responds with AI actions -""" - -import json -import time -import logging -from pathlib import Path -from watchdog.observers import Observer -from watchdog.events import FileSystemEventHandler - -from agent.balatro_agent import BalatroAgent - -# File paths -REQUEST_FILE = "/tmp/balatro_request.json" -RESPONSE_FILE = "/tmp/balatro_response.json" - -class RequestHandler(FileSystemEventHandler): - """Handles incoming request files from Balatro""" - - def __init__(self): - self.agent = BalatroAgent() - self.logger = logging.getLogger(__name__) - - def on_created(self, event): - """Handle new request file creation""" - if event.src_path == REQUEST_FILE and not event.is_directory: - self.process_request() - - def on_modified(self, event): - """Handle request file modification""" - if event.src_path == REQUEST_FILE and not event.is_directory: - self.process_request() - - def process_request(self): - """Process a request file and generate response""" - try: - # Small delay to ensure file write is complete - time.sleep(0.01) - - # Read request - with open(REQUEST_FILE, 'r') as f: - request_data = json.load(f) - - self.logger.info(f"Processing request {request_data.get('id')}: state={request_data.get('state')}") - - # Get action from AI agent - action_response = self.agent.get_action(request_data) - - # Write response - response = { - "id": request_data.get("id"), - "action": action_response.get("action", "no_action"), - "params": action_response.get("params"), - "timestamp": time.time() - } - - with open(RESPONSE_FILE, 'w') as f: - json.dump(response, f) - - self.logger.info(f"Response written: {response['action']}") - - - except Exception as e: - self.logger.error(f"Error processing request: {e}") - - # Write error response - error_response = { - "id": request_data.get("id", 0) if 'request_data' in locals() else 0, - "action": "no_action", - "params": {"error": str(e)}, - "timestamp": time.time() - } - - try: - with open(RESPONSE_FILE, 'w') as f: - json.dump(error_response, f) - except: - pass - -def main(): - """Main file watcher loop""" - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' - ) - - logger = logging.getLogger(__name__) - logger.info("Starting Balatro RL file watcher...") - - # Set up file watcher - event_handler = RequestHandler() - observer = Observer() - observer.schedule(event_handler, path="/tmp", recursive=False) - - observer.start() - logger.info(f"Watching for requests at: {REQUEST_FILE}") - - try: - while True: - time.sleep(1) - except KeyboardInterrupt: - logger.info("Stopping file watcher...") - observer.stop() - - observer.join() - -if __name__ == "__main__": - main() diff --git a/ai/requirements.txt b/ai/requirements.txt index 
189933c..a7d3925 100644 --- a/ai/requirements.txt +++ b/ai/requirements.txt @@ -1,6 +1,5 @@ numpy>=1.21.0 torch>=1.9.0 gymnasium>=0.26.0 # For RL environment interface -stable-baselines3>=1.6.0 # For RL algorithms +stable-baselines3[extra]>=1.6.0 # For RL algorithms tensorboard>=2.8.0 # For training visualization -watchdog>=2.1.0 # For file system monitoring \ No newline at end of file diff --git a/ai/setup.sh b/ai/setup.sh index fdcebd9..3932e8c 100755 --- a/ai/setup.sh +++ b/ai/setup.sh @@ -22,5 +22,5 @@ echo "Setup complete!" echo "" echo "Usage:" echo " To activate environment: source venv/bin/activate" -echo " To run file watcher: python file_watcher.py" -echo " To train agent: python -m agent.training" \ No newline at end of file +echo " To test environment is working: python -m ai.test_env" +echo " To train agent: python -m ai.train_balatro" diff --git a/ai/test_env.py b/ai/test_env.py new file mode 100644 index 0000000..f768af5 --- /dev/null +++ b/ai/test_env.py @@ -0,0 +1,96 @@ +""" +Environment Testing Script for Balatro RL + +Tests the BalatroEnv to ensure it follows the Gym interface and +communicates properly with the Balatro mod via file I/O. + +Usage: + python test_env.py +""" + +import time +import logging +# from stable_baselines3.common.env_checker import check_env # Not compatible with pipes +from .environment.balatro_env import BalatroEnv + + +def test_manual_actions(): + """Manual testing with random actions""" + print("\n=== Manual Action Testing ===") + + env = BalatroEnv() + print("BalatroEnv created successfully") + + try: + # Reset and get initial observation (ONLY ONCE!) + print("Calling env.reset() - this should wait for Balatro...") + obs, info = env.reset() + print("env.reset() completed!") + print(f"Initial observation: {obs}") + print(f"Observation space: {env.observation_space}") + print(f"Action space: {env.action_space}") + print(f"Sample action: {env.action_space.sample()}") + + # Run test episodes + max_steps = 3 + max_episodes = 1 + + for episode in range(max_episodes): + print(f"\n--- Episode {episode + 1} ---") + # Don't call reset() again! Use the obs from above + + for step in range(max_steps): + # Get random action (like the old balatro_agent) + action = env.action_space.sample() + + print(f"Step {step + 1}: Taking action {action}") + + # Take action + obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + + print(f" obs={obs}, reward={reward}, done={done}") + + # Optional: Add delay for readability + time.sleep(0.1) + + if done: + print(f" Episode finished after {step + 1} steps") + break + + if not done: + print(f" Episode reached max steps ({max_steps})") + + except Exception as e: + print(f"โœ— Manual testing failed: {e}") + return False + finally: + env.close() + + print("โœ“ Manual testing completed!") + return True + + +def main(): + """Run all tests""" + print("Starting Balatro Environment Testing...") + + # Setup logging to see communication details + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + + success = True + success &= test_manual_actions() + + if success: + print("\n๐ŸŽ‰ All tests passed! Environment is ready for training.") + else: + print("\nโŒ Some tests failed. 
Check the environment implementation.") + + return success + + +if __name__ == "__main__": + main() diff --git a/ai/train_balatro.py b/ai/train_balatro.py new file mode 100755 index 0000000..ed396e5 --- /dev/null +++ b/ai/train_balatro.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python3 +""" +Balatro RL Training Script + +Main training script for teaching AI to play Balatro using Stable-Baselines3. +This script creates the Balatro environment, sets up the RL model, and runs training. + +Usage: + python train_balatro.py + +Requirements: + - Balatro game running with RLBridge mod + - file_watcher.py should NOT be running (this replaces it) +""" + +import logging +import time +from pathlib import Path + +# SB3 imports +from stable_baselines3 import DQN, A2C, PPO +from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback +from stable_baselines3.common.monitor import Monitor + +# Our custom environment +from .environment.balatro_env import BalatroEnv + + +def setup_logging(): + """Setup logging for training""" + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler('training.log'), + logging.StreamHandler() + ] + ) + + +def create_environment(): + """Create and wrap the Balatro environment""" + # Create base environment + env = BalatroEnv() + + # Wrap with Monitor for logging episode stats + env = Monitor(env, filename="training_monitor.csv") + + return env + + +def create_model(env, algorithm="DQN", model_path=None): + """ + Create SB3 model for training + + Args: + env: Balatro environment + algorithm: RL algorithm to use ("DQN", "A2C", "PPO") + model_path: Path to load existing model (optional) + + Returns: + SB3 model ready for training + """ + if algorithm == "DQN": + model = DQN( + "MlpPolicy", + env, + verbose=1, + learning_rate=1e-4, + buffer_size=10000, + learning_starts=1000, + target_update_interval=500, + train_freq=4, + gradient_steps=1, + exploration_fraction=0.1, + exploration_initial_eps=1.0, + exploration_final_eps=0.05, + tensorboard_log="./tensorboard_logs/" + ) + elif algorithm == "A2C": + model = A2C( + "MlpPolicy", + env, + verbose=1, + learning_rate=7e-4, + n_steps=5, + gamma=0.99, + tensorboard_log="./tensorboard_logs/" + ) + elif algorithm == "PPO": + model = PPO( + "MlpPolicy", + env, + verbose=1, + learning_rate=3e-4, + n_steps=2048, + batch_size=64, + n_epochs=10, + gamma=0.99, + tensorboard_log="./tensorboard_logs/" + ) + else: + raise ValueError(f"Unknown algorithm: {algorithm}") + + # Load existing model if path provided + if model_path and Path(model_path).exists(): + model.load(model_path) + print(f"Loaded existing model from {model_path}") + + return model + + +def create_callbacks(save_freq=1000): + """Create training callbacks for saving and evaluation""" + callbacks = [] + + # Checkpoint callback - save model periodically + checkpoint_callback = CheckpointCallback( + save_freq=save_freq, + save_path="./models/", + name_prefix="balatro_model" + ) + callbacks.append(checkpoint_callback) + + return callbacks + + +def train_agent(total_timesteps=50000, algorithm="DQN", save_path="./models/balatro_final"): + """ + Main training function + + Args: + total_timesteps: Number of training steps + algorithm: RL algorithm to use + save_path: Where to save final model + """ + logger = logging.getLogger(__name__) + logger.info(f"Starting Balatro RL training with {algorithm}") + logger.info(f"Training for {total_timesteps} timesteps") + + try: + # Create environment and 
model
+        env = create_environment()
+        model = create_model(env, algorithm)
+
+        # Create callbacks
+        callbacks = create_callbacks(save_freq=max(1000, total_timesteps // 20))
+
+        # Train the model
+        logger.info("Starting training...")
+        start_time = time.time()
+
+        model.learn(
+            total_timesteps=total_timesteps,
+            callback=callbacks,
+            progress_bar=True
+        )
+
+        training_time = time.time() - start_time
+        logger.info(f"Training completed in {training_time:.2f} seconds")
+
+        # Save final model
+        model.save(save_path)
+        logger.info(f"Model saved to {save_path}")
+
+        # Clean up environment
+        if hasattr(env, 'cleanup'):
+            env.cleanup()
+        elif hasattr(env.env, 'cleanup'):  # Monitor wrapper
+            env.env.cleanup()
+
+        return model
+
+    except KeyboardInterrupt:
+        logger.info("Training interrupted by user")
+        if hasattr(env, 'cleanup'):
+            env.cleanup()
+        elif hasattr(env.env, 'cleanup'):
+            env.env.cleanup()
+        return None
+    except Exception as e:
+        logger.error(f"Training failed: {e}")
+        if hasattr(env, 'cleanup'):
+            env.cleanup()
+        elif hasattr(env.env, 'cleanup'):
+            env.env.cleanup()
+        raise
+
+
+def test_trained_model(model_path, num_episodes=5):
+    """
+    Test a trained model
+
+    Args:
+        model_path: Path to trained model
+        num_episodes: Number of episodes to test
+    """
+    logger = logging.getLogger(__name__)
+    logger.info(f"Testing model from {model_path}")
+
+    # Create environment and load model
+    env = create_environment()
+    model = PPO.load(model_path)  # Match the algorithm used for training below
+
+    episode_rewards = []
+
+    for episode in range(num_episodes):
+        obs, _info = env.reset()
+        total_reward = 0
+        steps = 0
+
+        while True:
+            action, _states = model.predict(obs, deterministic=True)
+            obs, reward, terminated, truncated, info = env.step(action)
+            total_reward += reward
+            steps += 1
+
+            if terminated or truncated:
+                break
+
+        episode_rewards.append(total_reward)
+        logger.info(f"Episode {episode + 1}: {steps} steps, reward: {total_reward:.2f}")
+
+    avg_reward = sum(episode_rewards) / len(episode_rewards)
+    logger.info(f"Average reward over {num_episodes} episodes: {avg_reward:.2f}")
+
+    env.cleanup()
+    return episode_rewards
+
+
+if __name__ == "__main__":
+    # Setup
+    setup_logging()
+
+    # Create necessary directories
+    Path("./models").mkdir(exist_ok=True)
+    Path("./tensorboard_logs").mkdir(exist_ok=True)
+
+    # We'll let BalatroEnv create the pipes when needed
+    print("🔧 Pipe communication will be initialized with environment...")
+
+    # Train the agent
+    print("\n🎮 Starting Balatro RL Training!")
+    print("Setup steps:")
+    print("1. ✅ Balatro is running with RLBridge mod")
+    print("2. ✅ Balatro is in menu state")
+    print("3. Press Enter below to start training")
+    print("4. 
+    print("4. When prompted, press 'R' in Balatro within 3 seconds")
+    print("\n📡 Training will create pipes and connect automatically!")
+    print()
+
+    input("Press Enter to start training (then you'll have 3 seconds to press 'R' in Balatro)...")
+
+    try:
+        model = train_agent(
+            total_timesteps=10000,  # Start small for testing
+            algorithm="PPO",  # PPO supports MultiDiscrete actions
+            save_path="./models/balatro_trained"
+        )
+
+        if model:
+            print("\n🎉 Training completed successfully!")
+            print("Testing the trained model...")
+
+            # Test the trained model
+            test_trained_model("./models/balatro_trained", num_episodes=3)
+
+    except Exception as e:
+        print(f"\n❌ Training failed: {e}")
+        print("Check the logs for more details.")
+    finally:
+        # Pipes will be cleaned up by the environment
+        print("🧹 Training session ended")
diff --git a/ai/utils/communication.py b/ai/utils/communication.py
new file mode 100644
index 0000000..65b21b5
--- /dev/null
+++ b/ai/utils/communication.py
@@ -0,0 +1,192 @@
+"""
+Balatro Dual Pipe Communication Abstraction
+
+Pure dual pipe I/O layer with persistent handles for communicating with Balatro mod.
+Handles low-level pipe operations without any game logic or AI logic.
+
+This abstraction can be used by:
+- balatro_env.py (for SB3 training)
+- file_watcher.py (for testing)
+- Any other component that needs to talk to Balatro
+"""
+
+import json
+import logging
+import os
+from typing import Dict, Any, Optional
+
+
+class BalatroPipeIO:
+    """
+    Clean abstraction for dual pipe communication with Balatro mod using persistent handles
+
+    Responsibilities:
+    - Read JSON requests from Balatro mod via request pipe
+    - Write JSON responses back to Balatro mod via response pipe
+    - Keep pipe handles open persistently to avoid deadlocks
+    - Provide simple, clean interface for higher-level code
+    """
+
+    def __init__(self, request_pipe: str = "/tmp/balatro_request", response_pipe: str = "/tmp/balatro_response"):
+        self.request_pipe = request_pipe
+        self.response_pipe = response_pipe
+        self.logger = logging.getLogger(__name__)
+
+        # Persistent pipe handles
+        self.request_handle = None
+        self.response_handle = None
+
+        # Create pipes and open persistent handles
+        self.create_pipes()
+        self.open_persistent_handles()
+
+    def create_pipes(self) -> None:
+        """
+        Create dual named pipes for communication
+
+        Creates both request and response pipes if they don't exist.
+        Safe to call multiple times.
+        """
+        for pipe_path in [self.request_pipe, self.response_pipe]:
+            try:
+                # Remove existing pipe if it exists
+                if os.path.exists(pipe_path):
+                    os.unlink(pipe_path)
+                    self.logger.debug(f"Removed existing pipe: {pipe_path}")
+
+                # Create named pipe
+                os.mkfifo(pipe_path)
+                self.logger.info(f"Created pipe: {pipe_path}")
+
+            except Exception as e:
+                self.logger.error(f"Failed to create pipe {pipe_path}: {e}")
+                raise RuntimeError(f"Could not create communication pipe: {pipe_path}")
+
+    def open_persistent_handles(self) -> None:
+        """
+        Open persistent handles for reading and writing
+
+        Opens request pipe for reading and response pipe for writing.
+        Keeps handles open to avoid deadlocks.
+ """ + import time + + try: + self.logger.info(f"๐Ÿ”ง Waiting for Balatro to connect...") + self.logger.info(f" Press 'R' in Balatro now to activate RL training!") + + # Open request pipe for reading (Balatro writes to this) + self.request_handle = open(self.request_pipe, 'r') + + # Open response pipe for writing (AI writes to this) + self.response_handle = open(self.response_pipe, 'w') + + except Exception as e: + self.logger.error(f"Failed to open persistent handles: {e}") + self.cleanup_handles() + raise RuntimeError(f"Could not open persistent pipe handles: {e}") + + def wait_for_request(self) -> Optional[Dict[str, Any]]: + """ + Wait for new request from Balatro mod using persistent handle + + Blocks until Balatro writes a request to the pipe. + No timeout needed - pipes block until data arrives. + + Returns: + Parsed JSON request data, or None if error + """ + if not self.request_handle: + self.logger.error("Request handle not open") + return None + + try: + # Read from persistent request handle + request_line = self.request_handle.readline().strip() + if not request_line: + return None + + request_data = json.loads(request_line) + self.logger.info(f"๐Ÿ“ฅ RECEIVED REQUEST: {request_line}...") # Show first 100 chars + return request_data + + except json.JSONDecodeError as e: + self.logger.error(f"Invalid JSON in request: {e}") + return None + except Exception as e: + self.logger.error(f"Error reading from request pipe: {e}") + return None + + def send_response(self, response_data: Dict[str, Any]) -> bool: + """ + Send response back to Balatro mod using persistent handle + + Writes response data to response pipe for Balatro to read. + + Args: + response_data: Response dictionary to send + + Returns: + True if successful, False if error + """ + if not self.response_handle: + self.logger.error("Response handle not open") + return False + + try: + # Write to persistent response handle + json.dump(response_data, self.response_handle) + self.response_handle.write('\n') # Important: newline for pipe communication + self.response_handle.flush() # Force write to pipe immediately + + self.logger.info(f"๐Ÿ“ค SENT RESPONSE: {json.dumps(response_data)}") + return True + + except Exception as e: + self.logger.error(f"โŒ ERROR sending response: {e}") + import traceback + self.logger.error(f"โŒ TRACEBACK: {traceback.format_exc()}") + return False + + def cleanup_handles(self): + """ + Close persistent pipe handles + + Closes the open file handles. + """ + if self.request_handle: + try: + self.request_handle.close() + self.logger.debug("Closed request handle") + except Exception as e: + self.logger.warning(f"Failed to close request handle: {e}") + self.request_handle = None + + if self.response_handle: + try: + self.response_handle.close() + self.logger.debug("Closed response handle") + except Exception as e: + self.logger.warning(f"Failed to close response handle: {e}") + self.response_handle = None + + def cleanup(self): + """ + Clean up communication pipes and handles + + Closes handles and removes pipe files from the filesystem. 
+ """ + # Close handles first + self.cleanup_handles() + + # Remove pipe files + for pipe_path in [self.request_pipe, self.response_pipe]: + try: + if os.path.exists(pipe_path): + os.unlink(pipe_path) + self.logger.debug(f"Removed pipe: {pipe_path}") + except Exception as e: + self.logger.warning(f"Failed to remove pipe {pipe_path}: {e}") + + self.logger.debug("Dual pipe communication cleanup complete") + diff --git a/ai/utils/debug.py b/ai/utils/debug.py new file mode 100644 index 0000000..429fc4a --- /dev/null +++ b/ai/utils/debug.py @@ -0,0 +1,39 @@ +""" +Debug utilities for timestamped print statements and debugging helpers +""" + +import datetime + + +def timestamp() -> str: + """ + Get current timestamp formatted as HH:MM:SS.mmm + + Returns: + str: Formatted timestamp string + """ + now = datetime.datetime.now() + return now.strftime("%H:%M:%S.%f")[:-3] # Remove last 3 digits to get milliseconds + + +def tprint(*args, **kwargs): + """ + Print with timestamp prefix + + Args: + *args: Arguments to print + **kwargs: Keyword arguments for print() + """ + print(f"[{timestamp()}]", *args, **kwargs) + + +def dprint(prefix: str, *args, **kwargs): + """ + Debug print with custom prefix and timestamp + + Args: + prefix: Custom prefix string + *args: Arguments to print + **kwargs: Keyword arguments for print() + """ + print(f"[{timestamp()}] [{prefix}]", *args, **kwargs) \ No newline at end of file diff --git a/ai/utils/logging.py b/ai/utils/logging.py deleted file mode 100644 index cead171..0000000 --- a/ai/utils/logging.py +++ /dev/null @@ -1,74 +0,0 @@ -""" -Logging setup for Balatro RL system -""" - -import logging -import sys -from pathlib import Path -from typing import Optional - -def setup_logging( - level: str = "INFO", - log_file: Optional[str] = None, - format_string: Optional[str] = None -) -> None: - """ - Setup logging configuration - - Args: - level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) - log_file: Optional log file path - format_string: Optional custom format string - """ - if format_string is None: - format_string = ( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - - # Convert string level to logging constant - numeric_level = getattr(logging, level.upper(), None) - if not isinstance(numeric_level, int): - raise ValueError(f'Invalid log level: {level}') - - # Create formatter - formatter = logging.Formatter(format_string) - - # Setup root logger - root_logger = logging.getLogger() - root_logger.setLevel(numeric_level) - - # Remove existing handlers - for handler in root_logger.handlers[:]: - root_logger.removeHandler(handler) - - # Console handler - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setLevel(numeric_level) - console_handler.setFormatter(formatter) - root_logger.addHandler(console_handler) - - # File handler (optional) - if log_file: - log_path = Path(log_file) - log_path.parent.mkdir(parents=True, exist_ok=True) - - file_handler = logging.FileHandler(log_file) - file_handler.setLevel(numeric_level) - file_handler.setFormatter(formatter) - root_logger.addHandler(file_handler) - - # Set specific logger levels - logging.getLogger("uvicorn").setLevel(logging.INFO) - logging.getLogger("fastapi").setLevel(logging.INFO) - -def get_logger(name: str) -> logging.Logger: - """ - Get a logger with the specified name - - Args: - name: Logger name - - Returns: - Logger instance - """ - return logging.getLogger(name) \ No newline at end of file diff --git a/ai/utils/mappers.py b/ai/utils/mappers.py new file mode 100644 index 
0000000..e48ae69 --- /dev/null +++ b/ai/utils/mappers.py @@ -0,0 +1,245 @@ +""" +Mappers for converting between Balatro game data and RL-compatible formats. + +This module handles the two-way data transformation: +1. BalatroStateMapper: Converts incoming Balatro JSON to normalized RL observations +2. BalatroActionMapper: Converts RL actions to Balatro command JSON +""" + +import numpy as np +import time +from typing import Dict, List, Any +from ..utils.validation import GameStateValidator, ResponseValidator +import logging + + + +class BalatroStateMapper: + """ + Converts raw Balatro game state JSON to normalized RL observations. + + Handles: + - Card data normalization + - Game state parsing + - Observation space formatting + """ + def __init__(self, observation_size: int, max_actions: int): + self.observation_size = observation_size + self.max_actions = max_actions + + # Logger + self.logger = logging.getLogger(__name__) + + # GameStateValidator + self.game_state_validator = GameStateValidator() + + #TODO review Might be something wrong here + def process_game_state(self, raw_state: Dict[str, Any] | None) -> np.ndarray: + """ + Convert Balatro's raw JSON game state into neural network input format + converting into standardized numerical arrays that neural networks can + process. + + Args: + raw_state: Raw game state from Balatro mod JSON + + Returns: + Processed state suitable for RL training + """ + if not raw_state: + return np.zeros(self.observation_size, dtype=np.float32) # TODO maybe raise error? + + # Validate game state request + try: + self.game_state_validator.validate_game_state(raw_state) + except ValueError as e: + self.logger.error(f"Invalid game state: {e}") + + hand_features = self._extract_hand_features(raw_state.get('hand', {})) + game_features = self._extract_game_features(raw_state) + available_actions = self._extract_available_actions(raw_state.get('available_actions', [])) + chips = self._extract_chip_features(raw_state) + # TODO extract state + + return np.concatenate([hand_features, game_features, available_actions, chips]) + + def _extract_available_actions(self, available_actions: List[int]) -> np.ndarray: + """ + Convert available actions into a np array + Args: + available_actions: Available actions list from Balatro game state + Returns: + Fixed-size numpy array of hand features + """ + mask = np.zeros(self.max_actions, dtype=np.float32) + for action_id in available_actions: + mask[action_id] = 1.0 + return mask + + def _extract_hand_features(self, hand: Dict[str, Any]) -> np.ndarray: + """ + Convert Balatro hand data into numerical features for neural network + + Transforms card dictionaries into fixed-size numerical arrays: + - Card values: 2-14 (2 through Ace) + - Suits: 0-3 (Hearts, Diamonds, Clubs, Spades) + - Card abilities: chips, mult, special effects + - Pads/truncates to fixed hand size for consistent input + + Args: + hand: Hand dictionary from Balatro game state + + Returns: + Fixed-size numpy array of hand features + """ + + cards = hand.get('cards', []) + features = [] + + for card in cards: + # Extract card features + suit_encoding = self._encode_suit(card.get('suit', '')) + value_encoding = self._encode_value(card.get('base', {}).get('value', '')) + nominal = card.get('base', {}).get('nominal', 0) + + # Add ability features + ability = card.get('ability', {}) + ability_features = [ + ability.get('t_chips', 0), + ability.get('t_mult', 0), + ability.get('x_mult', 1), + ability.get('mult', 0) + ] + + card_features = [suit_encoding, 
value_encoding, nominal] + ability_features + features.extend(card_features) + + # Pad or truncate to fixed size (e.g., 8 cards max * 7 features each) + # TODO hand.get('size') + # TODO hand.get('highlighted_count') + max_cards = 8 + features_per_card = 7 + + if len(features) < max_cards * features_per_card: + features.extend([0] * (max_cards * features_per_card - len(features))) + else: + features = features[:max_cards * features_per_card] + + return np.array(features, dtype=np.float32) + + def _extract_game_features(self, state: Dict[str, Any]) -> np.ndarray: + """ + Extract numerical game-level features from Balatro state + + Converts game metadata into neural network inputs: + - Current game state (menu, selecting hand, etc.) + - Available actions count + - Money, chips, round progression + - Remaining hands/discards + + Args: + state: Full Balatro game state dictionary + + Returns: + Numpy array of normalized game features + """ + + features = [ + state.get('state', 0), # Game state + len(state.get('available_actions', [])), # Number of available actions + ] + + return np.array(features, dtype=np.float32) + + def _extract_chip_features(self, state: Dict[str, Any]) -> np.ndarray: + """ + Extract information relating to chips + + Args: + state: Full Balatro game state dictionary + Returns: + Numpy array of normalized chip features + """ + features = [ + state.get('chips', 0), + state.get('chip_target', 0) + ] + + return np.array(features, dtype=np.float32) + + def _encode_suit(self, suit: str) -> int: + """Encode suit as integer""" + suit_map = {'Hearts': 0, 'Diamonds': 1, 'Clubs': 2, 'Spades': 3} + return suit_map.get(suit, 0) + + def _encode_value(self, value: str) -> int: + """Encode card value as integer""" + if value.isdigit(): + return int(value) + + value_map = { + 'Jack': 11, 'Queen': 12, 'King': 13, 'Ace': 14, + 'A': 14, 'K': 13, 'Q': 12, 'J': 11 + } + return value_map.get(value, 0) + + + +class BalatroActionMapper: + """ + Converts RL actions to Balatro command JSON. + + Handles: + - Binary action conversion to card indices + - Action validation + - JSON response formatting + """ + + def __init__(self, action_slices: Dict[str, slice]): + self.slices = action_slices + + # Validator + self.response_validator = ResponseValidator() + + # Logger + self.logger = logging.getLogger(__name__) + + def process_action(self, rl_action: np.ndarray) -> Dict[str, Any]: + """ + Convert RL action to Balatro JSON. 
+ + Args: + rl_action: Binary action array from RL agent + game_state: Current game state for validation + + Returns: + JSON response formatted for Balatro mod + """ + ai_action = rl_action[self.slices["action_selection"]].tolist()[0] + response_data = { + "action": ai_action, + "params": self._extract_select_hand_params(rl_action), + } + self.response_validator.validate_response(response_data) + + # Validate action structure + try: + self.response_validator.validate_response(response_data) + except ValueError as e: + self.logger.error(f"Invalid game state: {e}") + + return response_data + + def _extract_select_hand_params(self, raw_action: np.ndarray) -> List[int]: + """ + Converts the raw action to a list of Lua card indices + + Args: + raw_action: The whole action from the RL agent + Returns: + List of 1-based card indices for Lua + """ + card_indices = raw_action[self.slices["card_indices"]] + return [i + 1 for i, val in enumerate(card_indices) if val == 1] + + diff --git a/ai/utils/serialization.py b/ai/utils/serialization.py deleted file mode 100644 index dafab0a..0000000 --- a/ai/utils/serialization.py +++ /dev/null @@ -1,122 +0,0 @@ -""" -JSON Serialization utilities for Balatro RL -Handles conversion between game state formats -""" - -import json -from typing import Dict, Any, List, Optional - - -class GameStateSerializer: - """Handles serialization/deserialization of game states""" - - @staticmethod - def serialize_game_state(game_state: Dict[str, Any]) -> str: - """Convert game state dict to JSON string""" - try: - return json.dumps(game_state, indent=2, ensure_ascii=False) - except (TypeError, ValueError) as e: - raise ValueError(f"Failed to serialize game state: {e}") - - @staticmethod - def deserialize_game_state(json_string: str) -> Dict[str, Any]: - """Convert JSON string to game state dict""" - try: - return json.loads(json_string) - except json.JSONDecodeError as e: - raise ValueError(f"Failed to deserialize game state: {e}") - - @staticmethod - def serialize_action(action: Dict[str, Any]) -> str: - """Convert action dict to JSON string""" - try: - return json.dumps(action, ensure_ascii=False) - except (TypeError, ValueError) as e: - raise ValueError(f"Failed to serialize action: {e}") - - @staticmethod - def deserialize_action(json_string: str) -> Dict[str, Any]: - """Convert JSON string to action dict""" - try: - return json.loads(json_string) - except json.JSONDecodeError as e: - raise ValueError(f"Failed to deserialize action: {e}") - - -class GameStateValidator: - """Validates game state structure and content""" - - @staticmethod - def validate_game_state(game_state: Dict[str, Any]) -> bool: - """Validate that game state has required fields""" - required_fields = ['state', 'available_actions'] - - for field in required_fields: - if field not in game_state: - raise ValueError(f"Missing required field: {field}") - - # Validate hand structure if present - if 'hand' in game_state: - GameStateValidator._validate_hand(game_state['hand']) - - return True - - @staticmethod - def _validate_hand(hand: Dict[str, Any]) -> bool: - """Validate hand structure""" - if not isinstance(hand, dict): - raise ValueError("Hand must be a dictionary") - - if 'cards' in hand and not isinstance(hand['cards'], list): - raise ValueError("Hand cards must be a list") - - # Validate each card - for i, card in enumerate(hand.get('cards', [])): - GameStateValidator._validate_card(card, i) - - return True - - @staticmethod - def _validate_card(card: Dict[str, Any], index: int) -> bool: - """Validate 
individual card structure""" - required_fields = ['suit', 'base'] - - for field in required_fields: - if field not in card: - raise ValueError(f"Card {index} missing required field: {field}") - - # Validate base structure - if 'base' in card: - base = card['base'] - if 'value' not in base or 'nominal' not in base: - raise ValueError(f"Card {index} base missing value or nominal") - - return True - - @staticmethod - def validate_action(action: Dict[str, Any]) -> bool: - """Validate action structure""" - if 'action' not in action: - raise ValueError("Action must have 'action' field") - - action_type = action['action'] - - # Validate specific action types - if action_type in ['play_hand', 'discard']: - if 'card_indices' not in action: - raise ValueError(f"Action {action_type} requires card_indices") - - if not isinstance(action['card_indices'], list): - raise ValueError("card_indices must be a list") - - return True - - -# Convenience functions -def to_json(obj: Dict[str, Any]) -> str: - """Quick serialize to JSON""" - return GameStateSerializer.serialize_game_state(obj) - -def from_json(json_str: str) -> Dict[str, Any]: - """Quick deserialize from JSON""" - return GameStateSerializer.deserialize_game_state(json_str) \ No newline at end of file diff --git a/ai/utils/validation.py b/ai/utils/validation.py new file mode 100644 index 0000000..52a3175 --- /dev/null +++ b/ai/utils/validation.py @@ -0,0 +1,89 @@ +""" +Game State Validation utilities for Balatro RL +Validates game state structure and content +""" + +from typing import Dict, Any, List + + + +class GameStateValidator: + """ + Validates game state json structure + + Request Contract (Macro level not everything) + { + game_state: Dict, # The whole game state table + available_actions: List, # The available actions all as integers + retry_count: int, # TODO update observation space. 
this is to handle retries if AI fails + } + """ + + @staticmethod + def validate_game_state(game_state: Dict[str, Any]) -> bool: + """Validate that game state has required fields""" + required_fields = ['game_state', 'available_actions', 'retry_count'] + # TODO we can add more validations as needed like ensuring available_actions is a list and stuff + # TODO game_over might be a validation + # TODO validate chips for rewards + + for field in required_fields: + if field not in game_state: + raise ValueError(f"Missing required field: {field}") + + # Validate hand structure if present + if 'hand' in game_state: + GameStateValidator._validate_hand(game_state['hand']) + + return True + + @staticmethod + def _validate_hand(hand: Dict[str, Any]) -> bool: + """Validate hand structure""" + if not isinstance(hand, dict): + raise ValueError("Hand must be a dictionary") + + if 'cards' in hand and not isinstance(hand['cards'], list): + raise ValueError("Hand cards must be a list") + + # Validate each card + for i, card in enumerate(hand.get('cards', [])): + GameStateValidator._validate_card(card, i) + + return True + + @staticmethod + def _validate_card(card: Dict[str, Any], index: int) -> bool: + """Validate individual card structure""" + required_fields = ['suit', 'base'] + + for field in required_fields: + if field not in card: + raise ValueError(f"Card {index} missing required field: {field}") + + # Validate base structure + if 'base' in card: + base = card['base'] + if 'value' not in base or 'nominal' not in base: + raise ValueError(f"Card {index} base missing value or nominal") + + return True + +class ResponseValidator: + """ + Validates response json structure + + Response contract + { + action, # The action to take + params, # The params in the event the action takes in params + } + """ + + @staticmethod + def validate_response(response: Dict[str, Any]): + required_fields = ["action"] + for field in required_fields: + if field not in response: + raise ValueError(f"Missing required field: {field}") +
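
For orientation, here is a minimal round trip through the validators defined in ai/utils/validation.py. The key names come from the validator contracts above; the field values, and the exact nesting the mod actually emits, are illustrative assumptions only.

# Illustrative request/response pair exercising the validators above.
from ai.utils.validation import GameStateValidator, ResponseValidator

example_request = {
    "game_state": {"state": 7, "chips": 120, "chip_target": 300},  # hypothetical game table
    "available_actions": [0, 1, 2],
    "retry_count": 0,
    "hand": {"cards": [
        {"suit": "Hearts", "base": {"value": "King", "nominal": 10}},
        {"suit": "Spades", "base": {"value": "4", "nominal": 4}},
    ]},
}
GameStateValidator.validate_game_state(example_request)   # returns True or raises ValueError

example_response = {"action": 1, "params": [1, 2]}         # params: 1-based card indices for Lua
ResponseValidator.validate_response(example_response)      # raises ValueError if "action" is missing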
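
And for the overall flow, a minimal sketch of one request/response cycle on the AI side using the modules introduced in this patch. The observation size, the MultiDiscrete slice layout, and the placeholder action are assumptions for illustration; in training this loop presumably lives inside BalatroEnv (not shown in full here), and the imports assume the ai directory is importable as a package from the repository root.

# One assumed request/response round trip: pipes -> observation -> action -> response.
import numpy as np
from ai.utils.communication import BalatroPipeIO
from ai.utils.mappers import BalatroStateMapper, BalatroActionMapper

# Assumed sizes: 8 cards * 7 features + 2 game features + 10 action-mask slots + 2 chip features = 70.
state_mapper = BalatroStateMapper(observation_size=70, max_actions=10)
action_mapper = BalatroActionMapper({
    "action_selection": slice(0, 1),   # index 0 holds the action id
    "card_indices": slice(1, 9),       # one binary flag per card slot
})

pipe_io = BalatroPipeIO()              # creates both pipes; blocks until Balatro opens its ends

request = pipe_io.wait_for_request()   # blocks until the mod writes a JSON request
observation = state_mapper.process_game_state(request)   # validates and flattens into a fixed-size vector

# A trained model would supply this instead: action, _ = model.predict(observation)
rl_action = np.array([1, 1, 1, 0, 0, 0, 0, 0, 0])        # action id 1, playing the first two cards

pipe_io.send_response(action_mapper.process_action(rl_action))   # -> {"action": 1, "params": [1, 2]}
pipe_io.cleanup()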