Establish working pipe communication between the AI and Balatro, with an initial iteration running

This commit is contained in:
Angel Valentin 2025-07-19 17:02:15 -04:00
parent bbeab8c635
commit 9b5f949a0b
25 changed files with 1364 additions and 971 deletions

.gitignore vendored (10 changed lines)
View File

@ -56,13 +56,17 @@ logs/
*.temp *.temp
.cache/ .cache/
# Model checkpoints and training artifacts # Training artifacts (generated during RL training sessions)
models/
tensorboard_logs/
training.log
training_monitor.csv
# Other ML artifacts
*.pth *.pth
*.pt *.pt
*.ckpt *.ckpt
checkpoints/ checkpoints/
models/
tensorboard_logs/
wandb/ wandb/
# Large reference files (don't commit decompiled game source) # Large reference files (don't commit decompiled game source)

View File

@ -55,11 +55,32 @@ Balatro uses numbered states defined in `globals.lua`:
- `G.GAME` - Main game data structure - `G.GAME` - Main game data structure
- `G.FUNCS` - Game function callbacks - `G.FUNCS` - Game function callbacks
## Communication Architecture
### Dual Pipe Communication
The mod uses separate named pipes for request/response communication with the AI:
- **Request pipe path**: `/tmp/balatro_request` (Balatro writes, AI reads)
- **Response pipe path**: `/tmp/balatro_response` (AI writes, Balatro reads)
- **Protocol**: Persistent pipe handles with blocking reads
- **Flow**:
1. Balatro opens persistent handles to both pipes
2. Balatro writes JSON request to request pipe
3. AI reads request from request pipe
4. AI writes JSON response to response pipe
5. Balatro performs blocking read from response pipe
### Benefits of Dual Pipes
- Clear separation of request/response channels
- Persistent handles avoid repeated open/close overhead
- Blocking reads ensure proper synchronization
- Each pipe is unidirectional, eliminating read/write conflicts
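A minimal sketch of the AI side of this protocol (illustrative only, not the project's actual `BalatroPipeIO` class; it assumes the pipe paths above and newline-delimited JSON, matching the mod's line-based reads):

```python
import json
import os

REQUEST_PIPE = "/tmp/balatro_request"    # Balatro writes, AI reads
RESPONSE_PIPE = "/tmp/balatro_response"  # AI writes, Balatro reads

# The AI side owns pipe creation; create the FIFOs if missing
for path in (REQUEST_PIPE, RESPONSE_PIPE):
    if not os.path.exists(path):
        os.mkfifo(path)

# Persistent handles: open once, reuse for every request/response exchange
with open(REQUEST_PIPE, "r") as req, open(RESPONSE_PIPE, "w") as resp:
    while True:
        line = req.readline()  # blocking read; waits for Balatro's next request
        if not line:
            break              # Balatro closed its end of the pipe
        request = json.loads(line)
        # ... choose an action from request["game_state"] and
        # request["available_actions"] ...
        resp.write(json.dumps({"action": "no_action"}) + "\n")
        resp.flush()           # push the response through immediately
```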
## Current Status ## Current Status
The mod can successfully: The mod can successfully:
- Start a run automatically - Start a run automatically
- Detect blind selection state - Detect blind selection state
- Automatically select the next blind - Automatically select the next blind
- Communicate with AI via dual pipe system
## Communication with Claude ## Communication with Claude
- Please keep chats as efficient as possible and brief. No fluff talk, just get to the point - Please keep chats as efficient as possible and brief. No fluff talk, just get to the point

View File

@ -28,7 +28,12 @@ Made a symlink -> ln -s ~/dev/balatro-rl/RLBridge /mnt/gamerlinuxssd/SteamLibrar
- [ ] Training loop integration - [ ] Training loop integration
### Game Features ### Game Features
- [ ] Always offer restart_run as an action option while a run is ongoing
- [ ] Blind selection choices (skip vs select) - [ ] Blind selection choices (skip vs select)
- [ ] Extended game state (money, discards, hands played) - [ ] Extended game state (money, discards, hands played)
- [ ] Shop interactions - [ ] Shop interactions
- [ ] Joker management - [ ] Joker management
### RL Enhancements
- [ ] Add a "Replay System" to analyze successful actions: save the seed and keep an action log so runs can be reproduced (see the sketch below).
This could be done just before the game is reset, as a check against whatever criteria decide which runs are worth saving.
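A rough Python sketch of that pre-reset check (the `ReplayLog` helper, the chips threshold, and the `replays/` directory are all hypothetical, named here only for illustration):

```python
import json
import os
import time

class ReplayLog:
    """Collects the seed and action history for one run (hypothetical helper)."""

    def __init__(self, seed=None):
        self.seed = seed
        self.actions = []

    def record(self, action):
        # Append every action the agent takes during the run
        self.actions.append(action)

    def maybe_save(self, final_state, min_chips=300):
        # Example criterion: only keep runs that scored enough chips
        if final_state.get("chips", 0) < min_chips:
            return None
        os.makedirs("replays", exist_ok=True)
        path = f"replays/run_{int(time.time())}.json"
        with open(path, "w") as f:
            json.dump({"seed": self.seed, "actions": self.actions}, f)
        return path
```

The environment's `reset()` could call `maybe_save(self.current_state)` right before clearing per-run state.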

View File

@ -13,6 +13,10 @@ local utils = require("utils")
local last_state_hash = nil local last_state_hash = nil
local last_actions_hash = nil local last_actions_hash = nil
local pending_action = nil local pending_action = nil
local retry_count = 0
local need_retry_request = false
local rl_training_active = false
local last_key_pressed = nil
--- Initialize AI system --- Initialize AI system
--- Sets up communication and prepares the AI for operation --- Sets up communication and prepares the AI for operation
@ -20,12 +24,50 @@ local pending_action = nil
function AI.init() function AI.init()
utils.log_ai("Initializing AI system...") utils.log_ai("Initializing AI system...")
communication.init() communication.init()
-- Hook into Love2D keyboard events
if love and love.keypressed then
local original_keypressed = love.keypressed
love.keypressed = function(key)
-- Store the key press for our AI
last_key_pressed = key
-- Call original function
if original_keypressed then
original_keypressed(key)
end
end
else
-- love.keypressed is not defined yet, so install our own handler
-- (assumes no other mod replaces it afterwards)
love.keypressed = function(key)
last_key_pressed = key
end
end
end end
--- Main AI update loop (called every frame) --- Main AI update loop (called every frame)
--- Monitors game state changes, handles communication, and executes AI actions --- Monitors game state changes, handles communication, and executes AI actions
--- @return nil --- @return nil
function AI.update() function AI.update()
-- Check for key press to start/stop RL training
if last_key_pressed then
if last_key_pressed == "r" then
if not rl_training_active then
rl_training_active = true
utils.log_ai("🚀 RL Training STARTED (R pressed)")
end
elseif last_key_pressed == "t" then
if rl_training_active then
rl_training_active = false
utils.log_ai("⏹️ RL Training STOPPED (T pressed)")
end
end
-- Clear the key press
last_key_pressed = nil
end
-- Don't process AI requests unless RL training is active
if not rl_training_active then
return
end
-- Get current game state -- Get current game state
local current_state = output.get_game_state() local current_state = output.get_game_state()
local available_actions = action.get_available_actions() local available_actions = action.get_available_actions()
@ -40,29 +82,40 @@ function AI.update()
return return
end end
-- Create simple hash to detect state changes -- Create hash to detect state changes
local state_hash = AI.hash_state(current_state) local state_hash = AI.hash_state(current_state)
local actions_hash = AI.hash_actions(available_actions) local actions_hash = AI.hash_actions(available_actions)
-- Request action from AI if state or actions changed if state_hash ~= last_state_hash or actions_hash ~= last_actions_hash or need_retry_request then
if state_hash ~= last_state_hash or actions_hash ~= last_actions_hash then -- State has changed
if state_hash ~= last_state_hash then if state_hash ~= last_state_hash then
utils.log_ai("State changed to: " .. utils.log_ai("State changed to: " ..
current_state.state .. " (" .. utils.state_name(current_state.state) .. ")") current_state.state .. " (" .. utils.get_state_name(current_state.state) .. ")")
action.reset_state() action.reset_state()
last_state_hash = state_hash last_state_hash = state_hash
end end
-- Available actions have changed
if actions_hash ~= last_actions_hash then if actions_hash ~= last_actions_hash then
utils.log_ai("Available actions changed: " .. utils.log_ai("Available actions changed: " ..
table.concat(utils.get_action_names(available_actions), ", ")) table.concat(utils.get_action_names(available_actions), ", "))
last_actions_hash = actions_hash last_actions_hash = actions_hash
end end
-- Request action from AI via file I/O -- Request action from AI
local ai_response = communication.request_action(current_state, available_actions) need_retry_request = false
if ai_response and ai_response.action ~= "no_action" then local ai_response = communication.request_action(current_state, available_actions, retry_count)
pending_action = ai_response
if ai_response then
-- Handling handshake
if ai_response.action == "ready" then
utils.log_ai("Handshake complete - AI ready")
-- Don't return - continue normal game loop
-- Force a state check on next frame to send real request
last_state_hash = nil
else
pending_action = ai_response
end
end end
end end
@ -71,8 +124,12 @@ function AI.update()
local result = action.execute_action(pending_action.action, pending_action.params) local result = action.execute_action(pending_action.action, pending_action.params)
if result.success then if result.success then
utils.log_ai("Action executed successfully: " .. pending_action.action) utils.log_ai("Action executed successfully: " .. pending_action.action)
retry_count = 0
else else
utils.log_ai("Action failed: " .. (result.error or "Unknown error")) -- Update retry_count to indicate state change
utils.log_ai("Action failed: " .. (result.error or "Unknown error") .. " RETRYING...")
retry_count = retry_count + 1
need_retry_request = true
end end
pending_action = nil pending_action = nil
end end
@ -83,7 +140,7 @@ end
--- @param game_state table Current game state data --- @param game_state table Current game state data
--- @return string Hash representing the current state --- @return string Hash representing the current state
function AI.hash_state(game_state) function AI.hash_state(game_state)
return game_state.state .. "_" .. (game_state.round or 0) .. "_" .. (game_state.chips or 0) return game_state.state .. "_" .. (game_state.chips or 0) -- TODO update hash to be more unique
end end
--- Create simple hash of available actions --- Create simple hash of available actions

View File

@ -1,124 +1,117 @@
--- RLBridge communication module --- RLBridge communication module
--- Handles file-based communication between the game and external AI system --- Handles dual pipe communication with persistent handles between the game and external AI system
local COMM = {} local COMM = {}
local utils = require("utils") local utils = require("utils")
local action = require("actions")
local json = require("dkjson") local json = require("dkjson")
-- File-based communication settings -- Dual pipe communication settings with persistent handles
local comm_enabled = false local comm_enabled = false
local request_file = "/tmp/balatro_request.json" local request_pipe = "/tmp/balatro_request"
local response_file = "/tmp/balatro_response.json" local response_pipe = "/tmp/balatro_response"
local request_counter = 0 local request_handle = nil
local last_response_time = 0 local response_handle = nil
--- Check if response file has new content --- Initialize dual pipe communication with persistent handles
--- @param expected_id number Expected request ID --- Sets up persistent pipe handles with external AI system
--- @return table|nil Response data if new response available, nil otherwise
local function check_for_response(expected_id)
local response_file_handle = io.open(response_file, "r")
if not response_file_handle then
return nil
end
local response_json = response_file_handle:read("*all")
response_file_handle:close()
if not response_json or response_json == "" then
return nil
end
local response_data = json.decode(response_json)
if not response_data then
return nil
end
-- Check if this is a new response for our request
if response_data.id == expected_id and response_data.timestamp then
if response_data.timestamp > last_response_time then
last_response_time = response_data.timestamp
return response_data
end
end
return nil
end
--- Initialize file-based communication
--- Sets up file-based communication channel with external AI system
--- @return nil --- @return nil
function COMM.init() function COMM.init()
utils.log_comm("Initializing file-based communication...") utils.log_comm("Initializing dual pipe communication with persistent handles...")
comm_enabled = true utils.log_comm("Note: Pipes will be opened when first needed (lazy initialization)")
last_response_time = 0 -- Reset response tracking comm_enabled = true -- Enable communication, pipes will open on first use
utils.log_comm("Ready for file-based requests")
utils.log_comm("Request file: " .. request_file)
utils.log_comm("Response file: " .. response_file)
end end
--- Send game turn request to AI and get action via files --- Lazy initialization of pipe handles when first needed
--- @return boolean True if pipes are ready, false otherwise
function COMM.ensure_pipes_open()
if request_handle and response_handle then
return true -- Already open
end
-- Try to open pipes (io.open fails if the AI has not created them yet, and blocks until the AI opens the other end)
utils.log_comm("Opening persistent pipe handles...")
-- Open response pipe for reading (keep open)
response_handle = io.open(response_pipe, "r")
if not response_handle then
utils.log_comm("ERROR: Cannot open response pipe for reading: " .. response_pipe)
return false
end
-- Open request pipe for writing (keep open)
request_handle = io.open(request_pipe, "w")
if not request_handle then
utils.log_comm("ERROR: Cannot open request pipe for writing: " .. request_pipe)
if response_handle then
response_handle:close()
response_handle = nil
end
return false
end
utils.log_comm("Persistent pipe handles opened successfully")
utils.log_comm("Request pipe (write): " .. request_pipe)
utils.log_comm("Response pipe (read): " .. response_pipe)
return true
end
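Context for the lazy initialization above: opening a named pipe blocks until the other end is opened as well (unless `O_NONBLOCK` is used), so neither process can finish connecting until both are running. A tiny Python illustration of that blocking behavior (hypothetical `/tmp/demo_pipe` path):

```python
import os

PIPE = "/tmp/demo_pipe"  # throwaway demo path
if not os.path.exists(PIPE):
    os.mkfifo(PIPE)

# This open() blocks until another process opens the pipe for writing;
# open(PIPE, "w") would likewise block until a reader appears. That is
# why the mod defers opening its handles until the AI side is ready.
with open(PIPE, "r") as fh:
    print(fh.readline())
```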
--- Send game turn request to AI and get action via persistent pipe handles
--- @param game_state table Current game state data --- @param game_state table Current game state data
--- @param available_actions table Available actions list --- @param available_actions table Available actions list
--- @return table|nil Action response from AI, nil if error --- @return table|nil Action response from AI, nil if error
function COMM.request_action(game_state, available_actions) function COMM.request_action(game_state, available_actions, retry_count)
if not comm_enabled then if not comm_enabled then
utils.log_comm("ERROR: Communication not enabled")
return nil return nil
end end
request_counter = request_counter + 1 -- Lazy initialization - open pipes when first needed
if not COMM.ensure_pipes_open() then
utils.log_comm("ERROR: Failed to open pipe handles")
return nil
end
local request = { local request = {
id = request_counter,
timestamp = os.time(),
game_state = game_state, game_state = game_state,
available_actions = available_actions or {} available_actions = available_actions or {},
retry_count = retry_count,
} }
utils.log_comm("Requesting action for state: " .. tostring(game_state.state)) utils.log_comm(utils.get_timestamp() .. "Sending action request for state: " ..
tostring(game_state.state) .. " (" .. utils.get_state_name(game_state.state) .. ")")
-- Write request to file -- Encode request as JSON
local json_data = json.encode(request) local json_data = json.encode(request)
if not json_data then if not json_data then
utils.log_comm("ERROR: Failed to encode request JSON") utils.log_comm("ERROR: Failed to encode request JSON")
return nil return nil
end end
local request_file_handle = io.open(request_file, "w") -- Write request to persistent handle
if not request_file_handle then request_handle:write(json_data .. "\n")
utils.log_comm("ERROR: Cannot write to request file: " .. request_file) request_handle:flush() -- Ensure data is sent immediately
utils.log_comm(utils.get_timestamp() .. "Request sent...")
-- Read response from persistent handle
utils.log_comm(utils.get_timestamp() .. "about to read the respones")
local response_json = response_handle:read("*line")
if not response_json or response_json == "" then
utils.log_comm("ERROR: No response received from AI")
return nil return nil
end end
request_file_handle:write(json_data) local response_data = json.decode(response_json)
request_file_handle:close() if not response_data then
utils.log_comm("ERROR: Failed to decode response JSON")
-- Wait for response file (with timeout) return nil
local max_wait_time = 2 -- seconds
local wait_interval = 0.05 -- seconds
local total_waited = 0
while total_waited < max_wait_time do
local response_data = check_for_response(request_counter)
if response_data then
utils.log_comm("AI action: " .. tostring(response_data.action))
return response_data
end
-- Sleep for a short time using busy wait
local start_time = os.clock()
while (os.clock() - start_time) < wait_interval do
-- Busy wait for short duration
end
total_waited = total_waited + wait_interval
end end
utils.log_comm("ERROR: Timeout waiting for AI response") utils.log_comm("AI action: " .. tostring(response_data.action))
return nil return response_data
end end
--- Check if file communication is enabled --- Check if pipe communication is enabled
--- Returns the current communication status --- Returns the current communication status
--- @return boolean True if enabled, false otherwise --- @return boolean True if enabled, false otherwise
function COMM.is_connected() function COMM.is_connected()
@ -126,12 +119,22 @@ function COMM.is_connected()
end end
--- Close communication --- Close communication
--- Terminates the communication channel with the AI system --- Terminates the persistent pipe handles with the AI system
--- @return nil --- @return nil
function COMM.close() function COMM.close()
comm_enabled = false comm_enabled = false
last_response_time = 0
utils.log_comm("File communication disabled") if request_handle then
request_handle:close()
request_handle = nil
end
if response_handle then
response_handle:close()
response_handle = nil
end
utils.log_comm("Persistent pipe communication closed")
end end
return COMM return COMM

View File

@ -15,6 +15,7 @@ G.SETTINGS.reduced_motion = true
function INIT.start_run() function INIT.start_run()
print("Starting balatro run") print("Starting balatro run")
-- Initialize AI system -- Initialize AI system
local ai = require("ai") local ai = require("ai")
ai.init() ai.init()

View File

@ -4,15 +4,20 @@
local O = {} local O = {}
local actions = require("actions")
--- Get comprehensive game state for AI --- Get comprehensive game state for AI
--- Collects all relevant game information into a structured format --- Collects all relevant game information into a structured format
--- @return table Complete game state data for AI processing --- @return table Complete game state data for AI processing
function O.get_game_state() function O.get_game_state()
-- Calculate game_over
local game_over = 0
if G.STATE == G.STATES.GAME_OVER then
game_over = 1
end
return { return {
-- Basic state info -- Basic state info
state = G.STATE, state = G.STATE,
game_over = game_over,
-- Game progression -- Game progression
-- round = G.GAME and G.GAME.round or 0, -- round = G.GAME and G.GAME.round or 0,
@ -27,16 +32,9 @@ function O.get_game_state()
-- Hand info (if applicable) -- Hand info (if applicable)
hand = O.get_hand_info(), hand = O.get_hand_info(),
-- -- Jokers info -- Chips/score info
-- jokers = O.get_jokers_info(), chips = G.GAME and G.GAME.chips or 0,
chip_target = G.GAME and G.GAME.blind and G.GAME.blind.chips or 0,
-- -- Chips/score info
-- chips = G.GAME and G.GAME.chips or 0,
-- chip_target = G.GAME and G.GAME.blind and G.GAME.blind.chips or 0,
-- -- Action feedback for AI
-- last_action_success = true, -- Did last action work?
-- last_action_error = nil, -- What went wrong?
} }
end end

View File

@ -46,7 +46,7 @@ end
--- Maps numeric state IDs to their string representations for debugging --- Maps numeric state IDs to their string representations for debugging
--- @param state_id number Numeric state identifier --- @param state_id number Numeric state identifier
--- @return string Human-readable state name or "UNKNOWN_STATE" --- @return string Human-readable state name or "UNKNOWN_STATE"
function UTIL.state_name(state_id) function UTIL.get_state_name(state_id)
for key, value in pairs(G.STATES) do for key, value in pairs(G.STATES) do
if value == state_id then if value == state_id then
return key return key
@ -133,4 +133,22 @@ function UTIL.get_action_names(available_actions)
return names return names
end end
--- Check whether an array-style table contains a value
--- @param tbl table Array-like table to search
--- @param val any Value to look for
--- @return boolean True if found, false otherwise
function UTIL.contains(tbl, val)
for _, v in ipairs(tbl) do
if v == val then
return true
end
end
return false
end
--- Get current timestamp as formatted string
--- Returns current time in HH:MM:SS.mmm format for debugging
--- @return string Formatted timestamp
function UTIL.get_timestamp()
local time = os.time()
-- os.clock() is CPU time, so the millisecond part is only a rough
-- sub-second marker for ordering log lines, not true wall-clock ms
local ms = (os.clock() % 1) * 1000
return os.date("%H:%M:%S", time) .. string.format(".%03d", ms)
end
return UTIL return UTIL

View File

@ -1,95 +0,0 @@
"""
Balatro RL Agent
Contains the AI logic for making decisions in Balatro
"""
from typing import Dict, Any, List
import random
import logging
class BalatroAgent:
"""Main AI agent for playing Balatro"""
def __init__(self):
self.logger = logging.getLogger(__name__)
# TODO: Initialize your RL model here
# self.model = load_model("path/to/model")
# self.training = True
def get_action(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""
Given game state, return the action to take
Args:
game_state: Current game state from Balatro mod
Returns:
Action dictionary to send back to mod
"""
try:
# Extract relevant info from game state
available_actions = request.get('available_actions', [])
game_state = request.get('game_state', {})
self.logger.info(f"Processing state {game_state.get('state', 0)} with {len(available_actions)} actions")
# TODO: Replace with actual RL logic
action = self._get_random_action(available_actions) # TODO we are getting random actions for testing
# TODO: If training, update model with reward
# if self.training:
# self._update_model(game_state, action, reward)
return action
except Exception as e:
self.logger.error(f"Error in get_action: {e}")
return {"action": "no_action"}
def _get_random_action(self, available_actions: List) -> Dict[str, Any]:
"""Temporary random action selection - replace with RL logic"""
# available_actions comes as [1, 2] format
if not available_actions or not isinstance(available_actions, List):
return {"action": "no_action"}
# Simple random action selection
selected_action_id = random.choice(available_actions)
self.logger.info(f"Selected action ID: {selected_action_id} from available: {available_actions}")
# TODO this is for testing, the AI doesn't really care about what the action_names are
# but we do so we can test out actions and make sure they work
# Map action IDs to names for logic (but return the ID)
action_names = {1: "start_run", 2: "select_blind", 3: "select_hand"}
action_name = action_names.get(selected_action_id, "unknown")
# Pick random cards for now
params = {}
if action_name == "select_hand":
params["card_indices"] = [1, 3, 5]
# Return the action ID (number), not the name
return {"action": selected_action_id, "params": params}
def train_step(self, game_state: Dict, action: Dict, reward: float, next_state: Dict):
"""
Perform one training step
TODO: Implement RL training logic here
- Update Q-values
- Update neural network weights
- Store experience in replay buffer
"""
pass
def save_model(self, path: str):
"""Save the trained model"""
# TODO: Implement model saving
pass
def load_model(self, path: str):
"""Load a trained model"""
# TODO: Implement model loading
pass

View File

@ -1,83 +0,0 @@
"""
Neural network policy for Balatro RL agent
"""
import torch
import torch.nn as nn
import numpy as np
from typing import Dict, Any
class BalatroPolicy(nn.Module):
"""
Neural network policy for Balatro decision making
"""
def __init__(self, state_dim: int = 128, action_dim: int = 64, hidden_dim: int = 256):
super().__init__()
self.state_dim = state_dim
self.action_dim = action_dim
# Feature extraction layers
self.feature_net = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
# Action value head
self.value_head = nn.Linear(hidden_dim, 1)
# Action probability head
self.action_head = nn.Linear(hidden_dim, action_dim)
def forward(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass through the network
Args:
state: Game state tensor
Returns:
Tuple of (action_logits, state_value)
"""
features = self.feature_net(state)
action_logits = self.action_head(features)
state_value = self.value_head(features)
return action_logits, state_value
def get_action(self, state: np.ndarray, available_actions: list) -> tuple[int, float]:
"""
Get action from policy given current state
Args:
state: Current game state as numpy array
available_actions: List of available action indices
Returns:
Tuple of (action_index, action_probability)
"""
with torch.no_grad():
state_tensor = torch.FloatTensor(state).unsqueeze(0)
action_logits, _ = self.forward(state_tensor)
# Mask unavailable actions
masked_logits = action_logits.clone()
if available_actions:
mask = torch.full_like(action_logits, float('-inf'))
mask[0, available_actions] = 0
masked_logits = action_logits + mask
# Get action probabilities
action_probs = torch.softmax(masked_logits, dim=-1)
# Sample action
action_dist = torch.distributions.Categorical(action_probs)
action = action_dist.sample()
return action.item(), action_probs[0, action].item()

View File

@ -1,153 +0,0 @@
"""
Training logic for Balatro RL agent
"""
import torch
import torch.optim as optim
import numpy as np
from typing import Dict, List, Any
from .policy import BalatroPolicy
import logging
class BalatroTrainer:
"""
Handles training of the Balatro RL agent
"""
def __init__(self, policy: BalatroPolicy, learning_rate: float = 3e-4):
self.policy = policy
self.optimizer = optim.Adam(policy.parameters(), lr=learning_rate)
self.logger = logging.getLogger(__name__)
# Training buffers
self.states = []
self.actions = []
self.rewards = []
self.values = []
self.log_probs = []
def collect_experience(self, state: np.ndarray, action: int, reward: float,
value: float, log_prob: float):
"""
Collect experience for training
Args:
state: Game state
action: Action taken
reward: Reward received
value: State value estimate
log_prob: Log probability of action
"""
self.states.append(state)
self.actions.append(action)
self.rewards.append(reward)
self.values.append(value)
self.log_probs.append(log_prob)
def compute_returns(self, next_value: float = 0.0, gamma: float = 0.99) -> List[float]:
"""
Compute discounted returns
Args:
next_value: Value of next state (for bootstrapping)
gamma: Discount factor
Returns:
List of discounted returns
"""
returns = []
R = next_value
for reward in reversed(self.rewards):
R = reward + gamma * R
returns.insert(0, R)
return returns
def update_policy(self, gamma: float = 0.99, entropy_coef: float = 0.01,
value_coef: float = 0.5) -> Dict[str, float]:
"""
Update policy using collected experience
Args:
gamma: Discount factor
entropy_coef: Entropy regularization coefficient
value_coef: Value loss coefficient
Returns:
Dictionary of training metrics
"""
if not self.states:
return {}
# Convert to tensors
states = torch.FloatTensor(np.array(self.states))
actions = torch.LongTensor(self.actions)
old_log_probs = torch.FloatTensor(self.log_probs)
values = torch.FloatTensor(self.values)
# Compute returns
returns = self.compute_returns(gamma=gamma)
returns = torch.FloatTensor(returns)
# Compute advantages
advantages = returns - values
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# Forward pass
action_logits, new_values = self.policy(states)
# Compute losses
action_dist = torch.distributions.Categorical(logits=action_logits)
new_log_probs = action_dist.log_prob(actions)
entropy = action_dist.entropy().mean()
# Policy loss (PPO-style, but simplified)
ratio = torch.exp(new_log_probs - old_log_probs)
policy_loss = -(ratio * advantages).mean()
# Value loss
value_loss = torch.nn.functional.mse_loss(new_values.squeeze(), returns)
# Total loss
total_loss = policy_loss + value_coef * value_loss - entropy_coef * entropy
# Update
self.optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
self.optimizer.step()
# Clear buffers
self.clear_buffers()
# Return metrics
return {
'policy_loss': policy_loss.item(),
'value_loss': value_loss.item(),
'entropy': entropy.item(),
'total_loss': total_loss.item()
}
def clear_buffers(self):
"""Clear experience buffers"""
self.states.clear()
self.actions.clear()
self.rewards.clear()
self.values.clear()
self.log_probs.clear()
def save_model(self, path: str):
"""Save model checkpoint"""
torch.save({
'policy_state_dict': self.policy.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict()
}, path)
self.logger.info(f"Model saved to {path}")
def load_model(self, path: str):
"""Load model checkpoint"""
checkpoint = torch.load(path)
self.policy.load_state_dict(checkpoint['policy_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.logger.info(f"Model loaded from {path}")

View File

@ -1,177 +1,190 @@
""" """
Balatro RL Environment Balatro RL Environment
Wraps the HTTP communication in a gym-like interface for RL training Wraps the pipe-based communication with Balatro mod in a standard RL interface.
This acts as a translator between Balatro's JSON pipe communication and
RL libraries that expect gym-style step()/reset() methods.
""" """
import numpy as np import numpy as np
from typing import Dict, Any, Tuple, List, Optional from typing import Dict, Any, Tuple, List, Optional
import logging import logging
import gymnasium as gym
from gymnasium import spaces
from ..utils.debug import tprint
from ..utils.serialization import GameStateValidator from ..utils.communication import BalatroPipeIO
from .reward import BalatroRewardCalculator
from ..utils.mappers import BalatroStateMapper, BalatroActionMapper
class BalatroEnvironment: class BalatroEnv(gym.Env):
""" """
Gym-like environment interface for Balatro RL training Standard RL Environment wrapper for Balatro
This class will eventually wrap the HTTP communication Translates between:
and provide a standard RL interface with step(), reset(), etc. - Balatro mod's JSON pipe communication (/tmp/balatro_request, /tmp/balatro_response)
- Standard RL interface (step, reset, observation spaces)
This allows RL libraries like Stable-Baselines3 to train on Balatro
without knowing about the underlying pipe communication system.
""" """
def __init__(self): def __init__(self):
super().__init__()
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.current_state = None self.current_state = None
self.prev_state = None
self.game_over = False self.game_over = False
# TODO: Define action and observation spaces # Initialize communication and reward systems
# self.action_space = ... self.pipe_io = BalatroPipeIO()
# self.observation_space = ... self.reward_calculator = BalatroRewardCalculator()
def reset(self) -> Dict[str, Any]: # Define Gymnasium spaces
"""Reset the environment for a new episode""" # Action Spaces; This should describe the type and shape of the action
self.current_state = None # Constants
self.game_over = False self.MAX_ACTIONS = 10
action_selection = np.array([self.MAX_ACTIONS])
card_indices = np.array([2] * 10) # Binary include/exclude flag for each of up to 10 cards in hand
self.action_space = spaces.MultiDiscrete(np.concatenate([
action_selection,
card_indices
]))
ACTION_SLICE_LAYOUT = [
("action_selection", 1),
("card_indices", 10)
]
slices = self._build_action_slices(ACTION_SLICE_LAYOUT)
# TODO: Send reset signal to Balatro mod # Observation space: This should describe the type and shape of the observation
# This might involve starting a new run # Constants
self.OBSERVATION_SIZE = 70 # you can get this value by running test_env.py.
return self._get_initial_state() self.observation_space = spaces.Box(
low=-np.inf, # lowest bound of observation data
high=np.inf, # highest bound of observation data
shape=(self.OBSERVATION_SIZE,), # 1D feature vector; adjust if the state encoding changes
dtype=np.float32 # Data type of the numbers
)
# Initialize mappers
self.state_mapper = BalatroStateMapper(observation_size=self.OBSERVATION_SIZE, max_actions=self.MAX_ACTIONS)
self.action_mapper = BalatroActionMapper(action_slices=slices)
def step(self, action: Dict[str, Any]) -> Tuple[Dict, float, bool, Dict]: def reset(self, seed=None, options=None): #TODO add types
""" """
Take an action in the environment Reset the environment for a new episode
In Balatro context, this means starting a new run.
Communicates with Balatro mod via pipes to initiate reset.
Returns:
Initial observation/game state
"""
self.current_state = None
self.prev_state = None
self.game_over = False
# Reset reward tracking
self.reward_calculator.reset()
# Wait for initial request from Balatro (game start)
initial_request = self.pipe_io.wait_for_request()
if not initial_request:
raise RuntimeError("Failed to receive initial request from Balatro")
# Send dummy response to complete the handshake
success = self.pipe_io.send_response({"action": "ready"})
if not success:
raise RuntimeError("Failed to complete handshake")
# Process initial state for SB3
self.current_state = initial_request
initial_observation = self.state_mapper.process_game_state(self.current_state)
return initial_observation, {}
def step(self, action): #TODO add types
"""
Take an action in the Balatro environment
Sends action to Balatro mod via JSON pipe, waits for response,
calculates reward, and returns standard RL step format.
Args: Args:
action: Action to take action: Action dictionary (e.g., {"action": 1, "params": {...}})
Returns: Returns:
Tuple of (observation, reward, done, info) Tuple of (observation, reward, done, info) where:
- observation: Processed game state for neural network
- reward: Calculated reward for this step
- done: Whether episode is finished (game over)
- info: Additional debug information
""" """
# TODO: Send action to Balatro via HTTP # Store previous state for reward calculation
# TODO: Receive new game state self.prev_state = self.current_state
# TODO: Calculate reward
# TODO: Check if episode is done # Send action response to Balatro mod
response_data = self.action_mapper.process_action(rl_action=action)
observation = self._process_game_state(self.current_state) success = self.pipe_io.send_response(response_data)
reward = self._calculate_reward() if not success:
done = self.game_over raise RuntimeError("Failed to send response to Balatro")
info = {}
return observation, reward, done, info # Wait for next request with new game state
next_request = self.pipe_io.wait_for_request()
def _get_initial_state(self) -> Dict[str, Any]: if not next_request:
"""Get initial game state""" # If no response, assume game ended
# TODO: Implement self.game_over = True
return {} observation = self.state_mapper.process_game_state(self.current_state)
reward = 0.0
def _process_game_state(self, raw_state: Dict[str, Any]) -> Dict[str, Any]: return observation, reward, True, {"timeout": True} #TODO this is a bug we need to return 5 values
"""
Process raw game state into format suitable for RL agent
This might involve: # Update current state
- Extracting relevant features self.current_state = next_request
- Normalizing values
- Converting to numpy arrays
"""
if not raw_state:
return {}
# Validate state structure # Process new state for SB3
try: observation = self.state_mapper.process_game_state(self.current_state)
GameStateValidator.validate_game_state(raw_state)
except ValueError as e:
self.logger.error(f"Invalid game state: {e}")
return {}
# TODO: Extract and process features # Calculate reward using expert reward calculator
processed = { reward = self.reward_calculator.calculate_reward(
'hand_features': self._extract_hand_features(raw_state.get('hand', {})), current_state=self.current_state,
'game_features': self._extract_game_features(raw_state), prev_state=self.prev_state if self.prev_state else {}
'available_actions': raw_state.get('available_actions', []) )
# Check if episode is done
done = bool(self.current_state.get('game_over', 0)) # game_over is sent as 0/1 by the mod's output module
terminated = done
truncated = False # Not using time limits for now
info = {
'balatro_state': self.current_state.get('state', 0),
'available_actions': next_request.get('available_actions', [])
} }
# TODO: verify that an observation is returned when we run out of actions; it may not be, which causes problems
return observation, reward, terminated, truncated, info
def cleanup(self):
"""
Clean up environment resources
return processed Call this when shutting down to clean up pipe communication.
"""
def _extract_hand_features(self, hand: Dict[str, Any]) -> np.ndarray: self.pipe_io.cleanup()
"""Extract numerical features from hand"""
# TODO: Convert hand cards to numerical representation @staticmethod
# This might include: def _build_action_slices(layout: List[Tuple[str, int]]) -> Dict[str, slice]:
# - Card values (2-14 for 2-A) """
# - Suits (0-3 for different suits) Create slices for our actions so that we can precisely extract the
# - Special abilities/bonuses right params to send over to balatro
cards = hand.get('cards', []) Args:
features = [] layout: Our ACTION_SLICE_LAYOUT that contains action name and size
Return:
for card in cards: A dictionary containing a key being our action space slice name, and
# Extract card features the slice
suit_encoding = self._encode_suit(card.get('suit', '')) """
value_encoding = self._encode_value(card.get('base', {}).get('value', '')) slices = {}
nominal = card.get('base', {}).get('nominal', 0) start = 0
for action_name, size in layout:
# Add ability features slices[action_name] = slice(start, start + size)
ability = card.get('ability', {}) start += size
ability_features = [ return slices
ability.get('t_chips', 0),
ability.get('t_mult', 0),
ability.get('x_mult', 1),
ability.get('mult', 0)
]
card_features = [suit_encoding, value_encoding, nominal] + ability_features
features.extend(card_features)
# Pad or truncate to fixed size (e.g., 8 cards max * 7 features each)
max_cards = 8
features_per_card = 7
if len(features) < max_cards * features_per_card:
features.extend([0] * (max_cards * features_per_card - len(features)))
else:
features = features[:max_cards * features_per_card]
return np.array(features, dtype=np.float32)
def _extract_game_features(self, state: Dict[str, Any]) -> np.ndarray:
"""Extract game-level features"""
# TODO: Extract features like:
# - Current chips
# - Target chips
# - Money
# - Round/ante
# - Discards remaining
# - Hands remaining
features = [
state.get('state', 0), # Game state
len(state.get('available_actions', [])), # Number of available actions
]
return np.array(features, dtype=np.float32)
def _encode_suit(self, suit: str) -> int:
"""Encode suit as integer"""
suit_map = {'Hearts': 0, 'Diamonds': 1, 'Clubs': 2, 'Spades': 3}
return suit_map.get(suit, 0)
def _encode_value(self, value: str) -> int:
"""Encode card value as integer"""
if value.isdigit():
return int(value)
value_map = {
'Jack': 11, 'Queen': 12, 'King': 13, 'Ace': 14,
'A': 14, 'K': 13, 'Q': 12, 'J': 11
}
return value_map.get(value, 0)
def _calculate_reward(self) -> float:
"""Calculate reward for current state"""
# TODO: Implement reward calculation
# This might be based on:
# - Chips scored
# - Blind completion
# - Round progression
# - Final score
return 0.0
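To make the slice layout concrete, here is roughly what `_build_action_slices` produces for the layout above and how a flat MultiDiscrete action could be split back into parts (the splitting logic lives in `BalatroActionMapper`, which is not shown in this diff, so treat this as an assumed illustration):

```python
import numpy as np

layout = [("action_selection", 1), ("card_indices", 10)]

slices, start = {}, 0
for name, size in layout:
    slices[name] = slice(start, start + size)
    start += size
# slices == {"action_selection": slice(0, 1), "card_indices": slice(1, 11)}

action = np.array([3, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0])  # one sampled MultiDiscrete action
action_id = int(action[slices["action_selection"]][0])  # -> 3
card_flags = action[slices["card_indices"]]             # per-card include/exclude flags
card_indices = [i + 1 for i, flag in enumerate(card_flags) if flag]  # 1-based, like params["card_indices"]
# -> [1, 3, 5]
```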

View File

@ -1,5 +1,11 @@
""" """
Reward calculation for Balatro RL environment Balatro Reward System - The Expert on What's Good/Bad
This module is the single source of truth for reward calculation in Balatro RL.
All reward logic is centralized here to make experimentation and tuning easier.
The BalatroRewardCalculator analyzes game state changes and assigns rewards
that teach the AI what constitutes good vs bad Balatro gameplay.
""" """
from typing import Dict, Any from typing import Dict, Any
@ -7,114 +13,86 @@ import numpy as np
class BalatroRewardCalculator: class BalatroRewardCalculator:
""" """
Calculates rewards for Balatro RL training Expert reward calculator for Balatro RL training
This is the single authority on what constitutes good/bad play in Balatro.
Centralizes all reward logic for easy experimentation and tuning.
Reward philosophy:
- Positive rewards for progress (chips, rounds, money)
- Bonus rewards for efficiency (fewer hands/discards used)
- Large rewards for major milestones (completing antes)
- Negative rewards for game over (based on progress made)
""" """
def __init__(self): def __init__(self):
self.prev_score = 0 self.chips = 0
self.prev_money = 0
self.prev_round = 0
self.prev_ante = 0
def calculate_reward(self, current_state: Dict[str, Any], def calculate_reward(self, current_state: Dict[str, Any],
prev_state: Dict[str, Any] = None) -> float: prev_state: Dict[str, Any] = None) -> float:
""" """
Calculate reward based on game state changes Main reward calculation method - analyzes state changes and assigns rewards
This is the core method that determines what the AI should optimize for.
Examines differences between previous and current game state to calculate
appropriate rewards for the action that caused the transition.
Args: Args:
current_state: Current game state current_state: Current Balatro game state
prev_state: Previous game state (optional) prev_state: Previous Balatro game state (None for first step)
Returns: Returns:
Calculated reward Float reward value (positive = good, negative = bad, zero = neutral)
""" """
reward = 0.0 reward = 0.0
# Extract relevant metrics # Extract relevant metrics
current_score = current_state.get('score', 0) current_chips = current_state.get('chips', 0)
current_money = current_state.get('money', 0) game_over = current_state.get('game_over', False) #TODO fix
current_round = current_state.get('round', 0)
current_ante = current_state.get('ante', 0)
game_over = current_state.get('game_over', False)
# Score-based rewards # Score-based rewards
score_diff = current_score - self.prev_score chip_diff = current_chips - self.chips
if score_diff > 0: if chip_diff > 0:
reward += score_diff * 0.001 # Small reward for score increases reward += chip_diff * 0.001 # Small reward for score increases
# Money-based rewards
money_diff = current_money - self.prev_money
if money_diff > 0:
reward += money_diff * 0.01 # Reward for gaining money
# Round progression rewards
if current_round > self.prev_round:
reward += 10.0 # Big reward for completing rounds
# Ante progression rewards
if current_ante > self.prev_ante:
reward += 50.0 # Very big reward for completing antes
# Game over penalty # Game over penalty
if game_over: if game_over:
# Penalty based on how early the game ended # Penalty based on how early the game ended
max_ante = 8 # Typical max ante in Balatro # max_ante = 8 # Typical max ante in Balatro
completion_ratio = current_ante / max_ante # completion_ratio = current_ante / max_ante
reward += completion_ratio * 100.0 # Reward based on progress # reward += completion_ratio * 100.0 # Reward based on progress
reward -= 10 #TODO keeping this simple for now
# Efficiency rewards (optional)
# reward += self._calculate_efficiency_reward(current_state)
# Update previous state # Update previous state
self.prev_score = current_score self.chips = current_chips
self.prev_money = current_money
self.prev_round = current_round
self.prev_ante = current_ante
return reward
def _calculate_efficiency_reward(self, state: Dict[str, Any]) -> float:
"""
Calculate rewards for efficient play
Args:
state: Current game state
Returns:
Efficiency reward
"""
reward = 0.0
# Reward for using fewer discards
discards_remaining = state.get('discards', 0)
if discards_remaining > 0:
reward += discards_remaining * 0.1
# Reward for using fewer hands
hands_remaining = state.get('hands', 0)
if hands_remaining > 0:
reward += hands_remaining * 0.5
return reward return reward
def reset(self): def reset(self):
"""Reset reward calculator for new episode""" """
self.prev_score = 0 Reset reward calculator state for new episode
self.prev_money = 0
self.prev_round = 0 Called at the start of each new Balatro run to clear
self.prev_ante = 0 previous state tracking variables.
"""
self.chips = 0
def get_shaped_reward(self, state: Dict[str, Any], action: str) -> float: def get_shaped_reward(self, state: Dict[str, Any], action: str) -> float:
""" """
Get shaped reward based on specific actions Calculate action-specific reward shaping
Provides small rewards/penalties for specific actions to guide
AI behavior beyond just game state changes. Use sparingly to
avoid overriding the main reward signal.
Args: Args:
state: Current game state state: Current game state when action was taken
action: Action taken action: Action that was taken (for action-specific rewards)
Returns: Returns:
Shaped reward Small shaped reward value (usually -1 to +1)
""" """
# TODO: revisit this; it could help but is currently not in a working state
reward = 0.0 reward = 0.0
# Action-specific rewards # Action-specific rewards
@ -134,4 +112,4 @@ class BalatroRewardCalculator:
# Neutral - depends on if it's a good purchase # Neutral - depends on if it's a good purchase
reward += 0.0 reward += 0.0
return reward return reward
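A quick usage sketch of the calculator over one transition (module path assumed from the imports elsewhere in this commit; values invented for illustration):

```python
from ai.environment.reward import BalatroRewardCalculator

calc = BalatroRewardCalculator()
calc.reset()  # chips tracker starts at 0

# First transition: 450 chips scored, game still running
reward = calc.calculate_reward(
    current_state={"chips": 450, "game_over": 0},
    prev_state={"chips": 0, "game_over": 0},
)
# chip_diff = 450 - 0, so reward = 450 * 0.001 = 0.45

# Later transition: run ends
reward = calc.calculate_reward(
    current_state={"chips": 450, "game_over": 1},
    prev_state={"chips": 450, "game_over": 0},
)
# chip_diff = 0 and the game-over penalty applies: reward = -10
```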

View File

@ -1,111 +0,0 @@
#!/usr/bin/env python3
"""
File-based communication handler for Balatro RL
Watches for request files and responds with AI actions
"""
import json
import time
import logging
from pathlib import Path
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from agent.balatro_agent import BalatroAgent
# File paths
REQUEST_FILE = "/tmp/balatro_request.json"
RESPONSE_FILE = "/tmp/balatro_response.json"
class RequestHandler(FileSystemEventHandler):
"""Handles incoming request files from Balatro"""
def __init__(self):
self.agent = BalatroAgent()
self.logger = logging.getLogger(__name__)
def on_created(self, event):
"""Handle new request file creation"""
if event.src_path == REQUEST_FILE and not event.is_directory:
self.process_request()
def on_modified(self, event):
"""Handle request file modification"""
if event.src_path == REQUEST_FILE and not event.is_directory:
self.process_request()
def process_request(self):
"""Process a request file and generate response"""
try:
# Small delay to ensure file write is complete
time.sleep(0.01)
# Read request
with open(REQUEST_FILE, 'r') as f:
request_data = json.load(f)
self.logger.info(f"Processing request {request_data.get('id')}: state={request_data.get('state')}")
# Get action from AI agent
action_response = self.agent.get_action(request_data)
# Write response
response = {
"id": request_data.get("id"),
"action": action_response.get("action", "no_action"),
"params": action_response.get("params"),
"timestamp": time.time()
}
with open(RESPONSE_FILE, 'w') as f:
json.dump(response, f)
self.logger.info(f"Response written: {response['action']}")
except Exception as e:
self.logger.error(f"Error processing request: {e}")
# Write error response
error_response = {
"id": request_data.get("id", 0) if 'request_data' in locals() else 0,
"action": "no_action",
"params": {"error": str(e)},
"timestamp": time.time()
}
try:
with open(RESPONSE_FILE, 'w') as f:
json.dump(error_response, f)
except:
pass
def main():
"""Main file watcher loop"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
logger.info("Starting Balatro RL file watcher...")
# Set up file watcher
event_handler = RequestHandler()
observer = Observer()
observer.schedule(event_handler, path="/tmp", recursive=False)
observer.start()
logger.info(f"Watching for requests at: {REQUEST_FILE}")
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
logger.info("Stopping file watcher...")
observer.stop()
observer.join()
if __name__ == "__main__":
main()

View File

@ -1,6 +1,5 @@
numpy>=1.21.0 numpy>=1.21.0
torch>=1.9.0 torch>=1.9.0
gymnasium>=0.26.0 # For RL environment interface gymnasium>=0.26.0 # For RL environment interface
stable-baselines3>=1.6.0 # For RL algorithms stable-baselines3[extra]>=1.6.0 # For RL algorithms
tensorboard>=2.8.0 # For training visualization tensorboard>=2.8.0 # For training visualization
watchdog>=2.1.0 # For file system monitoring

View File

@ -22,5 +22,5 @@ echo "Setup complete!"
echo "" echo ""
echo "Usage:" echo "Usage:"
echo " To activate environment: source venv/bin/activate" echo " To activate environment: source venv/bin/activate"
echo " To run file watcher: python file_watcher.py" echo " To test environment is working: python -m ai.test_env"
echo " To train agent: python -m agent.training" echo " To train agent: python -m ai.train_balatro"

ai/test_env.py Normal file (96 lines)
View File

@ -0,0 +1,96 @@
"""
Environment Testing Script for Balatro RL
Tests the BalatroEnv to ensure it follows the Gym interface and
communicates properly with the Balatro mod via pipe I/O.
Usage:
python -m ai.test_env
"""
import time
import logging
# from stable_baselines3.common.env_checker import check_env # Not compatible with pipes
from .environment.balatro_env import BalatroEnv
def test_manual_actions():
"""Manual testing with random actions"""
print("\n=== Manual Action Testing ===")
env = BalatroEnv()
print("BalatroEnv created successfully")
try:
# Reset and get initial observation (ONLY ONCE!)
print("Calling env.reset() - this should wait for Balatro...")
obs, info = env.reset()
print("env.reset() completed!")
print(f"Initial observation: {obs}")
print(f"Observation space: {env.observation_space}")
print(f"Action space: {env.action_space}")
print(f"Sample action: {env.action_space.sample()}")
# Run test episodes
max_steps = 3
max_episodes = 1
for episode in range(max_episodes):
print(f"\n--- Episode {episode + 1} ---")
# Don't call reset() again! Use the obs from above
for step in range(max_steps):
# Get random action (like the old balatro_agent)
action = env.action_space.sample()
print(f"Step {step + 1}: Taking action {action}")
# Take action
obs, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
print(f" obs={obs}, reward={reward}, done={done}")
# Optional: Add delay for readability
time.sleep(0.1)
if done:
print(f" Episode finished after {step + 1} steps")
break
if not done:
print(f" Episode reached max steps ({max_steps})")
except Exception as e:
print(f"✗ Manual testing failed: {e}")
return False
finally:
env.close()
print("✓ Manual testing completed!")
return True
def main():
"""Run all tests"""
print("Starting Balatro Environment Testing...")
# Setup logging to see communication details
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
success = True
success &= test_manual_actions()
if success:
print("\n🎉 All tests passed! Environment is ready for training.")
else:
print("\n❌ Some tests failed. Check the environment implementation.")
return success
if __name__ == "__main__":
main()

ai/train_balatro.py Executable file (273 lines)
View File

@ -0,0 +1,273 @@
#!/usr/bin/env python3
"""
Balatro RL Training Script
Main training script for teaching AI to play Balatro using Stable-Baselines3.
This script creates the Balatro environment, sets up the RL model, and runs training.
Usage:
python -m ai.train_balatro
Requirements:
- Balatro game running with RLBridge mod
- file_watcher.py should NOT be running (this replaces it)
"""
import logging
import time
from pathlib import Path
# SB3 imports
from stable_baselines3 import DQN, A2C, PPO
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
from stable_baselines3.common.monitor import Monitor
# Our custom environment
from .environment.balatro_env import BalatroEnv
def setup_logging():
"""Setup logging for training"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('training.log'),
logging.StreamHandler()
]
)
def create_environment():
"""Create and wrap the Balatro environment"""
# Create base environment
env = BalatroEnv()
# Wrap with Monitor for logging episode stats
env = Monitor(env, filename="training_monitor.csv")
return env
def create_model(env, algorithm="DQN", model_path=None):
"""
Create SB3 model for training
Args:
env: Balatro environment
algorithm: RL algorithm to use ("DQN", "A2C", "PPO")
model_path: Path to load existing model (optional)
Returns:
SB3 model ready for training
"""
if algorithm == "DQN":
model = DQN(
"MlpPolicy",
env,
verbose=1,
learning_rate=1e-4,
buffer_size=10000,
learning_starts=1000,
target_update_interval=500,
train_freq=4,
gradient_steps=1,
exploration_fraction=0.1,
exploration_initial_eps=1.0,
exploration_final_eps=0.05,
tensorboard_log="./tensorboard_logs/"
)
elif algorithm == "A2C":
model = A2C(
"MlpPolicy",
env,
verbose=1,
learning_rate=7e-4,
n_steps=5,
gamma=0.99,
tensorboard_log="./tensorboard_logs/"
)
elif algorithm == "PPO":
model = PPO(
"MlpPolicy",
env,
verbose=1,
learning_rate=3e-4,
n_steps=2048,
batch_size=64,
n_epochs=10,
gamma=0.99,
tensorboard_log="./tensorboard_logs/"
)
else:
raise ValueError(f"Unknown algorithm: {algorithm}")
# Load existing model if path provided
if model_path and Path(model_path).exists():
model.load(model_path)
print(f"Loaded existing model from {model_path}")
return model
def create_callbacks(save_freq=1000):
"""Create training callbacks for saving and evaluation"""
callbacks = []
# Checkpoint callback - save model periodically
checkpoint_callback = CheckpointCallback(
save_freq=save_freq,
save_path="./models/",
name_prefix="balatro_model"
)
callbacks.append(checkpoint_callback)
return callbacks
def train_agent(total_timesteps=50000, algorithm="DQN", save_path="./models/balatro_final"):
"""
Main training function
Args:
total_timesteps: Number of training steps
algorithm: RL algorithm to use
save_path: Where to save final model
"""
logger = logging.getLogger(__name__)
logger.info(f"Starting Balatro RL training with {algorithm}")
logger.info(f"Training for {total_timesteps} timesteps")
try:
# Create environment and model
env = create_environment()
model = create_model(env, algorithm)
# Create callbacks
callbacks = create_callbacks(save_freq=max(1000, total_timesteps // 20))
# Train the model
logger.info("Starting training...")
start_time = time.time()
model.learn(
total_timesteps=total_timesteps,
callback=callbacks,
progress_bar=True
)
training_time = time.time() - start_time
logger.info(f"Training completed in {training_time:.2f} seconds")
# Save final model
model.save(save_path)
logger.info(f"Model saved to {save_path}")
# Clean up environment
if hasattr(env, 'cleanup'):
env.cleanup()
elif hasattr(env.env, 'cleanup'): # Monitor wrapper
env.env.cleanup()
return model
except KeyboardInterrupt:
logger.info("Training interrupted by user")
if hasattr(env, 'cleanup'):
env.cleanup()
elif hasattr(env.env, 'cleanup'):
env.env.cleanup()
return None
except Exception as e:
logger.error(f"Training failed: {e}")
if hasattr(env, 'cleanup'):
env.cleanup()
elif hasattr(env.env, 'cleanup'):
env.env.cleanup()
raise
def test_trained_model(model_path, num_episodes=5):
"""
Test a trained model
Args:
model_path: Path to trained model
num_episodes: Number of episodes to test
"""
logger = logging.getLogger(__name__)
logger.info(f"Testing model from {model_path}")
# Create environment and load model
env = create_environment()
model = PPO.load(model_path)  # __main__ trains with PPO, so load with the matching algorithm
episode_rewards = []
for episode in range(num_episodes):
obs, info = env.reset()  # Gymnasium reset returns (obs, info)
total_reward = 0
steps = 0
while True:
action, _states = model.predict(obs, deterministic=True)
obs, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
total_reward += reward
steps += 1
if done:
break
episode_rewards.append(total_reward)
logger.info(f"Episode {episode + 1}: {steps} steps, reward: {total_reward:.2f}")
avg_reward = sum(episode_rewards) / len(episode_rewards)
logger.info(f"Average reward over {num_episodes} episodes: {avg_reward:.2f}")
env.cleanup()
return episode_rewards
if __name__ == "__main__":
# Setup
setup_logging()
# Create necessary directories
Path(".models").mkdir(exist_ok=True)
Path(".tensorboard_logs").mkdir(exist_ok=True)
# We'll let BalatroEnv create the pipes when needed
print("🔧 Pipe communication will be initialized with environment...")
# Train the agent
print("\n🎮 Starting Balatro RL Training!")
print("Setup steps:")
print("1. ✅ Balatro is running with RLBridge mod")
print("2. ✅ Balatro is in menu state")
print("3. Press Enter below to start training")
print("4. When prompted, press 'R' in Balatro within 3 seconds")
print("\n📡 Training will create pipes and connect automatically!")
print()
input("Press Enter to start training (then you'll have 3 seconds to press 'R' in Balatro)...")
try:
model = train_agent(
total_timesteps=10000, # Start small for testing
algorithm="PPO", # PPO supports MultiDiscrete actions
save_path="./models/balatro_trained"
)
if model:
print("\n🎉 Training completed successfully!")
print("Testing the trained model...")
# Test the trained model
test_trained_model("./models/balatro_trained", num_episodes=3)
except Exception as e:
print(f"\n❌ Training failed: {e}")
print("Check the logs for more details.")
finally:
# Pipes will be cleaned up by the environment
print("🧹 Training session ended")

ai/utils/communication.py Normal file (192 lines)
View File

@ -0,0 +1,192 @@
"""
Balatro Dual Pipe Communication Abstraction
Pure dual pipe I/O layer with persistent handles for communicating with Balatro mod.
Handles low-level pipe operations without any game logic or AI logic.
This abstraction can be used by:
- balatro_env.py (for SB3 training)
- file_watcher.py (for testing)
- Any other component that needs to talk to Balatro
"""
import json
import logging
import os
from typing import Dict, Any, Optional
class BalatroPipeIO:
"""
Clean abstraction for dual pipe communication with Balatro mod using persistent handles
Responsibilities:
- Read JSON requests from Balatro mod via request pipe
- Write JSON responses back to Balatro mod via response pipe
- Keep pipe handles open persistently to avoid deadlocks
- Provide simple, clean interface for higher-level code
"""
def __init__(self, request_pipe: str = "/tmp/balatro_request", response_pipe: str = "/tmp/balatro_response"):
self.request_pipe = request_pipe
self.response_pipe = response_pipe
self.logger = logging.getLogger(__name__)
# Persistent pipe handles
self.request_handle = None
self.response_handle = None
# Create pipes and open persistent handles
self.create_pipes()
self.open_persistent_handles()
def create_pipes(self) -> None:
"""
Create dual named pipes for communication
Removes any stale request/response pipes, then creates fresh ones.
Safe to call multiple times.
"""
for pipe_path in [self.request_pipe, self.response_pipe]:
try:
# Remove existing pipe if it exists
if os.path.exists(pipe_path):
os.unlink(pipe_path)
self.logger.debug(f"Removed existing pipe: {pipe_path}")
# Create named pipe
os.mkfifo(pipe_path)
self.logger.info(f"Created pipe: {pipe_path}")
except Exception as e:
self.logger.error(f"Failed to create pipe {pipe_path}: {e}")
raise RuntimeError(f"Could not create communication pipe: {pipe_path}")
def open_persistent_handles(self) -> None:
"""
Open persistent handles for reading and writing
Opens request pipe for reading and response pipe for writing.
Keeps handles open to avoid deadlocks.
"""
try:
self.logger.info("🔧 Waiting for Balatro to connect...")
self.logger.info("   Press 'R' in Balatro now to activate RL training!")
# Open request pipe for reading; open() on a FIFO blocks until Balatro opens the write end
self.request_handle = open(self.request_pipe, 'r')
# Open response pipe for writing; blocks until Balatro opens the read end
self.response_handle = open(self.response_pipe, 'w')
except Exception as e:
self.logger.error(f"Failed to open persistent handles: {e}")
self.cleanup_handles()
raise RuntimeError(f"Could not open persistent pipe handles: {e}")
def wait_for_request(self) -> Optional[Dict[str, Any]]:
"""
Wait for new request from Balatro mod using persistent handle
Blocks until Balatro writes a request to the pipe.
No timeout needed - pipes block until data arrives.
Returns:
Parsed JSON request data, or None if error
"""
if not self.request_handle:
self.logger.error("Request handle not open")
return None
try:
# Read from persistent request handle
request_line = self.request_handle.readline().strip()
if not request_line:
return None  # Empty read: blank line, or the writer closed its end (EOF)
request_data = json.loads(request_line)
self.logger.info(f"📥 RECEIVED REQUEST: {request_line}...") # Show first 100 chars
return request_data
except json.JSONDecodeError as e:
self.logger.error(f"Invalid JSON in request: {e}")
return None
except Exception as e:
self.logger.error(f"Error reading from request pipe: {e}")
return None
def send_response(self, response_data: Dict[str, Any]) -> bool:
"""
Send response back to Balatro mod using persistent handle
Writes response data to response pipe for Balatro to read.
Args:
response_data: Response dictionary to send
Returns:
True if successful, False if error
"""
if not self.response_handle:
self.logger.error("Response handle not open")
return False
try:
# Write to persistent response handle
json.dump(response_data, self.response_handle)
self.response_handle.write('\n') # Important: newline for pipe communication
self.response_handle.flush() # Force write to pipe immediately
self.logger.info(f"📤 SENT RESPONSE: {json.dumps(response_data)}")
return True
except Exception as e:
self.logger.error(f"❌ ERROR sending response: {e}")
import traceback
self.logger.error(f"❌ TRACEBACK: {traceback.format_exc()}")
return False
def cleanup_handles(self):
"""
Close persistent pipe handles
Closes the open file handles.
"""
if self.request_handle:
try:
self.request_handle.close()
self.logger.debug("Closed request handle")
except Exception as e:
self.logger.warning(f"Failed to close request handle: {e}")
self.request_handle = None
if self.response_handle:
try:
self.response_handle.close()
self.logger.debug("Closed response handle")
except Exception as e:
self.logger.warning(f"Failed to close response handle: {e}")
self.response_handle = None
def cleanup(self):
"""
Clean up communication pipes and handles
Closes handles and removes pipe files from the filesystem.
"""
# Close handles first
self.cleanup_handles()
# Remove pipe files
for pipe_path in [self.request_pipe, self.response_pipe]:
try:
if os.path.exists(pipe_path):
os.unlink(pipe_path)
self.logger.debug(f"Removed pipe: {pipe_path}")
except Exception as e:
self.logger.warning(f"Failed to remove pipe {pipe_path}: {e}")
self.logger.debug("Dual pipe communication cleanup complete")

39
ai/utils/debug.py Normal file
View File

@ -0,0 +1,39 @@
"""
Debug utilities for timestamped print statements and debugging helpers
"""
import datetime
def timestamp() -> str:
"""
Get current timestamp formatted as HH:MM:SS.mmm
Returns:
str: Formatted timestamp string
"""
now = datetime.datetime.now()
return now.strftime("%H:%M:%S.%f")[:-3] # Remove last 3 digits to get milliseconds
def tprint(*args, **kwargs):
"""
Print with timestamp prefix
Args:
*args: Arguments to print
**kwargs: Keyword arguments for print()
"""
print(f"[{timestamp()}]", *args, **kwargs)
def dprint(prefix: str, *args, **kwargs):
"""
Debug print with custom prefix and timestamp
Args:
prefix: Custom prefix string
*args: Arguments to print
**kwargs: Keyword arguments for print()
"""
print(f"[{timestamp()}] [{prefix}]", *args, **kwargs)

View File

@ -1,74 +0,0 @@
"""
Logging setup for Balatro RL system
"""
import logging
import sys
from pathlib import Path
from typing import Optional
def setup_logging(
level: str = "INFO",
log_file: Optional[str] = None,
format_string: Optional[str] = None
) -> None:
"""
Setup logging configuration
Args:
level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
log_file: Optional log file path
format_string: Optional custom format string
"""
if format_string is None:
format_string = (
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# Convert string level to logging constant
numeric_level = getattr(logging, level.upper(), None)
if not isinstance(numeric_level, int):
raise ValueError(f'Invalid log level: {level}')
# Create formatter
formatter = logging.Formatter(format_string)
# Setup root logger
root_logger = logging.getLogger()
root_logger.setLevel(numeric_level)
# Remove existing handlers
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
# Console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(numeric_level)
console_handler.setFormatter(formatter)
root_logger.addHandler(console_handler)
# File handler (optional)
if log_file:
log_path = Path(log_file)
log_path.parent.mkdir(parents=True, exist_ok=True)
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(numeric_level)
file_handler.setFormatter(formatter)
root_logger.addHandler(file_handler)
# Set specific logger levels
logging.getLogger("uvicorn").setLevel(logging.INFO)
logging.getLogger("fastapi").setLevel(logging.INFO)
def get_logger(name: str) -> logging.Logger:
"""
Get a logger with the specified name
Args:
name: Logger name
Returns:
Logger instance
"""
return logging.getLogger(name)

245
ai/utils/mappers.py Normal file
View File

@ -0,0 +1,245 @@
"""
Mappers for converting between Balatro game data and RL-compatible formats.
This module handles the two-way data transformation:
1. BalatroStateMapper: Converts incoming Balatro JSON to normalized RL observations
2. BalatroActionMapper: Converts RL actions to Balatro command JSON
"""
import numpy as np
import time
from typing import Dict, List, Any
from ..utils.validation import GameStateValidator, ResponseValidator
import logging
class BalatroStateMapper:
"""
Converts raw Balatro game state JSON to normalized RL observations.
Handles:
- Card data normalization
- Game state parsing
- Observation space formatting
"""
def __init__(self, observation_size: int, max_actions: int):
self.observation_size = observation_size
self.max_actions = max_actions
# Logger
self.logger = logging.getLogger(__name__)
# GameStateValidator
self.game_state_validator = GameStateValidator()
# TODO: review, there may be a bug here
def process_game_state(self, raw_state: Dict[str, Any] | None) -> np.ndarray:
"""
Convert Balatro's raw JSON game state into the standardized numerical
arrays a neural network can process.
Args:
raw_state: Raw game state from Balatro mod JSON
Returns:
Processed state suitable for RL training
"""
if not raw_state:
return np.zeros(self.observation_size, dtype=np.float32) # TODO maybe raise error?
# Validate game state request
try:
self.game_state_validator.validate_game_state(raw_state)
except ValueError as e:
self.logger.error(f"Invalid game state: {e}")
hand_features = self._extract_hand_features(raw_state.get('hand', {}))
game_features = self._extract_game_features(raw_state)
available_actions = self._extract_available_actions(raw_state.get('available_actions', []))
chips = self._extract_chip_features(raw_state)
# TODO extract state
return np.concatenate([hand_features, game_features, available_actions, chips])
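# Resulting observation layout (observation_size must match): 8 cards * 7
# features = 56 hand values, 2 game features, max_actions mask entries,
# and 2 chip features, i.e. observation_size = 60 + max_actions.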
def _extract_available_actions(self, available_actions: List[int]) -> np.ndarray:
"""
Convert the list of available action IDs into a binary action mask
Args:
available_actions: Available actions list from Balatro game state
Returns:
Fixed-size binary mask over the action space (1.0 = available)
"""
mask = np.zeros(self.max_actions, dtype=np.float32)
for action_id in available_actions:
mask[action_id] = 1.0
return mask
def _extract_hand_features(self, hand: Dict[str, Any]) -> np.ndarray:
"""
Convert Balatro hand data into numerical features for neural network
Transforms card dictionaries into fixed-size numerical arrays:
- Card values: 2-14 (2 through Ace)
- Suits: 0-3 (Hearts, Diamonds, Clubs, Spades)
- Card abilities: chips, mult, special effects
- Pads/truncates to fixed hand size for consistent input
Args:
hand: Hand dictionary from Balatro game state
Returns:
Fixed-size numpy array of hand features
"""
cards = hand.get('cards', [])
features = []
for card in cards:
# Extract card features
suit_encoding = self._encode_suit(card.get('suit', ''))
value_encoding = self._encode_value(card.get('base', {}).get('value', ''))
nominal = card.get('base', {}).get('nominal', 0)
# Add ability features
ability = card.get('ability', {})
ability_features = [
ability.get('t_chips', 0),
ability.get('t_mult', 0),
ability.get('x_mult', 1),
ability.get('mult', 0)
]
card_features = [suit_encoding, value_encoding, nominal] + ability_features
features.extend(card_features)
# Pad or truncate to fixed size (e.g., 8 cards max * 7 features each)
# TODO hand.get('size')
# TODO hand.get('highlighted_count')
max_cards = 8
features_per_card = 7
if len(features) < max_cards * features_per_card:
features.extend([0] * (max_cards * features_per_card - len(features)))
else:
features = features[:max_cards * features_per_card]
return np.array(features, dtype=np.float32)
def _extract_game_features(self, state: Dict[str, Any]) -> np.ndarray:
"""
Extract numerical game-level features from Balatro state
Converts game metadata into neural network inputs:
- Current game state (menu, selecting hand, etc.)
- Available actions count
- Money, chips, round progression
- Remaining hands/discards
Args:
state: Full Balatro game state dictionary
Returns:
Numpy array of normalized game features
"""
features = [
state.get('state', 0), # Game state
len(state.get('available_actions', [])), # Number of available actions
]
return np.array(features, dtype=np.float32)
def _extract_chip_features(self, state: Dict[str, Any]) -> np.ndarray:
"""
Extract information relating to chips
Args:
state: Full Balatro game state dictionary
Returns:
Numpy array of normalized chip features
"""
features = [
state.get('chips', 0),
state.get('chip_target', 0)
]
return np.array(features, dtype=np.float32)
def _encode_suit(self, suit: str) -> int:
"""Encode suit as integer"""
suit_map = {'Hearts': 0, 'Diamonds': 1, 'Clubs': 2, 'Spades': 3}
return suit_map.get(suit, 0)
def _encode_value(self, value: str) -> int:
"""Encode card value as integer"""
if value.isdigit():
return int(value)
value_map = {
'Jack': 11, 'Queen': 12, 'King': 13, 'Ace': 14,
'A': 14, 'K': 13, 'Q': 12, 'J': 11
}
return value_map.get(value, 0)
class BalatroActionMapper:
"""
Converts RL actions to Balatro command JSON.
Handles:
- Binary action conversion to card indices
- Action validation
- JSON response formatting
"""
def __init__(self, action_slices: Dict[str, slice]):
self.slices = action_slices
# Validator
self.response_validator = ResponseValidator()
# Logger
self.logger = logging.getLogger(__name__)
def process_action(self, rl_action: np.ndarray) -> Dict[str, Any]:
"""
Convert RL action to Balatro JSON.
Args:
rl_action: Binary action array from RL agent
Returns:
JSON response formatted for Balatro mod
"""
ai_action = rl_action[self.slices["action_selection"]].tolist()[0]
response_data = {
"action": ai_action,
"params": self._extract_select_hand_params(rl_action),
}
# Validate action structure
try:
self.response_validator.validate_response(response_data)
except ValueError as e:
self.logger.error(f"Invalid response: {e}")
return response_data
def _extract_select_hand_params(self, raw_action: np.ndarray) -> List[int]:
"""
Converts the raw action to a list of Lua card indices
Args:
raw_action: The whole action from the RL agent
Returns:
List of 1-based card indices for Lua
"""
card_indices = raw_action[self.slices["card_indices"]]
return [i + 1 for i, val in enumerate(card_indices) if val == 1]
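# Worked example (illustrative; the real slice layout is supplied by the
# environment, so this particular arrangement is an assumption): with
# action_selection at index 0 and 8 card slots after it,
#
#   slices = {"action_selection": slice(0, 1), "card_indices": slice(1, 9)}
#   mapper = BalatroActionMapper(slices)
#   mapper.process_action(np.array([2, 1, 0, 0, 1, 0, 0, 1, 0]))
#
# returns {"action": 2, "params": [1, 4, 7]}, i.e. 1-based indices for Lua.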

View File

@ -1,122 +0,0 @@
"""
JSON Serialization utilities for Balatro RL
Handles conversion between game state formats
"""
import json
from typing import Dict, Any, List, Optional
class GameStateSerializer:
"""Handles serialization/deserialization of game states"""
@staticmethod
def serialize_game_state(game_state: Dict[str, Any]) -> str:
"""Convert game state dict to JSON string"""
try:
return json.dumps(game_state, indent=2, ensure_ascii=False)
except (TypeError, ValueError) as e:
raise ValueError(f"Failed to serialize game state: {e}")
@staticmethod
def deserialize_game_state(json_string: str) -> Dict[str, Any]:
"""Convert JSON string to game state dict"""
try:
return json.loads(json_string)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to deserialize game state: {e}")
@staticmethod
def serialize_action(action: Dict[str, Any]) -> str:
"""Convert action dict to JSON string"""
try:
return json.dumps(action, ensure_ascii=False)
except (TypeError, ValueError) as e:
raise ValueError(f"Failed to serialize action: {e}")
@staticmethod
def deserialize_action(json_string: str) -> Dict[str, Any]:
"""Convert JSON string to action dict"""
try:
return json.loads(json_string)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to deserialize action: {e}")
class GameStateValidator:
"""Validates game state structure and content"""
@staticmethod
def validate_game_state(game_state: Dict[str, Any]) -> bool:
"""Validate that game state has required fields"""
required_fields = ['state', 'available_actions']
for field in required_fields:
if field not in game_state:
raise ValueError(f"Missing required field: {field}")
# Validate hand structure if present
if 'hand' in game_state:
GameStateValidator._validate_hand(game_state['hand'])
return True
@staticmethod
def _validate_hand(hand: Dict[str, Any]) -> bool:
"""Validate hand structure"""
if not isinstance(hand, dict):
raise ValueError("Hand must be a dictionary")
if 'cards' in hand and not isinstance(hand['cards'], list):
raise ValueError("Hand cards must be a list")
# Validate each card
for i, card in enumerate(hand.get('cards', [])):
GameStateValidator._validate_card(card, i)
return True
@staticmethod
def _validate_card(card: Dict[str, Any], index: int) -> bool:
"""Validate individual card structure"""
required_fields = ['suit', 'base']
for field in required_fields:
if field not in card:
raise ValueError(f"Card {index} missing required field: {field}")
# Validate base structure
if 'base' in card:
base = card['base']
if 'value' not in base or 'nominal' not in base:
raise ValueError(f"Card {index} base missing value or nominal")
return True
@staticmethod
def validate_action(action: Dict[str, Any]) -> bool:
"""Validate action structure"""
if 'action' not in action:
raise ValueError("Action must have 'action' field")
action_type = action['action']
# Validate specific action types
if action_type in ['play_hand', 'discard']:
if 'card_indices' not in action:
raise ValueError(f"Action {action_type} requires card_indices")
if not isinstance(action['card_indices'], list):
raise ValueError("card_indices must be a list")
return True
# Convenience functions
def to_json(obj: Dict[str, Any]) -> str:
"""Quick serialize to JSON"""
return GameStateSerializer.serialize_game_state(obj)
def from_json(json_str: str) -> Dict[str, Any]:
"""Quick deserialize from JSON"""
return GameStateSerializer.deserialize_game_state(json_str)

89
ai/utils/validation.py Normal file
View File

@ -0,0 +1,89 @@
"""
Game State Validation utilities for Balatro RL
Validates game state structure and content
"""
from typing import Dict, Any, List
class GameStateValidator:
"""
Validates game state json structure
Request contract (macro-level; not every field is listed)
{
game_state: Dict, # The whole game state table
available_actions: List, # All available action IDs, as integers
retry_count: int, # TODO update observation space; used to handle retries when the AI fails
}
"""
@staticmethod
def validate_game_state(game_state: Dict[str, Any]) -> bool:
"""Validate that game state has required fields"""
required_fields = ['game_state', 'available_actions', 'retry_count']
# TODO: add more validations as needed (e.g., ensure available_actions is a list)
# TODO game_over might be a validation
# TODO validate chips for rewards
for field in required_fields:
if field not in game_state:
raise ValueError(f"Missing required field: {field}")
# Validate hand structure if present
if 'hand' in game_state:
GameStateValidator._validate_hand(game_state['hand'])
return True
@staticmethod
def _validate_hand(hand: Dict[str, Any]) -> bool:
"""Validate hand structure"""
if not isinstance(hand, dict):
raise ValueError("Hand must be a dictionary")
if 'cards' in hand and not isinstance(hand['cards'], list):
raise ValueError("Hand cards must be a list")
# Validate each card
for i, card in enumerate(hand.get('cards', [])):
GameStateValidator._validate_card(card, i)
return True
@staticmethod
def _validate_card(card: Dict[str, Any], index: int) -> bool:
"""Validate individual card structure"""
required_fields = ['suit', 'base']
for field in required_fields:
if field not in card:
raise ValueError(f"Card {index} missing required field: {field}")
# Validate base structure
if 'base' in card:
base = card['base']
if 'value' not in base or 'nominal' not in base:
raise ValueError(f"Card {index} base missing value or nominal")
return True
class ResponseValidator:
"""
Validates response json structure
Response contract
{
action, # The action to take
params, # Optional parameters, for actions that take them
}
"""
@staticmethod
def validate_response(response: Dict[str, Any]):
required_fields = ["action"]
for field in required_fields:
if field not in response:
raise ValueError(f"Missing required field: {field}")
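# Minimal payloads satisfying the two contracts above (illustrative):
#
#   GameStateValidator.validate_game_state(
#       {"game_state": {}, "available_actions": [0, 1], "retry_count": 0}
#   )  # -> True
#
#   ResponseValidator.validate_response({"action": 1, "params": [2, 3]})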