// ANNPlayer.java
import java.util.Random;
public class ANNPlayer
{
private NueralNet m_Net;
private int m_GameNum;
private int m_MaxGameNum;
private Random rand;
ANNPlayer(double alpha, double lambda)
{
m_Net=new NueralNet(6,6);
m_Net.setLambda(lambda);
m_Net.setAlpha(alpha);
m_MaxGameNum=1000;
m_GameNum=0;
rand = new Random(System.currentTimeMillis());
}
private int getLineStatus(int l1,int l2, int l3)
{
if ((l1==l2) && (l2==l3))
{
return l1*3;
}
if ((l1==l2) && (l3==0))
{
return l1*2;
}
if ((l1==l3) && (l2==0))
{
return l1*2;
}
if ((l2==l3) && (l1==0))
{
return l2*2;
}
if ((l1!=0) && (l2==0) && (l3==0))
{
return l1;
}
if ((l1==0) && (l2!=0) && (l3==0))
{
return l2;
}
if ((l1==0) && (l2==0) && (l3!=0))
{
return l3;
}
return 0;
}
private boolean win(int[][] Table)
{
int[] S=new int[3];
int i;
int j;
for (i=0;i<3;i++)
{
for (j=0;j<3;j++)
{
S[j]=Table[i][j];
}
if ((S[0]==S[1]) && (S[1]==S[2]) && (S[0]!=0))
{
return true;
}
}
for (i=0;i<3;i++)
{
for (j=0;j<3;j++)
{
S[j]=Table[j][i];
}
if ((S[0]==S[1]) && (S[1]==S[2]) && (S[0]!=0))
{
return true;
}
}
for (i=0;i<3;i++)
{
S[i]=Table[i][i];
}
if ((S[0]==S[1]) && (S[1]==S[2]) && (S[0]!=0))
{
return true;
}
for (i=0;i<3;i++)
{
S[i]=Table[i][2-i];
}
if ((S[0]==S[1]) && (S[1]==S[2]) && (S[0]!=0))
{
return true;
}
return false;
}
private static void copyTable(int Table[][], int NewTable[][])
{
for (int i=0;i<3;i++)
{
for (int j=0;j<3;j++)
{
NewTable[i][j] = Table[i][j];
}
}
}
private double[] prepareInputVector(int Player, int Table[][])
{
double []Input=new double[6];
int i;
int j;
int lineStatus;
int[] S=new int[3];
for (i=0;i<3;i++)
{
for (j=0;j<3;j++)
{
S[j]=Table[i][j];
}
lineStatus=getLineStatus(S[0],S[1],S[2]);
if (sign(lineStatus)==sign(Player))
{
Input[Math.abs(lineStatus)-1]++;
}
else if (sign(lineStatus)==sign(-1*Player))
{
Input[3+Math.abs(lineStatus)-1]--;
}
}
for (i=0;i<3;i++)
{
for (j=0;j<3;j++)
{
S[j]=Table[j][i];
}
lineStatus=getLineStatus(S[0],S[1],S[2]);
if (sign(lineStatus)==sign(Player))
{
Input[Math.abs(lineStatus)-1]++;
}
else if (sign(lineStatus)==sign(-1*Player))
{
Input[3+Math.abs(lineStatus)-1]--;
}
}
for (i=0;i<3;i++)
{
S[i]=Table[i][i];
}
lineStatus=getLineStatus(S[0],S[1],S[2]);
if (sign(lineStatus)==sign(Player))
{
Input[Math.abs(lineStatus)-1]++;
}
else if (sign(lineStatus)==sign(-1*Player))
{
Input[3+Math.abs(lineStatus)-1]--;
}
for (i=0;i<3;i++)
{
S[i]=Table[i][2-i];
}
lineStatus=getLineStatus(S[0],S[1],S[2]);
if (sign(lineStatus)==sign(Player))
{
Input[Math.abs(lineStatus)-1]++;
}
else if (sign(lineStatus)==sign(-1*Player))
{
Input[3+Math.abs(lineStatus)-1]--;
}
return Input;
}
public void getNextMove(int[][] Table)
{
int[][] bestTable = new int[3][3];
getBestTable(Table, bestTable, -1, false);
copyTable(bestTable, Table);
}
private double getScore(int Player, int[][] Table)
{
double score = m_Net.getScore((double[])prepareInputVector(-1,
Table).clone());
return score*Player*-1;
}
private boolean getBestTable(int[][] Table, int[][] BestTable,
int Player, boolean train)
{
boolean randomSelect = false;
double progress = (double)m_GameNum/(double)m_MaxGameNum;
if (Player == 1 && train)
{
if (rand.nextDouble()>progress)
{
randomSelect = true;
}
}
if (Player == -1 && train)
{
if (rand.nextDouble()>0.9)
{
randomSelect = true;
}
}
int i;
int j;
double MaxScore=0;
double CurrentScore=0;
int[][] NewTable = new int[3][3];
boolean first=true;
for (i=0;i<3;i++)
{
for (j=0;j<3;j++)
{
if (Table[i][j]!=0)
{
continue;
}
copyTable(Table,NewTable);
NewTable[i][j]=Player;
CurrentScore = getScore(Player,NewTable);
if (randomSelect)
{
CurrentScore = rand.nextDouble();
if (Player==1)
{
if (win(NewTable))
{
CurrentScore = 2.0;
}
}
}
if (first)
{
MaxScore = CurrentScore;
copyTable(NewTable,BestTable);
first = false;
}
else
{
if (MaxScore<CurrentScore)
{
MaxScore = CurrentScore;
copyTable(NewTable,BestTable);
}
else if (MaxScore==CurrentScore)
{
if (Math.round(rand.nextDouble())==1.0)
{
MaxScore = CurrentScore;
copyTable(NewTable,BestTable);
}
}
}
}
}
return randomSelect;
}
public void markMove(int[][] table, int player)
{
double reward;
if (win(table))
{
reward = player*-1.0;
}
else
{
reward = 0.0;
}
m_Net.updateWeights(reward, (double[])prepareInputVector(-1,
table).clone());
}
public void finishMarkMoves(int[][] table)
{
m_Net.finishUpdateWeights();
}
public void trainNet(int Player)
{
int[][] table = new int[3][3];;
int[][] BestTable = new int[3][3];
double reward = 0.0;
int MoveNum = 0;
boolean randomMove;
while(true)
{
randomMove = getBestTable(table,BestTable,Player, true);
copyTable(BestTable,table);
if (win(table))
{
reward = Player*-1;
}
else
{
reward = 0.0;
}
if (!randomMove)
{
m_Net.updateWeights(0.0, (double[])prepareInputVector(-1,
table).clone());
}
if (reward!=0)
{
break;
}
else if (MoveNum==8)
{
break;
}
MoveNum++;
Player*=-1;
}
//m_Net.updateWeights(reward/*-getScore(-1, table)*/,
(double[])prepareInputVector(-1, table).clone());
m_Net.finishUpdateWeights(reward);
m_GameNum++;
}
public int getGameNum()
{
return m_GameNum;
}
public int getMaxGameNum()
{
return m_MaxGameNum;
}
public void setMaxGameNum(int num)
{
m_MaxGameNum=num;
}
public void setGameNum(int num)
{
m_GameNum=num;
}
private int sign(int x)
{
if (x<0) return -1;
if (x>0) return 1;
return 0;
}
}
// Related files:
// TicTacToe.java
// ANNPlayer.java
// NueralNet.java
// HumenPlayer.java
// MessageBox.java
// TurnDialog.java
// Applet1.java