Have a successful pipe communication between the AI and Balatro, with some initial iteration going
This commit is contained in:
parent
bbeab8c635
commit
9b5f949a0b
@ -56,13 +56,17 @@ logs/
*.temp
.cache/

# Model checkpoints and training artifacts
# Training artifacts (generated during RL training sessions)
models/
tensorboard_logs/
training.log
training_monitor.csv

# Other ML artifacts
*.pth
*.pt
*.ckpt
checkpoints/
models/
tensorboard_logs/
wandb/

# Large reference files (don't commit decompiled game source)
21 CLAUDE.md
@ -55,11 +55,32 @@ Balatro uses numbered states defined in `globals.lua`:
- `G.GAME` - Main game data structure
- `G.FUNCS` - Game function callbacks

## Communication Architecture

### Dual Pipe Communication
The mod uses separate named pipes for request/response communication with the AI:
- **Request pipe path**: `/tmp/balatro_request` (Balatro writes, AI reads)
- **Response pipe path**: `/tmp/balatro_response` (AI writes, Balatro reads)
- **Protocol**: Persistent pipe handles with blocking reads
- **Flow**:
  1. Balatro opens persistent handles to both pipes
  2. Balatro writes a JSON request to the request pipe
  3. The AI reads the request from the request pipe
  4. The AI writes a JSON response to the response pipe
  5. Balatro performs a blocking read on the response pipe

(A minimal AI-side sketch of this flow is included at the end of this section.)

### Benefits of Dual Pipes
- Clear separation of request and response channels
- Persistent handles avoid repeated open/close overhead
- Blocking reads ensure proper synchronization
- Each pipe is unidirectional, eliminating read/write conflicts

## Current Status
The mod can successfully:
- Start a run automatically
- Detect the blind selection state
- Automatically select the next blind
- Communicate with the AI via the dual pipe system

## Communication with Claude
- Keep chats as efficient and brief as possible. No fluff, just get to the point.
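A minimal Python sketch of the AI side of this flow. The pipe paths are the ones above; everything else (function shape, the `no_action` placeholder) is illustrative and is not the project's actual `BalatroPipeIO` implementation:

```python
import json
import os

REQUEST_PIPE = "/tmp/balatro_request"
RESPONSE_PIPE = "/tmp/balatro_response"

# Create the FIFOs up front; the mod's ensure_pipes_open() expects them to exist.
for path in (REQUEST_PIPE, RESPONSE_PIPE):
    if not os.path.exists(path):
        os.mkfifo(path)

# Open in the mirror order of the mod (it opens response for read, then request for
# write), so neither side deadlocks while opening the FIFOs.
with open(RESPONSE_PIPE, "w") as resp, open(REQUEST_PIPE, "r") as req:
    for line in req:                        # blocking read: one JSON request per line
        request = json.loads(line)
        response = {"action": "no_action"}  # placeholder for real decision logic
        resp.write(json.dumps(response) + "\n")
        resp.flush()                        # flush so Balatro's blocking read returns
```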
@ -28,7 +28,12 @@ Made a symlink -> ln -s ~/dev/balatro-rl/RLBridge /mnt/gamerlinuxssd/SteamLibrar
- [ ] Training loop integration

### Game Features
- [ ] Always offer restart_run as an action option while a game is ongoing
- [ ] Blind selection choices (skip vs select)
- [ ] Extended game state (money, discards, hands played)
- [ ] Shop interactions
- [ ] Joker management

### RL Enhancements
- [ ] Add a "Replay System" to analyze successful actions: save the seed and keep an action log for reproduction.
  This could hook in just before the game is reset: check whether the run meets the save criteria, then save it (see the sketch below).
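One possible shape for that replay record, sketched in Python (the field names, paths, and chip threshold are placeholders, not existing code):

```python
import json
import os
import time
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List

@dataclass
class ReplayRecord:
    seed: str                                                     # run seed, for reproduction
    actions: List[Dict[str, Any]] = field(default_factory=list)  # ordered action log
    final_chips: int = 0

def maybe_save_replay(record: ReplayRecord, min_chips: int = 300) -> None:
    """Hook this in just before the game is reset; only keeps runs worth analyzing."""
    if record.final_chips >= min_chips:                 # save criterion is a placeholder
        os.makedirs("replays", exist_ok=True)
        with open(f"replays/replay_{int(time.time())}.json", "w") as f:
            json.dump(asdict(record), f)
```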
@ -13,6 +13,10 @@ local utils = require("utils")
local last_state_hash = nil
local last_actions_hash = nil
local pending_action = nil
local retry_count = 0
local need_retry_request = false
local rl_training_active = false
local last_key_pressed = nil

--- Initialize AI system
--- Sets up communication and prepares the AI for operation
@ -20,12 +24,50 @@ local pending_action = nil
function AI.init()
    utils.log_ai("Initializing AI system...")
    communication.init()

    -- Hook into Love2D keyboard events
    if love and love.keypressed then
        local original_keypressed = love.keypressed
        love.keypressed = function(key)
            -- Store the key press for our AI
            last_key_pressed = key

            -- Call original function
            if original_keypressed then
                original_keypressed(key)
            end
        end
    else
        -- Love2D keypressed hook not available; nothing to do
    end
end

--- Main AI update loop (called every frame)
--- Monitors game state changes, handles communication, and executes AI actions
--- @return nil
function AI.update()
    -- Check for key press to start/stop RL training
    if last_key_pressed then
        if last_key_pressed == "r" then
            if not rl_training_active then
                rl_training_active = true
                utils.log_ai("🚀 RL Training STARTED (R pressed)")
            end
        elseif last_key_pressed == "t" then
            if rl_training_active then
                rl_training_active = false
                utils.log_ai("⏹️ RL Training STOPPED (T pressed)")
            end
        end

        -- Clear the key press
        last_key_pressed = nil
    end

    -- Don't process AI requests unless RL training is active
    if not rl_training_active then
        return
    end

    -- Get current game state
    local current_state = output.get_game_state()
    local available_actions = action.get_available_actions()
@ -40,29 +82,40 @@ function AI.update()
|
|||
return
|
||||
end
|
||||
|
||||
-- Create simple hash to detect state changes
|
||||
-- Create hash to detect state changes
|
||||
local state_hash = AI.hash_state(current_state)
|
||||
local actions_hash = AI.hash_actions(available_actions)
|
||||
|
||||
-- Request action from AI if state or actions changed
|
||||
if state_hash ~= last_state_hash or actions_hash ~= last_actions_hash then
|
||||
if state_hash ~= last_state_hash or actions_hash ~= last_actions_hash or need_retry_request then
|
||||
-- State has changed
|
||||
if state_hash ~= last_state_hash then
|
||||
utils.log_ai("State changed to: " ..
|
||||
current_state.state .. " (" .. utils.state_name(current_state.state) .. ")")
|
||||
current_state.state .. " (" .. utils.get_state_name(current_state.state) .. ")")
|
||||
action.reset_state()
|
||||
last_state_hash = state_hash
|
||||
end
|
||||
|
||||
-- Available actions have changed
|
||||
if actions_hash ~= last_actions_hash then
|
||||
utils.log_ai("Available actions changed: " ..
|
||||
table.concat(utils.get_action_names(available_actions), ", "))
|
||||
last_actions_hash = actions_hash
|
||||
end
|
||||
|
||||
-- Request action from AI via file I/O
|
||||
local ai_response = communication.request_action(current_state, available_actions)
|
||||
if ai_response and ai_response.action ~= "no_action" then
|
||||
pending_action = ai_response
|
||||
-- Request action from AI
|
||||
need_retry_request = false
|
||||
local ai_response = communication.request_action(current_state, available_actions, retry_count)
|
||||
|
||||
if ai_response then
|
||||
-- Handling handshake
|
||||
if ai_response.action == "ready" then
|
||||
utils.log_ai("Handshake complete - AI ready")
|
||||
-- Don't return - continue normal game loop
|
||||
-- Force a state check on next frame to send real request
|
||||
last_state_hash = nil
|
||||
else
|
||||
pending_action = ai_response
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
|
@ -71,8 +124,12 @@ function AI.update()
|
|||
local result = action.execute_action(pending_action.action, pending_action.params)
|
||||
if result.success then
|
||||
utils.log_ai("Action executed successfully: " .. pending_action.action)
|
||||
retry_count = 0
|
||||
else
|
||||
utils.log_ai("Action failed: " .. (result.error or "Unknown error"))
|
||||
-- Update retry_count to indicate state change
|
||||
utils.log_ai("Action failed: " .. (result.error or "Unknown error") .. " RETRYING...")
|
||||
retry_count = retry_count + 1
|
||||
need_retry_request = true
|
||||
end
|
||||
pending_action = nil
|
||||
end
|
||||
|
|
@ -83,7 +140,7 @@ end
|
|||
--- @param game_state table Current game state data
|
||||
--- @return string Hash representing the current state
|
||||
function AI.hash_state(game_state)
|
||||
return game_state.state .. "_" .. (game_state.round or 0) .. "_" .. (game_state.chips or 0)
|
||||
return game_state.state .. "_" .. (game_state.chips or 0) -- TODO update hash to be more unique
|
||||
end
|
||||
|
||||
--- Create simple hash of available actions
|
||||
|
|
|
|||
|
|
@ -1,124 +1,117 @@
|
|||
--- RLBridge communication module
|
||||
--- Handles file-based communication between the game and external AI system
|
||||
--- Handles dual pipe communication with persistent handles between the game and external AI system
|
||||
|
||||
local COMM = {}
|
||||
local utils = require("utils")
|
||||
local action = require("actions")
|
||||
local json = require("dkjson")
|
||||
|
||||
-- File-based communication settings
|
||||
-- Dual pipe communication settings with persistent handles
|
||||
local comm_enabled = false
|
||||
local request_file = "/tmp/balatro_request.json"
|
||||
local response_file = "/tmp/balatro_response.json"
|
||||
local request_counter = 0
|
||||
local last_response_time = 0
|
||||
local request_pipe = "/tmp/balatro_request"
|
||||
local response_pipe = "/tmp/balatro_response"
|
||||
local request_handle = nil
|
||||
local response_handle = nil
|
||||
|
||||
--- Check if response file has new content
|
||||
--- @param expected_id number Expected request ID
|
||||
--- @return table|nil Response data if new response available, nil otherwise
|
||||
local function check_for_response(expected_id)
|
||||
local response_file_handle = io.open(response_file, "r")
|
||||
if not response_file_handle then
|
||||
return nil
|
||||
end
|
||||
|
||||
local response_json = response_file_handle:read("*all")
|
||||
response_file_handle:close()
|
||||
|
||||
if not response_json or response_json == "" then
|
||||
return nil
|
||||
end
|
||||
|
||||
local response_data = json.decode(response_json)
|
||||
if not response_data then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Check if this is a new response for our request
|
||||
if response_data.id == expected_id and response_data.timestamp then
|
||||
if response_data.timestamp > last_response_time then
|
||||
last_response_time = response_data.timestamp
|
||||
return response_data
|
||||
end
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Initialize file-based communication
|
||||
--- Sets up file-based communication channel with external AI system
|
||||
--- Initialize dual pipe communication with persistent handles
|
||||
--- Sets up persistent pipe handles with external AI system
|
||||
--- @return nil
|
||||
function COMM.init()
|
||||
utils.log_comm("Initializing file-based communication...")
|
||||
comm_enabled = true
|
||||
last_response_time = 0 -- Reset response tracking
|
||||
|
||||
utils.log_comm("Ready for file-based requests")
|
||||
utils.log_comm("Request file: " .. request_file)
|
||||
utils.log_comm("Response file: " .. response_file)
|
||||
utils.log_comm("Initializing dual pipe communication with persistent handles...")
|
||||
utils.log_comm("Note: Pipes will be opened when first needed (lazy initialization)")
|
||||
comm_enabled = true -- Enable communication, pipes will open on first use
|
||||
end
|
||||
|
||||
--- Send game turn request to AI and get action via files
|
||||
--- Lazy initialization of pipe handles when first needed
|
||||
--- @return boolean True if pipes are ready, false otherwise
|
||||
function COMM.ensure_pipes_open()
|
||||
if request_handle and response_handle then
|
||||
return true -- Already open
|
||||
end
|
||||
|
||||
-- Try to open pipes (this will block until AI creates them)
|
||||
utils.log_comm("Opening persistent pipe handles...")
|
||||
|
||||
-- Open response pipe for reading (keep open)
|
||||
response_handle = io.open(response_pipe, "r")
|
||||
if not response_handle then
|
||||
utils.log_comm("ERROR: Cannot open response pipe for reading: " .. response_pipe)
|
||||
return false
|
||||
end
|
||||
|
||||
-- Open request pipe for writing (keep open)
|
||||
request_handle = io.open(request_pipe, "w")
|
||||
if not request_handle then
|
||||
utils.log_comm("ERROR: Cannot open request pipe for writing: " .. request_pipe)
|
||||
if response_handle then
|
||||
response_handle:close()
|
||||
response_handle = nil
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
utils.log_comm("Persistent pipe handles opened successfully")
|
||||
utils.log_comm("Request pipe (write): " .. request_pipe)
|
||||
utils.log_comm("Response pipe (read): " .. response_pipe)
|
||||
return true
|
||||
end
|
||||
|
||||
--- Send game turn request to AI and get action via persistent pipe handles
|
||||
--- @param game_state table Current game state data
|
||||
--- @param available_actions table Available actions list
|
||||
--- @return table|nil Action response from AI, nil if error
|
||||
function COMM.request_action(game_state, available_actions)
|
||||
function COMM.request_action(game_state, available_actions, retry_count)
|
||||
if not comm_enabled then
|
||||
utils.log_comm("ERROR: Communication not enabled")
|
||||
return nil
|
||||
end
|
||||
|
||||
request_counter = request_counter + 1
|
||||
-- Lazy initialization - open pipes when first needed
|
||||
if not COMM.ensure_pipes_open() then
|
||||
utils.log_comm("ERROR: Failed to open pipe handles")
|
||||
return nil
|
||||
end
|
||||
|
||||
local request = {
|
||||
id = request_counter,
|
||||
timestamp = os.time(),
|
||||
game_state = game_state,
|
||||
available_actions = available_actions or {}
|
||||
available_actions = available_actions or {},
|
||||
retry_count = retry_count,
|
||||
}
|
||||
|
||||
utils.log_comm("Requesting action for state: " .. tostring(game_state.state))
|
||||
utils.log_comm(utils.get_timestamp() .. "Sending action request for state: " ..
|
||||
tostring(game_state.state) .. " (" .. utils.get_state_name(game_state.state) .. ")")
|
||||
|
||||
-- Write request to file
|
||||
-- Encode request as JSON
|
||||
local json_data = json.encode(request)
|
||||
if not json_data then
|
||||
utils.log_comm("ERROR: Failed to encode request JSON")
|
||||
return nil
|
||||
end
|
||||
|
||||
local request_file_handle = io.open(request_file, "w")
|
||||
if not request_file_handle then
|
||||
utils.log_comm("ERROR: Cannot write to request file: " .. request_file)
|
||||
-- Write request to persistent handle
|
||||
request_handle:write(json_data .. "\n")
|
||||
request_handle:flush() -- Ensure data is sent immediately
|
||||
utils.log_comm(utils.get_timestamp() .. "Request sent...")
|
||||
|
||||
-- Read response from persistent handle
|
||||
utils.log_comm(utils.get_timestamp() .. "about to read the respones")
|
||||
local response_json = response_handle:read("*line")
|
||||
|
||||
if not response_json or response_json == "" then
|
||||
utils.log_comm("ERROR: No response received from AI")
|
||||
return nil
|
||||
end
|
||||
|
||||
request_file_handle:write(json_data)
|
||||
request_file_handle:close()
|
||||
|
||||
-- Wait for response file (with timeout)
|
||||
local max_wait_time = 2 -- seconds
|
||||
local wait_interval = 0.05 -- seconds
|
||||
local total_waited = 0
|
||||
|
||||
while total_waited < max_wait_time do
|
||||
local response_data = check_for_response(request_counter)
|
||||
if response_data then
|
||||
utils.log_comm("AI action: " .. tostring(response_data.action))
|
||||
return response_data
|
||||
end
|
||||
|
||||
-- Sleep for a short time using busy wait
|
||||
local start_time = os.clock()
|
||||
while (os.clock() - start_time) < wait_interval do
|
||||
-- Busy wait for short duration
|
||||
end
|
||||
total_waited = total_waited + wait_interval
|
||||
local response_data = json.decode(response_json)
|
||||
if not response_data then
|
||||
utils.log_comm("ERROR: Failed to decode response JSON")
|
||||
return nil
|
||||
end
|
||||
|
||||
utils.log_comm("ERROR: Timeout waiting for AI response")
|
||||
return nil
|
||||
utils.log_comm("AI action: " .. tostring(response_data.action))
|
||||
return response_data
|
||||
end
|
||||
|
||||
--- Check if file communication is enabled
|
||||
--- Check if pipe communication is enabled
|
||||
--- Returns the current communication status
|
||||
--- @return boolean True if enabled, false otherwise
|
||||
function COMM.is_connected()
|
||||
|
|
@ -126,12 +119,22 @@ function COMM.is_connected()
|
|||
end
|
||||
|
||||
--- Close communication
|
||||
--- Terminates the communication channel with the AI system
|
||||
--- Terminates the persistent pipe handles with the AI system
|
||||
--- @return nil
|
||||
function COMM.close()
|
||||
comm_enabled = false
|
||||
last_response_time = 0
|
||||
utils.log_comm("File communication disabled")
|
||||
|
||||
if request_handle then
|
||||
request_handle:close()
|
||||
request_handle = nil
|
||||
end
|
||||
|
||||
if response_handle then
|
||||
response_handle:close()
|
||||
response_handle = nil
|
||||
end
|
||||
|
||||
utils.log_comm("Persistent pipe communication closed")
|
||||
end
|
||||
|
||||
return COMM
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ G.SETTINGS.reduced_motion = true
|
|||
function INIT.start_run()
|
||||
print("Starting balatro run")
|
||||
|
||||
|
||||
-- Initialize AI system
|
||||
local ai = require("ai")
|
||||
ai.init()
|
||||
|
|
|
|||
|
|
@ -4,15 +4,20 @@
|
|||
|
||||
local O = {}
|
||||
|
||||
local actions = require("actions")
|
||||
|
||||
--- Get comprehensive game state for AI
|
||||
--- Collects all relevant game information into a structured format
|
||||
--- @return table Complete game state data for AI processing
|
||||
function O.get_game_state()
|
||||
-- Calculate game_over
|
||||
local game_over = 0
|
||||
if G.STATE == G.STATES.GAME_OVER then
|
||||
game_over = 1
|
||||
end
|
||||
|
||||
return {
|
||||
-- Basic state info
|
||||
state = G.STATE,
|
||||
game_over = game_over,
|
||||
|
||||
-- Game progression
|
||||
-- round = G.GAME and G.GAME.round or 0,
|
||||
|
|
@ -27,16 +32,9 @@ function O.get_game_state()
|
|||
-- Hand info (if applicable)
|
||||
hand = O.get_hand_info(),
|
||||
|
||||
-- -- Jokers info
|
||||
-- jokers = O.get_jokers_info(),
|
||||
|
||||
-- -- Chips/score info
|
||||
-- chips = G.GAME and G.GAME.chips or 0,
|
||||
-- chip_target = G.GAME and G.GAME.blind and G.GAME.blind.chips or 0,
|
||||
|
||||
-- -- Action feedback for AI
|
||||
-- last_action_success = true, -- Did last action work?
|
||||
-- last_action_error = nil, -- What went wrong?
|
||||
-- Chips/score info
|
||||
chips = G.GAME and G.GAME.chips or 0,
|
||||
chip_target = G.GAME and G.GAME.blind and G.GAME.blind.chips or 0,
|
||||
}
|
||||
end
|
||||
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ end
|
|||
--- Maps numeric state IDs to their string representations for debugging
|
||||
--- @param state_id number Numeric state identifier
|
||||
--- @return string Human-readable state name or "UNKNOWN_STATE"
|
||||
function UTIL.state_name(state_id)
|
||||
function UTIL.get_state_name(state_id)
|
||||
for key, value in pairs(G.STATES) do
|
||||
if value == state_id then
|
||||
return key
|
||||
|
|
@ -133,4 +133,22 @@ function UTIL.get_action_names(available_actions)
|
|||
return names
|
||||
end
|
||||
|
||||
function UTIL.contains(tbl, val)
    for _, v in ipairs(tbl) do
        if v == val then
            return true
        end
    end
    return false
end
|
||||
|
||||
--- Get current timestamp as formatted string
|
||||
--- Returns current time in HH:MM:SS.mmm format for debugging
|
||||
--- @return string Formatted timestamp
|
||||
function UTIL.get_timestamp()
|
||||
local time = os.time()
|
||||
local ms = (os.clock() % 1) * 1000
|
||||
return os.date("%H:%M:%S", time) .. string.format(".%03d", ms)
|
||||
end
|
||||
|
||||
return UTIL
|
||||
|
|
|
|||
|
|
@ -1,95 +0,0 @@
|
|||
"""
|
||||
Balatro RL Agent
|
||||
Contains the AI logic for making decisions in Balatro
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List
|
||||
import random
|
||||
import logging
|
||||
|
||||
|
||||
class BalatroAgent:
|
||||
"""Main AI agent for playing Balatro"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# TODO: Initialize your RL model here
|
||||
# self.model = load_model("path/to/model")
|
||||
# self.training = True
|
||||
|
||||
def get_action(self, request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Given game state, return the action to take
|
||||
|
||||
Args:
|
||||
game_state: Current game state from Balatro mod
|
||||
|
||||
Returns:
|
||||
Action dictionary to send back to mod
|
||||
"""
|
||||
try:
|
||||
# Extract relevant info from game state
|
||||
available_actions = request.get('available_actions', [])
|
||||
game_state = request.get('game_state', {})
|
||||
|
||||
self.logger.info(f"Processing state {game_state.get('state', 0)} with {len(available_actions)} actions")
|
||||
|
||||
# TODO: Replace with actual RL logic
|
||||
action = self._get_random_action(available_actions) # TODO we are getting random actions for testing
|
||||
|
||||
# TODO: If training, update model with reward
|
||||
# if self.training:
|
||||
# self._update_model(game_state, action, reward)
|
||||
|
||||
return action
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in get_action: {e}")
|
||||
return {"action": "no_action"}
|
||||
|
||||
def _get_random_action(self, available_actions: List) -> Dict[str, Any]:
|
||||
"""Temporary random action selection - replace with RL logic"""
|
||||
|
||||
# available_actions comes as [1, 2] format
|
||||
if not available_actions or not isinstance(available_actions, List):
|
||||
return {"action": "no_action"}
|
||||
|
||||
# Simple random action selection
|
||||
selected_action_id = random.choice(available_actions)
|
||||
self.logger.info(f"Selected action ID: {selected_action_id} from available: {available_actions}")
|
||||
|
||||
# TODO this is for testing, the AI doesn't really care about what the action_names are
|
||||
# but we do so we can test out actions and make sure they work
|
||||
# Map action IDs to names for logic (but return the ID)
|
||||
action_names = {1: "start_run", 2: "select_blind", 3: "select_hand"}
|
||||
action_name = action_names.get(selected_action_id, "unknown")
|
||||
|
||||
# Pick random cards for now
|
||||
params = {}
|
||||
if action_name == "select_hand":
|
||||
params["card_indices"] = [1, 3, 5]
|
||||
|
||||
# Return the action ID (number), not the name
|
||||
return {"action": selected_action_id, "params": params}
|
||||
|
||||
def train_step(self, game_state: Dict, action: Dict, reward: float, next_state: Dict):
|
||||
"""
|
||||
Perform one training step
|
||||
|
||||
TODO: Implement RL training logic here
|
||||
- Update Q-values
|
||||
- Update neural network weights
|
||||
- Store experience in replay buffer
|
||||
"""
|
||||
pass
|
||||
|
||||
def save_model(self, path: str):
|
||||
"""Save the trained model"""
|
||||
# TODO: Implement model saving
|
||||
pass
|
||||
|
||||
def load_model(self, path: str):
|
||||
"""Load a trained model"""
|
||||
# TODO: Implement model loading
|
||||
pass
|
||||
|
|
@ -1,83 +0,0 @@
|
|||
"""
|
||||
Neural network policy for Balatro RL agent
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from typing import Dict, Any
|
||||
|
||||
class BalatroPolicy(nn.Module):
|
||||
"""
|
||||
Neural network policy for Balatro decision making
|
||||
"""
|
||||
|
||||
def __init__(self, state_dim: int = 128, action_dim: int = 64, hidden_dim: int = 256):
|
||||
super().__init__()
|
||||
|
||||
self.state_dim = state_dim
|
||||
self.action_dim = action_dim
|
||||
|
||||
# Feature extraction layers
|
||||
self.feature_net = nn.Sequential(
|
||||
nn.Linear(state_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
# Action value head
|
||||
self.value_head = nn.Linear(hidden_dim, 1)
|
||||
|
||||
# Action probability head
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
|
||||
def forward(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Forward pass through the network
|
||||
|
||||
Args:
|
||||
state: Game state tensor
|
||||
|
||||
Returns:
|
||||
Tuple of (action_logits, state_value)
|
||||
"""
|
||||
features = self.feature_net(state)
|
||||
|
||||
action_logits = self.action_head(features)
|
||||
state_value = self.value_head(features)
|
||||
|
||||
return action_logits, state_value
|
||||
|
||||
def get_action(self, state: np.ndarray, available_actions: list) -> tuple[int, float]:
|
||||
"""
|
||||
Get action from policy given current state
|
||||
|
||||
Args:
|
||||
state: Current game state as numpy array
|
||||
available_actions: List of available action indices
|
||||
|
||||
Returns:
|
||||
Tuple of (action_index, action_probability)
|
||||
"""
|
||||
with torch.no_grad():
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0)
|
||||
action_logits, _ = self.forward(state_tensor)
|
||||
|
||||
# Mask unavailable actions
|
||||
masked_logits = action_logits.clone()
|
||||
if available_actions:
|
||||
mask = torch.full_like(action_logits, float('-inf'))
|
||||
mask[0, available_actions] = 0
|
||||
masked_logits = action_logits + mask
|
||||
|
||||
# Get action probabilities
|
||||
action_probs = torch.softmax(masked_logits, dim=-1)
|
||||
|
||||
# Sample action
|
||||
action_dist = torch.distributions.Categorical(action_probs)
|
||||
action = action_dist.sample()
|
||||
|
||||
return action.item(), action_probs[0, action].item()
|
||||
|
|
@ -1,153 +0,0 @@
|
|||
"""
|
||||
Training logic for Balatro RL agent
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
from typing import Dict, List, Any
|
||||
from .policy import BalatroPolicy
|
||||
import logging
|
||||
|
||||
class BalatroTrainer:
|
||||
"""
|
||||
Handles training of the Balatro RL agent
|
||||
"""
|
||||
|
||||
def __init__(self, policy: BalatroPolicy, learning_rate: float = 3e-4):
|
||||
self.policy = policy
|
||||
self.optimizer = optim.Adam(policy.parameters(), lr=learning_rate)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Training buffers
|
||||
self.states = []
|
||||
self.actions = []
|
||||
self.rewards = []
|
||||
self.values = []
|
||||
self.log_probs = []
|
||||
|
||||
def collect_experience(self, state: np.ndarray, action: int, reward: float,
|
||||
value: float, log_prob: float):
|
||||
"""
|
||||
Collect experience for training
|
||||
|
||||
Args:
|
||||
state: Game state
|
||||
action: Action taken
|
||||
reward: Reward received
|
||||
value: State value estimate
|
||||
log_prob: Log probability of action
|
||||
"""
|
||||
self.states.append(state)
|
||||
self.actions.append(action)
|
||||
self.rewards.append(reward)
|
||||
self.values.append(value)
|
||||
self.log_probs.append(log_prob)
|
||||
|
||||
def compute_returns(self, next_value: float = 0.0, gamma: float = 0.99) -> List[float]:
|
||||
"""
|
||||
Compute discounted returns
|
||||
|
||||
Args:
|
||||
next_value: Value of next state (for bootstrapping)
|
||||
gamma: Discount factor
|
||||
|
||||
Returns:
|
||||
List of discounted returns
|
||||
"""
|
||||
returns = []
|
||||
R = next_value
|
||||
|
||||
for reward in reversed(self.rewards):
|
||||
R = reward + gamma * R
|
||||
returns.insert(0, R)
|
||||
|
||||
return returns
|
||||
|
||||
def update_policy(self, gamma: float = 0.99, entropy_coef: float = 0.01,
|
||||
value_coef: float = 0.5) -> Dict[str, float]:
|
||||
"""
|
||||
Update policy using collected experience
|
||||
|
||||
Args:
|
||||
gamma: Discount factor
|
||||
entropy_coef: Entropy regularization coefficient
|
||||
value_coef: Value loss coefficient
|
||||
|
||||
Returns:
|
||||
Dictionary of training metrics
|
||||
"""
|
||||
if not self.states:
|
||||
return {}
|
||||
|
||||
# Convert to tensors
|
||||
states = torch.FloatTensor(np.array(self.states))
|
||||
actions = torch.LongTensor(self.actions)
|
||||
old_log_probs = torch.FloatTensor(self.log_probs)
|
||||
values = torch.FloatTensor(self.values)
|
||||
|
||||
# Compute returns
|
||||
returns = self.compute_returns(gamma=gamma)
|
||||
returns = torch.FloatTensor(returns)
|
||||
|
||||
# Compute advantages
|
||||
advantages = returns - values
|
||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
||||
|
||||
# Forward pass
|
||||
action_logits, new_values = self.policy(states)
|
||||
|
||||
# Compute losses
|
||||
action_dist = torch.distributions.Categorical(logits=action_logits)
|
||||
new_log_probs = action_dist.log_prob(actions)
|
||||
entropy = action_dist.entropy().mean()
|
||||
|
||||
# Policy loss (PPO-style, but simplified)
|
||||
ratio = torch.exp(new_log_probs - old_log_probs)
|
||||
policy_loss = -(ratio * advantages).mean()
|
||||
|
||||
# Value loss
|
||||
value_loss = torch.nn.functional.mse_loss(new_values.squeeze(), returns)
|
||||
|
||||
# Total loss
|
||||
total_loss = policy_loss + value_coef * value_loss - entropy_coef * entropy
|
||||
|
||||
# Update
|
||||
self.optimizer.zero_grad()
|
||||
total_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
|
||||
self.optimizer.step()
|
||||
|
||||
# Clear buffers
|
||||
self.clear_buffers()
|
||||
|
||||
# Return metrics
|
||||
return {
|
||||
'policy_loss': policy_loss.item(),
|
||||
'value_loss': value_loss.item(),
|
||||
'entropy': entropy.item(),
|
||||
'total_loss': total_loss.item()
|
||||
}
|
||||
|
||||
def clear_buffers(self):
|
||||
"""Clear experience buffers"""
|
||||
self.states.clear()
|
||||
self.actions.clear()
|
||||
self.rewards.clear()
|
||||
self.values.clear()
|
||||
self.log_probs.clear()
|
||||
|
||||
def save_model(self, path: str):
|
||||
"""Save model checkpoint"""
|
||||
torch.save({
|
||||
'policy_state_dict': self.policy.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict()
|
||||
}, path)
|
||||
self.logger.info(f"Model saved to {path}")
|
||||
|
||||
def load_model(self, path: str):
|
||||
"""Load model checkpoint"""
|
||||
checkpoint = torch.load(path)
|
||||
self.policy.load_state_dict(checkpoint['policy_state_dict'])
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
self.logger.info(f"Model loaded from {path}")
|
||||
|
|
@ -1,177 +1,190 @@
|
|||
"""
|
||||
Balatro RL Environment
|
||||
Wraps the HTTP communication in a gym-like interface for RL training
|
||||
Wraps the pipe-based communication with Balatro mod in a standard RL interface.
|
||||
This acts as a translator between Balatro's JSON pipe communication and
|
||||
RL libraries that expect gym-style step()/reset() methods.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Dict, Any, Tuple, List, Optional
|
||||
import logging
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
from ..utils.debug import tprint
|
||||
|
||||
from ..utils.serialization import GameStateValidator
|
||||
from ..utils.communication import BalatroPipeIO
|
||||
from .reward import BalatroRewardCalculator
|
||||
from ..utils.mappers import BalatroStateMapper, BalatroActionMapper
|
||||
|
||||
|
||||
class BalatroEnvironment:
|
||||
class BalatroEnv(gym.Env):
|
||||
"""
|
||||
Gym-like environment interface for Balatro RL training
|
||||
Standard RL Environment wrapper for Balatro
|
||||
|
||||
This class will eventually wrap the HTTP communication
|
||||
and provide a standard RL interface with step(), reset(), etc.
|
||||
Translates between:
|
||||
- Balatro mod's JSON pipe communication (/tmp/balatro_request, /tmp/balatro_response)
|
||||
- Standard RL interface (step, reset, observation spaces)
|
||||
|
||||
This allows RL libraries like Stable-Baselines3 to train on Balatro
|
||||
without knowing about the underlying pipe communication system.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.current_state = None
|
||||
self.prev_state = None
|
||||
self.game_over = False
|
||||
|
||||
# TODO: Define action and observation spaces
|
||||
# self.action_space = ...
|
||||
# self.observation_space = ...
|
||||
|
||||
def reset(self) -> Dict[str, Any]:
|
||||
"""Reset the environment for a new episode"""
|
||||
self.current_state = None
|
||||
self.game_over = False
|
||||
# Initialize communication and reward systems
|
||||
self.pipe_io = BalatroPipeIO()
|
||||
self.reward_calculator = BalatroRewardCalculator()
|
||||
|
||||
# Define Gymnasium spaces
|
||||
# Action space: describes the type and shape of the action
|
||||
# Constants
|
||||
self.MAX_ACTIONS = 10
|
||||
action_selection = np.array([self.MAX_ACTIONS])
|
||||
card_indices = np.array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2]) # Handles up to 10 cards in a hand
|
||||
self.action_space = spaces.MultiDiscrete(np.concatenate([
|
||||
action_selection,
|
||||
card_indices
|
||||
]))
|
||||
ACTION_SLICE_LAYOUT = [
|
||||
("action_selection", 1),
|
||||
("card_indices", 10)
|
||||
]
|
||||
slices = self._build_action_slices(ACTION_SLICE_LAYOUT)
|
||||
|
||||
# TODO: Send reset signal to Balatro mod
|
||||
# This might involve starting a new run
|
||||
|
||||
return self._get_initial_state()
|
||||
# Observation space: This should describe the type and shape of the observation
|
||||
# Constants
|
||||
self.OBSERVATION_SIZE = 70 # determined empirically by running test_env.py
|
||||
self.observation_space = spaces.Box(
|
||||
low=-np.inf, # lowest bound of observation data
|
||||
high=np.inf, # highest bound of observation data
|
||||
shape=(self.OBSERVATION_SIZE,), # 1D observation vector; adjust if the actual state size changes
|
||||
dtype=np.float32 # Data type of the numbers
|
||||
)
|
||||
|
||||
# Initialize mappers
|
||||
self.state_mapper = BalatroStateMapper(observation_size=self.OBSERVATION_SIZE, max_actions=self.MAX_ACTIONS)
|
||||
self.action_mapper = BalatroActionMapper(action_slices=slices)
|
||||
|
||||
def step(self, action: Dict[str, Any]) -> Tuple[Dict, float, bool, Dict]:
|
||||
def reset(self, seed=None, options=None): #TODO add types
|
||||
"""
|
||||
Take an action in the environment
|
||||
Reset the environment for a new episode
|
||||
|
||||
In the Balatro context, this means starting a new run.
|
||||
Communicates with Balatro mod via pipes to initiate reset.
|
||||
|
||||
Returns:
|
||||
Initial observation/game state
|
||||
"""
|
||||
self.current_state = None
|
||||
self.prev_state = None
|
||||
self.game_over = False
|
||||
|
||||
# Reset reward tracking
|
||||
self.reward_calculator.reset()
|
||||
|
||||
# Wait for initial request from Balatro (game start)
|
||||
initial_request = self.pipe_io.wait_for_request()
|
||||
if not initial_request:
|
||||
raise RuntimeError("Failed to receive initial request from Balatro")
|
||||
|
||||
# Send dummy response to complete handshake
|
||||
success = self.pipe_io.send_response({"action": "ready"})
|
||||
if not success:
|
||||
raise RuntimeError("Failed to complete handshake")
|
||||
|
||||
# Process initial state for SB3
|
||||
self.current_state = initial_request
|
||||
initial_observation = self.state_mapper.process_game_state(self.current_state)
|
||||
|
||||
return initial_observation, {}
|
||||
|
||||
def step(self, action): #TODO add types
|
||||
"""
|
||||
Take an action in the Balatro environment
|
||||
Sends action to Balatro mod via JSON pipe, waits for response,
|
||||
calculates reward, and returns standard RL step format.
|
||||
|
||||
Args:
|
||||
action: Action to take
|
||||
action: Action dictionary (e.g., {"action": 1, "params": {...}})
|
||||
|
||||
Returns:
|
||||
Tuple of (observation, reward, done, info)
|
||||
Tuple of (observation, reward, done, info) where:
|
||||
- observation: Processed game state for neural network
|
||||
- reward: Calculated reward for this step
|
||||
- done: Whether episode is finished (game over)
|
||||
- info: Additional debug information
|
||||
"""
|
||||
# TODO: Send action to Balatro via HTTP
|
||||
# TODO: Receive new game state
|
||||
# TODO: Calculate reward
|
||||
# TODO: Check if episode is done
|
||||
# Store previous state for reward calculation
|
||||
self.prev_state = self.current_state
|
||||
|
||||
# Send action response to Balatro mod
|
||||
response_data = self.action_mapper.process_action(rl_action=action)
|
||||
|
||||
observation = self._process_game_state(self.current_state)
|
||||
reward = self._calculate_reward()
|
||||
done = self.game_over
|
||||
info = {}
|
||||
success = self.pipe_io.send_response(response_data)
|
||||
if not success:
|
||||
raise RuntimeError("Failed to send response to Balatro")
|
||||
|
||||
return observation, reward, done, info
|
||||
|
||||
def _get_initial_state(self) -> Dict[str, Any]:
|
||||
"""Get initial game state"""
|
||||
# TODO: Implement
|
||||
return {}
|
||||
|
||||
def _process_game_state(self, raw_state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Process raw game state into format suitable for RL agent
|
||||
# Wait for next request with new game state
|
||||
next_request = self.pipe_io.wait_for_request()
|
||||
if not next_request:
|
||||
# If no response, assume game ended
|
||||
self.game_over = True
|
||||
observation = self.state_mapper.process_game_state(self.current_state)
|
||||
reward = 0.0
|
||||
return observation, reward, True, False, {"timeout": True} # obs, reward, terminated, truncated, info
|
||||
|
||||
This might involve:
|
||||
- Extracting relevant features
|
||||
- Normalizing values
|
||||
- Converting to numpy arrays
|
||||
"""
|
||||
if not raw_state:
|
||||
return {}
|
||||
# Update current state
|
||||
self.current_state = next_request
|
||||
|
||||
# Validate state structure
|
||||
try:
|
||||
GameStateValidator.validate_game_state(raw_state)
|
||||
except ValueError as e:
|
||||
self.logger.error(f"Invalid game state: {e}")
|
||||
return {}
|
||||
# Process new state for SB3
|
||||
observation = self.state_mapper.process_game_state(self.current_state)
|
||||
|
||||
# TODO: Extract and process features
|
||||
processed = {
|
||||
'hand_features': self._extract_hand_features(raw_state.get('hand', {})),
|
||||
'game_features': self._extract_game_features(raw_state),
|
||||
'available_actions': raw_state.get('available_actions', [])
|
||||
# Calculate reward using expert reward calculator
|
||||
reward = self.reward_calculator.calculate_reward(
|
||||
current_state=self.current_state,
|
||||
prev_state=self.prev_state if self.prev_state else {}
|
||||
)
|
||||
|
||||
# Check if episode is done
|
||||
done = bool(self.current_state.get('game_over', 0)) # game_over is supplied by output.lua
|
||||
terminated = done
|
||||
truncated = False # Not using time limits for now
|
||||
|
||||
info = {
|
||||
'balatro_state': self.current_state.get('state', 0),
|
||||
'available_actions': next_request.get('available_actions', [])
|
||||
}
|
||||
# TODO: the observation may not be returned when we run out of actions, which causes problems
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
Clean up environment resources
|
||||
|
||||
return processed
|
||||
|
||||
def _extract_hand_features(self, hand: Dict[str, Any]) -> np.ndarray:
|
||||
"""Extract numerical features from hand"""
|
||||
# TODO: Convert hand cards to numerical representation
|
||||
# This might include:
|
||||
# - Card values (2-14 for 2-A)
|
||||
# - Suits (0-3 for different suits)
|
||||
# - Special abilities/bonuses
|
||||
Call this when shutting down to clean up pipe communication.
|
||||
"""
|
||||
self.pipe_io.cleanup()
|
||||
|
||||
@staticmethod
|
||||
def _build_action_slices(layout: List[Tuple[str, int]]) -> Dict[str, slice]:
|
||||
"""
|
||||
Create slices for our actions so that we can precisely extract the
|
||||
right params to send over to balatro
|
||||
|
||||
cards = hand.get('cards', [])
|
||||
features = []
|
||||
|
||||
for card in cards:
|
||||
# Extract card features
|
||||
suit_encoding = self._encode_suit(card.get('suit', ''))
|
||||
value_encoding = self._encode_value(card.get('base', {}).get('value', ''))
|
||||
nominal = card.get('base', {}).get('nominal', 0)
|
||||
|
||||
# Add ability features
|
||||
ability = card.get('ability', {})
|
||||
ability_features = [
|
||||
ability.get('t_chips', 0),
|
||||
ability.get('t_mult', 0),
|
||||
ability.get('x_mult', 1),
|
||||
ability.get('mult', 0)
|
||||
]
|
||||
|
||||
card_features = [suit_encoding, value_encoding, nominal] + ability_features
|
||||
features.extend(card_features)
|
||||
|
||||
# Pad or truncate to fixed size (e.g., 8 cards max * 7 features each)
|
||||
max_cards = 8
|
||||
features_per_card = 7
|
||||
|
||||
if len(features) < max_cards * features_per_card:
|
||||
features.extend([0] * (max_cards * features_per_card - len(features)))
|
||||
else:
|
||||
features = features[:max_cards * features_per_card]
|
||||
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
def _extract_game_features(self, state: Dict[str, Any]) -> np.ndarray:
|
||||
"""Extract game-level features"""
|
||||
# TODO: Extract features like:
|
||||
# - Current chips
|
||||
# - Target chips
|
||||
# - Money
|
||||
# - Round/ante
|
||||
# - Discards remaining
|
||||
# - Hands remaining
|
||||
|
||||
features = [
|
||||
state.get('state', 0), # Game state
|
||||
len(state.get('available_actions', [])), # Number of available actions
|
||||
]
|
||||
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
def _encode_suit(self, suit: str) -> int:
|
||||
"""Encode suit as integer"""
|
||||
suit_map = {'Hearts': 0, 'Diamonds': 1, 'Clubs': 2, 'Spades': 3}
|
||||
return suit_map.get(suit, 0)
|
||||
|
||||
def _encode_value(self, value: str) -> int:
|
||||
"""Encode card value as integer"""
|
||||
if value.isdigit():
|
||||
return int(value)
|
||||
|
||||
value_map = {
|
||||
'Jack': 11, 'Queen': 12, 'King': 13, 'Ace': 14,
|
||||
'A': 14, 'K': 13, 'Q': 12, 'J': 11
|
||||
}
|
||||
return value_map.get(value, 0)
|
||||
|
||||
def _calculate_reward(self) -> float:
|
||||
"""Calculate reward for current state"""
|
||||
# TODO: Implement reward calculation
|
||||
# This might be based on:
|
||||
# - Chips scored
|
||||
# - Blind completion
|
||||
# - Round progression
|
||||
# - Final score
|
||||
|
||||
return 0.0
|
||||
Args:
|
||||
layout: Our ACTION_SLICE_LAYOUT that contains action name and size
|
||||
Return:
|
||||
A dictionary containing a key being our action space slice name, and
|
||||
the slice
|
||||
"""
|
||||
slices = {}
|
||||
start = 0
|
||||
for action_name, size in layout:
|
||||
slices[action_name] = slice(start, start + size)
|
||||
start += size
|
||||
return slices
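# Illustrative sketch (not part of the module): how the slices built above carve a
# sampled MultiDiscrete action into named parameter groups. The layout matches
# ACTION_SLICE_LAYOUT; the action values are made up.
#
#   layout  = [("action_selection", 1), ("card_indices", 10)]
#   slices  = BalatroEnv._build_action_slices(layout)
#   # slices == {"action_selection": slice(0, 1), "card_indices": slice(1, 11)}
#
#   rl_action = [2, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0]     # e.g. env.action_space.sample()
#   chosen    = rl_action[slices["action_selection"]]  # [2] -> index of the chosen action
#   card_mask = rl_action[slices["card_indices"]]      # 0/1 flag per card in hand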
|
||||
|
|
|
|||
|
|
@ -1,5 +1,11 @@
|
|||
"""
|
||||
Reward calculation for Balatro RL environment
|
||||
Balatro Reward System - The Expert on What's Good/Bad
|
||||
|
||||
This module is the single source of truth for reward calculation in Balatro RL.
|
||||
All reward logic is centralized here to make experimentation and tuning easier.
|
||||
|
||||
The BalatroRewardCalculator analyzes game state changes and assigns rewards
|
||||
that teach the AI what constitutes good vs bad Balatro gameplay.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
|
|
@ -7,114 +13,86 @@ import numpy as np
|
|||
|
||||
class BalatroRewardCalculator:
|
||||
"""
|
||||
Calculates rewards for Balatro RL training
|
||||
Expert reward calculator for Balatro RL training
|
||||
|
||||
This is the single authority on what constitutes good/bad play in Balatro.
|
||||
Centralizes all reward logic for easy experimentation and tuning.
|
||||
|
||||
Reward philosophy:
|
||||
- Positive rewards for progress (chips, rounds, money)
|
||||
- Bonus rewards for efficiency (fewer hands/discards used)
|
||||
- Large rewards for major milestones (completing antes)
|
||||
- Negative rewards for game over (based on progress made)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.prev_score = 0
|
||||
self.prev_money = 0
|
||||
self.prev_round = 0
|
||||
self.prev_ante = 0
|
||||
self.chips = 0
|
||||
|
||||
def calculate_reward(self, current_state: Dict[str, Any],
|
||||
prev_state: Dict[str, Any] = None) -> float:
|
||||
"""
|
||||
Calculate reward based on game state changes
|
||||
Main reward calculation method - analyzes state changes and assigns rewards
|
||||
|
||||
This is the core method that determines what the AI should optimize for.
|
||||
Examines differences between previous and current game state to calculate
|
||||
appropriate rewards for the action that caused the transition.
|
||||
|
||||
Args:
|
||||
current_state: Current game state
|
||||
prev_state: Previous game state (optional)
|
||||
current_state: Current Balatro game state
|
||||
prev_state: Previous Balatro game state (None for first step)
|
||||
|
||||
Returns:
|
||||
Calculated reward
|
||||
Float reward value (positive = good, negative = bad, zero = neutral)
|
||||
"""
|
||||
reward = 0.0
|
||||
|
||||
# Extract relevant metrics
|
||||
current_score = current_state.get('score', 0)
|
||||
current_money = current_state.get('money', 0)
|
||||
current_round = current_state.get('round', 0)
|
||||
current_ante = current_state.get('ante', 0)
|
||||
game_over = current_state.get('game_over', False)
|
||||
current_chips = current_state.get('chips', 0)
|
||||
game_over = current_state.get('game_over', False) #TODO fix
|
||||
|
||||
# Score-based rewards
|
||||
score_diff = current_score - self.prev_score
|
||||
if score_diff > 0:
|
||||
reward += score_diff * 0.001 # Small reward for score increases
|
||||
|
||||
# Money-based rewards
|
||||
money_diff = current_money - self.prev_money
|
||||
if money_diff > 0:
|
||||
reward += money_diff * 0.01 # Reward for gaining money
|
||||
|
||||
# Round progression rewards
|
||||
if current_round > self.prev_round:
|
||||
reward += 10.0 # Big reward for completing rounds
|
||||
|
||||
# Ante progression rewards
|
||||
if current_ante > self.prev_ante:
|
||||
reward += 50.0 # Very big reward for completing antes
|
||||
chip_diff = current_chips - self.chips
|
||||
if chip_diff > 0:
|
||||
reward += chip_diff * 0.001 # Small reward for score increases
|
||||
|
||||
# Game over penalty
|
||||
if game_over:
|
||||
# Penalty based on how early the game ended
|
||||
max_ante = 8 # Typical max ante in Balatro
|
||||
completion_ratio = current_ante / max_ante
|
||||
reward += completion_ratio * 100.0 # Reward based on progress
|
||||
# max_ante = 8 # Typical max ante in Balatro
|
||||
# completion_ratio = current_ante / max_ante
|
||||
# reward += completion_ratio * 100.0 # Reward based on progress
|
||||
reward -= 10 #TODO keeping this simple for now
|
||||
|
||||
# Efficiency rewards (optional)
|
||||
# reward += self._calculate_efficiency_reward(current_state)
|
||||
|
||||
# Update previous state
|
||||
self.prev_score = current_score
|
||||
self.prev_money = current_money
|
||||
self.prev_round = current_round
|
||||
self.prev_ante = current_ante
|
||||
|
||||
return reward
|
||||
|
||||
def _calculate_efficiency_reward(self, state: Dict[str, Any]) -> float:
|
||||
"""
|
||||
Calculate rewards for efficient play
|
||||
|
||||
Args:
|
||||
state: Current game state
|
||||
|
||||
Returns:
|
||||
Efficiency reward
|
||||
"""
|
||||
reward = 0.0
|
||||
|
||||
# Reward for using fewer discards
|
||||
discards_remaining = state.get('discards', 0)
|
||||
if discards_remaining > 0:
|
||||
reward += discards_remaining * 0.1
|
||||
|
||||
# Reward for using fewer hands
|
||||
hands_remaining = state.get('hands', 0)
|
||||
if hands_remaining > 0:
|
||||
reward += hands_remaining * 0.5
|
||||
self.chips = current_chips
|
||||
|
||||
return reward
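# Worked example of the simplified chip-based reward above (numbers are illustrative):
#   previous step left self.chips = 100
#   current_state = {"chips": 250, "game_over": 0}
#   chip_diff = 250 - 100 = 150  ->  reward += 150 * 0.001 = 0.15
#   if game_over were truthy instead, the reward would also take the flat -10 penalty.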
|
||||
|
||||
def reset(self):
|
||||
"""Reset reward calculator for new episode"""
|
||||
self.prev_score = 0
|
||||
self.prev_money = 0
|
||||
self.prev_round = 0
|
||||
self.prev_ante = 0
|
||||
"""
|
||||
Reset reward calculator state for new episode
|
||||
|
||||
Called at the start of each new Balatro run to clear
|
||||
previous state tracking variables.
|
||||
"""
|
||||
self.chips = 0
|
||||
|
||||
def get_shaped_reward(self, state: Dict[str, Any], action: str) -> float:
|
||||
"""
|
||||
Get shaped reward based on specific actions
|
||||
Calculate action-specific reward shaping
|
||||
|
||||
Provides small rewards/penalties for specific actions to guide
|
||||
AI behavior beyond just game state changes. Use sparingly to
|
||||
avoid overriding the main reward signal.
|
||||
|
||||
Args:
|
||||
state: Current game state
|
||||
action: Action taken
|
||||
state: Current game state when action was taken
|
||||
action: Action that was taken (for action-specific rewards)
|
||||
|
||||
Returns:
|
||||
Shaped reward
|
||||
Small shaped reward value (usually -1 to +1)
|
||||
"""
|
||||
# TODO look at this this could help but currently not in a working state
|
||||
reward = 0.0
|
||||
|
||||
# Action-specific rewards
|
||||
|
|
@ -134,4 +112,4 @@ class BalatroRewardCalculator:
|
|||
# Neutral - depends on if it's a good purchase
|
||||
reward += 0.0
|
||||
|
||||
return reward
|
||||
return reward
|
||||
|
|
|
|||
|
|
@ -1,111 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
File-based communication handler for Balatro RL
|
||||
Watches for request files and responds with AI actions
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
|
||||
from agent.balatro_agent import BalatroAgent
|
||||
|
||||
# File paths
|
||||
REQUEST_FILE = "/tmp/balatro_request.json"
|
||||
RESPONSE_FILE = "/tmp/balatro_response.json"
|
||||
|
||||
class RequestHandler(FileSystemEventHandler):
|
||||
"""Handles incoming request files from Balatro"""
|
||||
|
||||
def __init__(self):
|
||||
self.agent = BalatroAgent()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def on_created(self, event):
|
||||
"""Handle new request file creation"""
|
||||
if event.src_path == REQUEST_FILE and not event.is_directory:
|
||||
self.process_request()
|
||||
|
||||
def on_modified(self, event):
|
||||
"""Handle request file modification"""
|
||||
if event.src_path == REQUEST_FILE and not event.is_directory:
|
||||
self.process_request()
|
||||
|
||||
def process_request(self):
|
||||
"""Process a request file and generate response"""
|
||||
try:
|
||||
# Small delay to ensure file write is complete
|
||||
time.sleep(0.01)
|
||||
|
||||
# Read request
|
||||
with open(REQUEST_FILE, 'r') as f:
|
||||
request_data = json.load(f)
|
||||
|
||||
self.logger.info(f"Processing request {request_data.get('id')}: state={request_data.get('state')}")
|
||||
|
||||
# Get action from AI agent
|
||||
action_response = self.agent.get_action(request_data)
|
||||
|
||||
# Write response
|
||||
response = {
|
||||
"id": request_data.get("id"),
|
||||
"action": action_response.get("action", "no_action"),
|
||||
"params": action_response.get("params"),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
with open(RESPONSE_FILE, 'w') as f:
|
||||
json.dump(response, f)
|
||||
|
||||
self.logger.info(f"Response written: {response['action']}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing request: {e}")
|
||||
|
||||
# Write error response
|
||||
error_response = {
|
||||
"id": request_data.get("id", 0) if 'request_data' in locals() else 0,
|
||||
"action": "no_action",
|
||||
"params": {"error": str(e)},
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
try:
|
||||
with open(RESPONSE_FILE, 'w') as f:
|
||||
json.dump(error_response, f)
|
||||
except:
|
||||
pass
|
||||
|
||||
def main():
|
||||
"""Main file watcher loop"""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info("Starting Balatro RL file watcher...")
|
||||
|
||||
# Set up file watcher
|
||||
event_handler = RequestHandler()
|
||||
observer = Observer()
|
||||
observer.schedule(event_handler, path="/tmp", recursive=False)
|
||||
|
||||
observer.start()
|
||||
logger.info(f"Watching for requests at: {REQUEST_FILE}")
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Stopping file watcher...")
|
||||
observer.stop()
|
||||
|
||||
observer.join()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,6 +1,5 @@
numpy>=1.21.0
torch>=1.9.0
gymnasium>=0.26.0 # For RL environment interface
stable-baselines3>=1.6.0 # For RL algorithms
stable-baselines3[extra]>=1.6.0 # For RL algorithms
tensorboard>=2.8.0 # For training visualization
watchdog>=2.1.0 # For file system monitoring
|
@ -22,5 +22,5 @@ echo "Setup complete!"
echo ""
echo "Usage:"
echo " To activate environment: source venv/bin/activate"
echo " To run file watcher: python file_watcher.py"
echo " To train agent: python -m agent.training"
echo " To test that the environment works: python -m ai.test_env"
echo " To train agent: python -m ai.train_balatro"
|
|
|||
|
|
@ -0,0 +1,96 @@
|
|||
"""
|
||||
Environment Testing Script for Balatro RL
|
||||
|
||||
Tests the BalatroEnv to ensure it follows the Gym interface and
|
||||
communicates properly with the Balatro mod via file I/O.
|
||||
|
||||
Usage:
|
||||
python -m ai.test_env
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
# from stable_baselines3.common.env_checker import check_env # Not compatible with pipes
|
||||
from .environment.balatro_env import BalatroEnv
|
||||
|
||||
|
||||
def test_manual_actions():
|
||||
"""Manual testing with random actions"""
|
||||
print("\n=== Manual Action Testing ===")
|
||||
|
||||
env = BalatroEnv()
|
||||
print("BalatroEnv created successfully")
|
||||
|
||||
try:
|
||||
# Reset and get initial observation (ONLY ONCE!)
|
||||
print("Calling env.reset() - this should wait for Balatro...")
|
||||
obs, info = env.reset()
|
||||
print("env.reset() completed!")
|
||||
print(f"Initial observation: {obs}")
|
||||
print(f"Observation space: {env.observation_space}")
|
||||
print(f"Action space: {env.action_space}")
|
||||
print(f"Sample action: {env.action_space.sample()}")
|
||||
|
||||
# Run test episodes
|
||||
max_steps = 3
|
||||
max_episodes = 1
|
||||
|
||||
for episode in range(max_episodes):
|
||||
print(f"\n--- Episode {episode + 1} ---")
|
||||
# Don't call reset() again! Use the obs from above
|
||||
|
||||
for step in range(max_steps):
|
||||
# Get random action (like the old balatro_agent)
|
||||
action = env.action_space.sample()
|
||||
|
||||
print(f"Step {step + 1}: Taking action {action}")
|
||||
|
||||
# Take action
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
done = terminated or truncated
|
||||
|
||||
print(f" obs={obs}, reward={reward}, done={done}")
|
||||
|
||||
# Optional: Add delay for readability
|
||||
time.sleep(0.1)
|
||||
|
||||
if done:
|
||||
print(f" Episode finished after {step + 1} steps")
|
||||
break
|
||||
|
||||
if not done:
|
||||
print(f" Episode reached max steps ({max_steps})")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Manual testing failed: {e}")
|
||||
return False
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
print("✓ Manual testing completed!")
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
print("Starting Balatro Environment Testing...")
|
||||
|
||||
# Setup logging to see communication details
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
success = True
|
||||
success &= test_manual_actions()
|
||||
|
||||
if success:
|
||||
print("\n🎉 All tests passed! Environment is ready for training.")
|
||||
else:
|
||||
print("\n❌ Some tests failed. Check the environment implementation.")
|
||||
|
||||
return success
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,273 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Balatro RL Training Script
|
||||
|
||||
Main training script for teaching AI to play Balatro using Stable-Baselines3.
|
||||
This script creates the Balatro environment, sets up the RL model, and runs training.
|
||||
|
||||
Usage:
|
||||
python -m ai.train_balatro
|
||||
|
||||
Requirements:
|
||||
- Balatro game running with RLBridge mod
|
||||
- file_watcher.py should NOT be running (this replaces it)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# SB3 imports
|
||||
from stable_baselines3 import DQN, A2C, PPO
|
||||
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
|
||||
# Our custom environment
|
||||
from .environment.balatro_env import BalatroEnv
|
||||
|
||||
|
||||
def setup_logging():
|
||||
"""Setup logging for training"""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('training.log'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def create_environment():
|
||||
"""Create and wrap the Balatro environment"""
|
||||
# Create base environment
|
||||
env = BalatroEnv()
|
||||
|
||||
# Wrap with Monitor for logging episode stats
|
||||
env = Monitor(env, filename="training_monitor.csv")
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def create_model(env, algorithm="DQN", model_path=None):
|
||||
"""
|
||||
Create SB3 model for training
|
||||
|
||||
Args:
|
||||
env: Balatro environment
|
||||
algorithm: RL algorithm to use ("DQN", "A2C", "PPO")
|
||||
model_path: Path to load existing model (optional)
|
||||
|
||||
Returns:
|
||||
SB3 model ready for training
|
||||
"""
|
||||
if algorithm == "DQN":
|
||||
model = DQN(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
verbose=1,
|
||||
learning_rate=1e-4,
|
||||
buffer_size=10000,
|
||||
learning_starts=1000,
|
||||
target_update_interval=500,
|
||||
train_freq=4,
|
||||
gradient_steps=1,
|
||||
exploration_fraction=0.1,
|
||||
exploration_initial_eps=1.0,
|
||||
exploration_final_eps=0.05,
|
||||
tensorboard_log="./tensorboard_logs/"
|
||||
)
|
||||
elif algorithm == "A2C":
|
||||
model = A2C(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
verbose=1,
|
||||
learning_rate=7e-4,
|
||||
n_steps=5,
|
||||
gamma=0.99,
|
||||
tensorboard_log="./tensorboard_logs/"
|
||||
)
|
||||
elif algorithm == "PPO":
|
||||
model = PPO(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
verbose=1,
|
||||
learning_rate=3e-4,
|
||||
n_steps=2048,
|
||||
batch_size=64,
|
||||
n_epochs=10,
|
||||
gamma=0.99,
|
||||
tensorboard_log="./tensorboard_logs/"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown algorithm: {algorithm}")
|
||||
|
||||
# Load existing model if path provided
|
||||
if model_path and Path(model_path).exists():
|
||||
model.set_parameters(model_path)  # load saved weights into this model (calling .load() on an instance returns a new model instead)
|
||||
print(f"Loaded existing model from {model_path}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def create_callbacks(save_freq=1000):
|
||||
"""Create training callbacks for saving and evaluation"""
|
||||
callbacks = []
|
||||
|
||||
# Checkpoint callback - save model periodically
|
||||
checkpoint_callback = CheckpointCallback(
|
||||
save_freq=save_freq,
|
||||
save_path="./models/",
|
||||
name_prefix="balatro_model"
|
||||
)
|
||||
callbacks.append(checkpoint_callback)
|
||||
|
||||
return callbacks
|
||||
|
||||
|
||||
def train_agent(total_timesteps=50000, algorithm="DQN", save_path="./models/balatro_final"):
|
||||
"""
|
||||
Main training function
|
||||
|
||||
Args:
|
||||
total_timesteps: Number of training steps
|
||||
algorithm: RL algorithm to use
|
||||
save_path: Where to save final model
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"Starting Balatro RL training with {algorithm}")
|
||||
logger.info(f"Training for {total_timesteps} timesteps")
|
||||
|
||||
try:
|
||||
# Create environment and model
|
||||
env = create_environment()
|
||||
model = create_model(env, algorithm)
|
||||
|
||||
# Create callbacks
|
||||
callbacks = create_callbacks(save_freq=max(1000, total_timesteps // 20))
|
||||
|
||||
# Train the model
|
||||
logger.info("Starting training...")
|
||||
start_time = time.time()
|
||||
|
||||
model.learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callbacks,
|
||||
progress_bar=True
|
||||
)
|
||||
|
||||
training_time = time.time() - start_time
|
||||
logger.info(f"Training completed in {training_time:.2f} seconds")
|
||||
|
||||
# Save final model
|
||||
model.save(save_path)
|
||||
logger.info(f"Model saved to {save_path}")
|
||||
|
||||
# Clean up environment
|
||||
if hasattr(env, 'cleanup'):
|
||||
env.cleanup()
|
||||
elif hasattr(env.env, 'cleanup'): # Monitor wrapper
|
||||
env.env.cleanup()
|
||||
|
||||
return model
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Training interrupted by user")
|
||||
if hasattr(env, 'cleanup'):
|
||||
env.cleanup()
|
||||
elif hasattr(env.env, 'cleanup'):
|
||||
env.env.cleanup()
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Training failed: {e}")
|
||||
if hasattr(env, 'cleanup'):
|
||||
env.cleanup()
|
||||
elif hasattr(env.env, 'cleanup'):
|
||||
env.env.cleanup()
|
||||
raise
|
||||
|
||||
|
||||
def test_trained_model(model_path, num_episodes=5):
|
||||
"""
|
||||
Test a trained model
|
||||
|
||||
Args:
|
||||
model_path: Path to trained model
|
||||
num_episodes: Number of episodes to test
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"Testing model from {model_path}")
|
||||
|
||||
# Create environment and load model
|
||||
env = create_environment()
|
||||
model = PPO.load(model_path)  # must match the algorithm used for training (PPO in __main__)
|
||||
|
||||
episode_rewards = []
|
||||
|
||||
for episode in range(num_episodes):
|
||||
obs, info = env.reset()
|
||||
total_reward = 0
|
||||
steps = 0
|
||||
|
||||
while True:
|
||||
action, _states = model.predict(obs, deterministic=True)
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
done = terminated or truncated
|
||||
total_reward += reward
|
||||
steps += 1
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
episode_rewards.append(total_reward)
|
||||
logger.info(f"Episode {episode + 1}: {steps} steps, reward: {total_reward:.2f}")
|
||||
|
||||
avg_reward = sum(episode_rewards) / len(episode_rewards)
|
||||
logger.info(f"Average reward over {num_episodes} episodes: {avg_reward:.2f}")
|
||||
|
||||
env.cleanup()
|
||||
return episode_rewards
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Setup
|
||||
setup_logging()
|
||||
|
||||
# Create necessary directories
|
||||
Path(".models").mkdir(exist_ok=True)
|
||||
Path(".tensorboard_logs").mkdir(exist_ok=True)
|
||||
|
||||
# We'll let BalatroEnv create the pipes when needed
|
||||
print("🔧 Pipe communication will be initialized with environment...")
|
||||
|
||||
# Train the agent
|
||||
print("\n🎮 Starting Balatro RL Training!")
|
||||
print("Setup steps:")
|
||||
print("1. ✅ Balatro is running with RLBridge mod")
|
||||
print("2. ✅ Balatro is in menu state")
|
||||
print("3. Press Enter below to start training")
|
||||
print("4. When prompted, press 'R' in Balatro within 3 seconds")
|
||||
print("\n📡 Training will create pipes and connect automatically!")
|
||||
print()
|
||||
|
||||
input("Press Enter to start training (then you'll have 3 seconds to press 'R' in Balatro)...")
|
||||
|
||||
try:
|
||||
model = train_agent(
|
||||
total_timesteps=10000, # Start small for testing
|
||||
algorithm="PPO", # PPO supports MultiDiscrete actions
|
||||
save_path="./models/balatro_trained"
|
||||
)
|
||||
|
||||
if model:
|
||||
print("\n🎉 Training completed successfully!")
|
||||
print("Testing the trained model...")
|
||||
|
||||
# Test the trained model
|
||||
test_trained_model("./models/balatro_trained", num_episodes=3)
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Training failed: {e}")
|
||||
print("Check the logs for more details.")
|
||||
finally:
|
||||
# Pipes will be cleaned up by the environment
|
||||
print("🧹 Training session ended")
|
||||
|
|
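For reference, resuming training from a checkpoint written by the CheckpointCallback could look roughly like the sketch below. This is a hypothetical snippet, not part of this commit; the checkpoint filename follows SB3's `<name_prefix>_<steps>_steps.zip` pattern and should be adjusted to an actual file.

    # Hypothetical resume-from-checkpoint sketch (illustrative only)
    env = create_environment()
    checkpoint = "./models/balatro_model_5000_steps.zip"  # example path, adjust to a real checkpoint
    model = create_model(env, algorithm="PPO", model_path=checkpoint)
    model.learn(total_timesteps=5000, callback=create_callbacks(), progress_bar=True)
    model.save("./models/balatro_resumed")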
@ -0,0 +1,192 @@
|
|||
"""
|
||||
Balatro Dual Pipe Communication Abstraction
|
||||
|
||||
Pure dual pipe I/O layer with persistent handles for communicating with Balatro mod.
|
||||
Handles low-level pipe operations without any game logic or AI logic.
|
||||
|
||||
This abstraction can be used by:
|
||||
- balatro_env.py (for SB3 training)
|
||||
- file_watcher.py (for testing)
|
||||
- Any other component that needs to talk to Balatro
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
class BalatroPipeIO:
|
||||
"""
|
||||
Clean abstraction for dual pipe communication with Balatro mod using persistent handles
|
||||
|
||||
Responsibilities:
|
||||
- Read JSON requests from Balatro mod via request pipe
|
||||
- Write JSON responses back to Balatro mod via response pipe
|
||||
- Keep pipe handles open persistently to avoid deadlocks
|
||||
- Provide simple, clean interface for higher-level code
|
||||
"""
|
||||
|
||||
def __init__(self, request_pipe: str = "/tmp/balatro_request", response_pipe: str = "/tmp/balatro_response"):
|
||||
self.request_pipe = request_pipe
|
||||
self.response_pipe = response_pipe
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Persistent pipe handles
|
||||
self.request_handle = None
|
||||
self.response_handle = None
|
||||
|
||||
# Create pipes and open persistent handles
|
||||
self.create_pipes()
|
||||
self.open_persistent_handles()
|
||||
|
||||
def create_pipes(self) -> None:
|
||||
"""
|
||||
Create dual named pipes for communication
|
||||
|
||||
Creates both request and response pipes if they don't exist.
|
||||
Safe to call multiple times.
|
||||
"""
|
||||
for pipe_path in [self.request_pipe, self.response_pipe]:
|
||||
try:
|
||||
# Remove existing pipe if it exists
|
||||
if os.path.exists(pipe_path):
|
||||
os.unlink(pipe_path)
|
||||
self.logger.debug(f"Removed existing pipe: {pipe_path}")
|
||||
|
||||
# Create named pipe
|
||||
os.mkfifo(pipe_path)
|
||||
self.logger.info(f"Created pipe: {pipe_path}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to create pipe {pipe_path}: {e}")
|
||||
raise RuntimeError(f"Could not create communication pipe: {pipe_path}")
|
||||
|
||||
def open_persistent_handles(self) -> None:
|
||||
"""
|
||||
Open persistent handles for reading and writing
|
||||
|
||||
Opens request pipe for reading and response pipe for writing.
|
||||
Keeps handles open to avoid deadlocks.
|
||||
"""
|
||||
import time
|
||||
|
||||
try:
|
||||
self.logger.info(f"🔧 Waiting for Balatro to connect...")
|
||||
self.logger.info(f" Press 'R' in Balatro now to activate RL training!")
|
||||
|
||||
# Open request pipe for reading (Balatro writes to this)
|
||||
self.request_handle = open(self.request_pipe, 'r')
|
||||
|
||||
# Open response pipe for writing (AI writes to this)
|
||||
self.response_handle = open(self.response_pipe, 'w')
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to open persistent handles: {e}")
|
||||
self.cleanup_handles()
|
||||
raise RuntimeError(f"Could not open persistent pipe handles: {e}")
|
||||
|
||||
def wait_for_request(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Wait for new request from Balatro mod using persistent handle
|
||||
|
||||
Blocks until Balatro writes a request to the pipe.
|
||||
No timeout needed - pipes block until data arrives.
|
||||
|
||||
Returns:
|
||||
Parsed JSON request data, or None if error
|
||||
"""
|
||||
if not self.request_handle:
|
||||
self.logger.error("Request handle not open")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Read from persistent request handle
|
||||
request_line = self.request_handle.readline().strip()
|
||||
if not request_line:
|
||||
return None
|
||||
|
||||
request_data = json.loads(request_line)
|
||||
self.logger.info(f"📥 RECEIVED REQUEST: {request_line}...") # Show first 100 chars
|
||||
return request_data
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
self.logger.error(f"Invalid JSON in request: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error reading from request pipe: {e}")
|
||||
return None
|
||||
|
||||
def send_response(self, response_data: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Send response back to Balatro mod using persistent handle
|
||||
|
||||
Writes response data to response pipe for Balatro to read.
|
||||
|
||||
Args:
|
||||
response_data: Response dictionary to send
|
||||
|
||||
Returns:
|
||||
True if successful, False if error
|
||||
"""
|
||||
if not self.response_handle:
|
||||
self.logger.error("Response handle not open")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Write to persistent response handle
|
||||
json.dump(response_data, self.response_handle)
|
||||
self.response_handle.write('\n') # Important: newline for pipe communication
|
||||
self.response_handle.flush() # Force write to pipe immediately
|
||||
|
||||
self.logger.info(f"📤 SENT RESPONSE: {json.dumps(response_data)}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"❌ ERROR sending response: {e}")
|
||||
import traceback
|
||||
self.logger.error(f"❌ TRACEBACK: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
def cleanup_handles(self):
|
||||
"""
|
||||
Close persistent pipe handles
|
||||
|
||||
Closes the open file handles.
|
||||
"""
|
||||
if self.request_handle:
|
||||
try:
|
||||
self.request_handle.close()
|
||||
self.logger.debug("Closed request handle")
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to close request handle: {e}")
|
||||
self.request_handle = None
|
||||
|
||||
if self.response_handle:
|
||||
try:
|
||||
self.response_handle.close()
|
||||
self.logger.debug("Closed response handle")
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to close response handle: {e}")
|
||||
self.response_handle = None
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
Clean up communication pipes and handles
|
||||
|
||||
Closes handles and removes pipe files from the filesystem.
|
||||
"""
|
||||
# Close handles first
|
||||
self.cleanup_handles()
|
||||
|
||||
# Remove pipe files
|
||||
for pipe_path in [self.request_pipe, self.response_pipe]:
|
||||
try:
|
||||
if os.path.exists(pipe_path):
|
||||
os.unlink(pipe_path)
|
||||
self.logger.debug(f"Removed pipe: {pipe_path}")
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to remove pipe {pipe_path}: {e}")
|
||||
|
||||
self.logger.debug("Dual pipe communication cleanup complete")
|
||||
|
||||
|
|
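A minimal consumer loop over this abstraction might look like the sketch below, assuming the mod writes one newline-terminated JSON object per request to `/tmp/balatro_request` and then blocks on `/tmp/balatro_response`. The `{"action": 0, "params": []}` reply is a placeholder, not a real action id.

    # Illustrative usage sketch for BalatroPipeIO (not part of the diff above)
    io = BalatroPipeIO()                      # creates FIFOs, then blocks until Balatro opens them
    try:
        while True:
            request = io.wait_for_request()   # blocks for one newline-terminated JSON line
            if request is None:               # malformed JSON or closed pipe
                break
            io.send_response({"action": 0, "params": []})  # placeholder reply
    finally:
        io.cleanup()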
@ -0,0 +1,39 @@
|
|||
"""
|
||||
Debug utilities for timestamped print statements and debugging helpers
|
||||
"""
|
||||
|
||||
import datetime
|
||||
|
||||
|
||||
def timestamp() -> str:
|
||||
"""
|
||||
Get current timestamp formatted as HH:MM:SS.mmm
|
||||
|
||||
Returns:
|
||||
str: Formatted timestamp string
|
||||
"""
|
||||
now = datetime.datetime.now()
|
||||
return now.strftime("%H:%M:%S.%f")[:-3] # Remove last 3 digits to get milliseconds
|
||||
|
||||
|
||||
def tprint(*args, **kwargs):
|
||||
"""
|
||||
Print with timestamp prefix
|
||||
|
||||
Args:
|
||||
*args: Arguments to print
|
||||
**kwargs: Keyword arguments for print()
|
||||
"""
|
||||
print(f"[{timestamp()}]", *args, **kwargs)
|
||||
|
||||
|
||||
def dprint(prefix: str, *args, **kwargs):
|
||||
"""
|
||||
Debug print with custom prefix and timestamp
|
||||
|
||||
Args:
|
||||
prefix: Custom prefix string
|
||||
*args: Arguments to print
|
||||
**kwargs: Keyword arguments for print()
|
||||
"""
|
||||
print(f"[{timestamp()}] [{prefix}]", *args, **kwargs)
|
||||
|
|
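Example usage and output format (timestamps below are illustrative):

    tprint("pipe opened")          # -> [14:03:07.412] pipe opened
    dprint("AI", "sent response")  # -> [14:03:07.498] [AI] sent response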
@ -1,74 +0,0 @@
|
|||
"""
|
||||
Logging setup for Balatro RL system
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
def setup_logging(
|
||||
level: str = "INFO",
|
||||
log_file: Optional[str] = None,
|
||||
format_string: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Setup logging configuration
|
||||
|
||||
Args:
|
||||
level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
log_file: Optional log file path
|
||||
format_string: Optional custom format string
|
||||
"""
|
||||
if format_string is None:
|
||||
format_string = (
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
# Convert string level to logging constant
|
||||
numeric_level = getattr(logging, level.upper(), None)
|
||||
if not isinstance(numeric_level, int):
|
||||
raise ValueError(f'Invalid log level: {level}')
|
||||
|
||||
# Create formatter
|
||||
formatter = logging.Formatter(format_string)
|
||||
|
||||
# Setup root logger
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(numeric_level)
|
||||
|
||||
# Remove existing handlers
|
||||
for handler in root_logger.handlers[:]:
|
||||
root_logger.removeHandler(handler)
|
||||
|
||||
# Console handler
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setLevel(numeric_level)
|
||||
console_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# File handler (optional)
|
||||
if log_file:
|
||||
log_path = Path(log_file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_handler = logging.FileHandler(log_file)
|
||||
file_handler.setLevel(numeric_level)
|
||||
file_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(file_handler)
|
||||
|
||||
# Set specific logger levels
|
||||
logging.getLogger("uvicorn").setLevel(logging.INFO)
|
||||
logging.getLogger("fastapi").setLevel(logging.INFO)
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""
|
||||
Get a logger with the specified name
|
||||
|
||||
Args:
|
||||
name: Logger name
|
||||
|
||||
Returns:
|
||||
Logger instance
|
||||
"""
|
||||
return logging.getLogger(name)
|
||||
|
|
@ -0,0 +1,245 @@
|
|||
"""
|
||||
Mappers for converting between Balatro game data and RL-compatible formats.
|
||||
|
||||
This module handles the two-way data transformation:
|
||||
1. BalatroStateMapper: Converts incoming Balatro JSON to normalized RL observations
|
||||
2. BalatroActionMapper: Converts RL actions to Balatro command JSON
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import time
|
||||
from typing import Dict, List, Any
|
||||
from ..utils.validation import GameStateValidator, ResponseValidator
|
||||
import logging
|
||||
|
||||
|
||||
|
||||
class BalatroStateMapper:
|
||||
"""
|
||||
Converts raw Balatro game state JSON to normalized RL observations.
|
||||
|
||||
Handles:
|
||||
- Card data normalization
|
||||
- Game state parsing
|
||||
- Observation space formatting
|
||||
"""
|
||||
def __init__(self, observation_size: int, max_actions: int):
|
||||
self.observation_size = observation_size
|
||||
self.max_actions = max_actions
|
||||
|
||||
# Logger
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# GameStateValidator
|
||||
self.game_state_validator = GameStateValidator()
|
||||
|
||||
#TODO review Might be something wrong here
|
||||
def process_game_state(self, raw_state: Dict[str, Any] | None) -> np.ndarray:
|
||||
"""
|
||||
Convert Balatro's raw JSON game state into the standardized numerical
|
||||
array format (observation) that the neural network consumes.
|
||||
|
||||
Args:
|
||||
raw_state: Raw game state from Balatro mod JSON
|
||||
|
||||
Returns:
|
||||
Processed state suitable for RL training
|
||||
"""
|
||||
if not raw_state:
|
||||
return np.zeros(self.observation_size, dtype=np.float32) # TODO maybe raise error?
|
||||
|
||||
# Validate game state request
|
||||
try:
|
||||
self.game_state_validator.validate_game_state(raw_state)
|
||||
except ValueError as e:
|
||||
self.logger.error(f"Invalid game state: {e}")
|
||||
|
||||
hand_features = self._extract_hand_features(raw_state.get('hand', {}))
|
||||
game_features = self._extract_game_features(raw_state)
|
||||
available_actions = self._extract_available_actions(raw_state.get('available_actions', []))
|
||||
chips = self._extract_chip_features(raw_state)
|
||||
# TODO extract state
|
||||
|
||||
return np.concatenate([hand_features, game_features, available_actions, chips])
|
||||
|
||||
def _extract_available_actions(self, available_actions: List[int]) -> np.ndarray:
|
||||
"""
|
||||
Convert the available-actions list into a binary mask array
|
||||
Args:
|
||||
available_actions: Available actions list from Balatro game state
|
||||
Returns:
|
||||
Fixed-size binary mask over action ids (1.0 = available, 0.0 = unavailable)
|
||||
"""
|
||||
mask = np.zeros(self.max_actions, dtype=np.float32)
|
||||
for action_id in available_actions:
|
||||
mask[action_id] = 1.0
|
||||
return mask
|
||||
|
||||
def _extract_hand_features(self, hand: Dict[str, Any]) -> np.ndarray:
|
||||
"""
|
||||
Convert Balatro hand data into numerical features for neural network
|
||||
|
||||
Transforms card dictionaries into fixed-size numerical arrays:
|
||||
- Card values: 2-14 (2 through Ace)
|
||||
- Suits: 0-3 (Hearts, Diamonds, Clubs, Spades)
|
||||
- Card abilities: chips, mult, special effects
|
||||
- Pads/truncates to fixed hand size for consistent input
|
||||
|
||||
Args:
|
||||
hand: Hand dictionary from Balatro game state
|
||||
|
||||
Returns:
|
||||
Fixed-size numpy array of hand features
|
||||
"""
|
||||
|
||||
cards = hand.get('cards', [])
|
||||
features = []
|
||||
|
||||
for card in cards:
|
||||
# Extract card features
|
||||
suit_encoding = self._encode_suit(card.get('suit', ''))
|
||||
value_encoding = self._encode_value(card.get('base', {}).get('value', ''))
|
||||
nominal = card.get('base', {}).get('nominal', 0)
|
||||
|
||||
# Add ability features
|
||||
ability = card.get('ability', {})
|
||||
ability_features = [
|
||||
ability.get('t_chips', 0),
|
||||
ability.get('t_mult', 0),
|
||||
ability.get('x_mult', 1),
|
||||
ability.get('mult', 0)
|
||||
]
|
||||
|
||||
card_features = [suit_encoding, value_encoding, nominal] + ability_features
|
||||
features.extend(card_features)
|
||||
|
||||
# Pad or truncate to fixed size (e.g., 8 cards max * 7 features each)
|
||||
# TODO hand.get('size')
|
||||
# TODO hand.get('highlighted_count')
|
||||
max_cards = 8
|
||||
features_per_card = 7
|
||||
|
||||
if len(features) < max_cards * features_per_card:
|
||||
features.extend([0] * (max_cards * features_per_card - len(features)))
|
||||
else:
|
||||
features = features[:max_cards * features_per_card]
|
||||
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
def _extract_game_features(self, state: Dict[str, Any]) -> np.ndarray:
|
||||
"""
|
||||
Extract numerical game-level features from Balatro state
|
||||
|
||||
Converts game metadata into neural network inputs:
|
||||
- Current game state (menu, selecting hand, etc.)
|
||||
- Available actions count
|
||||
- Money, chips, round progression
|
||||
- Remaining hands/discards
|
||||
|
||||
Args:
|
||||
state: Full Balatro game state dictionary
|
||||
|
||||
Returns:
|
||||
Numpy array of normalized game features
|
||||
"""
|
||||
|
||||
features = [
|
||||
state.get('state', 0), # Game state
|
||||
len(state.get('available_actions', [])), # Number of available actions
|
||||
]
|
||||
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
def _extract_chip_features(self, state: Dict[str, Any]) -> np.ndarray:
|
||||
"""
|
||||
Extract information relating to chips
|
||||
|
||||
Args:
|
||||
state: Full Balatro game state dictionary
|
||||
Returns:
|
||||
Numpy array of normalized chip features
|
||||
"""
|
||||
features = [
|
||||
state.get('chips', 0),
|
||||
state.get('chip_target', 0)
|
||||
]
|
||||
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
def _encode_suit(self, suit: str) -> int:
|
||||
"""Encode suit as integer"""
|
||||
suit_map = {'Hearts': 0, 'Diamonds': 1, 'Clubs': 2, 'Spades': 3}
|
||||
return suit_map.get(suit, 0)
|
||||
|
||||
def _encode_value(self, value: str) -> int:
|
||||
"""Encode card value as integer"""
|
||||
if value.isdigit():
|
||||
return int(value)
|
||||
|
||||
value_map = {
|
||||
'Jack': 11, 'Queen': 12, 'King': 13, 'Ace': 14,
|
||||
'A': 14, 'K': 13, 'Q': 12, 'J': 11
|
||||
}
|
||||
return value_map.get(value, 0)
|
||||
|
||||
|
||||
|
||||
class BalatroActionMapper:
|
||||
"""
|
||||
Converts RL actions to Balatro command JSON.
|
||||
|
||||
Handles:
|
||||
- Binary action conversion to card indices
|
||||
- Action validation
|
||||
- JSON response formatting
|
||||
"""
|
||||
|
||||
def __init__(self, action_slices: Dict[str, slice]):
|
||||
self.slices = action_slices
|
||||
|
||||
# Validator
|
||||
self.response_validator = ResponseValidator()
|
||||
|
||||
# Logger
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def process_action(self, rl_action: np.ndarray) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert RL action to Balatro JSON.
|
||||
|
||||
Args:
|
||||
rl_action: Binary action array from RL agent
|
||||
|
||||
|
||||
Returns:
|
||||
JSON response formatted for Balatro mod
|
||||
"""
|
||||
ai_action = rl_action[self.slices["action_selection"]].tolist()[0]
|
||||
response_data = {
|
||||
"action": ai_action,
|
||||
"params": self._extract_select_hand_params(rl_action),
|
||||
}
|
||||
|
||||
|
||||
# Validate action structure
|
||||
try:
|
||||
self.response_validator.validate_response(response_data)
|
||||
except ValueError as e:
|
||||
self.logger.error(f"Invalid game state: {e}")
|
||||
|
||||
return response_data
|
||||
|
||||
def _extract_select_hand_params(self, raw_action: np.ndarray) -> List[int]:
|
||||
"""
|
||||
Converts the raw action to a list of Lua card indices
|
||||
|
||||
Args:
|
||||
raw_action: The whole action from the RL agent
|
||||
Returns:
|
||||
List of 1-based card indices for Lua
|
||||
"""
|
||||
card_indices = raw_action[self.slices["card_indices"]]
|
||||
return [i + 1 for i, val in enumerate(card_indices) if val == 1]
|
||||
|
||||
|
||||
|
|
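Two worked examples may help here. With the defaults hard-coded in `_extract_hand_features` (8 cards x 7 features) plus 2 game features, the action mask, and 2 chip features, the observation produced by `process_game_state` is 60 + `max_actions` values long, so `observation_size` must be set accordingly. For the action mapper, the slice layout below is hypothetical (the real slices come from the environment's action-space definition); it only illustrates the conversion from a 0-based card mask to 1-based Lua indices.

    import numpy as np

    # Hypothetical slice layout: one action-id slot followed by an 8-card mask
    slices = {"action_selection": slice(0, 1), "card_indices": slice(1, 9)}
    mapper = BalatroActionMapper(slices)

    rl_action = np.array([2, 1, 0, 1, 1, 0, 0, 0, 0])  # action id 2; cards 1, 3 and 4 selected
    print(mapper.process_action(rl_action))
    # -> {'action': 2, 'params': [1, 3, 4]}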
@ -1,122 +0,0 @@
|
|||
"""
|
||||
JSON Serialization utilities for Balatro RL
|
||||
Handles conversion between game state formats
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
|
||||
class GameStateSerializer:
|
||||
"""Handles serialization/deserialization of game states"""
|
||||
|
||||
@staticmethod
|
||||
def serialize_game_state(game_state: Dict[str, Any]) -> str:
|
||||
"""Convert game state dict to JSON string"""
|
||||
try:
|
||||
return json.dumps(game_state, indent=2, ensure_ascii=False)
|
||||
except (TypeError, ValueError) as e:
|
||||
raise ValueError(f"Failed to serialize game state: {e}")
|
||||
|
||||
@staticmethod
|
||||
def deserialize_game_state(json_string: str) -> Dict[str, Any]:
|
||||
"""Convert JSON string to game state dict"""
|
||||
try:
|
||||
return json.loads(json_string)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to deserialize game state: {e}")
|
||||
|
||||
@staticmethod
|
||||
def serialize_action(action: Dict[str, Any]) -> str:
|
||||
"""Convert action dict to JSON string"""
|
||||
try:
|
||||
return json.dumps(action, ensure_ascii=False)
|
||||
except (TypeError, ValueError) as e:
|
||||
raise ValueError(f"Failed to serialize action: {e}")
|
||||
|
||||
@staticmethod
|
||||
def deserialize_action(json_string: str) -> Dict[str, Any]:
|
||||
"""Convert JSON string to action dict"""
|
||||
try:
|
||||
return json.loads(json_string)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to deserialize action: {e}")
|
||||
|
||||
|
||||
class GameStateValidator:
|
||||
"""Validates game state structure and content"""
|
||||
|
||||
@staticmethod
|
||||
def validate_game_state(game_state: Dict[str, Any]) -> bool:
|
||||
"""Validate that game state has required fields"""
|
||||
required_fields = ['state', 'available_actions']
|
||||
|
||||
for field in required_fields:
|
||||
if field not in game_state:
|
||||
raise ValueError(f"Missing required field: {field}")
|
||||
|
||||
# Validate hand structure if present
|
||||
if 'hand' in game_state:
|
||||
GameStateValidator._validate_hand(game_state['hand'])
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _validate_hand(hand: Dict[str, Any]) -> bool:
|
||||
"""Validate hand structure"""
|
||||
if not isinstance(hand, dict):
|
||||
raise ValueError("Hand must be a dictionary")
|
||||
|
||||
if 'cards' in hand and not isinstance(hand['cards'], list):
|
||||
raise ValueError("Hand cards must be a list")
|
||||
|
||||
# Validate each card
|
||||
for i, card in enumerate(hand.get('cards', [])):
|
||||
GameStateValidator._validate_card(card, i)
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _validate_card(card: Dict[str, Any], index: int) -> bool:
|
||||
"""Validate individual card structure"""
|
||||
required_fields = ['suit', 'base']
|
||||
|
||||
for field in required_fields:
|
||||
if field not in card:
|
||||
raise ValueError(f"Card {index} missing required field: {field}")
|
||||
|
||||
# Validate base structure
|
||||
if 'base' in card:
|
||||
base = card['base']
|
||||
if 'value' not in base or 'nominal' not in base:
|
||||
raise ValueError(f"Card {index} base missing value or nominal")
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def validate_action(action: Dict[str, Any]) -> bool:
|
||||
"""Validate action structure"""
|
||||
if 'action' not in action:
|
||||
raise ValueError("Action must have 'action' field")
|
||||
|
||||
action_type = action['action']
|
||||
|
||||
# Validate specific action types
|
||||
if action_type in ['play_hand', 'discard']:
|
||||
if 'card_indices' not in action:
|
||||
raise ValueError(f"Action {action_type} requires card_indices")
|
||||
|
||||
if not isinstance(action['card_indices'], list):
|
||||
raise ValueError("card_indices must be a list")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# Convenience functions
|
||||
def to_json(obj: Dict[str, Any]) -> str:
|
||||
"""Quick serialize to JSON"""
|
||||
return GameStateSerializer.serialize_game_state(obj)
|
||||
|
||||
def from_json(json_str: str) -> Dict[str, Any]:
|
||||
"""Quick deserialize from JSON"""
|
||||
return GameStateSerializer.deserialize_game_state(json_str)
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
"""
|
||||
Game State Validation utilities for Balatro RL
|
||||
Validates game state structure and content
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List
|
||||
|
||||
|
||||
|
||||
class GameStateValidator:
|
||||
"""
|
||||
Validates game state JSON structure
|
||||
|
||||
Request contract (macro level, not exhaustive)
|
||||
{
|
||||
game_state: Dict, # The whole game state table
|
||||
available_actions: List, # The available action ids, all integers
|
||||
retry_count: int, # TODO: add to the observation space; tracks retries when the AI's action fails
|
||||
}
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def validate_game_state(game_state: Dict[str, Any]) -> bool:
|
||||
"""Validate that game state has required fields"""
|
||||
required_fields = ['game_state', 'available_actions', 'retry_count']
|
||||
# TODO: add more validations as needed, e.g. ensure available_actions is a list
|
||||
# TODO game_over might be a validation
|
||||
# TODO validate chips for rewards
|
||||
|
||||
for field in required_fields:
|
||||
if field not in game_state:
|
||||
raise ValueError(f"Missing required field: {field}")
|
||||
|
||||
# Validate hand structure if present
|
||||
if 'hand' in game_state:
|
||||
GameStateValidator._validate_hand(game_state['hand'])
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _validate_hand(hand: Dict[str, Any]) -> bool:
|
||||
"""Validate hand structure"""
|
||||
if not isinstance(hand, dict):
|
||||
raise ValueError("Hand must be a dictionary")
|
||||
|
||||
if 'cards' in hand and not isinstance(hand['cards'], list):
|
||||
raise ValueError("Hand cards must be a list")
|
||||
|
||||
# Validate each card
|
||||
for i, card in enumerate(hand.get('cards', [])):
|
||||
GameStateValidator._validate_card(card, i)
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _validate_card(card: Dict[str, Any], index: int) -> bool:
|
||||
"""Validate individual card structure"""
|
||||
required_fields = ['suit', 'base']
|
||||
|
||||
for field in required_fields:
|
||||
if field not in card:
|
||||
raise ValueError(f"Card {index} missing required field: {field}")
|
||||
|
||||
# Validate base structure
|
||||
if 'base' in card:
|
||||
base = card['base']
|
||||
if 'value' not in base or 'nominal' not in base:
|
||||
raise ValueError(f"Card {index} base missing value or nominal")
|
||||
|
||||
return True
|
||||
|
||||
class ResponseValidator:
|
||||
"""
|
||||
Validates response JSON structure
|
||||
|
||||
Response contract
|
||||
{
|
||||
action, # The action to take
|
||||
params, # Optional parameters, used when the action requires them
|
||||
}
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def validate_response(response: Dict[str, Any]):
|
||||
required_fields = ["action"]
|
||||
for field in required_fields:
|
||||
if field not in response:
|
||||
raise ValueError(f"Missing required field: {field}")
|
||||
|
||||
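For reference, the two contracts above can be exercised with minimal payloads like the sketch below; all field values are illustrative and only the structure matters to these validators.

    # Minimal request that satisfies GameStateValidator (values are made up)
    request = {
        "game_state": {"state": 1},        # raw Balatro game state table
        "available_actions": [0, 1, 2],    # action ids the mod will accept
        "retry_count": 0,
        "hand": {
            "cards": [
                {"suit": "Hearts", "base": {"value": "King", "nominal": 10}},
            ],
        },
    }
    GameStateValidator.validate_game_state(request)    # returns True or raises ValueError

    # Minimal response that satisfies ResponseValidator
    response = {"action": 1, "params": [1]}
    ResponseValidator.validate_response(response)      # raises ValueError if "action" is missing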