Used the debugger to iron out all the bugs step by step. Things are looking way smoother.
This commit is contained in:
parent 2890671520
commit e0d20b386b
@@ -85,3 +85,5 @@ The mod can successfully:
## Communication with Claude
- Please keep chats as efficient and brief as possible. No fluff talk, just get to the point
- Try to break things down into first principles effectively
- Do not add explanatory comments about why values changed (e.g., "# Changed from X to Y because...") - just use the correct values
- Comments should explain what something does, not why it was modified
README.md
@@ -38,14 +38,16 @@ number it is? I think we should add the hands and discards in the game state as
- [x] Should we not give reward for just plain increasing chips? If you think about it, you can play anything and increase
chips. Perhaps we just want to reward round wins; just scoring chips is not enough. Wonder if the losing penalty is not enough
- [x] I wonder if there's a problem with the fact that they get points out of every hand played. I feel like it should learn to play more complex hands instead of just getting points even if only one hand scores; we should maybe have the rewards reflect that

### RL Enhancements
- [x] **Retry Count Penalty**: Penalize high retry_count in rewards to discourage invalid actions. Currently retry_count tracks failed action attempts, but we could use this signal to teach the AI which actions are actually valid in each state. Formula: `reward -= retry_count * penalty_factor`. This would incentivize the AI to learn valid action spaces rather than trial-and-error.
- [ ] Add a "Replay System" to analyze successful actions. For example, save the seed, keep an action log for reproduction, etc.
Can probably do this by adding it before the game is reset, like a check on what criteria I want to save for, and save
- [ ] Now that we have improved logging that shows win rate and such, maybe we can reward the AI for increased win rate? That is my main goal, so that it wins
near 100% of the time
- [ ] Add some mechanism for finding out how many times the AI has won the game, too
Figure out a way to get a replay of the game. For example, I just noticed the RL model scored a 592. It would be amazing to have that saved somewhere. But I would need the seed (which we can't get if we haven't lost), or maybe there's a way to get a recording of it, which would work too
There's gotta be a function that gets the current seed even when the game isn't over
- [ ] Is there a way to reward the AI for getting "fire" type scores, which are really good?

### RL Enhancements
- [ ] **Retry Count Penalty**: Penalize high retry_count in rewards to discourage invalid actions. Currently retry_count tracks failed action attempts, but we could use this signal to teach the AI which actions are actually valid in each state. Formula: `reward -= retry_count * penalty_factor`. This would incentivize the AI to learn valid action spaces rather than trial-and-error.
- [ ] Add a "Replay System" to analyze successful actions. For example, save the seed, keep an action log for reproduction, etc.
Can probably do this by adding it before the game is reset, like a check on what criteria I want to save for, and save

### DEBUGGING
@@ -7,33 +7,32 @@ local utils = require("utils")
local ACTIONS = {}

-- Action constants (like G.STATES pattern)
ACTIONS.START_RUN = 1
ACTIONS.SELECT_BLIND = 2
ACTIONS.SELECT_HAND = 3
ACTIONS.PLAY_HAND = 4
ACTIONS.DISCARD_HAND = 5
-- Core gameplay actions are 1,2,3 for clean AI mapping
ACTIONS.SELECT_HAND = 1
ACTIONS.PLAY_HAND = 2
ACTIONS.DISCARD_HAND = 3
-- Auto-executed actions (not exposed to AI)
ACTIONS.START_RUN = 4
ACTIONS.SELECT_BLIND = 5
ACTIONS.RESTART_RUN = 6
ACTIONS.CASH_OUT = 7

-- Action mapping tables
local ACTION_IDS = {
start_run = ACTIONS.START_RUN,
select_blind = ACTIONS.SELECT_BLIND,
select_hand = ACTIONS.SELECT_HAND,
play_hand = ACTIONS.PLAY_HAND,
discard_hand = ACTIONS.DISCARD_HAND,
start_run = ACTIONS.START_RUN,
select_blind = ACTIONS.SELECT_BLIND,
restart_run = ACTIONS.RESTART_RUN,
cash_out = ACTIONS.CASH_OUT,
}

local ID_TO_ACTION = {
[ACTIONS.START_RUN] = "start_run",
[ACTIONS.SELECT_BLIND] = "select_blind",
[ACTIONS.SELECT_HAND] = "select_hand",
[ACTIONS.PLAY_HAND] = "play_hand",
[ACTIONS.DISCARD_HAND] = "discard_hand",
[ACTIONS.START_RUN] = "start_run",
[ACTIONS.SELECT_BLIND] = "select_blind",
[ACTIONS.RESTART_RUN] = "restart_run",
[ACTIONS.CASH_OUT] = "cash_out",
}

-- Centralized action state tracking
@@ -112,15 +111,6 @@ local action_registry = {
return (G.STATE == G.STATES.GAME_OVER or G.STATE == G.STATES.ROUND_EVAL) and not action_state.restart_run
end,
},
cash_out = {
execute = function(params)
ACTIONS.reset_state()
return input.cash_out()
end,
available_when = function()
return G.STATE == G.STATES.ROUND_EVAL and G.GAME.blind and G.GAME.blind.defeated and not action_state.cash_out
end,
},
}

--- Get all currently available actions
@@ -12,9 +12,9 @@ local utils = require("utils")
-- State management
local last_combined_hash = nil
local pending_action = nil
local need_retry_request = false
local rl_training_active = false
local last_key_pressed = nil
local retry_count = 0

--- Initialize AI system
--- Sets up communication and prepares the AI for operation
@@ -75,7 +75,7 @@ function AI.update()
-- Create combined hash to detect meaningful changes
local combined_hash = AI.hash_combined_state(current_state, available_actions)

if combined_hash ~= last_combined_hash or need_retry_request then
if combined_hash ~= last_combined_hash then
-- Game state or available actions have changed
utils.log_ai("State/Actions changed: State: " ..
current_state.state .. " (" .. utils.get_state_name(current_state.state) .. ") | " ..
@@ -83,20 +83,21 @@ function AI.update()

action.reset_state()

-- Request action from AI
need_retry_request = false
-- Auto-skip trivial actions (don't send to AI)
if AI.should_auto_skip(current_state, available_actions) then
AI.execute_auto_skip_action(current_state, available_actions)
return
end

-- Add retry_count to current state
current_state.retry_count = retry_count

-- Request action from AI (only for core gameplay)
local ai_response = communication.request_action(current_state, available_actions)

if ai_response then
-- Handling handshake
if ai_response.action == "ready" then
utils.log_ai("Handshake complete - AI ready")
-- Force a state check on next frame to send real request
last_combined_hash = nil
else
pending_action = ai_response
last_combined_hash = combined_hash
end
pending_action = ai_response
last_combined_hash = combined_hash
end
end
@@ -116,17 +117,23 @@ function AI.update()
local result = action.execute_action(pending_action.action, pending_action.params)
if result.success then
utils.log_ai("Action executed successfully: " .. pending_action.action)
retry_count = 0 -- Reset retry count on success
pending_action = nil
utils.log_ai("\n\n\n")
else
utils.log_ai("Action failed: " .. (result.error or "Unknown error") .. " RETRYING...")
need_retry_request = true
utils.log_ai("Action failed: " .. (result.error or "Unknown error"))
retry_count = retry_count + 1
utils.log_ai("Retry count: " .. retry_count)
-- Keep pending_action to retry on next frame
-- Force state recheck to send updated state with retry_count
last_combined_hash = nil
end
else
utils.log_ai("Action no longer valid (state changed), discarding: " .. pending_action.action)
-- Force a retry to get a new action for the current state
need_retry_request = true
retry_count = 0 -- Reset on state change
pending_action = nil
utils.log_ai("\n\n\n")
end
pending_action = nil
utils.log_ai("\n\n\n")
end
end
@@ -160,4 +167,42 @@ function AI.hash_combined_state(game_state, available_actions)
return combined
end

--- Check if current state should be auto-skipped (not sent to AI)
--- @param current_state table Current game state
--- @param available_actions table Available actions list
--- @return boolean True if should auto-skip, false if send to AI
function AI.should_auto_skip(current_state, available_actions)
-- Auto-skip START_RUN in menu (action ID = 4)
if current_state.state == G.STATES.MENU and #available_actions == 1 and available_actions[1] == 4 then
return true
end

-- Auto-skip SELECT_BLIND in blind selection (action ID = 5)
if current_state.state == G.STATES.BLIND_SELECT and #available_actions == 1 and available_actions[1] == 5 then
return true
end

-- Don't auto-skip anything else - core actions (1,2,3) go to AI

return false
end

--- Execute auto-skip action without AI involvement
--- @param current_state table Current game state
--- @param available_actions table Available actions list
function AI.execute_auto_skip_action(current_state, available_actions)
local action_id = available_actions[1]
utils.log_ai("Auto-executing action: " .. action.get_action_name(action_id))

local result = action.execute_action(action_id, {})
if result.success then
utils.log_ai("Auto-execution successful: " .. action.get_action_name(action_id))
else
utils.log_ai("Auto-execution failed: " .. (result.error or "Unknown error"))
end

-- Force state recheck after auto-execution
last_combined_hash = nil
end

return AI
@@ -98,14 +98,5 @@ function I.discard_hand()
return { success = true }
end

--- Click "Cash Out" button to cash out after winning a blind
--- @return table Result with success status and optional error message
function I.cash_out()
-- Find the cash out button and simulate clicking it
local cash_out_button = { config = { button = nil } }
G.FUNCS.cash_out(cash_out_button)
utils.log_input("cash_out " .. utils.completed_success_msg)
return { success = true }
end

return I
@@ -40,7 +40,7 @@ end
--- @return number Chips needed to beat current blind
function O.get_blind_chips()
if not G.GAME or not G.GAME.blind then
return 300 -- Default ante 1 small blind requirement
return 300 -- TODO: probably fix this if we are doing more than one blind (default ante 1 small blind requirement)
end
return G.GAME.blind.chips or 300
end
@@ -60,7 +60,6 @@ function O.get_hand_info()
nominal = card.base.nominal or 0,
value = card.base.value or ""
},
debuff = card.debuff or false,
highlighted = card.highlighted or false,
suit = card.base.suit or ""
})
@@ -10,7 +10,6 @@ from typing import Dict, Any, Tuple, List, Optional
import logging
import gymnasium as gym
from gymnasium import spaces
from ..utils.debug import tprint

from ..utils.communication import BalatroPipeIO
from .reward import BalatroRewardCalculator
@@ -43,11 +42,11 @@ class BalatroEnv(gym.Env):

# Define Gymnasium spaces
# Action Spaces; This should describe the type and shape of the action
# Constants
self.MAX_ACTIONS = 7 # get from balatro actions.lua
# Constants - Core gameplay actions only (SELECT_HAND=1, PLAY_HAND=2, DISCARD_HAND=3)
self.MAX_ACTIONS = 3
self.MAX_CARDS = 8 # Max cards in hand
action_selection = np.array([self.MAX_ACTIONS])
card_indices = np.array([2] * self.MAX_CARDS) # 8 cards, each can be selected (1) or not (0)
card_indices = np.array([2] * self.MAX_CARDS) # 8 cards, each can be selected (1) or not (0) #TODO can we or have we already masked card selection?
self.action_space = spaces.MultiDiscrete(np.concatenate([
action_selection,
card_indices
@@ -60,7 +59,7 @@ class BalatroEnv(gym.Env):

# Observation space: This should describe the type and shape of the observation
# Constants
self.OBSERVATION_SIZE = 198 # Minimal observation space for ante 1 focus (removed retry_count)
self.OBSERVATION_SIZE = 215
self.observation_space = spaces.Box(
low=-np.inf, # lowest bound of observation data
high=np.inf, # highest bound of observation data
@@ -72,7 +71,7 @@ class BalatroEnv(gym.Env):
self.state_mapper = BalatroStateMapper(observation_size=self.OBSERVATION_SIZE, max_actions=self.MAX_ACTIONS)
self.action_mapper = BalatroActionMapper(action_slices=slices)

def reset(self, seed=None, options=None): #TODO add types
def reset(self, seed=None, options=None):
"""
Reset the environment for a new episode
@@ -95,11 +94,6 @@ class BalatroEnv(gym.Env):
if not initial_request:
raise RuntimeError("Failed to receive initial request from Balatro")

# Send dummy response to complete handshake
success = self.pipe_io.send_response({"action": "ready"})
if not success:
raise RuntimeError("Failed to complete handshake")

# Process initial state for SB3
self.current_state = initial_request
initial_observation = self.state_mapper.process_game_state(self.current_state)
@@ -140,7 +134,6 @@ class BalatroEnv(gym.Env):
# Wait for next request with new game state
next_request = self.pipe_io.wait_for_request()
if not next_request:
# If no response, assume game ended
self.game_over = True
observation = self.state_mapper.process_game_state(self.current_state)
reward = 0.0
@@ -149,30 +142,31 @@ class BalatroEnv(gym.Env):
# Update current state
self.current_state = next_request

# Check for game over condition
game_over_flag = self.current_state.get('game_state', {}).get('game_over', 0)
if game_over_flag == 1:
observation = self.state_mapper.process_game_state(self.current_state)
reward = self.reward_calculator.calculate_reward(
current_state=self.current_state,
prev_state=self.prev_state if self.prev_state else {}
)

# Auto-send restart command to Balatro
restart_response = {"action": 6, "params": []}
self.pipe_io.send_response(restart_response)

return observation, reward, True, False, {"game_over": True}

# Process new state for SB3
observation = self.state_mapper.process_game_state(self.current_state)

# Calculate reward using expert reward calculator - no more retry penalties!
# Calculate reward using expert reward calculator
reward = self.reward_calculator.calculate_reward(
current_state=self.current_state,
prev_state=self.prev_state if self.prev_state else {}
)

# Check if episode is done - delay end until after restart
game_over_flag = self.current_state.get('game_state', {}).get('game_over', 0)

if game_over_flag == 1:
if not self.restart_pending:
# First time seeing game over - don't end episode yet, wait for restart
self.restart_pending = True
done = False
else:
# Game has restarted after game over, now end episode
self.restart_pending = False
done = True
else:
# Normal gameplay, no episode end
done = False
done = False

terminated = done
truncated = False # Not using time limits for now
@@ -181,10 +175,7 @@ class BalatroEnv(gym.Env):
available_actions = next_request.get('available_actions', [])
action_mask = self._create_action_mask(available_actions)

info = {
'balatro_state': self.current_state.get('state', 0),
'available_actions': available_actions,
}
info = {}

# Store action mask for MaskablePPO
self._action_masks = action_mask
@@ -211,16 +202,26 @@ class BalatroEnv(gym.Env):
"""Create action mask for MultiDiscrete space"""
action_masks = []

# Action selection mask (7 possible actions)
# Action selection mask (3 possible actions: SELECT_HAND=1, PLAY_HAND=2, DISCARD_HAND=3)
# Map Balatro action IDs to AI indices: 1->0, 2->1, 3->2
action_selection_mask = [False] * self.MAX_ACTIONS
balatro_to_ai_mapping = {1: 0, 2: 1, 3: 2} # SELECT_HAND, PLAY_HAND, DISCARD_HAND

for action_id in available_actions:
if 1 <= action_id <= self.MAX_ACTIONS:
action_selection_mask[action_id - 1] = True
if action_id in balatro_to_ai_mapping:
ai_index = balatro_to_ai_mapping[action_id]
action_selection_mask[ai_index] = True
action_masks.append(action_selection_mask)

# Card selection masks (8 cards, each can be 0 or 1)
for _ in range(self.MAX_CARDS):
action_masks.append([True, True])
# Card selection masks - context-aware based on available actions
if any(action_id in [2, 3] for action_id in available_actions):
# PLAY_HAND or DISCARD_HAND available - card params don't matter
for _ in range(self.MAX_CARDS):
action_masks.append([True, False]) # Force "not selected"
else:
# Only SELECT_HAND available - allow card selection
for _ in range(self.MAX_CARDS):
action_masks.append([True, True])

# Flatten for MaskablePPO
return [item for sublist in action_masks for item in sublist]
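For context, the flattened list returned above is the per-sub-action mask layout that sb3-contrib's MaskablePPO expects for a MultiDiscrete space (3 action slots plus 8 x 2 card slots = 19 entries). A minimal sketch of wiring it up on the training side, assuming the `_action_masks` attribute set in `step()` above; the wrapper and model setup here are illustrative, not this repo's actual training script:

```python
# Illustrative sketch only; assumes sb3-contrib is installed and that BalatroEnv
# stores the latest flattened mask in self._action_masks (as in the diff above).
import numpy as np
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker


def mask_fn(env) -> np.ndarray:
    # Return the most recent flattened MultiDiscrete mask.
    mask = getattr(env, "_action_masks", None)
    if mask is None:
        # Before the first step, allow everything (sum of nvec = 3 + 8*2 = 19 entries).
        mask = [True] * int(np.sum(env.action_space.nvec))
    return np.array(mask, dtype=bool)


env = ActionMasker(BalatroEnv(), mask_fn)  # BalatroEnv is the class shown in this diff
model = MaskablePPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
```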
@@ -46,6 +46,14 @@ class BalatroRewardCalculator:
# Extract key metrics
current_chips = inner_game_state.get('chips', 0)
game_over = inner_game_state.get('game_over', 0)
retry_count = inner_game_state.get('retry_count', 0)

# === RETRY PENALTY ===
# Negative reward for invalid actions that require retries
if retry_count > 0:
retry_penalty = -0.1 * retry_count  # -0.1 per retry
reward += retry_penalty
reward_breakdown.append(f"retry_penalty: {retry_penalty:.2f} (retries: {retry_count})")

# Check if blind is defeated by comparing chips to requirement
blind_chips = inner_game_state.get('blind_chips', 300)  # TODO 300 only focusing on first blind
@@ -82,7 +90,9 @@ class BalatroRewardCalculator:
elif chip_percentage >= self.REWARD_THRESHOLDS["decent"]:
reward += 1.0
reward_breakdown.append(f"Decent hand (+{chip_gain} chips, {chip_percentage:.1f}% of blind): +1.0")
# <25% of blind = no reward (too small to matter)
else:
reward += 0.5
reward_breakdown.append(f"Small hand (+{chip_gain} chips, {chip_percentage:.1f}% of blind): +0.5")

# === REMOVED HAND TYPE REWARDS ===
# Hand type rewards removed - in Balatro, only chips matter!
@@ -91,16 +101,16 @@ class BalatroRewardCalculator:
# === BLIND COMPLETION ===
# Main goal - beat the blind (only reward once per episode)
if blind_defeated and not self.blind_already_defeated and game_over == 0:
reward += 500.0  # SUCCESS! Increased from +100 to +500
reward_breakdown.append(f"BLIND DEFEATED: +500.0")
reward += 50.0  # SUCCESS! Normalized from +500 to +50
reward_breakdown.append(f"BLIND DEFEATED: +50.0")
self.blind_already_defeated = True
self.winning_chips = current_chips  # Store winning chip count

# === PENALTIES ===
# Game over penalty - ONLY for actual losses (blind not defeated)
if game_over == 1 and not hasattr(self, 'game_over_penalty_applied') and not self.blind_already_defeated:
reward -= 200.0  # Increased from -20 to -200
reward_breakdown.append("Game over: -200.0")
reward -= 20.0  # Normalized from -200 to -20
reward_breakdown.append("Game over: -20.0")
self.game_over_penalty_applied = True

# Track episode total reward
@@ -69,8 +69,6 @@ class BalatroPipeIO:
Opens request pipe for reading and response pipe for writing.
Keeps handles open to avoid deadlocks.
"""
import time

try:
self.logger.info(f"🔧 Waiting for Balatro to connect...")
self.logger.info(f"   Press 'R' in Balatro now to activate RL training!")
@@ -150,7 +150,7 @@ class BalatroStateMapper:

cards = hand.get('cards', [])

NON_CARDS_FEATURES = 2  # TODO Update this if we add more non-card features
features.append(float(hand.get('size', 0)))
features.append(float(hand.get('highlighted_count', 0)))
@@ -172,17 +172,19 @@ class BalatroStateMapper:
}
value = base.get('value')
card_features.extend(make_onehot(values_mapping.get(value, 13), 14))

# Nominal value (actual chip value used in game calculations)
card_features.append(base.get('nominal', 0.0))

features.extend(card_features)

# Pad or truncate to fixed size
max_cards = 8  # Standard Balatro hand size
features_per_card = 20  # 1+5+14 = highlighted+suit_onehot+value_onehot
max_cards = 8  # TODO: this might have to be updated in the future if we go bigger (standard Balatro hand size)
features_per_card = 21  # 1+5+14+1 = highlighted+suit_onehot+value_onehot+nominal

if len(features) < max_cards * features_per_card:
features.extend([0] * (max_cards * features_per_card - len(features)))
else:
features = features[:max_cards * features_per_card]
# If no cards in hand, pad with zeros
if len(features) == NON_CARDS_FEATURES:
features.extend([0.0] * (max_cards * features_per_card))

return features
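`make_onehot` is called throughout this file but its definition is not part of this diff. A plausible helper, consistent with how it is used above (an index and a vector size, e.g. `make_onehot(values_mapping.get(value, 13), 14)` and `make_onehot(hand_index, len(hand_types))`), might look like this; it is a hypothetical reconstruction, not the repo's actual code:

```python
# Hypothetical helper, not taken from the repo: builds a one-hot vector of the
# given size, leaving it all zeros if the index is out of range.
def make_onehot(index: int, size: int) -> list:
    vec = [0.0] * size
    if 0 <= index < size:
        vec[index] = 1.0
    return vec
```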
@@ -206,8 +208,9 @@ class BalatroStateMapper:
features.extend(self._extract_round_features(state.get('round', {})))
features.append(float(state.get('blind_chips', 0)))
features.append(float(state.get('chips', 0)))
features.extend(make_onehot(state.get('state', 0), 10))  # Reduced state space
features.extend(make_onehot(state.get('state', 0), 20))
features.append(float(state.get('game_over', 0)))
features.append(float(state.get('retry_count', 0)))
features.extend(self._extract_hand_features(state.get('hand', {})))
features.extend(self._extract_current_hand_scoring(state.get('current_hand', {})))
@@ -264,12 +267,11 @@ class BalatroStateMapper:
"Flush Five"  # 12
]

handname = current_hand.get('handname', 'None')
try:
hand_index = hand_types.index(handname)
except ValueError:
hand_index = 0  # Default to "None" for unknown hands

hand_name = current_hand.get('handname', 'None')
if not hand_name:
hand_name = "None"
hand_index = hand_types.index(hand_name)

features.extend(make_onehot(hand_index, len(hand_types)))

return features
@@ -305,8 +307,13 @@ class BalatroActionMapper:
JSON response formatted for Balatro mod
"""
ai_action = rl_action[self.slices["action_selection"]].tolist()[0]

# Map AI indices to Balatro action IDs: 0->1, 1->2, 2->3
ai_to_balatro_mapping = {0: 1, 1: 2, 2: 3}  # SELECT_HAND, PLAY_HAND, DISCARD_HAND
balatro_action_id = ai_to_balatro_mapping.get(ai_action, 1)  # Default to SELECT_HAND

response_data = {
"action": ai_action + 1,  # Convert 0-based AI output to 1-based Balatro action IDs
"action": balatro_action_id,
"params": self._extract_select_hand_params(rl_action),
}
self.response_validator.validate_response(response_data)
@@ -54,6 +54,7 @@ class GameStateValidator:
assert isinstance(game_state["state"], int)
assert isinstance(game_state["blind_chips"], (int, float))
assert isinstance(game_state["chips"], (int, float))
assert isinstance(game_state.get("retry_count", 0), int)

return True
@@ -94,7 +95,7 @@ class GameStateValidator:
@staticmethod
def _validate_card(card: Dict[str, Any], index: int) -> bool:
"""Validate individual card structure"""
required_fields = ['suit', 'base', 'highlighted', 'debuff']
required_fields = ['suit', 'base', 'highlighted']

for field in required_fields:
assert field in card, f"Card {index} missing required field: {field}"
@@ -125,4 +126,8 @@ class ResponseValidator:

assert isinstance(response["action"], int)
assert isinstance(response["params"], list)

# Validate action is within valid range (1=SELECT_HAND, 2=PLAY_HAND, 3=DISCARD_HAND)
valid_actions = [1, 2, 3]
assert response["action"] in valid_actions, f"Invalid action ID: {response['action']}"