adding replay system. just a json file lol
This commit is contained in:
parent
02aaf4faa8
commit
102d005a81
15
README.md
15
README.md
|
|
@ -42,12 +42,11 @@ chips. Perhpas we just want to get wins of rounds just scoring chips is not enou
|
|||
|
||||
### RL Enhancements
|
||||
- [x] **Retry Count Penalty**: Penalize high retry_count in rewards to discourage invalid actions. Currently retry_count tracks failed action attempts, but we could use this signal to teach the AI which actions are actually valid in each state. Formula: `reward -= retry_count * penalty_factor`. This would incentivize the AI to learn valid action spaces rather than trial-and-error.
|
||||
- [ ] Add a "Replay System" to analyze successful actions. For example, save seed, have an action log for reproduction etc
|
||||
Can probably do this by adding it before the game is reset like a check on what criteria I want to save for, and save
|
||||
- [ ] Now that we have improved logging that shows win rate and stuff, maybe we can reward the AI for increased win rate and stuff? that is my main goal so that it wins
|
||||
near 100% of the time
|
||||
- [ ] Add some mechanism of finding out how many times the AI has won the game also
|
||||
figure out a way to get a replay of the game. for example i just noticed the RL model scored a 592. It would be amazing to have that saved somewhere. But I would need the seed (which we can't get if we haven't lost) Or maybe there's a way to get a recording of it which would work too
|
||||
There's gotta be a function that gets the current seed even when the game isn't over
|
||||
- [x] I just noticed the RL model scored a 592. It would be amazing to have that saved somewhere.
|
||||
- Create replay system that only saves winning games into a file or something (TBD)
|
||||
- We would probably store the raw requests and raw responses, and if we win, we can save, if not we can reset the list
|
||||
- The idea is that I'll have the seed so I can just look at the actions the requests and responses, plugin the seed manually in the game and play it out myself
|
||||
- Add something where we only keep top 5. I don't to have a long log of a bunch of wins
|
||||
|
||||
### DEBUGGING
|
||||
|
||||
### DEBUGGING
|
||||
|
|
|
|||
|
|
@ -14,10 +14,16 @@ function O.get_game_state()
|
|||
game_over = 1
|
||||
end
|
||||
|
||||
local game_win = 0
|
||||
if G.STATE == G.STATES.ROUND_EVAL then
|
||||
game_win = 1
|
||||
end
|
||||
|
||||
return {
|
||||
-- Basic state info
|
||||
state = G.STATE,
|
||||
game_over = game_over,
|
||||
game_win = game_win,
|
||||
|
||||
-- Round info (hands/discards left)
|
||||
round = O.get_round_info(),
|
||||
|
|
@ -33,6 +39,9 @@ function O.get_game_state()
|
|||
|
||||
-- Current hand scoring (chips × mult = score)
|
||||
current_hand = O.get_current_hand_scoring(),
|
||||
|
||||
-- Current seed
|
||||
seed = G.GAME and G.GAME.pseudorandom.seed or 0,
|
||||
}
|
||||
end
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from gymnasium import spaces
|
|||
from ..utils.communication import BalatroPipeIO
|
||||
from .reward import BalatroRewardCalculator
|
||||
from ..utils.mappers import BalatroStateMapper, BalatroActionMapper
|
||||
from ..utils.replay import ReplaySystem
|
||||
|
||||
|
||||
class BalatroEnv(gym.Env):
|
||||
|
|
@ -40,6 +41,10 @@ class BalatroEnv(gym.Env):
|
|||
self.pipe_io = BalatroPipeIO()
|
||||
self.reward_calculator = BalatroRewardCalculator()
|
||||
|
||||
# Replay System
|
||||
self.replay_system = ReplaySystem()
|
||||
self.actions_taken = []
|
||||
|
||||
# Define Gymnasium spaces
|
||||
# Action Spaces; This should describe the type and shape of the action
|
||||
# Constants - Core gameplay actions only (SELECT_HAND=1, PLAY_HAND=2, DISCARD_HAND=3)
|
||||
|
|
@ -59,7 +64,7 @@ class BalatroEnv(gym.Env):
|
|||
|
||||
# Observation space: This should describe the type and shape of the observation
|
||||
# Constants
|
||||
self.OBSERVATION_SIZE = 215
|
||||
self.OBSERVATION_SIZE = 216
|
||||
self.observation_space = spaces.Box(
|
||||
low=-np.inf, # lowest bound of observation data
|
||||
high=np.inf, # highest bound of observation data
|
||||
|
|
@ -85,6 +90,7 @@ class BalatroEnv(gym.Env):
|
|||
self.prev_state = None
|
||||
self.game_over = False
|
||||
self.restart_pending = False
|
||||
self.actions_taken = []
|
||||
|
||||
# Reset reward tracking
|
||||
self.reward_calculator.reset()
|
||||
|
|
@ -126,7 +132,7 @@ class BalatroEnv(gym.Env):
|
|||
|
||||
# Send action response to Balatro mod
|
||||
response_data = self.action_mapper.process_action(rl_action=action)
|
||||
|
||||
self.actions_taken.append(response_data)
|
||||
success = self.pipe_io.send_response(response_data)
|
||||
if not success:
|
||||
raise RuntimeError("Failed to send response to Balatro")
|
||||
|
|
@ -141,9 +147,10 @@ class BalatroEnv(gym.Env):
|
|||
|
||||
# Update current state
|
||||
self.current_state = next_request
|
||||
game_state = self.current_state.get('game_state', {})
|
||||
|
||||
# Check for game over condition
|
||||
game_over_flag = self.current_state.get('game_state', {}).get('game_over', 0)
|
||||
game_over_flag = game_state.get('game_over', 0)
|
||||
if game_over_flag == 1:
|
||||
observation = self.state_mapper.process_game_state(self.current_state)
|
||||
reward = self.reward_calculator.calculate_reward(
|
||||
|
|
@ -155,8 +162,33 @@ class BalatroEnv(gym.Env):
|
|||
restart_response = {"action": 6, "params": []}
|
||||
self.pipe_io.send_response(restart_response)
|
||||
|
||||
return observation, reward, True, False, {"game_over": True}
|
||||
|
||||
return observation, reward, True, False, {}
|
||||
|
||||
# Check for game win condition
|
||||
game_win_flag = game_state.get('game_win', 0)
|
||||
if game_win_flag == 1:
|
||||
observation = self.state_mapper.process_game_state(self.current_state)
|
||||
reward = self.reward_calculator.calculate_reward(
|
||||
current_state=self.current_state,
|
||||
prev_state=self.prev_state if self.prev_state else {}
|
||||
)
|
||||
|
||||
# Save replay
|
||||
self.replay_system.try_save_replay(
|
||||
file_path=self.replay_system.REPLAY_FILE_PATH,
|
||||
seed=game_state.get('seed', ''),
|
||||
actions=self.actions_taken,
|
||||
score=reward,
|
||||
chips=game_state.get('chips', 0)
|
||||
)
|
||||
|
||||
# Auto-send restart command to Balatro
|
||||
restart_response = {"action": 6, "params": []}
|
||||
self.pipe_io.send_response(restart_response)
|
||||
|
||||
return observation, reward, True, False, {}
|
||||
|
||||
|
||||
# Process new state for SB3
|
||||
observation = self.state_mapper.process_game_state(self.current_state)
|
||||
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ class BalatroRewardCalculator:
|
|||
# Check if blind is defeated by comparing chips to requirement
|
||||
blind_chips = inner_game_state.get('blind_chips', 300) # TODO 300 only focusing on first blind
|
||||
# Only consider blind defeated if we actually have chips AND a valid blind requirement
|
||||
blind_defeated = (current_chips > 0 and blind_chips > 0 and current_chips >= blind_chips)
|
||||
blind_defeated = inner_game_state.get('game_win', 0)
|
||||
|
||||
|
||||
# Hand type info - use current hand scoring
|
||||
|
|
|
|||
|
|
@ -210,6 +210,7 @@ class BalatroStateMapper:
|
|||
features.append(float(state.get('chips', 0)))
|
||||
features.extend(make_onehot(state.get('state', 0), 20))
|
||||
features.append(float(state.get('game_over', 0)))
|
||||
features.append(float(state.get('game_win', 0)))
|
||||
features.append(float(state.get('retry_count', 0)))
|
||||
features.extend(self._extract_hand_features(state.get('hand', {})))
|
||||
features.extend(self._extract_current_hand_scoring(state.get('current_hand', {})))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,89 @@
|
|||
import json
|
||||
from typing import List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
class ReplaySystem:
|
||||
def __init__(self, max_replays: int = 10):
|
||||
self.MAX_REPLAYS = max_replays
|
||||
self.REPLAY_FILE_PATH = "replays.json"
|
||||
|
||||
def try_save_replay(self, file_path: str, seed: str, actions: List[Dict[str, Any]], score: float, chips: int):
|
||||
"""Save the current replay to a file if the score is among the top MAX_REPLAYS."""
|
||||
timestamp = datetime.now().isoformat()
|
||||
|
||||
replay_data = {
|
||||
"seed": seed,
|
||||
"timestamp": timestamp,
|
||||
"chips": chips,
|
||||
"score": score,
|
||||
"actions": actions,
|
||||
}
|
||||
|
||||
# Load existing replays or create new list
|
||||
replays = self.load_replays(file_path)
|
||||
|
||||
# If we have fewer than MAX_REPLAYS, just add it
|
||||
if len(replays) < self.MAX_REPLAYS:
|
||||
replays.append(replay_data)
|
||||
else:
|
||||
# Check if this score is higher than the lowest score
|
||||
replays.sort(key=lambda x: x['score'], reverse=True)
|
||||
if score > replays[-1]['score']:
|
||||
# Replace the lowest scoring replay
|
||||
replays[-1] = replay_data
|
||||
else:
|
||||
# Score is not high enough, don't add it
|
||||
return len(replays)
|
||||
|
||||
# Sort by score (highest first) and keep only top MAX_REPLAYS
|
||||
replays.sort(key=lambda x: x['score'], reverse=True)
|
||||
replays = replays[:self.MAX_REPLAYS]
|
||||
|
||||
# Save back to file
|
||||
self.save_replays(file_path, replays)
|
||||
|
||||
return len(replays)
|
||||
|
||||
def load_replays(self, file_path: str) -> List[Dict[str, Any]]:
|
||||
"""Load replays from file or return empty list if file doesn't exist."""
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
# Handle both list format and single replay format
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
elif isinstance(data, dict) and "replay" in data:
|
||||
return [data["replay"]]
|
||||
else:
|
||||
return []
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
return []
|
||||
|
||||
def save_replays(self, file_path: str, replays: List[Dict[str, Any]]):
|
||||
"""Save replays to file."""
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(replays, f, indent=4)
|
||||
|
||||
def sort_replays(self, file_path: str) -> List[Dict[str, Any]]:
|
||||
"""Sort replays by score and return the top MAX_REPLAYS."""
|
||||
replays = self.load_replays(file_path)
|
||||
replays.sort(key=lambda x: x['score'], reverse=True)
|
||||
return replays[:self.MAX_REPLAYS]
|
||||
|
||||
def get_top_replays(self, file_path: str, count: int = None) -> List[Dict[str, Any]]:
|
||||
"""Get the top replays from the file."""
|
||||
if count is None:
|
||||
count = self.MAX_REPLAYS
|
||||
|
||||
replays = self.load_replays(file_path)
|
||||
replays.sort(key=lambda x: x['score'], reverse=True)
|
||||
return replays[:count]
|
||||
|
||||
def clear_replays(self, file_path: str):
|
||||
"""Clear all replays from the file."""
|
||||
self.save_replays(file_path, [])
|
||||
|
||||
def get_replay_count(self, file_path: str) -> int:
|
||||
"""Get the number of replays in the file."""
|
||||
replays = self.load_replays(file_path)
|
||||
return len(replays)
|
||||
|
|
@ -51,10 +51,12 @@ class GameStateValidator:
|
|||
GameStateValidator._validate_round(game_state['round'])
|
||||
GameStateValidator._validate_current_hand(game_state['current_hand'])
|
||||
assert game_state["game_over"] in [0, 1]
|
||||
assert game_state["game_win"] in [0, 1]
|
||||
assert isinstance(game_state["state"], int)
|
||||
assert isinstance(game_state["blind_chips"], (int, float))
|
||||
assert isinstance(game_state["chips"], (int, float))
|
||||
assert isinstance(game_state.get("retry_count", 0), int)
|
||||
assert isinstance(game_state.get("seed", 0), str)
|
||||
|
||||
return True
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue