JUNTO Practice: Battleship Bots (Part 2)
Discussed on February 06, 2021.
Improve your Battleship bots. (Refer to the Jupyter notebook that I emailed you.) We will have the bots compete at the next meeting.
Results:

Each pairing played 1,000 games. The tallies below are pandas value_counts output over the winner of each game.

Oscar vs. John:

Oscar              908
John                72
GameResult.DRAW     20
dtype: int64

Dan vs. John:

Dan     997
John      3
dtype: int64

Dan vs. Oscar:

Dan                948
Oscar               35
GameResult.DRAW     17
dtype: int64
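A minimal sketch of how such a tally might be produced, assuming a hypothetical play_game helper that plays one game and returns the winning bot's name or GameResult.DRAW (the real harness lives in the emailed notebook):

import pandas as pd

# play_game is hypothetical; it plays one game between the two bots
# (defined in the solutions below) and returns the winner's name or
# GameResult.DRAW.
outcomes = [play_game(NeesonBot(), RandomPlayer()) for _ in range(1_000)]
print(pd.Series(outcomes).value_counts())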
Solutions
Oscar Martinez
import math
import pdb
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# device, ReplayMemory, Transition, and the game classes (Player,
# Position, GuessResult, random_arrangement) come from the shared notebook.


class BattleShipVanillaCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, 5, padding=2)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(3, 5, 5, padding=2)
        self.fc1 = nn.Linear(5 * 2 * 2, 100)
        self.fc2 = nn.Linear(100, 100)  # one Q-value per board cell

    def forward(self, x):
        # Remember which cells were already guessed (nonzero inputs).
        self.mask = x.detach().clone().flatten(1) != 0
        x = self.pool(F.relu(self.conv1(x)))  # (N, 1, 10, 10) -> (N, 3, 5, 5)
        x = self.pool(F.relu(self.conv2(x)))  # (N, 3, 5, 5) -> (N, 5, 2, 2)
        x = x.view(-1, 5 * 2 * 2)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Already-guessed cells can never be chosen by argmax.
        x[self.mask] = float("-inf")
        return x
# CNN RL hyperparameters
BATCH_SIZE = 128
GAMMA = 0.999
EPS_START = 1
EPS_END = 0.05
EPS_DECAY = 1000  # higher decays epsilon more slowly
TARGET_UPDATE = 300  # how often to sync the target network; we tried increasing this
lr = 1e-1
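# Illustrative shape of the epsilon schedule used in select_action below:
# at steps_done = 0 the threshold is 1.0; after 1,000 steps it is about
# 0.05 + 0.95/e ~= 0.40; by 5,000 steps it is ~0.056, essentially EPS_END.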
# Initialize model
q_net = BattleShipVanillaCNN().to(device)
checkpoint = False
if checkpoint:
    q_net.load_state_dict(
        torch.load(
            "battleship/vanilla_cnn/q_net_1611531763_1456876.pt",
            map_location=device,
        )
    )
target_net = BattleShipVanillaCNN().to(device)
target_net.load_state_dict(q_net.state_dict())
target_net.eval()
optimizer = optim.Adam(q_net.parameters(), lr)
replay_memory = ReplayMemory(200_000)
steps_done = 0
if checkpoint:
    steps_done = 1456876
scores = []
losses = []
def select_action(state, training=True):
    global steps_done
    eps_sample = random.random()
    # Exponentially decaying epsilon-greedy threshold.
    eps_threshold = EPS_END + (
        EPS_START - EPS_END
    ) * math.exp(-1.0 * steps_done / EPS_DECAY)
    if training:
        # Only the step counter is gated on `training`; the
        # epsilon-greedy branch below applies either way.
        steps_done += 1
    if eps_sample > eps_threshold:
        # Exploit: we could either sample from the Q-values or take
        # the max; here we go with the max.
        q_net.eval()
        with torch.no_grad():
            result = q_net(state)
            result = result.max(1)[1].unsqueeze(1)
        q_net.train()
    else:
        # Explore: pick a random cell that has not been guessed yet.
        moves = [
            i for i in range(100) if state.flatten()[i] == 0
        ]
        result = torch.tensor(
            [random.sample(moves, 1)],
            device=device,
            dtype=torch.long,
        )
    return result
def optimize_model():
    global steps_done
    global losses
    if len(replay_memory) < BATCH_SIZE:
        return
    if not q_net.training:
        q_net.train()

    # Q network update.
    transitions = replay_memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))
    # Mask of transitions whose game did not end on this step.
    non_final_mask = torch.tensor(
        tuple(map(lambda s: s is not None, batch.next_state)),
        device=device,
        dtype=torch.bool,
    )
    non_final_next_states = torch.cat(
        [s for s in batch.next_state if s is not None]
    )
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.stack(batch.reward)

    # Q(s, a) for the actions that were actually taken.
    state_action_values = q_net(state_batch).gather(1, action_batch)

    # Bellman target: r + GAMMA * max_a' Q_target(s', a'),
    # with V(s') = 0 for terminal states.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    target_out = target_net(non_final_next_states)
    next_state_values[non_final_mask] = target_out.max(1)[0].detach()
    expected_state_action_values = (
        next_state_values.unsqueeze(1) * GAMMA
    ) + reward_batch

    q_loss = F.smooth_l1_loss(
        state_action_values, expected_state_action_values
    )
    if q_loss.item() == float("inf"):
        print(target_out)
        print(expected_state_action_values)
        pdb.set_trace()
    losses.append(q_loss.item())
    if steps_done % 1000 == 0:
        print(f"Q Current Loss {q_loss.item():.4f}")

    optimizer.zero_grad()
    q_loss.backward()
    # Clamp gradients to [-1, 1] to keep updates stable.
    for param in q_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
def extract_coord(flat_index):
    # Convert a flat 0-99 board index into a 1-based Position.
    row = (flat_index // 10) + 1
    col = (flat_index % 10) + 1
    return Position(row, col)
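# For example: extract_coord(0) == Position(1, 1),
# extract_coord(9) == Position(1, 10), and extract_coord(23) == Position(3, 4).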
class NeesonBot(Player):
    def __init__(self, training=False):
        self.training = training

    def hook_start(self):
        # Board state: 0 = unguessed, 1 = hit, -1 = miss.
        self.enemy_board = torch.zeros((10, 10), device=device)
        self.sunk = 0
        self.turns = 0

    def hook_self_guess(
        self, *, guess: Position, result: GuessResult
    ):
        self.previous_state = self.enemy_board.detach().clone()
        self.turns += 1
        if result == GuessResult.HIT:
            reward = 1
            self.enemy_board[guess.row - 1, guess.col - 1] = 1
        elif result == GuessResult.SINK:
            self.sunk += 1
            reward = 5
            self.enemy_board[guess.row - 1, guess.col - 1] = 1
            if self.sunk == 5:
                # All ships sunk: scale the final reward by the
                # hit/miss ratio so that faster wins earn more.
                bins = (
                    (self.enemy_board + 1)
                    .flatten()
                    .int()
                    .bincount()
                )
                hits = bins[-1]
                misses = bins[0]
                if misses == 0:
                    misses += 1
                factor = (hits / misses).item()
                reward = reward * (10 * factor)
                # Experimental: mark the game as terminated.
                self.enemy_board = None
        elif result == GuessResult.MISS:
            reward = -1
            self.enemy_board[guess.row - 1, guess.col - 1] = -1
        self.reward = torch.tensor([reward], device=device)
        if self.enemy_board is None:
            # Terminal transition: next_state is None.
            replay_memory.push(
                self.previous_state.unsqueeze(0).unsqueeze(0),
                self.action,
                None,
                self.reward,
            )
        else:
            replay_memory.push(
                self.previous_state.unsqueeze(0).unsqueeze(0),
                self.action,
                self.enemy_board.detach()
                .clone()
                .unsqueeze(0)
                .unsqueeze(0),
                self.reward,
            )
        if self.training:
            optimize_model()

    def arrange(self):
        return random_arrangement(10)

    def guess(self):
        action = select_action(
            self.enemy_board.unsqueeze(0).unsqueeze(0)
        )
        self.action = action
        return extract_coord(action.item())
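The episode loop that drives training is not reproduced here. A minimal sketch of how these pieces fit together, assuming a hypothetical play_training_game(bot) helper that runs one full game through the engine (which calls guess() and hook_self_guess() each turn), and assuming the target network is synced every TARGET_UPDATE episodes as in the standard PyTorch DQN recipe:

bot = NeesonBot(training=True)
for episode in range(5_000):  # episode count is illustrative
    play_training_game(bot)  # hypothetical: the engine drives one game
    scores.append(bot.turns)  # fewer turns per game = better play
    if episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(q_net.state_dict())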
John Lekberg
import random


class RandomPlayer(Player):
    def hook_start(self) -> None:
        self.remaining_guesses = [
            Position(r, c)
            for r in range(1, 11)
            for c in range(1, 11)
        ]

    def arrange(self) -> Arrangement:
        # Fixed arrangement: all five ships stacked in the top-left corner.
        return Arrangement(
            ShipPlacement(Position(1, 1), Direction.HORIZONTAL),
            ShipPlacement(Position(2, 1), Direction.HORIZONTAL),
            ShipPlacement(Position(3, 1), Direction.HORIZONTAL),
            ShipPlacement(Position(4, 1), Direction.HORIZONTAL),
            ShipPlacement(Position(5, 1), Direction.HORIZONTAL),
        )

    def guess(self) -> Position:
        # Guess uniformly at random among the unguessed positions.
        assert len(self.remaining_guesses) > 0
        result = random.choice(self.remaining_guesses)
        self.remaining_guesses.remove(result)
        return result
Daniel Bassett
class ZeroPosition(Player):
    def hook_start(self) -> None:
        self.remaining_guesses = [
            Position(r, c)
            for r in range(1, 11)
            for c in range(1, 11)
        ]

    def arrange(self) -> Arrangement:
        # Same fixed top-left arrangement as RandomPlayer above.
        return Arrangement(
            ShipPlacement(Position(1, 1), Direction.HORIZONTAL),
            ShipPlacement(Position(2, 1), Direction.HORIZONTAL),
            ShipPlacement(Position(3, 1), Direction.HORIZONTAL),
            ShipPlacement(Position(4, 1), Direction.HORIZONTAL),
            ShipPlacement(Position(5, 1), Direction.HORIZONTAL),
        )

    def guess(self) -> Position:
        # Sweep the board in row-major order, starting at (1, 1).
        assert len(self.remaining_guesses) > 0
        return self.remaining_guesses.pop(0)