Have successful pipe communication between the AI and Balatro, with some initial iteration going
This commit is contained in:
parent bbeab8c635
commit 9b5f949a0b
@@ -56,13 +56,17 @@ logs/
 *.temp
 .cache/
 
-# Model checkpoints and training artifacts
+# Training artifacts (generated during RL training sessions)
+models/
+tensorboard_logs/
+training.log
+training_monitor.csv
+
+# Other ML artifacts
 *.pth
 *.pt
 *.ckpt
 checkpoints/
-models/
-tensorboard_logs/
 wandb/
 
 # Large reference files (don't commit decompiled game source)
CLAUDE.md (21 lines changed)
@@ -55,11 +55,32 @@ Balatro uses numbered states defined in `globals.lua`:
 - `G.GAME` - Main game data structure
 - `G.FUNCS` - Game function callbacks
 
+## Communication Architecture
+
+### Dual Pipe Communication
+The mod uses separate named pipes for request/response communication with the AI:
+- **Request pipe path**: `/tmp/balatro_request` (Balatro writes, AI reads)
+- **Response pipe path**: `/tmp/balatro_response` (AI writes, Balatro reads)
+- **Protocol**: Persistent pipe handles with blocking reads
+- **Flow**:
+  1. Balatro opens persistent handles to both pipes
+  2. Balatro writes JSON request to request pipe
+  3. AI reads request from request pipe
+  4. AI writes JSON response to response pipe
+  5. Balatro performs blocking read from response pipe
+
+### Benefits of Dual Pipes
+- Clear separation of request/response channels
+- Persistent handles avoid repeated open/close overhead
+- Blocking reads ensure proper synchronization
+- Each pipe has single direction, eliminating read/write conflicts
+
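A minimal sketch of the AI side of this flow in Python (pipe paths from the list above; the message fields and the open order are assumptions based on the mod code in this commit, not a spec):

```python
import json
import os

REQUEST_PIPE = "/tmp/balatro_request"    # Balatro writes, AI reads
RESPONSE_PIPE = "/tmp/balatro_response"  # AI writes, Balatro reads

# The AI side creates both FIFOs; the mod only opens persistent handles to them (step 1).
for path in (REQUEST_PIPE, RESPONSE_PIPE):
    if not os.path.exists(path):
        os.mkfifo(path)

# Open the response pipe (write) before the request pipe (read) to mirror the mod's
# open order, so neither side blocks forever waiting for the other end.
with open(RESPONSE_PIPE, "w") as resp, open(REQUEST_PIPE, "r") as req:
    while True:
        line = req.readline()                      # step 3: read one JSON request line (blocks)
        if not line:                               # writer closed the pipe
            break
        request = json.loads(line)
        actions = request.get("available_actions", [])
        # step 4: write a JSON response; Balatro's blocking read (step 5) picks it up
        reply = {"action": actions[0] if actions else "no_action", "params": {}}
        resp.write(json.dumps(reply) + "\n")
        resp.flush()
```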
 ## Current Status
 The mod can successfully:
 - Start a run automatically
 - Detect blind selection state
 - Automatically select the next blind
+- Communicate with AI via dual pipe system
 
 ## Communication with Claude
 - Please keep chats as efficient as possible and brief. No fluff talk, just get to the point
@@ -28,7 +28,12 @@ Made a symlink -> ln -s ~/dev/balatro-rl/RLBridge /mnt/gamerlinuxssd/SteamLibrar
 - [ ] Training loop integration
 
 ### Game Features
+- [ ] Always have restart_run as an action option assuming the game is ongoing
 - [ ] Blind selection choices (skip vs select)
 - [ ] Extended game state (money, discards, hands played)
 - [ ] Shop interactions
 - [ ] Joker management
+
+### RL Enhancements
+- [ ] Add a "Replay System" to analyze successful actions. For example, save seed, have an action log for reproduction etc
+Can probably do this by adding it before the game is reset like a check on what criteria I want to save for, and save
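A possible shape for that replay record, sketched in Python (the class name, fields, and save criterion are illustrative only, not part of this commit):

```python
import json
import time
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List


@dataclass
class ReplayRecord:
    """Hypothetical replay record: the seed plus the action log needed to reproduce a run."""
    seed: str
    actions: List[Dict[str, Any]] = field(default_factory=list)
    ante_reached: int = 0
    final_chips: int = 0

    def log_action(self, action: Dict[str, Any]) -> None:
        self.actions.append(action)

    def save_if_interesting(self, directory: str, min_ante: int = 2) -> bool:
        # Check the save criteria right before the run is reset, as described above.
        if self.ante_reached < min_ante:
            return False
        path = f"{directory}/replay_{int(time.time())}.json"
        with open(path, "w") as f:
            json.dump(asdict(self), f)
        return True
```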
@@ -13,6 +13,10 @@ local utils = require("utils")
 local last_state_hash = nil
 local last_actions_hash = nil
 local pending_action = nil
+local retry_count = 0
+local need_retry_request = false
+local rl_training_active = false
+local last_key_pressed = nil
 
 --- Initialize AI system
 --- Sets up communication and prepares the AI for operation
@@ -20,12 +24,50 @@ local pending_action = nil
 function AI.init()
     utils.log_ai("Initializing AI system...")
     communication.init()
+
+    -- Hook into Love2D keyboard events
+    if love and love.keypressed then
+        local original_keypressed = love.keypressed
+        love.keypressed = function(key)
+            -- Store the key press for our AI
+            last_key_pressed = key
+
+            -- Call original function
+            if original_keypressed then
+                original_keypressed(key)
+            end
+        end
+    else
+    end
 end
 
 --- Main AI update loop (called every frame)
 --- Monitors game state changes, handles communication, and executes AI actions
 --- @return nil
 function AI.update()
+    -- Check for key press to start/stop RL training
+    if last_key_pressed then
+        if last_key_pressed == "r" then
+            if not rl_training_active then
+                rl_training_active = true
+                utils.log_ai("🚀 RL Training STARTED (R pressed)")
+            end
+        elseif last_key_pressed == "t" then
+            if rl_training_active then
+                rl_training_active = false
+                utils.log_ai("⏹️ RL Training STOPPED (T pressed)")
+            end
+        end
+
+        -- Clear the key press
+        last_key_pressed = nil
+    end
+
+    -- Don't process AI requests unless RL training is active
+    if not rl_training_active then
+        return
+    end
+
     -- Get current game state
     local current_state = output.get_game_state()
     local available_actions = action.get_available_actions()
@@ -40,29 +82,40 @@ function AI.update()
         return
     end
 
-    -- Create simple hash to detect state changes
+    -- Create hash to detect state changes
     local state_hash = AI.hash_state(current_state)
     local actions_hash = AI.hash_actions(available_actions)
 
-    -- Request action from AI if state or actions changed
-    if state_hash ~= last_state_hash or actions_hash ~= last_actions_hash then
+    if state_hash ~= last_state_hash or actions_hash ~= last_actions_hash or need_retry_request then
+        -- State has changed
         if state_hash ~= last_state_hash then
             utils.log_ai("State changed to: " ..
-                current_state.state .. " (" .. utils.state_name(current_state.state) .. ")")
+                current_state.state .. " (" .. utils.get_state_name(current_state.state) .. ")")
             action.reset_state()
             last_state_hash = state_hash
         end
 
+        -- Available actions have changed
         if actions_hash ~= last_actions_hash then
             utils.log_ai("Available actions changed: " ..
                 table.concat(utils.get_action_names(available_actions), ", "))
             last_actions_hash = actions_hash
         end
 
-        -- Request action from AI via file I/O
-        local ai_response = communication.request_action(current_state, available_actions)
-        if ai_response and ai_response.action ~= "no_action" then
-            pending_action = ai_response
+        -- Request action from AI
+        need_retry_request = false
+        local ai_response = communication.request_action(current_state, available_actions, retry_count)
+
+        if ai_response then
+            -- Handling handshake
+            if ai_response.action == "ready" then
+                utils.log_ai("Handshake complete - AI ready")
+                -- Don't return - continue normal game loop
+                -- Force a state check on next frame to send real request
+                last_state_hash = nil
+            else
+                pending_action = ai_response
+            end
         end
     end
 
@@ -71,8 +124,12 @@ function AI.update()
         local result = action.execute_action(pending_action.action, pending_action.params)
         if result.success then
             utils.log_ai("Action executed successfully: " .. pending_action.action)
+            retry_count = 0
         else
-            utils.log_ai("Action failed: " .. (result.error or "Unknown error"))
+            -- Update retry_count to indicate state change
+            utils.log_ai("Action failed: " .. (result.error or "Unknown error") .. " RETRYING...")
+            retry_count = retry_count + 1
+            need_retry_request = true
         end
         pending_action = nil
     end
@@ -83,7 +140,7 @@ end
 --- @param game_state table Current game state data
 --- @return string Hash representing the current state
 function AI.hash_state(game_state)
-    return game_state.state .. "_" .. (game_state.round or 0) .. "_" .. (game_state.chips or 0)
+    return game_state.state .. "_" .. (game_state.chips or 0) -- TODO update hash to be more unique
 end
 
 --- Create simple hash of available actions
@@ -1,124 +1,117 @@
 --- RLBridge communication module
---- Handles file-based communication between the game and external AI system
+--- Handles dual pipe communication with persistent handles between the game and external AI system
 
 local COMM = {}
 local utils = require("utils")
-local action = require("actions")
 local json = require("dkjson")
 
--- File-based communication settings
+-- Dual pipe communication settings with persistent handles
 local comm_enabled = false
-local request_file = "/tmp/balatro_request.json"
-local response_file = "/tmp/balatro_response.json"
-local request_counter = 0
-local last_response_time = 0
+local request_pipe = "/tmp/balatro_request"
+local response_pipe = "/tmp/balatro_response"
+local request_handle = nil
+local response_handle = nil
 
---- Check if response file has new content
---- @param expected_id number Expected request ID
---- @return table|nil Response data if new response available, nil otherwise
-local function check_for_response(expected_id)
-    local response_file_handle = io.open(response_file, "r")
-    if not response_file_handle then
-        return nil
-    end
-
-    local response_json = response_file_handle:read("*all")
-    response_file_handle:close()
-
-    if not response_json or response_json == "" then
-        return nil
-    end
-
-    local response_data = json.decode(response_json)
-    if not response_data then
-        return nil
-    end
-
-    -- Check if this is a new response for our request
-    if response_data.id == expected_id and response_data.timestamp then
-        if response_data.timestamp > last_response_time then
-            last_response_time = response_data.timestamp
-            return response_data
-        end
-    end
-
-    return nil
-end
-
---- Initialize file-based communication
---- Sets up file-based communication channel with external AI system
+--- Initialize dual pipe communication with persistent handles
+--- Sets up persistent pipe handles with external AI system
 --- @return nil
 function COMM.init()
-    utils.log_comm("Initializing file-based communication...")
-    comm_enabled = true
-    last_response_time = 0 -- Reset response tracking
-
-    utils.log_comm("Ready for file-based requests")
-    utils.log_comm("Request file: " .. request_file)
-    utils.log_comm("Response file: " .. response_file)
+    utils.log_comm("Initializing dual pipe communication with persistent handles...")
+    utils.log_comm("Note: Pipes will be opened when first needed (lazy initialization)")
+    comm_enabled = true -- Enable communication, pipes will open on first use
 end
 
---- Send game turn request to AI and get action via files
+--- Lazy initialization of pipe handles when first needed
+--- @return boolean True if pipes are ready, false otherwise
+function COMM.ensure_pipes_open()
+    if request_handle and response_handle then
+        return true -- Already open
+    end
+
+    -- Try to open pipes (this will block until AI creates them)
+    utils.log_comm("Opening persistent pipe handles...")
+
+    -- Open response pipe for reading (keep open)
+    response_handle = io.open(response_pipe, "r")
+    if not response_handle then
+        utils.log_comm("ERROR: Cannot open response pipe for reading: " .. response_pipe)
+        return false
+    end
+
+    -- Open request pipe for writing (keep open)
+    request_handle = io.open(request_pipe, "w")
+    if not request_handle then
+        utils.log_comm("ERROR: Cannot open request pipe for writing: " .. request_pipe)
+        if response_handle then
+            response_handle:close()
+            response_handle = nil
+        end
+        return false
+    end
+
+    utils.log_comm("Persistent pipe handles opened successfully")
+    utils.log_comm("Request pipe (write): " .. request_pipe)
+    utils.log_comm("Response pipe (read): " .. response_pipe)
+    return true
+end
+
+--- Send game turn request to AI and get action via persistent pipe handles
 --- @param game_state table Current game state data
 --- @param available_actions table Available actions list
 --- @return table|nil Action response from AI, nil if error
-function COMM.request_action(game_state, available_actions)
+function COMM.request_action(game_state, available_actions, retry_count)
     if not comm_enabled then
+        utils.log_comm("ERROR: Communication not enabled")
        return nil
    end
 
-    request_counter = request_counter + 1
+    -- Lazy initialization - open pipes when first needed
+    if not COMM.ensure_pipes_open() then
+        utils.log_comm("ERROR: Failed to open pipe handles")
+        return nil
+    end
+
    local request = {
-        id = request_counter,
-        timestamp = os.time(),
        game_state = game_state,
-        available_actions = available_actions or {}
+        available_actions = available_actions or {},
+        retry_count = retry_count,
    }
 
-    utils.log_comm("Requesting action for state: " .. tostring(game_state.state))
+    utils.log_comm(utils.get_timestamp() .. "Sending action request for state: " ..
+        tostring(game_state.state) .. " (" .. utils.get_state_name(game_state.state) .. ")")
 
-    -- Write request to file
+    -- Encode request as JSON
    local json_data = json.encode(request)
    if not json_data then
        utils.log_comm("ERROR: Failed to encode request JSON")
        return nil
    end
 
-    local request_file_handle = io.open(request_file, "w")
-    if not request_file_handle then
-        utils.log_comm("ERROR: Cannot write to request file: " .. request_file)
+    -- Write request to persistent handle
+    request_handle:write(json_data .. "\n")
+    request_handle:flush() -- Ensure data is sent immediately
+    utils.log_comm(utils.get_timestamp() .. "Request sent...")
+
+    -- Read response from persistent handle
+    utils.log_comm(utils.get_timestamp() .. "about to read the respones")
+    local response_json = response_handle:read("*line")
+
+    if not response_json or response_json == "" then
+        utils.log_comm("ERROR: No response received from AI")
        return nil
    end
 
-    request_file_handle:write(json_data)
-    request_file_handle:close()
-
-    -- Wait for response file (with timeout)
-    local max_wait_time = 2 -- seconds
-    local wait_interval = 0.05 -- seconds
-    local total_waited = 0
-
-    while total_waited < max_wait_time do
-        local response_data = check_for_response(request_counter)
-        if response_data then
-            utils.log_comm("AI action: " .. tostring(response_data.action))
-            return response_data
-        end
-
-        -- Sleep for a short time using busy wait
-        local start_time = os.clock()
-        while (os.clock() - start_time) < wait_interval do
-            -- Busy wait for short duration
-        end
-        total_waited = total_waited + wait_interval
+    local response_data = json.decode(response_json)
+    if not response_data then
+        utils.log_comm("ERROR: Failed to decode response JSON")
+        return nil
    end
 
-    utils.log_comm("ERROR: Timeout waiting for AI response")
-    return nil
+    utils.log_comm("AI action: " .. tostring(response_data.action))
+    return response_data
 end
 
---- Check if file communication is enabled
+--- Check if pipe communication is enabled
 --- Returns the current communication status
 --- @return boolean True if enabled, false otherwise
 function COMM.is_connected()
@@ -126,12 +119,22 @@ function COMM.is_connected()
 end
 
 --- Close communication
---- Terminates the communication channel with the AI system
+--- Terminates the persistent pipe handles with the AI system
 --- @return nil
 function COMM.close()
     comm_enabled = false
-    last_response_time = 0
-    utils.log_comm("File communication disabled")
+
+    if request_handle then
+        request_handle:close()
+        request_handle = nil
+    end
+
+    if response_handle then
+        response_handle:close()
+        response_handle = nil
+    end
+
+    utils.log_comm("Persistent pipe communication closed")
 end
 
 return COMM
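The environment code later in this commit calls a `BalatroPipeIO` helper with `wait_for_request()` and `send_response()`, but its implementation is not included in the diff. A minimal Python sketch of what such a counterpart to the Lua module above might look like (method names taken from the calls in environment.py; everything else is assumed):

```python
import json
import os
from typing import Any, Dict, Optional

REQUEST_PIPE = "/tmp/balatro_request"    # Balatro writes, AI reads
RESPONSE_PIPE = "/tmp/balatro_response"  # AI writes, Balatro reads


class BalatroPipeIO:
    """Sketch of an AI-side counterpart to the Lua COMM module (assumed implementation)."""

    def __init__(self) -> None:
        # Create the FIFOs if needed; the mod only opens them, it does not create them.
        for path in (REQUEST_PIPE, RESPONSE_PIPE):
            if not os.path.exists(path):
                os.mkfifo(path)
        # Open order mirrors COMM.ensure_pipes_open (response first, then request),
        # so neither side blocks forever waiting for the other end.
        self._response = open(RESPONSE_PIPE, "w")
        self._request = open(REQUEST_PIPE, "r")

    def wait_for_request(self) -> Optional[Dict[str, Any]]:
        line = self._request.readline()          # blocking read of one JSON line
        return json.loads(line) if line.strip() else None

    def send_response(self, response: Dict[str, Any]) -> bool:
        self._response.write(json.dumps(response) + "\n")
        self._response.flush()                   # match the flush on the Lua side
        return True

    def cleanup(self) -> None:
        self._request.close()
        self._response.close()
```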
@@ -15,6 +15,7 @@ G.SETTINGS.reduced_motion = true
 function INIT.start_run()
     print("Starting balatro run")
 
+
     -- Initialize AI system
     local ai = require("ai")
     ai.init()
@@ -4,15 +4,20 @@
 
 local O = {}
 
-local actions = require("actions")
-
 --- Get comprehensive game state for AI
 --- Collects all relevant game information into a structured format
 --- @return table Complete game state data for AI processing
 function O.get_game_state()
+    -- Calculate game_over
+    local game_over = 0
+    if G.STATE == G.STATES.GAME_OVER then
+        game_over = 1
+    end
+
     return {
         -- Basic state info
         state = G.STATE,
+        game_over = game_over,
 
         -- Game progression
         -- round = G.GAME and G.GAME.round or 0,
@@ -27,16 +32,9 @@ function O.get_game_state()
         -- Hand info (if applicable)
         hand = O.get_hand_info(),
 
-        -- -- Jokers info
-        -- jokers = O.get_jokers_info(),
-
-        -- -- Chips/score info
-        -- chips = G.GAME and G.GAME.chips or 0,
-        -- chip_target = G.GAME and G.GAME.blind and G.GAME.blind.chips or 0,
-
-        -- -- Action feedback for AI
-        -- last_action_success = true, -- Did last action work?
-        -- last_action_error = nil, -- What went wrong?
+        -- Chips/score info
+        chips = G.GAME and G.GAME.chips or 0,
+        chip_target = G.GAME and G.GAME.blind and G.GAME.blind.chips or 0,
     }
 end
@@ -46,7 +46,7 @@ end
 --- Maps numeric state IDs to their string representations for debugging
 --- @param state_id number Numeric state identifier
 --- @return string Human-readable state name or "UNKNOWN_STATE"
-function UTIL.state_name(state_id)
+function UTIL.get_state_name(state_id)
     for key, value in pairs(G.STATES) do
         if value == state_id then
             return key
@@ -133,4 +133,22 @@ function UTIL.get_action_names(available_actions)
     return names
 end
 
+function UTIL.contains(tbl, val)
+    for x, _ in ipairs(tbl) do
+        if tbl[x] == val then
+            return true
+        end
+    end
+    return false
+end
+
+--- Get current timestamp as formatted string
+--- Returns current time in HH:MM:SS.mmm format for debugging
+--- @return string Formatted timestamp
+function UTIL.get_timestamp()
+    local time = os.time()
+    local ms = (os.clock() % 1) * 1000
+    return os.date("%H:%M:%S", time) .. string.format(".%03d", ms)
+end
+
 return UTIL
@ -1,95 +0,0 @@
|
||||||
"""
|
|
||||||
Balatro RL Agent
|
|
||||||
Contains the AI logic for making decisions in Balatro
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, Any, List
|
|
||||||
import random
|
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
class BalatroAgent:
|
|
||||||
"""Main AI agent for playing Balatro"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# TODO: Initialize your RL model here
|
|
||||||
# self.model = load_model("path/to/model")
|
|
||||||
# self.training = True
|
|
||||||
|
|
||||||
def get_action(self, request: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Given game state, return the action to take
|
|
||||||
|
|
||||||
Args:
|
|
||||||
game_state: Current game state from Balatro mod
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Action dictionary to send back to mod
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Extract relevant info from game state
|
|
||||||
available_actions = request.get('available_actions', [])
|
|
||||||
game_state = request.get('game_state', {})
|
|
||||||
|
|
||||||
self.logger.info(f"Processing state {game_state.get('state', 0)} with {len(available_actions)} actions")
|
|
||||||
|
|
||||||
# TODO: Replace with actual RL logic
|
|
||||||
action = self._get_random_action(available_actions) # TODO we are getting random actions for testing
|
|
||||||
|
|
||||||
# TODO: If training, update model with reward
|
|
||||||
# if self.training:
|
|
||||||
# self._update_model(game_state, action, reward)
|
|
||||||
|
|
||||||
return action
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error in get_action: {e}")
|
|
||||||
return {"action": "no_action"}
|
|
||||||
|
|
||||||
def _get_random_action(self, available_actions: List) -> Dict[str, Any]:
|
|
||||||
"""Temporary random action selection - replace with RL logic"""
|
|
||||||
|
|
||||||
# available_actions comes as [1, 2] format
|
|
||||||
if not available_actions or not isinstance(available_actions, List):
|
|
||||||
return {"action": "no_action"}
|
|
||||||
|
|
||||||
# Simple random action selection
|
|
||||||
selected_action_id = random.choice(available_actions)
|
|
||||||
self.logger.info(f"Selected action ID: {selected_action_id} from available: {available_actions}")
|
|
||||||
|
|
||||||
# TODO this is for testing, the AI doesn't really care about what the action_names are
|
|
||||||
# but we do so we can test out actions and make sure they work
|
|
||||||
# Map action IDs to names for logic (but return the ID)
|
|
||||||
action_names = {1: "start_run", 2: "select_blind", 3: "select_hand"}
|
|
||||||
action_name = action_names.get(selected_action_id, "unknown")
|
|
||||||
|
|
||||||
# Pick random cards for now
|
|
||||||
params = {}
|
|
||||||
if action_name == "select_hand":
|
|
||||||
params["card_indices"] = [1, 3, 5]
|
|
||||||
|
|
||||||
# Return the action ID (number), not the name
|
|
||||||
return {"action": selected_action_id, "params": params}
|
|
||||||
|
|
||||||
def train_step(self, game_state: Dict, action: Dict, reward: float, next_state: Dict):
|
|
||||||
"""
|
|
||||||
Perform one training step
|
|
||||||
|
|
||||||
TODO: Implement RL training logic here
|
|
||||||
- Update Q-values
|
|
||||||
- Update neural network weights
|
|
||||||
- Store experience in replay buffer
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def save_model(self, path: str):
|
|
||||||
"""Save the trained model"""
|
|
||||||
# TODO: Implement model saving
|
|
||||||
pass
|
|
||||||
|
|
||||||
def load_model(self, path: str):
|
|
||||||
"""Load a trained model"""
|
|
||||||
# TODO: Implement model loading
|
|
||||||
pass
|
|
||||||
|
|
@ -1,83 +0,0 @@
|
||||||
"""
|
|
||||||
Neural network policy for Balatro RL agent
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import numpy as np
|
|
||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
class BalatroPolicy(nn.Module):
|
|
||||||
"""
|
|
||||||
Neural network policy for Balatro decision making
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, state_dim: int = 128, action_dim: int = 64, hidden_dim: int = 256):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.state_dim = state_dim
|
|
||||||
self.action_dim = action_dim
|
|
||||||
|
|
||||||
# Feature extraction layers
|
|
||||||
self.feature_net = nn.Sequential(
|
|
||||||
nn.Linear(state_dim, hidden_dim),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Linear(hidden_dim, hidden_dim),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Linear(hidden_dim, hidden_dim),
|
|
||||||
nn.ReLU()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Action value head
|
|
||||||
self.value_head = nn.Linear(hidden_dim, 1)
|
|
||||||
|
|
||||||
# Action probability head
|
|
||||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
|
||||||
|
|
||||||
def forward(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Forward pass through the network
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: Game state tensor
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (action_logits, state_value)
|
|
||||||
"""
|
|
||||||
features = self.feature_net(state)
|
|
||||||
|
|
||||||
action_logits = self.action_head(features)
|
|
||||||
state_value = self.value_head(features)
|
|
||||||
|
|
||||||
return action_logits, state_value
|
|
||||||
|
|
||||||
def get_action(self, state: np.ndarray, available_actions: list) -> tuple[int, float]:
|
|
||||||
"""
|
|
||||||
Get action from policy given current state
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: Current game state as numpy array
|
|
||||||
available_actions: List of available action indices
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (action_index, action_probability)
|
|
||||||
"""
|
|
||||||
with torch.no_grad():
|
|
||||||
state_tensor = torch.FloatTensor(state).unsqueeze(0)
|
|
||||||
action_logits, _ = self.forward(state_tensor)
|
|
||||||
|
|
||||||
# Mask unavailable actions
|
|
||||||
masked_logits = action_logits.clone()
|
|
||||||
if available_actions:
|
|
||||||
mask = torch.full_like(action_logits, float('-inf'))
|
|
||||||
mask[0, available_actions] = 0
|
|
||||||
masked_logits = action_logits + mask
|
|
||||||
|
|
||||||
# Get action probabilities
|
|
||||||
action_probs = torch.softmax(masked_logits, dim=-1)
|
|
||||||
|
|
||||||
# Sample action
|
|
||||||
action_dist = torch.distributions.Categorical(action_probs)
|
|
||||||
action = action_dist.sample()
|
|
||||||
|
|
||||||
return action.item(), action_probs[0, action].item()
|
|
||||||
|
|
@ -1,153 +0,0 @@
|
||||||
"""
|
|
||||||
Training logic for Balatro RL agent
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.optim as optim
|
|
||||||
import numpy as np
|
|
||||||
from typing import Dict, List, Any
|
|
||||||
from .policy import BalatroPolicy
|
|
||||||
import logging
|
|
||||||
|
|
||||||
class BalatroTrainer:
|
|
||||||
"""
|
|
||||||
Handles training of the Balatro RL agent
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, policy: BalatroPolicy, learning_rate: float = 3e-4):
|
|
||||||
self.policy = policy
|
|
||||||
self.optimizer = optim.Adam(policy.parameters(), lr=learning_rate)
|
|
||||||
self.logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Training buffers
|
|
||||||
self.states = []
|
|
||||||
self.actions = []
|
|
||||||
self.rewards = []
|
|
||||||
self.values = []
|
|
||||||
self.log_probs = []
|
|
||||||
|
|
||||||
def collect_experience(self, state: np.ndarray, action: int, reward: float,
|
|
||||||
value: float, log_prob: float):
|
|
||||||
"""
|
|
||||||
Collect experience for training
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: Game state
|
|
||||||
action: Action taken
|
|
||||||
reward: Reward received
|
|
||||||
value: State value estimate
|
|
||||||
log_prob: Log probability of action
|
|
||||||
"""
|
|
||||||
self.states.append(state)
|
|
||||||
self.actions.append(action)
|
|
||||||
self.rewards.append(reward)
|
|
||||||
self.values.append(value)
|
|
||||||
self.log_probs.append(log_prob)
|
|
||||||
|
|
||||||
def compute_returns(self, next_value: float = 0.0, gamma: float = 0.99) -> List[float]:
|
|
||||||
"""
|
|
||||||
Compute discounted returns
|
|
||||||
|
|
||||||
Args:
|
|
||||||
next_value: Value of next state (for bootstrapping)
|
|
||||||
gamma: Discount factor
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of discounted returns
|
|
||||||
"""
|
|
||||||
returns = []
|
|
||||||
R = next_value
|
|
||||||
|
|
||||||
for reward in reversed(self.rewards):
|
|
||||||
R = reward + gamma * R
|
|
||||||
returns.insert(0, R)
|
|
||||||
|
|
||||||
return returns
|
|
||||||
|
|
||||||
def update_policy(self, gamma: float = 0.99, entropy_coef: float = 0.01,
|
|
||||||
value_coef: float = 0.5) -> Dict[str, float]:
|
|
||||||
"""
|
|
||||||
Update policy using collected experience
|
|
||||||
|
|
||||||
Args:
|
|
||||||
gamma: Discount factor
|
|
||||||
entropy_coef: Entropy regularization coefficient
|
|
||||||
value_coef: Value loss coefficient
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary of training metrics
|
|
||||||
"""
|
|
||||||
if not self.states:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
# Convert to tensors
|
|
||||||
states = torch.FloatTensor(np.array(self.states))
|
|
||||||
actions = torch.LongTensor(self.actions)
|
|
||||||
old_log_probs = torch.FloatTensor(self.log_probs)
|
|
||||||
values = torch.FloatTensor(self.values)
|
|
||||||
|
|
||||||
# Compute returns
|
|
||||||
returns = self.compute_returns(gamma=gamma)
|
|
||||||
returns = torch.FloatTensor(returns)
|
|
||||||
|
|
||||||
# Compute advantages
|
|
||||||
advantages = returns - values
|
|
||||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
|
||||||
|
|
||||||
# Forward pass
|
|
||||||
action_logits, new_values = self.policy(states)
|
|
||||||
|
|
||||||
# Compute losses
|
|
||||||
action_dist = torch.distributions.Categorical(logits=action_logits)
|
|
||||||
new_log_probs = action_dist.log_prob(actions)
|
|
||||||
entropy = action_dist.entropy().mean()
|
|
||||||
|
|
||||||
# Policy loss (PPO-style, but simplified)
|
|
||||||
ratio = torch.exp(new_log_probs - old_log_probs)
|
|
||||||
policy_loss = -(ratio * advantages).mean()
|
|
||||||
|
|
||||||
# Value loss
|
|
||||||
value_loss = torch.nn.functional.mse_loss(new_values.squeeze(), returns)
|
|
||||||
|
|
||||||
# Total loss
|
|
||||||
total_loss = policy_loss + value_coef * value_loss - entropy_coef * entropy
|
|
||||||
|
|
||||||
# Update
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
total_loss.backward()
|
|
||||||
torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
|
|
||||||
self.optimizer.step()
|
|
||||||
|
|
||||||
# Clear buffers
|
|
||||||
self.clear_buffers()
|
|
||||||
|
|
||||||
# Return metrics
|
|
||||||
return {
|
|
||||||
'policy_loss': policy_loss.item(),
|
|
||||||
'value_loss': value_loss.item(),
|
|
||||||
'entropy': entropy.item(),
|
|
||||||
'total_loss': total_loss.item()
|
|
||||||
}
|
|
||||||
|
|
||||||
def clear_buffers(self):
|
|
||||||
"""Clear experience buffers"""
|
|
||||||
self.states.clear()
|
|
||||||
self.actions.clear()
|
|
||||||
self.rewards.clear()
|
|
||||||
self.values.clear()
|
|
||||||
self.log_probs.clear()
|
|
||||||
|
|
||||||
def save_model(self, path: str):
|
|
||||||
"""Save model checkpoint"""
|
|
||||||
torch.save({
|
|
||||||
'policy_state_dict': self.policy.state_dict(),
|
|
||||||
'optimizer_state_dict': self.optimizer.state_dict()
|
|
||||||
}, path)
|
|
||||||
self.logger.info(f"Model saved to {path}")
|
|
||||||
|
|
||||||
def load_model(self, path: str):
|
|
||||||
"""Load model checkpoint"""
|
|
||||||
checkpoint = torch.load(path)
|
|
||||||
self.policy.load_state_dict(checkpoint['policy_state_dict'])
|
|
||||||
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|
||||||
self.logger.info(f"Model loaded from {path}")
|
|
||||||
|
|
@ -1,177 +1,190 @@
|
||||||
"""
|
"""
|
||||||
Balatro RL Environment
|
Balatro RL Environment
|
||||||
Wraps the HTTP communication in a gym-like interface for RL training
|
Wraps the pipe-based communication with Balatro mod in a standard RL interface.
|
||||||
|
This acts as a translator between Balatro's JSON pipe communication and
|
||||||
|
RL libraries that expect gym-style step()/reset() methods.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Dict, Any, Tuple, List, Optional
|
from typing import Dict, Any, Tuple, List, Optional
|
||||||
import logging
|
import logging
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium import spaces
|
||||||
|
from ..utils.debug import tprint
|
||||||
|
|
||||||
from ..utils.serialization import GameStateValidator
|
from ..utils.communication import BalatroPipeIO
|
||||||
|
from .reward import BalatroRewardCalculator
|
||||||
|
from ..utils.mappers import BalatroStateMapper, BalatroActionMapper
|
||||||
|
|
||||||
|
|
||||||
class BalatroEnvironment:
|
class BalatroEnv(gym.Env):
|
||||||
"""
|
"""
|
||||||
Gym-like environment interface for Balatro RL training
|
Standard RL Environment wrapper for Balatro
|
||||||
|
|
||||||
This class will eventually wrap the HTTP communication
|
Translates between:
|
||||||
and provide a standard RL interface with step(), reset(), etc.
|
- Balatro mod's JSON pipe communication (/tmp/balatro_request, /tmp/balatro_response)
|
||||||
|
- Standard RL interface (step, reset, observation spaces)
|
||||||
|
|
||||||
|
This allows RL libraries like Stable-Baselines3 to train on Balatro
|
||||||
|
without knowing about the underlying pipe communication system.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
self.current_state = None
|
self.current_state = None
|
||||||
|
self.prev_state = None
|
||||||
self.game_over = False
|
self.game_over = False
|
||||||
|
|
||||||
# TODO: Define action and observation spaces
|
# Initialize communication and reward systems
|
||||||
# self.action_space = ...
|
self.pipe_io = BalatroPipeIO()
|
||||||
# self.observation_space = ...
|
self.reward_calculator = BalatroRewardCalculator()
|
||||||
|
|
||||||
def reset(self) -> Dict[str, Any]:
|
# Define Gymnasium spaces
|
||||||
"""Reset the environment for a new episode"""
|
# Action Spaces; This should describe the type and shape of the action
|
||||||
self.current_state = None
|
# Constants
|
||||||
self.game_over = False
|
self.MAX_ACTIONS = 10
|
||||||
|
action_selection = np.array([self.MAX_ACTIONS])
|
||||||
|
card_indices = np.array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2]) # Handles up to 10 cards in a hand
|
||||||
|
self.action_space = spaces.MultiDiscrete(np.concatenate([
|
||||||
|
action_selection,
|
||||||
|
card_indices
|
||||||
|
]))
|
||||||
|
ACTION_SLICE_LAYOUT = [
|
||||||
|
("action_selection", 1),
|
||||||
|
("card_indices", 10)
|
||||||
|
]
|
||||||
|
slices = self._build_action_slices(ACTION_SLICE_LAYOUT)
|
||||||
|
|
||||||
# TODO: Send reset signal to Balatro mod
|
# Observation space: This should describe the type and shape of the observation
|
||||||
# This might involve starting a new run
|
# Constants
|
||||||
|
self.OBSERVATION_SIZE = 70 # you can get this value by running test_env.py.
|
||||||
return self._get_initial_state()
|
self.observation_space = spaces.Box(
|
||||||
|
low=-np.inf, # lowest bound of observation data
|
||||||
|
high=np.inf, # highest bound of observation data
|
||||||
|
shape=(self.OBSERVATION_SIZE,), # Adjust based on actual state size which This is a 1D array
|
||||||
|
dtype=np.float32 # Data type of the numbers
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize mappers
|
||||||
|
self.state_mapper = BalatroStateMapper(observation_size=self.OBSERVATION_SIZE, max_actions=self.MAX_ACTIONS)
|
||||||
|
self.action_mapper = BalatroActionMapper(action_slices=slices)
|
||||||
|
|
||||||
def step(self, action: Dict[str, Any]) -> Tuple[Dict, float, bool, Dict]:
|
def reset(self, seed=None, options=None): #TODO add types
|
||||||
"""
|
"""
|
||||||
Take an action in the environment
|
Reset the environment for a new episode
|
||||||
|
|
||||||
|
In Balatro context, this means starting a new run.
|
||||||
|
Communicates with Balatro mod via pipes to initiate reset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Initial observation/game state
|
||||||
|
"""
|
||||||
|
self.current_state = None
|
||||||
|
self.prev_state = None
|
||||||
|
self.game_over = False
|
||||||
|
|
||||||
|
# Reset reward tracking
|
||||||
|
self.reward_calculator.reset()
|
||||||
|
|
||||||
|
# Wait for initial request from Balatro (game start)
|
||||||
|
initial_request = self.pipe_io.wait_for_request()
|
||||||
|
if not initial_request:
|
||||||
|
raise RuntimeError("Failed to receive initial request from Balatro")
|
||||||
|
|
||||||
|
# Send dummy response to complete hand shake
|
||||||
|
success = self.pipe_io.send_response({"action": "ready"})
|
||||||
|
if not success:
|
||||||
|
raise RuntimeError("Failed to complete handshake")
|
||||||
|
|
||||||
|
# Process initial state for SB3
|
||||||
|
self.current_state = initial_request
|
||||||
|
initial_observation = self.state_mapper.process_game_state(self.current_state)
|
||||||
|
|
||||||
|
return initial_observation, {}
|
||||||
|
|
||||||
|
def step(self, action): #TODO add types
|
||||||
|
"""
|
||||||
|
Take an action in the Balatro environment
|
||||||
|
Sends action to Balatro mod via JSON pipe, waits for response,
|
||||||
|
calculates reward, and returns standard RL step format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
action: Action to take
|
action: Action dictionary (e.g., {"action": 1, "params": {...}})
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (observation, reward, done, info)
|
Tuple of (observation, reward, done, info) where:
|
||||||
|
- observation: Processed game state for neural network
|
||||||
|
- reward: Calculated reward for this step
|
||||||
|
- done: Whether episode is finished (game over)
|
||||||
|
- info: Additional debug information
|
||||||
"""
|
"""
|
||||||
# TODO: Send action to Balatro via HTTP
|
# Store previous state for reward calculation
|
||||||
# TODO: Receive new game state
|
self.prev_state = self.current_state
|
||||||
# TODO: Calculate reward
|
|
||||||
# TODO: Check if episode is done
|
# Send action response to Balatro mod
|
||||||
|
response_data = self.action_mapper.process_action(rl_action=action)
|
||||||
|
|
||||||
observation = self._process_game_state(self.current_state)
|
success = self.pipe_io.send_response(response_data)
|
||||||
reward = self._calculate_reward()
|
if not success:
|
||||||
done = self.game_over
|
raise RuntimeError("Failed to send response to Balatro")
|
||||||
info = {}
|
|
||||||
|
|
||||||
return observation, reward, done, info
|
# Wait for next request with new game state
|
||||||
|
next_request = self.pipe_io.wait_for_request()
|
||||||
def _get_initial_state(self) -> Dict[str, Any]:
|
if not next_request:
|
||||||
"""Get initial game state"""
|
# If no response, assume game ended
|
||||||
# TODO: Implement
|
self.game_over = True
|
||||||
return {}
|
observation = self.state_mapper.process_game_state(self.current_state)
|
||||||
|
reward = 0.0
|
||||||
def _process_game_state(self, raw_state: Dict[str, Any]) -> Dict[str, Any]:
|
return observation, reward, True, {"timeout": True} #TODO this is a bug we need to return 5 values
|
||||||
"""
|
|
||||||
Process raw game state into format suitable for RL agent
|
|
||||||
|
|
||||||
This might involve:
|
# Update current state
|
||||||
- Extracting relevant features
|
self.current_state = next_request
|
||||||
- Normalizing values
|
|
||||||
- Converting to numpy arrays
|
|
||||||
"""
|
|
||||||
if not raw_state:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
# Validate state structure
|
# Process new state for SB3
|
||||||
try:
|
observation = self.state_mapper.process_game_state(self.current_state)
|
||||||
GameStateValidator.validate_game_state(raw_state)
|
|
||||||
except ValueError as e:
|
|
||||||
self.logger.error(f"Invalid game state: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
# TODO: Extract and process features
|
# Calculate reward using expert reward calculator
|
||||||
processed = {
|
reward = self.reward_calculator.calculate_reward(
|
||||||
'hand_features': self._extract_hand_features(raw_state.get('hand', {})),
|
current_state=self.current_state,
|
||||||
'game_features': self._extract_game_features(raw_state),
|
prev_state=self.prev_state if self.prev_state else {}
|
||||||
'available_actions': raw_state.get('available_actions', [])
|
)
|
||||||
|
|
||||||
|
# Check if episode is done
|
||||||
|
done = bool(self.current_state.get('game_over', 0)) # TODO send a game_over in our state
|
||||||
|
terminated = done
|
||||||
|
truncated = False # Not using time limits for now
|
||||||
|
|
||||||
|
info = {
|
||||||
|
'balatro_state': self.current_state.get('state', 0),
|
||||||
|
'available_actions': next_request.get('available_actions', [])
|
||||||
}
|
}
|
||||||
|
# TODO have a feeling observation is not being return when we run out of actions which causes problems
|
||||||
|
return observation, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""
|
||||||
|
Clean up environment resources
|
||||||
|
|
||||||
return processed
|
Call this when shutting down to clean up pipe communication.
|
||||||
|
"""
|
||||||
def _extract_hand_features(self, hand: Dict[str, Any]) -> np.ndarray:
|
self.pipe_io.cleanup()
|
||||||
"""Extract numerical features from hand"""
|
|
||||||
# TODO: Convert hand cards to numerical representation
|
@staticmethod
|
||||||
# This might include:
|
def _build_action_slices(layout: List[Tuple[str, int]]) -> Dict[str, slice]:
|
||||||
# - Card values (2-14 for 2-A)
|
"""
|
||||||
# - Suits (0-3 for different suits)
|
Create slices for our actions so that we can precisely extract the
|
||||||
# - Special abilities/bonuses
|
right params to send over to balatro
|
||||||
|
|
||||||
cards = hand.get('cards', [])
|
Args:
|
||||||
features = []
|
layout: Our ACTION_SLICE_LAYOUT that contains action name and size
|
||||||
|
Return:
|
||||||
for card in cards:
|
A dictionary containing a key being our action space slice name, and
|
||||||
# Extract card features
|
the slice
|
||||||
suit_encoding = self._encode_suit(card.get('suit', ''))
|
"""
|
||||||
value_encoding = self._encode_value(card.get('base', {}).get('value', ''))
|
slices = {}
|
||||||
nominal = card.get('base', {}).get('nominal', 0)
|
start = 0
|
||||||
|
for action_name, size in layout:
|
||||||
# Add ability features
|
slices[action_name] = slice(start, start + size)
|
||||||
ability = card.get('ability', {})
|
start += size
|
||||||
ability_features = [
|
return slices
|
||||||
ability.get('t_chips', 0),
|
|
||||||
ability.get('t_mult', 0),
|
|
||||||
ability.get('x_mult', 1),
|
|
||||||
ability.get('mult', 0)
|
|
||||||
]
|
|
||||||
|
|
||||||
card_features = [suit_encoding, value_encoding, nominal] + ability_features
|
|
||||||
features.extend(card_features)
|
|
||||||
|
|
||||||
# Pad or truncate to fixed size (e.g., 8 cards max * 7 features each)
|
|
||||||
max_cards = 8
|
|
||||||
features_per_card = 7
|
|
||||||
|
|
||||||
if len(features) < max_cards * features_per_card:
|
|
||||||
features.extend([0] * (max_cards * features_per_card - len(features)))
|
|
||||||
else:
|
|
||||||
features = features[:max_cards * features_per_card]
|
|
||||||
|
|
||||||
return np.array(features, dtype=np.float32)
|
|
||||||
|
|
||||||
def _extract_game_features(self, state: Dict[str, Any]) -> np.ndarray:
|
|
||||||
"""Extract game-level features"""
|
|
||||||
# TODO: Extract features like:
|
|
||||||
# - Current chips
|
|
||||||
# - Target chips
|
|
||||||
# - Money
|
|
||||||
# - Round/ante
|
|
||||||
# - Discards remaining
|
|
||||||
# - Hands remaining
|
|
||||||
|
|
||||||
features = [
|
|
||||||
state.get('state', 0), # Game state
|
|
||||||
len(state.get('available_actions', [])), # Number of available actions
|
|
||||||
]
|
|
||||||
|
|
||||||
return np.array(features, dtype=np.float32)
|
|
||||||
|
|
||||||
def _encode_suit(self, suit: str) -> int:
|
|
||||||
"""Encode suit as integer"""
|
|
||||||
suit_map = {'Hearts': 0, 'Diamonds': 1, 'Clubs': 2, 'Spades': 3}
|
|
||||||
return suit_map.get(suit, 0)
|
|
||||||
|
|
||||||
def _encode_value(self, value: str) -> int:
|
|
||||||
"""Encode card value as integer"""
|
|
||||||
if value.isdigit():
|
|
||||||
return int(value)
|
|
||||||
|
|
||||||
value_map = {
|
|
||||||
'Jack': 11, 'Queen': 12, 'King': 13, 'Ace': 14,
|
|
||||||
'A': 14, 'K': 13, 'Q': 12, 'J': 11
|
|
||||||
}
|
|
||||||
return value_map.get(value, 0)
|
|
||||||
|
|
||||||
def _calculate_reward(self) -> float:
|
|
||||||
"""Calculate reward for current state"""
|
|
||||||
# TODO: Implement reward calculation
|
|
||||||
# This might be based on:
|
|
||||||
# - Chips scored
|
|
||||||
# - Blind completion
|
|
||||||
# - Round progression
|
|
||||||
# - Final score
|
|
||||||
|
|
||||||
return 0.0
|
|
||||||
|
|
|
||||||
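The `BalatroActionMapper` used in the environment above is not part of this diff. A minimal sketch of how a flat MultiDiscrete action vector could be decoded into the JSON the mod expects, using a slice layout like the one built by `_build_action_slices` (treating `card_indices` as a 10-slot binary mask and the `{"action": id, "params": ...}` response shape as assumptions):

```python
import numpy as np
from typing import Any, Dict, List, Tuple


def build_action_slices(layout: List[Tuple[str, int]]) -> Dict[str, slice]:
    """Same idea as BalatroEnv._build_action_slices: map each name to its slice of the vector."""
    slices, start = {}, 0
    for name, size in layout:
        slices[name] = slice(start, start + size)
        start += size
    return slices


def decode_action(rl_action: np.ndarray, slices: Dict[str, slice]) -> Dict[str, Any]:
    """Turn a MultiDiscrete sample into the {"action": id, "params": ...} shape the mod reads."""
    action_id = int(rl_action[slices["action_selection"]][0])
    # card_indices: convert set bits into 1-based card positions for the mod
    mask = rl_action[slices["card_indices"]]
    cards = [i + 1 for i, flag in enumerate(mask) if flag == 1]
    return {"action": action_id, "params": {"card_indices": cards}}


# Example: action 3 (select_hand) playing cards 1, 3 and 5
slices = build_action_slices([("action_selection", 1), ("card_indices", 10)])
sample = np.array([3, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0])
print(decode_action(sample, slices))   # {'action': 3, 'params': {'card_indices': [1, 3, 5]}}
```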
|
|
@ -1,5 +1,11 @@
|
||||||
"""
|
"""
|
||||||
Reward calculation for Balatro RL environment
|
Balatro Reward System - The Expert on What's Good/Bad
|
||||||
|
|
||||||
|
This module is the single source of truth for reward calculation in Balatro RL.
|
||||||
|
All reward logic is centralized here to make experimentation and tuning easier.
|
||||||
|
|
||||||
|
The BalatroRewardCalculator analyzes game state changes and assigns rewards
|
||||||
|
that teach the AI what constitutes good vs bad Balatro gameplay.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
@ -7,114 +13,86 @@ import numpy as np
|
||||||
|
|
||||||
class BalatroRewardCalculator:
|
class BalatroRewardCalculator:
|
||||||
"""
|
"""
|
||||||
Calculates rewards for Balatro RL training
|
Expert reward calculator for Balatro RL training
|
||||||
|
|
||||||
|
This is the single authority on what constitutes good/bad play in Balatro.
|
||||||
|
Centralizes all reward logic for easy experimentation and tuning.
|
||||||
|
|
||||||
|
Reward philosophy:
|
||||||
|
- Positive rewards for progress (chips, rounds, money)
|
||||||
|
- Bonus rewards for efficiency (fewer hands/discards used)
|
||||||
|
- Large rewards for major milestones (completing antes)
|
||||||
|
- Negative rewards for game over (based on progress made)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.prev_score = 0
|
self.chips = 0
|
||||||
self.prev_money = 0
|
|
||||||
self.prev_round = 0
|
|
||||||
self.prev_ante = 0
|
|
||||||
|
|
||||||
def calculate_reward(self, current_state: Dict[str, Any],
|
def calculate_reward(self, current_state: Dict[str, Any],
|
||||||
prev_state: Dict[str, Any] = None) -> float:
|
prev_state: Dict[str, Any] = None) -> float:
|
||||||
"""
|
"""
|
||||||
Calculate reward based on game state changes
|
Main reward calculation method - analyzes state changes and assigns rewards
|
||||||
|
|
||||||
|
This is the core method that determines what the AI should optimize for.
|
||||||
|
Examines differences between previous and current game state to calculate
|
||||||
|
appropriate rewards for the action that caused the transition.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
current_state: Current game state
|
current_state: Current Balatro game state
|
||||||
prev_state: Previous game state (optional)
|
prev_state: Previous Balatro game state (None for first step)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Calculated reward
|
Float reward value (positive = good, negative = bad, zero = neutral)
|
||||||
"""
|
"""
|
||||||
reward = 0.0
|
reward = 0.0
|
||||||
|
|
||||||
# Extract relevant metrics
|
# Extract relevant metrics
|
||||||
current_score = current_state.get('score', 0)
|
current_chips = current_state.get('chips', 0)
|
||||||
current_money = current_state.get('money', 0)
|
game_over = current_state.get('game_over', False) #TODO fix
|
||||||
current_round = current_state.get('round', 0)
|
|
||||||
current_ante = current_state.get('ante', 0)
|
|
||||||
game_over = current_state.get('game_over', False)
|
|
||||||
|
|
||||||
# Score-based rewards
|
# Score-based rewards
|
||||||
score_diff = current_score - self.prev_score
|
chip_diff = current_chips - self.chips
|
||||||
if score_diff > 0:
|
if chip_diff > 0:
|
||||||
reward += score_diff * 0.001 # Small reward for score increases
reward += chip_diff * 0.001 # Small reward for score increases

# Money-based rewards
money_diff = current_money - self.prev_money
if money_diff > 0:
reward += money_diff * 0.01 # Reward for gaining money

# Round progression rewards
if current_round > self.prev_round:
reward += 10.0 # Big reward for completing rounds

# Ante progression rewards
if current_ante > self.prev_ante:
reward += 50.0 # Very big reward for completing antes

# Game over penalty
if game_over:
# Penalty based on how early the game ended
max_ante = 8 # Typical max ante in Balatro
# max_ante = 8 # Typical max ante in Balatro
completion_ratio = current_ante / max_ante
# completion_ratio = current_ante / max_ante
reward += completion_ratio * 100.0 # Reward based on progress
# reward += completion_ratio * 100.0 # Reward based on progress
reward -= 10 #TODO keeping this simple for now

# Efficiency rewards (optional)
# reward += self._calculate_efficiency_reward(current_state)

# Update previous state
self.prev_score = current_score
self.chips = current_chips
self.prev_money = current_money
self.prev_round = current_round
self.prev_ante = current_ante

return reward

def _calculate_efficiency_reward(self, state: Dict[str, Any]) -> float:
"""
Calculate rewards for efficient play

Args:
state: Current game state

Returns:
Efficiency reward
"""
reward = 0.0

# Reward for using fewer discards
discards_remaining = state.get('discards', 0)
if discards_remaining > 0:
reward += discards_remaining * 0.1

# Reward for using fewer hands
hands_remaining = state.get('hands', 0)
if hands_remaining > 0:
reward += hands_remaining * 0.5

return reward

def reset(self):
"""Reset reward calculator for new episode"""
"""
self.prev_score = 0
Reset reward calculator state for new episode
self.prev_money = 0
self.prev_round = 0
Called at the start of each new Balatro run to clear
self.prev_ante = 0
previous state tracking variables.
"""
self.chips = 0

def get_shaped_reward(self, state: Dict[str, Any], action: str) -> float:
"""
Get shaped reward based on specific actions
Calculate action-specific reward shaping

Provides small rewards/penalties for specific actions to guide
AI behavior beyond just game state changes. Use sparingly to
avoid overriding the main reward signal.

Args:
state: Current game state
state: Current game state when action was taken
action: Action taken
action: Action that was taken (for action-specific rewards)

Returns:
Shaped reward
Small shaped reward value (usually -1 to +1)
"""
# TODO look at this this could help but currently not in a working state
reward = 0.0

# Action-specific rewards

@@ -134,4 +112,4 @@ class BalatroRewardCalculator:
# Neutral - depends on if it's a good purchase
reward += 0.0

return reward
|
||||||
|
|
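For reference, a minimal standalone sketch of the chip-delta reward scheme the calculator now uses (function name and example values are illustrative, not part of the mod):

def simple_chip_reward(prev_chips: int, current_chips: int, game_over: bool) -> float:
    # Small reward for chip increases, flat penalty when the run ends
    reward = (current_chips - prev_chips) * 0.001
    if game_over:
        reward -= 10  # kept simple for now, mirroring the TODO above
    return reward

print(simple_chip_reward(0, 300, False))   # 0.3
print(simple_chip_reward(300, 300, True))  # -10.0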
@@ -1,111 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
File-based communication handler for Balatro RL
|
|
||||||
Watches for request files and responds with AI actions
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from watchdog.observers import Observer
|
|
||||||
from watchdog.events import FileSystemEventHandler
|
|
||||||
|
|
||||||
from agent.balatro_agent import BalatroAgent
|
|
||||||
|
|
||||||
# File paths
|
|
||||||
REQUEST_FILE = "/tmp/balatro_request.json"
|
|
||||||
RESPONSE_FILE = "/tmp/balatro_response.json"
|
|
||||||
|
|
||||||
class RequestHandler(FileSystemEventHandler):
|
|
||||||
"""Handles incoming request files from Balatro"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.agent = BalatroAgent()
|
|
||||||
self.logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def on_created(self, event):
|
|
||||||
"""Handle new request file creation"""
|
|
||||||
if event.src_path == REQUEST_FILE and not event.is_directory:
|
|
||||||
self.process_request()
|
|
||||||
|
|
||||||
def on_modified(self, event):
|
|
||||||
"""Handle request file modification"""
|
|
||||||
if event.src_path == REQUEST_FILE and not event.is_directory:
|
|
||||||
self.process_request()
|
|
||||||
|
|
||||||
def process_request(self):
|
|
||||||
"""Process a request file and generate response"""
|
|
||||||
try:
|
|
||||||
# Small delay to ensure file write is complete
|
|
||||||
time.sleep(0.01)
|
|
||||||
|
|
||||||
# Read request
|
|
||||||
with open(REQUEST_FILE, 'r') as f:
|
|
||||||
request_data = json.load(f)
|
|
||||||
|
|
||||||
self.logger.info(f"Processing request {request_data.get('id')}: state={request_data.get('state')}")
|
|
||||||
|
|
||||||
# Get action from AI agent
|
|
||||||
action_response = self.agent.get_action(request_data)
|
|
||||||
|
|
||||||
# Write response
|
|
||||||
response = {
|
|
||||||
"id": request_data.get("id"),
|
|
||||||
"action": action_response.get("action", "no_action"),
|
|
||||||
"params": action_response.get("params"),
|
|
||||||
"timestamp": time.time()
|
|
||||||
}
|
|
||||||
|
|
||||||
with open(RESPONSE_FILE, 'w') as f:
|
|
||||||
json.dump(response, f)
|
|
||||||
|
|
||||||
self.logger.info(f"Response written: {response['action']}")
|
|
||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error processing request: {e}")
|
|
||||||
|
|
||||||
# Write error response
|
|
||||||
error_response = {
|
|
||||||
"id": request_data.get("id", 0) if 'request_data' in locals() else 0,
|
|
||||||
"action": "no_action",
|
|
||||||
"params": {"error": str(e)},
|
|
||||||
"timestamp": time.time()
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(RESPONSE_FILE, 'w') as f:
|
|
||||||
json.dump(error_response, f)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Main file watcher loop"""
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
logger.info("Starting Balatro RL file watcher...")
|
|
||||||
|
|
||||||
# Set up file watcher
|
|
||||||
event_handler = RequestHandler()
|
|
||||||
observer = Observer()
|
|
||||||
observer.schedule(event_handler, path="/tmp", recursive=False)
|
|
||||||
|
|
||||||
observer.start()
|
|
||||||
logger.info(f"Watching for requests at: {REQUEST_FILE}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
time.sleep(1)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
logger.info("Stopping file watcher...")
|
|
||||||
observer.stop()
|
|
||||||
|
|
||||||
observer.join()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
@@ -1,6 +1,5 @@
numpy>=1.21.0
torch>=1.9.0
gymnasium>=0.26.0 # For RL environment interface
stable-baselines3>=1.6.0 # For RL algorithms
stable-baselines3[extra]>=1.6.0 # For RL algorithms
tensorboard>=2.8.0 # For training visualization
watchdog>=2.1.0 # For file system monitoring
||||||
|
|
@@ -22,5 +22,5 @@ echo "Setup complete!"
echo ""
echo "Usage:"
echo " To activate environment: source venv/bin/activate"
echo " To run file watcher: python file_watcher.py"
echo " To test environment is working: python -m ai.test_env"
echo " To train agent: python -m agent.training"
echo " To train agent: python -m ai.train_balatro"
||||||
|
|
@@ -0,0 +1,96 @@
|
||||||
|
"""
|
||||||
|
Environment Testing Script for Balatro RL
|
||||||
|
|
||||||
|
Tests the BalatroEnv to ensure it follows the Gym interface and
|
||||||
|
communicates properly with the Balatro mod via file I/O.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python test_env.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
# from stable_baselines3.common.env_checker import check_env # Not compatible with pipes
|
||||||
|
from .environment.balatro_env import BalatroEnv
|
||||||
|
|
||||||
|
|
||||||
|
def test_manual_actions():
|
||||||
|
"""Manual testing with random actions"""
|
||||||
|
print("\n=== Manual Action Testing ===")
|
||||||
|
|
||||||
|
env = BalatroEnv()
|
||||||
|
print("BalatroEnv created successfully")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Reset and get initial observation (ONLY ONCE!)
|
||||||
|
print("Calling env.reset() - this should wait for Balatro...")
|
||||||
|
obs, info = env.reset()
|
||||||
|
print("env.reset() completed!")
|
||||||
|
print(f"Initial observation: {obs}")
|
||||||
|
print(f"Observation space: {env.observation_space}")
|
||||||
|
print(f"Action space: {env.action_space}")
|
||||||
|
print(f"Sample action: {env.action_space.sample()}")
|
||||||
|
|
||||||
|
# Run test episodes
|
||||||
|
max_steps = 3
|
||||||
|
max_episodes = 1
|
||||||
|
|
||||||
|
for episode in range(max_episodes):
|
||||||
|
print(f"\n--- Episode {episode + 1} ---")
|
||||||
|
# Don't call reset() again! Use the obs from above
|
||||||
|
|
||||||
|
for step in range(max_steps):
|
||||||
|
# Get random action (like the old balatro_agent)
|
||||||
|
action = env.action_space.sample()
|
||||||
|
|
||||||
|
print(f"Step {step + 1}: Taking action {action}")
|
||||||
|
|
||||||
|
# Take action
|
||||||
|
obs, reward, terminated, truncated, info = env.step(action)
|
||||||
|
done = terminated or truncated
|
||||||
|
|
||||||
|
print(f" obs={obs}, reward={reward}, done={done}")
|
||||||
|
|
||||||
|
# Optional: Add delay for readability
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
if done:
|
||||||
|
print(f" Episode finished after {step + 1} steps")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not done:
|
||||||
|
print(f" Episode reached max steps ({max_steps})")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ Manual testing failed: {e}")
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
print("✓ Manual testing completed!")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Run all tests"""
|
||||||
|
print("Starting Balatro Environment Testing...")
|
||||||
|
|
||||||
|
# Setup logging to see communication details
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
|
||||||
|
success = True
|
||||||
|
success &= test_manual_actions()
|
||||||
|
|
||||||
|
if success:
|
||||||
|
print("\n🎉 All tests passed! Environment is ready for training.")
|
||||||
|
else:
|
||||||
|
print("\n❌ Some tests failed. Check the environment implementation.")
|
||||||
|
|
||||||
|
return success
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
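The manual test above leans on the Gymnasium 5-tuple step contract; a minimal sketch with a built-in environment (BalatroEnv itself needs the game running, so CartPole stands in here):

import gymnasium as gym

env = gym.make("CartPole-v1")
obs, info = env.reset()                    # reset() returns (observation, info)
action = env.action_space.sample()         # random action, as in test_manual_actions
obs, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated             # an episode ends on either flag
env.close()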
@@ -0,0 +1,273 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Balatro RL Training Script
|
||||||
|
|
||||||
|
Main training script for teaching AI to play Balatro using Stable-Baselines3.
|
||||||
|
This script creates the Balatro environment, sets up the RL model, and runs training.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python train_balatro.py
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
- Balatro game running with RLBridge mod
|
||||||
|
- file_watcher.py should NOT be running (this replaces it)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# SB3 imports
|
||||||
|
from stable_baselines3 import DQN, A2C, PPO
|
||||||
|
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
|
||||||
|
from stable_baselines3.common.monitor import Monitor
|
||||||
|
|
||||||
|
# Our custom environment
|
||||||
|
from .environment.balatro_env import BalatroEnv
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging():
|
||||||
|
"""Setup logging for training"""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.FileHandler('training.log'),
|
||||||
|
logging.StreamHandler()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_environment():
|
||||||
|
"""Create and wrap the Balatro environment"""
|
||||||
|
# Create base environment
|
||||||
|
env = BalatroEnv()
|
||||||
|
|
||||||
|
# Wrap with Monitor for logging episode stats
|
||||||
|
env = Monitor(env, filename="training_monitor.csv")
|
||||||
|
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(env, algorithm="DQN", model_path=None):
|
||||||
|
"""
|
||||||
|
Create SB3 model for training
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: Balatro environment
|
||||||
|
algorithm: RL algorithm to use ("DQN", "A2C", "PPO")
|
||||||
|
model_path: Path to load existing model (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SB3 model ready for training
|
||||||
|
"""
|
||||||
|
if algorithm == "DQN":
|
||||||
|
model = DQN(
|
||||||
|
"MlpPolicy",
|
||||||
|
env,
|
||||||
|
verbose=1,
|
||||||
|
learning_rate=1e-4,
|
||||||
|
buffer_size=10000,
|
||||||
|
learning_starts=1000,
|
||||||
|
target_update_interval=500,
|
||||||
|
train_freq=4,
|
||||||
|
gradient_steps=1,
|
||||||
|
exploration_fraction=0.1,
|
||||||
|
exploration_initial_eps=1.0,
|
||||||
|
exploration_final_eps=0.05,
|
||||||
|
tensorboard_log="./tensorboard_logs/"
|
||||||
|
)
|
||||||
|
elif algorithm == "A2C":
|
||||||
|
model = A2C(
|
||||||
|
"MlpPolicy",
|
||||||
|
env,
|
||||||
|
verbose=1,
|
||||||
|
learning_rate=7e-4,
|
||||||
|
n_steps=5,
|
||||||
|
gamma=0.99,
|
||||||
|
tensorboard_log="./tensorboard_logs/"
|
||||||
|
)
|
||||||
|
elif algorithm == "PPO":
|
||||||
|
model = PPO(
|
||||||
|
"MlpPolicy",
|
||||||
|
env,
|
||||||
|
verbose=1,
|
||||||
|
learning_rate=3e-4,
|
||||||
|
n_steps=2048,
|
||||||
|
batch_size=64,
|
||||||
|
n_epochs=10,
|
||||||
|
gamma=0.99,
|
||||||
|
tensorboard_log="./tensorboard_logs/"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown algorithm: {algorithm}")
|
||||||
|
|
||||||
|
# Load existing model if path provided
|
||||||
|
if model_path and Path(model_path).exists():
|
||||||
|
model = model.load(model_path, env=env)
|
||||||
|
print(f"Loaded existing model from {model_path}")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def create_callbacks(save_freq=1000):
|
||||||
|
"""Create training callbacks for saving and evaluation"""
|
||||||
|
callbacks = []
|
||||||
|
|
||||||
|
# Checkpoint callback - save model periodically
|
||||||
|
checkpoint_callback = CheckpointCallback(
|
||||||
|
save_freq=save_freq,
|
||||||
|
save_path="./models/",
|
||||||
|
name_prefix="balatro_model"
|
||||||
|
)
|
||||||
|
callbacks.append(checkpoint_callback)
|
||||||
|
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
|
||||||
|
def train_agent(total_timesteps=50000, algorithm="DQN", save_path="./models/balatro_final"):
|
||||||
|
"""
|
||||||
|
Main training function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
total_timesteps: Number of training steps
|
||||||
|
algorithm: RL algorithm to use
|
||||||
|
save_path: Where to save final model
|
||||||
|
"""
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.info(f"Starting Balatro RL training with {algorithm}")
|
||||||
|
logger.info(f"Training for {total_timesteps} timesteps")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create environment and model
|
||||||
|
env = create_environment()
|
||||||
|
model = create_model(env, algorithm)
|
||||||
|
|
||||||
|
# Create callbacks
|
||||||
|
callbacks = create_callbacks(save_freq=max(1000, total_timesteps // 20))
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
logger.info("Starting training...")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
model.learn(
|
||||||
|
total_timesteps=total_timesteps,
|
||||||
|
callback=callbacks,
|
||||||
|
progress_bar=True
|
||||||
|
)
|
||||||
|
|
||||||
|
training_time = time.time() - start_time
|
||||||
|
logger.info(f"Training completed in {training_time:.2f} seconds")
|
||||||
|
|
||||||
|
# Save final model
|
||||||
|
model.save(save_path)
|
||||||
|
logger.info(f"Model saved to {save_path}")
|
||||||
|
|
||||||
|
# Clean up environment
|
||||||
|
if hasattr(env, 'cleanup'):
|
||||||
|
env.cleanup()
|
||||||
|
elif hasattr(env.env, 'cleanup'): # Monitor wrapper
|
||||||
|
env.env.cleanup()
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Training interrupted by user")
|
||||||
|
if hasattr(env, 'cleanup'):
|
||||||
|
env.cleanup()
|
||||||
|
elif hasattr(env.env, 'cleanup'):
|
||||||
|
env.env.cleanup()
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Training failed: {e}")
|
||||||
|
if hasattr(env, 'cleanup'):
|
||||||
|
env.cleanup()
|
||||||
|
elif hasattr(env.env, 'cleanup'):
|
||||||
|
env.env.cleanup()
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def test_trained_model(model_path, num_episodes=5):
|
||||||
|
"""
|
||||||
|
Test a trained model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to trained model
|
||||||
|
num_episodes: Number of episodes to test
|
||||||
|
"""
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.info(f"Testing model from {model_path}")
|
||||||
|
|
||||||
|
# Create environment and load model
|
||||||
|
env = create_environment()
|
||||||
|
model = PPO.load(model_path)
|
||||||
|
|
||||||
|
episode_rewards = []
|
||||||
|
|
||||||
|
for episode in range(num_episodes):
|
||||||
|
obs, _info = env.reset()
|
||||||
|
total_reward = 0
|
||||||
|
steps = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
action, _states = model.predict(obs, deterministic=True)
|
||||||
|
obs, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
|
||||||
|
total_reward += reward
|
||||||
|
steps += 1
|
||||||
|
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
|
||||||
|
episode_rewards.append(total_reward)
|
||||||
|
logger.info(f"Episode {episode + 1}: {steps} steps, reward: {total_reward:.2f}")
|
||||||
|
|
||||||
|
avg_reward = sum(episode_rewards) / len(episode_rewards)
|
||||||
|
logger.info(f"Average reward over {num_episodes} episodes: {avg_reward:.2f}")
|
||||||
|
|
||||||
|
env.cleanup()
|
||||||
|
return episode_rewards
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Setup
|
||||||
|
setup_logging()
|
||||||
|
|
||||||
|
# Create necessary directories
|
||||||
|
Path(".models").mkdir(exist_ok=True)
|
||||||
|
Path(".tensorboard_logs").mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# We'll let BalatroEnv create the pipes when needed
|
||||||
|
print("🔧 Pipe communication will be initialized with environment...")
|
||||||
|
|
||||||
|
# Train the agent
|
||||||
|
print("\n🎮 Starting Balatro RL Training!")
|
||||||
|
print("Setup steps:")
|
||||||
|
print("1. ✅ Balatro is running with RLBridge mod")
|
||||||
|
print("2. ✅ Balatro is in menu state")
|
||||||
|
print("3. Press Enter below to start training")
|
||||||
|
print("4. When prompted, press 'R' in Balatro within 3 seconds")
|
||||||
|
print("\n📡 Training will create pipes and connect automatically!")
|
||||||
|
print()
|
||||||
|
|
||||||
|
input("Press Enter to start training (then you'll have 3 seconds to press 'R' in Balatro)...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
model = train_agent(
|
||||||
|
total_timesteps=10000, # Start small for testing
|
||||||
|
algorithm="PPO", # PPO supports MultiDiscrete actions
|
||||||
|
save_path="./models/balatro_trained"
|
||||||
|
)
|
||||||
|
|
||||||
|
if model:
|
||||||
|
print("\n🎉 Training completed successfully!")
|
||||||
|
print("Testing the trained model...")
|
||||||
|
|
||||||
|
# Test the trained model
|
||||||
|
test_trained_model("./models/balatro_trained", num_episodes=3)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Training failed: {e}")
|
||||||
|
print("Check the logs for more details.")
|
||||||
|
finally:
|
||||||
|
# Pipes will be cleaned up by the environment
|
||||||
|
print("🧹 Training session ended")
|
||||||
|
|
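The PPO choice above is tied to the action space; a sketch of the distinction, assuming the environment exposes a MultiDiscrete action (sizes illustrative, e.g. one action-selection slot plus eight card-select flags):

from gymnasium.spaces import MultiDiscrete

action_space = MultiDiscrete([10] + [2] * 8)
print(action_space.sample())  # e.g. [3 0 1 1 0 0 1 0 0]

# Stable-Baselines3's DQN only handles Discrete action spaces, while PPO and A2C
# also accept MultiDiscrete, which is why the __main__ block selects PPO.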
@@ -0,0 +1,192 @@
|
||||||
|
"""
|
||||||
|
Balatro Dual Pipe Communication Abstraction
|
||||||
|
|
||||||
|
Pure dual pipe I/O layer with persistent handles for communicating with Balatro mod.
|
||||||
|
Handles low-level pipe operations without any game logic or AI logic.
|
||||||
|
|
||||||
|
This abstraction can be used by:
|
||||||
|
- balatro_env.py (for SB3 training)
|
||||||
|
- file_watcher.py (for testing)
|
||||||
|
- Any other component that needs to talk to Balatro
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class BalatroPipeIO:
|
||||||
|
"""
|
||||||
|
Clean abstraction for dual pipe communication with Balatro mod using persistent handles
|
||||||
|
|
||||||
|
Responsibilities:
|
||||||
|
- Read JSON requests from Balatro mod via request pipe
|
||||||
|
- Write JSON responses back to Balatro mod via response pipe
|
||||||
|
- Keep pipe handles open persistently to avoid deadlocks
|
||||||
|
- Provide simple, clean interface for higher-level code
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, request_pipe: str = "/tmp/balatro_request", response_pipe: str = "/tmp/balatro_response"):
|
||||||
|
self.request_pipe = request_pipe
|
||||||
|
self.response_pipe = response_pipe
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Persistent pipe handles
|
||||||
|
self.request_handle = None
|
||||||
|
self.response_handle = None
|
||||||
|
|
||||||
|
# Create pipes and open persistent handles
|
||||||
|
self.create_pipes()
|
||||||
|
self.open_persistent_handles()
|
||||||
|
|
||||||
|
def create_pipes(self) -> None:
|
||||||
|
"""
|
||||||
|
Create dual named pipes for communication
|
||||||
|
|
||||||
|
Creates both request and response pipes if they don't exist.
|
||||||
|
Safe to call multiple times.
|
||||||
|
"""
|
||||||
|
for pipe_path in [self.request_pipe, self.response_pipe]:
|
||||||
|
try:
|
||||||
|
# Remove existing pipe if it exists
|
||||||
|
if os.path.exists(pipe_path):
|
||||||
|
os.unlink(pipe_path)
|
||||||
|
self.logger.debug(f"Removed existing pipe: {pipe_path}")
|
||||||
|
|
||||||
|
# Create named pipe
|
||||||
|
os.mkfifo(pipe_path)
|
||||||
|
self.logger.info(f"Created pipe: {pipe_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to create pipe {pipe_path}: {e}")
|
||||||
|
raise RuntimeError(f"Could not create communication pipe: {pipe_path}")
|
||||||
|
|
||||||
|
def open_persistent_handles(self) -> None:
|
||||||
|
"""
|
||||||
|
Open persistent handles for reading and writing
|
||||||
|
|
||||||
|
Opens request pipe for reading and response pipe for writing.
|
||||||
|
Keeps handles open to avoid deadlocks.
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.logger.info(f"🔧 Waiting for Balatro to connect...")
|
||||||
|
self.logger.info(f" Press 'R' in Balatro now to activate RL training!")
|
||||||
|
|
||||||
|
# Open request pipe for reading (Balatro writes to this)
|
||||||
|
self.request_handle = open(self.request_pipe, 'r')
|
||||||
|
|
||||||
|
# Open response pipe for writing (AI writes to this)
|
||||||
|
self.response_handle = open(self.response_pipe, 'w')
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to open persistent handles: {e}")
|
||||||
|
self.cleanup_handles()
|
||||||
|
raise RuntimeError(f"Could not open persistent pipe handles: {e}")
|
||||||
|
|
||||||
|
def wait_for_request(self) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Wait for new request from Balatro mod using persistent handle
|
||||||
|
|
||||||
|
Blocks until Balatro writes a request to the pipe.
|
||||||
|
No timeout needed - pipes block until data arrives.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Parsed JSON request data, or None if error
|
||||||
|
"""
|
||||||
|
if not self.request_handle:
|
||||||
|
self.logger.error("Request handle not open")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Read from persistent request handle
|
||||||
|
request_line = self.request_handle.readline().strip()
|
||||||
|
if not request_line:
|
||||||
|
return None
|
||||||
|
|
||||||
|
request_data = json.loads(request_line)
|
||||||
|
self.logger.info(f"📥 RECEIVED REQUEST: {request_line}...") # Show first 100 chars
|
||||||
|
return request_data
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
self.logger.error(f"Invalid JSON in request: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error reading from request pipe: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def send_response(self, response_data: Dict[str, Any]) -> bool:
|
||||||
|
"""
|
||||||
|
Send response back to Balatro mod using persistent handle
|
||||||
|
|
||||||
|
Writes response data to response pipe for Balatro to read.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response_data: Response dictionary to send
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful, False if error
|
||||||
|
"""
|
||||||
|
if not self.response_handle:
|
||||||
|
self.logger.error("Response handle not open")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Write to persistent response handle
|
||||||
|
json.dump(response_data, self.response_handle)
|
||||||
|
self.response_handle.write('\n') # Important: newline for pipe communication
|
||||||
|
self.response_handle.flush() # Force write to pipe immediately
|
||||||
|
|
||||||
|
self.logger.info(f"📤 SENT RESPONSE: {json.dumps(response_data)}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"❌ ERROR sending response: {e}")
|
||||||
|
import traceback
|
||||||
|
self.logger.error(f"❌ TRACEBACK: {traceback.format_exc()}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def cleanup_handles(self):
|
||||||
|
"""
|
||||||
|
Close persistent pipe handles
|
||||||
|
|
||||||
|
Closes the open file handles.
|
||||||
|
"""
|
||||||
|
if self.request_handle:
|
||||||
|
try:
|
||||||
|
self.request_handle.close()
|
||||||
|
self.logger.debug("Closed request handle")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Failed to close request handle: {e}")
|
||||||
|
self.request_handle = None
|
||||||
|
|
||||||
|
if self.response_handle:
|
||||||
|
try:
|
||||||
|
self.response_handle.close()
|
||||||
|
self.logger.debug("Closed response handle")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Failed to close response handle: {e}")
|
||||||
|
self.response_handle = None
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""
|
||||||
|
Clean up communication pipes and handles
|
||||||
|
|
||||||
|
Closes handles and removes pipe files from the filesystem.
|
||||||
|
"""
|
||||||
|
# Close handles first
|
||||||
|
self.cleanup_handles()
|
||||||
|
|
||||||
|
# Remove pipe files
|
||||||
|
for pipe_path in [self.request_pipe, self.response_pipe]:
|
||||||
|
try:
|
||||||
|
if os.path.exists(pipe_path):
|
||||||
|
os.unlink(pipe_path)
|
||||||
|
self.logger.debug(f"Removed pipe: {pipe_path}")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Failed to remove pipe {pipe_path}: {e}")
|
||||||
|
|
||||||
|
self.logger.debug("Dual pipe communication cleanup complete")
|
||||||
|
|
||||||
|
|
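A minimal consumer loop for the BalatroPipeIO class above, using only its public methods (the import path is assumed; a real consumer would plug in an agent instead of echoing a no-op response):

import logging
from ai.communication.pipe_io import BalatroPipeIO  # module path assumed

logging.basicConfig(level=logging.INFO)

io = BalatroPipeIO()  # creates both pipes, then blocks until Balatro opens its ends
try:
    while True:
        request = io.wait_for_request()  # blocking read of one JSON line
        if request is None:
            continue
        io.send_response({"action": "no_action", "params": None})
finally:
    io.cleanup()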
@@ -0,0 +1,39 @@
"""
Debug utilities for timestamped print statements and debugging helpers
"""

import datetime


def timestamp() -> str:
    """
    Get current timestamp formatted as HH:MM:SS.mmm

    Returns:
        str: Formatted timestamp string
    """
    now = datetime.datetime.now()
    return now.strftime("%H:%M:%S.%f")[:-3]  # Remove last 3 digits to get milliseconds


def tprint(*args, **kwargs):
    """
    Print with timestamp prefix

    Args:
        *args: Arguments to print
        **kwargs: Keyword arguments for print()
    """
    print(f"[{timestamp()}]", *args, **kwargs)


def dprint(prefix: str, *args, **kwargs):
    """
    Debug print with custom prefix and timestamp

    Args:
        prefix: Custom prefix string
        *args: Arguments to print
        **kwargs: Keyword arguments for print()
    """
    print(f"[{timestamp()}] [{prefix}]", *args, **kwargs)
||||||
|
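Example use of the helpers above (timestamps illustrative):

tprint("Opened request pipe")            # [14:02:31.250] Opened request pipe
dprint("PIPE", "waiting for request")    # [14:02:31.251] [PIPE] waiting for request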
|
@ -1,74 +0,0 @@
|
||||||
"""
|
|
||||||
Logging setup for Balatro RL system
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
def setup_logging(
|
|
||||||
level: str = "INFO",
|
|
||||||
log_file: Optional[str] = None,
|
|
||||||
format_string: Optional[str] = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Setup logging configuration
|
|
||||||
|
|
||||||
Args:
|
|
||||||
level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
|
||||||
log_file: Optional log file path
|
|
||||||
format_string: Optional custom format string
|
|
||||||
"""
|
|
||||||
if format_string is None:
|
|
||||||
format_string = (
|
|
||||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert string level to logging constant
|
|
||||||
numeric_level = getattr(logging, level.upper(), None)
|
|
||||||
if not isinstance(numeric_level, int):
|
|
||||||
raise ValueError(f'Invalid log level: {level}')
|
|
||||||
|
|
||||||
# Create formatter
|
|
||||||
formatter = logging.Formatter(format_string)
|
|
||||||
|
|
||||||
# Setup root logger
|
|
||||||
root_logger = logging.getLogger()
|
|
||||||
root_logger.setLevel(numeric_level)
|
|
||||||
|
|
||||||
# Remove existing handlers
|
|
||||||
for handler in root_logger.handlers[:]:
|
|
||||||
root_logger.removeHandler(handler)
|
|
||||||
|
|
||||||
# Console handler
|
|
||||||
console_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
console_handler.setLevel(numeric_level)
|
|
||||||
console_handler.setFormatter(formatter)
|
|
||||||
root_logger.addHandler(console_handler)
|
|
||||||
|
|
||||||
# File handler (optional)
|
|
||||||
if log_file:
|
|
||||||
log_path = Path(log_file)
|
|
||||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
file_handler = logging.FileHandler(log_file)
|
|
||||||
file_handler.setLevel(numeric_level)
|
|
||||||
file_handler.setFormatter(formatter)
|
|
||||||
root_logger.addHandler(file_handler)
|
|
||||||
|
|
||||||
# Set specific logger levels
|
|
||||||
logging.getLogger("uvicorn").setLevel(logging.INFO)
|
|
||||||
logging.getLogger("fastapi").setLevel(logging.INFO)
|
|
||||||
|
|
||||||
def get_logger(name: str) -> logging.Logger:
|
|
||||||
"""
|
|
||||||
Get a logger with the specified name
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Logger name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Logger instance
|
|
||||||
"""
|
|
||||||
return logging.getLogger(name)
|
|
||||||
|
|
@@ -0,0 +1,245 @@
|
||||||
|
"""
|
||||||
|
Mappers for converting between Balatro game data and RL-compatible formats.
|
||||||
|
|
||||||
|
This module handles the two-way data transformation:
|
||||||
|
1. BalatroStateMapper: Converts incoming Balatro JSON to normalized RL observations
|
||||||
|
2. BalatroActionMapper: Converts RL actions to Balatro command JSON
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Any
|
||||||
|
from ..utils.validation import GameStateValidator, ResponseValidator
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BalatroStateMapper:
|
||||||
|
"""
|
||||||
|
Converts raw Balatro game state JSON to normalized RL observations.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- Card data normalization
|
||||||
|
- Game state parsing
|
||||||
|
- Observation space formatting
|
||||||
|
"""
|
||||||
|
def __init__(self, observation_size: int, max_actions: int):
|
||||||
|
self.observation_size = observation_size
|
||||||
|
self.max_actions = max_actions
|
||||||
|
|
||||||
|
# Logger
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# GameStateValidator
|
||||||
|
self.game_state_validator = GameStateValidator()
|
||||||
|
|
||||||
|
#TODO review Might be something wrong here
|
||||||
|
def process_game_state(self, raw_state: Dict[str, Any] | None) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Convert Balatro's raw JSON game state into neural network input format,
i.e. standardized numerical arrays that neural networks can process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_state: Raw game state from Balatro mod JSON
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed state suitable for RL training
|
||||||
|
"""
|
||||||
|
if not raw_state:
|
||||||
|
return np.zeros(self.observation_size, dtype=np.float32) # TODO maybe raise error?
|
||||||
|
|
||||||
|
# Validate game state request
|
||||||
|
try:
|
||||||
|
self.game_state_validator.validate_game_state(raw_state)
|
||||||
|
except ValueError as e:
|
||||||
|
self.logger.error(f"Invalid game state: {e}")
|
||||||
|
|
||||||
|
hand_features = self._extract_hand_features(raw_state.get('hand', {}))
|
||||||
|
game_features = self._extract_game_features(raw_state)
|
||||||
|
available_actions = self._extract_available_actions(raw_state.get('available_actions', []))
|
||||||
|
chips = self._extract_chip_features(raw_state)
|
||||||
|
# TODO extract state
|
||||||
|
|
||||||
|
return np.concatenate([hand_features, game_features, available_actions, chips])
|
||||||
|
|
||||||
|
def _extract_available_actions(self, available_actions: List[int]) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Convert available actions into a np array
|
||||||
|
Args:
|
||||||
|
available_actions: Available actions list from Balatro game state
|
||||||
|
Returns:
|
||||||
|
Binary action mask (1.0 where the action id is available)
|
||||||
|
"""
|
||||||
|
mask = np.zeros(self.max_actions, dtype=np.float32)
|
||||||
|
for action_id in available_actions:
|
||||||
|
mask[action_id] = 1.0
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def _extract_hand_features(self, hand: Dict[str, Any]) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Convert Balatro hand data into numerical features for neural network
|
||||||
|
|
||||||
|
Transforms card dictionaries into fixed-size numerical arrays:
|
||||||
|
- Card values: 2-14 (2 through Ace)
|
||||||
|
- Suits: 0-3 (Hearts, Diamonds, Clubs, Spades)
|
||||||
|
- Card abilities: chips, mult, special effects
|
||||||
|
- Pads/truncates to fixed hand size for consistent input
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hand: Hand dictionary from Balatro game state
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Fixed-size numpy array of hand features
|
||||||
|
"""
|
||||||
|
|
||||||
|
cards = hand.get('cards', [])
|
||||||
|
features = []
|
||||||
|
|
||||||
|
for card in cards:
|
||||||
|
# Extract card features
|
||||||
|
suit_encoding = self._encode_suit(card.get('suit', ''))
|
||||||
|
value_encoding = self._encode_value(card.get('base', {}).get('value', ''))
|
||||||
|
nominal = card.get('base', {}).get('nominal', 0)
|
||||||
|
|
||||||
|
# Add ability features
|
||||||
|
ability = card.get('ability', {})
|
||||||
|
ability_features = [
|
||||||
|
ability.get('t_chips', 0),
|
||||||
|
ability.get('t_mult', 0),
|
||||||
|
ability.get('x_mult', 1),
|
||||||
|
ability.get('mult', 0)
|
||||||
|
]
|
||||||
|
|
||||||
|
card_features = [suit_encoding, value_encoding, nominal] + ability_features
|
||||||
|
features.extend(card_features)
|
||||||
|
|
||||||
|
# Pad or truncate to fixed size (e.g., 8 cards max * 7 features each)
|
||||||
|
# TODO hand.get('size')
|
||||||
|
# TODO hand.get('highlighted_count')
|
||||||
|
max_cards = 8
|
||||||
|
features_per_card = 7
|
||||||
|
|
||||||
|
if len(features) < max_cards * features_per_card:
|
||||||
|
features.extend([0] * (max_cards * features_per_card - len(features)))
|
||||||
|
else:
|
||||||
|
features = features[:max_cards * features_per_card]
|
||||||
|
|
||||||
|
return np.array(features, dtype=np.float32)
|
||||||
|
|
||||||
|
def _extract_game_features(self, state: Dict[str, Any]) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Extract numerical game-level features from Balatro state
|
||||||
|
|
||||||
|
Converts game metadata into neural network inputs:
|
||||||
|
- Current game state (menu, selecting hand, etc.)
|
||||||
|
- Available actions count
|
||||||
|
- Money, chips, round progression
|
||||||
|
- Remaining hands/discards
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Full Balatro game state dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Numpy array of normalized game features
|
||||||
|
"""
|
||||||
|
|
||||||
|
features = [
|
||||||
|
state.get('state', 0), # Game state
|
||||||
|
len(state.get('available_actions', [])), # Number of available actions
|
||||||
|
]
|
||||||
|
|
||||||
|
return np.array(features, dtype=np.float32)
|
||||||
|
|
||||||
|
def _extract_chip_features(self, state: Dict[str, Any]) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Extract information relating to chips
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Full Balatro game state dictionary
|
||||||
|
Returns:
|
||||||
|
Numpy array of normalized chip features
|
||||||
|
"""
|
||||||
|
features = [
|
||||||
|
state.get('chips', 0),
|
||||||
|
state.get('chip_target', 0)
|
||||||
|
]
|
||||||
|
|
||||||
|
return np.array(features, dtype=np.float32)
|
||||||
|
|
||||||
|
def _encode_suit(self, suit: str) -> int:
|
||||||
|
"""Encode suit as integer"""
|
||||||
|
suit_map = {'Hearts': 0, 'Diamonds': 1, 'Clubs': 2, 'Spades': 3}
|
||||||
|
return suit_map.get(suit, 0)
|
||||||
|
|
||||||
|
def _encode_value(self, value: str) -> int:
|
||||||
|
"""Encode card value as integer"""
|
||||||
|
if value.isdigit():
|
||||||
|
return int(value)
|
||||||
|
|
||||||
|
value_map = {
|
||||||
|
'Jack': 11, 'Queen': 12, 'King': 13, 'Ace': 14,
|
||||||
|
'A': 14, 'K': 13, 'Q': 12, 'J': 11
|
||||||
|
}
|
||||||
|
return value_map.get(value, 0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BalatroActionMapper:
|
||||||
|
"""
|
||||||
|
Converts RL actions to Balatro command JSON.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- Binary action conversion to card indices
|
||||||
|
- Action validation
|
||||||
|
- JSON response formatting
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, action_slices: Dict[str, slice]):
|
||||||
|
self.slices = action_slices
|
||||||
|
|
||||||
|
# Validator
|
||||||
|
self.response_validator = ResponseValidator()
|
||||||
|
|
||||||
|
# Logger
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def process_action(self, rl_action: np.ndarray) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Convert RL action to Balatro JSON.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rl_action: Binary action array from RL agent
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON response formatted for Balatro mod
|
||||||
|
"""
|
||||||
|
ai_action = rl_action[self.slices["action_selection"]].tolist()[0]
|
||||||
|
response_data = {
|
||||||
|
"action": ai_action,
|
||||||
|
"params": self._extract_select_hand_params(rl_action),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Validate action structure
|
||||||
|
try:
|
||||||
|
self.response_validator.validate_response(response_data)
|
||||||
|
except ValueError as e:
|
||||||
|
self.logger.error(f"Invalid game state: {e}")
|
||||||
|
|
||||||
|
return response_data
|
||||||
|
|
||||||
|
def _extract_select_hand_params(self, raw_action: np.ndarray) -> List[int]:
|
||||||
|
"""
|
||||||
|
Converts the raw action to a list of Lua card indices
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_action: The whole action from the RL agent
|
||||||
|
Returns:
|
||||||
|
List of 1-based card indices for Lua
|
||||||
|
"""
|
||||||
|
card_indices = raw_action[self.slices["card_indices"]]
|
||||||
|
return [i + 1 for i, val in enumerate(card_indices) if val == 1]
|
||||||
|
|
||||||
|
|
||||||
|
|
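Putting the extractors above together, process_game_state concatenates hand, game, action-mask and chip features; a worked size check under the constants shown (max_actions is illustrative and must match the environment's setting):

max_cards, features_per_card = 8, 7          # from _extract_hand_features
hand_len = max_cards * features_per_card     # 56
game_len = 2                                 # state id + number of available actions
max_actions = 10                             # illustrative
chips_len = 2                                # chips + chip_target

observation_size = hand_len + game_len + max_actions + chips_len
print(observation_size)                      # 70 with these numbers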
@ -1,122 +0,0 @@
|
||||||
"""
|
|
||||||
JSON Serialization utilities for Balatro RL
|
|
||||||
Handles conversion between game state formats
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Dict, Any, List, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class GameStateSerializer:
|
|
||||||
"""Handles serialization/deserialization of game states"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def serialize_game_state(game_state: Dict[str, Any]) -> str:
|
|
||||||
"""Convert game state dict to JSON string"""
|
|
||||||
try:
|
|
||||||
return json.dumps(game_state, indent=2, ensure_ascii=False)
|
|
||||||
except (TypeError, ValueError) as e:
|
|
||||||
raise ValueError(f"Failed to serialize game state: {e}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def deserialize_game_state(json_string: str) -> Dict[str, Any]:
|
|
||||||
"""Convert JSON string to game state dict"""
|
|
||||||
try:
|
|
||||||
return json.loads(json_string)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
raise ValueError(f"Failed to deserialize game state: {e}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def serialize_action(action: Dict[str, Any]) -> str:
|
|
||||||
"""Convert action dict to JSON string"""
|
|
||||||
try:
|
|
||||||
return json.dumps(action, ensure_ascii=False)
|
|
||||||
except (TypeError, ValueError) as e:
|
|
||||||
raise ValueError(f"Failed to serialize action: {e}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def deserialize_action(json_string: str) -> Dict[str, Any]:
|
|
||||||
"""Convert JSON string to action dict"""
|
|
||||||
try:
|
|
||||||
return json.loads(json_string)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
raise ValueError(f"Failed to deserialize action: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class GameStateValidator:
|
|
||||||
"""Validates game state structure and content"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def validate_game_state(game_state: Dict[str, Any]) -> bool:
|
|
||||||
"""Validate that game state has required fields"""
|
|
||||||
required_fields = ['state', 'available_actions']
|
|
||||||
|
|
||||||
for field in required_fields:
|
|
||||||
if field not in game_state:
|
|
||||||
raise ValueError(f"Missing required field: {field}")
|
|
||||||
|
|
||||||
# Validate hand structure if present
|
|
||||||
if 'hand' in game_state:
|
|
||||||
GameStateValidator._validate_hand(game_state['hand'])
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _validate_hand(hand: Dict[str, Any]) -> bool:
|
|
||||||
"""Validate hand structure"""
|
|
||||||
if not isinstance(hand, dict):
|
|
||||||
raise ValueError("Hand must be a dictionary")
|
|
||||||
|
|
||||||
if 'cards' in hand and not isinstance(hand['cards'], list):
|
|
||||||
raise ValueError("Hand cards must be a list")
|
|
||||||
|
|
||||||
# Validate each card
|
|
||||||
for i, card in enumerate(hand.get('cards', [])):
|
|
||||||
GameStateValidator._validate_card(card, i)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _validate_card(card: Dict[str, Any], index: int) -> bool:
|
|
||||||
"""Validate individual card structure"""
|
|
||||||
required_fields = ['suit', 'base']
|
|
||||||
|
|
||||||
for field in required_fields:
|
|
||||||
if field not in card:
|
|
||||||
raise ValueError(f"Card {index} missing required field: {field}")
|
|
||||||
|
|
||||||
# Validate base structure
|
|
||||||
if 'base' in card:
|
|
||||||
base = card['base']
|
|
||||||
if 'value' not in base or 'nominal' not in base:
|
|
||||||
raise ValueError(f"Card {index} base missing value or nominal")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def validate_action(action: Dict[str, Any]) -> bool:
|
|
||||||
"""Validate action structure"""
|
|
||||||
if 'action' not in action:
|
|
||||||
raise ValueError("Action must have 'action' field")
|
|
||||||
|
|
||||||
action_type = action['action']
|
|
||||||
|
|
||||||
# Validate specific action types
|
|
||||||
if action_type in ['play_hand', 'discard']:
|
|
||||||
if 'card_indices' not in action:
|
|
||||||
raise ValueError(f"Action {action_type} requires card_indices")
|
|
||||||
|
|
||||||
if not isinstance(action['card_indices'], list):
|
|
||||||
raise ValueError("card_indices must be a list")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
# Convenience functions
|
|
||||||
def to_json(obj: Dict[str, Any]) -> str:
|
|
||||||
"""Quick serialize to JSON"""
|
|
||||||
return GameStateSerializer.serialize_game_state(obj)
|
|
||||||
|
|
||||||
def from_json(json_str: str) -> Dict[str, Any]:
|
|
||||||
"""Quick deserialize from JSON"""
|
|
||||||
return GameStateSerializer.deserialize_game_state(json_str)
|
|
||||||
|
|
@@ -0,0 +1,89 @@
|
||||||
|
"""
|
||||||
|
Game State Validation utilities for Balatro RL
|
||||||
|
Validates game state structure and content
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class GameStateValidator:
|
||||||
|
"""
|
||||||
|
Validates game state json structure
|
||||||
|
|
||||||
|
Request contract (macro level, not exhaustive)
|
||||||
|
{
|
||||||
|
game_state: Dict, # The whole game state table
|
||||||
|
available_actions: List, # The available actions all as integers
|
||||||
|
retry_count: int, # TODO update observation space. this is to handle retries if AI fails
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_game_state(game_state: Dict[str, Any]) -> bool:
|
||||||
|
"""Validate that game state has required fields"""
|
||||||
|
required_fields = ['game_state', 'available_actions', 'retry_count']
|
||||||
|
# TODO we can add more validations as needed like ensuring available_actions is a list and stuff
|
||||||
|
# TODO game_over might be a validation
|
||||||
|
# TODO validate chips for rewards
|
||||||
|
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in game_state:
|
||||||
|
raise ValueError(f"Missing required field: {field}")
|
||||||
|
|
||||||
|
# Validate hand structure if present
|
||||||
|
if 'hand' in game_state:
|
||||||
|
GameStateValidator._validate_hand(game_state['hand'])
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate_hand(hand: Dict[str, Any]) -> bool:
|
||||||
|
"""Validate hand structure"""
|
||||||
|
if not isinstance(hand, dict):
|
||||||
|
raise ValueError("Hand must be a dictionary")
|
||||||
|
|
||||||
|
if 'cards' in hand and not isinstance(hand['cards'], list):
|
||||||
|
raise ValueError("Hand cards must be a list")
|
||||||
|
|
||||||
|
# Validate each card
|
||||||
|
for i, card in enumerate(hand.get('cards', [])):
|
||||||
|
GameStateValidator._validate_card(card, i)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate_card(card: Dict[str, Any], index: int) -> bool:
|
||||||
|
"""Validate individual card structure"""
|
||||||
|
required_fields = ['suit', 'base']
|
||||||
|
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in card:
|
||||||
|
raise ValueError(f"Card {index} missing required field: {field}")
|
||||||
|
|
||||||
|
# Validate base structure
|
||||||
|
if 'base' in card:
|
||||||
|
base = card['base']
|
||||||
|
if 'value' not in base or 'nominal' not in base:
|
||||||
|
raise ValueError(f"Card {index} base missing value or nominal")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
class ResponseValidator:
|
||||||
|
"""
|
||||||
|
Validates response json structure
|
||||||
|
|
||||||
|
Response contract
|
||||||
|
{
|
||||||
|
action, # The action to take
|
||||||
|
params, # The params in the event the action takes in params
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_response(response: Dict[str, Any]):
|
||||||
|
required_fields = ["action"]
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in response:
|
||||||
|
raise ValueError(f"Missing required field: {field}")
|
||||||
|
|
||||||
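A minimal request that passes the validator above (values illustrative; real requests carry the full game state table from the mod):

request = {
    "game_state": {},             # whole game state table
    "available_actions": [0, 2],  # action ids as integers
    "retry_count": 0,
    "hand": {                     # optional, validated when present
        "cards": [
            {"suit": "Hearts", "base": {"value": "King", "nominal": 10}},
        ]
    },
}

GameStateValidator.validate_game_state(request)  # returns True or raises ValueError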