
samshipengs / coordinated-multi-agent-imitation-learning


This is an implementation of the paper "Coordinated Multi-Agent Imitation Learning", or its Sloan version "Data-Driven Ghosting using Deep Imitation Learning", using TensorFlow.

Jupyter Notebook 78.25% Python 20.63% HTML 0.48% JavaScript 0.31% CSS 0.33%
ghosting imitation-learning multi-agent


Coordinated-Multi-Agent-Imitation-Learning

The Toronto Raptors built a ghosting system to help the coaching staff analyze defensive plays. Games are recorded by a camera system above the arena, and staff mark the position where they think each player should have been; this marked position is the player's "ghost". However, this requires a lot of manual annotation. The Coordinated Multi-Agent Imitation Learning paper proposes a data-driven method instead. (For more details on the Raptors' ghosting system, see Lights, Cameras, Revolution.)

In this repo we attempt to implement the paper Coordinated-Multi-Agent-Imitation-Learning (or its Sloan version) with TensorFlow.

Introduction

We aim to predict the movements, or trajectories, of the defending players of a given team. (In principle we could also build a model that predicts offensive trajectories, but both the original ghosting work and this paper use defending players. I assume the reason is that defensive trajectories are somewhat easier to predict than offensive ones.)

To predict a trajectory, we roll out a sequence of predictions of each player's next action. A natural candidate for this task is a recurrent neural network (more specifically an LSTM), whose input is a sequence of (x, y) coordinates of every player (both the defending team and the opponent).

The end result we would like to achieve is this: for a given play in which team A is on defense, we show what another team B, presumably the best defensive team in the league, would do. This differs slightly from the Raptors' original ghosting work. Instead of focusing on what a specific player should do according to a coach's experience, this work models what another team would do in the same situation. (Again, in principle we could also model each specific player, but that would require a much larger data set and would be a harder task for the model to learn.)

Data

The up-to-date data is proprietary, but we found tracking and play-by-play data for 42 Toronto Raptors games played in Fall 2015 at this link. We use this data for our implementation; see the link for a detailed description of the data.

Below is a short preview of the data for the game with id 0021500463:

end_time_left home moments orig_events playbyplay quarter start_time_left visitor
0 702.31 {'abbreviation': 'CHI', 'players': [{'playerid... [[1, 1451351428029, 708.28, 12.78, None, [[-1,... [0] GAME_ID EVENTNUM EVENTMSGTYPE EVENTMS... 1 708.28 {'abbreviation': 'TOR', 'players': [{'playerid...
1 686.28 {'abbreviation': 'CHI', 'players': [{'playerid... [[1, 1451351428029, 708.28, 12.78, None, [[-1,... [1] GAME_ID EVENTNUM EVENTMSGTYPE EVENTMS... 1 708.28 {'abbreviation': 'TOR', 'players': [{'playerid...
2 668.42 {'abbreviation': 'CHI', 'players': [{'playerid... [[1, 1451351444029, 692.25, 12.21, None, [[-1,... [2, 3] GAME_ID EVENTNUM EVENTMSGTYPE EVENTMS... 1 692.25 {'abbreviation': 'TOR', 'players': [{'playerid...

The main columns we use to build the model are moments, quarter, home and visitor. Moments contain the most information, such as the basketball's location and every player's location, team ID and player ID. Quarter is used both as an input feature and in preprocessing. Home and visitor specify the team names and IDs, which is useful when validating the preprocessed data.
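To make the moments column concrete, here is a small parser for a single moment. The layout is inferred from the preview row above (a moment starts with quarter, a millisecond timestamp, game clock, shot clock and an unused None, followed by entity rows whose first entry, -1, marks the ball), so treat it as an assumption to check against the data description at the link; parse_moment is a hypothetical helper, not a function from this repo.

```python
import numpy as np


def parse_moment(moment):
    """Split one raw moment into its parts.

    Assumed layout (inferred from the preview, not documented here):
    [quarter, timestamp_ms, game_clock, shot_clock, unused, entities],
    where each entity row is [team_id, player_id, x, y, z] and the ball has team_id -1.
    """
    quarter, _timestamp_ms, game_clock, shot_clock, _unused, entities = moment
    rows = np.asarray(entities, dtype=float)
    is_ball = rows[:, 0] == -1
    return {
        "quarter": quarter,
        "game_clock": game_clock,
        "shot_clock": shot_clock,          # may be None, see Pre-processing
        "ball_xyz": rows[is_ball, 2:5],    # (n_ball, 3), normally (1, 3)
        "players": rows[~is_ball, :4],     # (n_players, 4): team_id, player_id, x, y
        # True only for the usable case of exactly 1 ball and 10 players
        "complete": bool(is_ball.sum() == 1 and (~is_ball).sum() == 10),
    }
```

The "complete" flag is the per-frame check that the first pre-processing step relies on.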

Pre-processing

Not all moments in the data set are used. Each event is supposed to describe a play precisely, but the given moments often contain frames that would not help the model. For example, some frames contain only 8 or 9 players, or the basketball is out of bounds; such frames cannot be used because the model expects a fixed input dimension. Many moments also contain frames that are not critical to decision making, e.g. dribbling before entering the half court, or the clock being stopped. The shot clock sometimes has a null value.

Also, to make learning easier for the model, we perform some extra preprocessing: we only model the defending players, and we normalize the court to a half court. The teams swap ends after half-time, which could confuse the model, and plays that involve the whole court are more dynamic in nature and therefore harder to predict.

We list each preprocessing step below:

  1. Remove frames that do not contain exactly 10 players and 1 basketball, and start a new event with the frames that follow (the same chunking applies to every subsequent step).
    The function remove_non_eleven in preprocessing.py does this.
    This discards frames in which players or the basketball are out of bounds.
  2. Chunk the moments by shot clock.
    chunk_halfcourt in preprocessing.py does this.
    If the shot clock turns to 24 (reset, e.g. after a rebound or turnover) or 0 (expired), or the shot clock is None or stopped, we remove those frames from the moments, since player behavior differs dramatically at these points.
  3. Chunk the moments to a half court.
    chunk_halfcourt in preprocessing.py does this.
    Remove all moments that are not contained within one half court, and transform the x coordinates to lie between 0 and 47 (an NBA court is 50x94 feet).
  4. Reorder the data.
    reorder_teams in preprocessing.py does this.
    Reorder the matrix in moments so that the first five players' data always come from the defending team.
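The four steps above can be sketched as follows. This is not the code in preprocessing.py: the frame format (a dict with a "shot_clock" value and an (11, 2) "xy" array holding 10 players plus the ball) and every function name are hypothetical, and detecting a stopped shot clock, which needs a comparison across frames, is left out. For the half-court step the sketch rotates the right half by 180 degrees (flipping y as well as x); that is an assumption chosen so the attacking team's left and right sides are preserved, whereas flipping x alone would mirror the play.

```python
import numpy as np

COURT_LENGTH, COURT_WIDTH = 94.0, 50.0   # NBA court in feet
HALF = COURT_LENGTH / 2                  # 47 ft


def split_on(frames, keep):
    """Split frames into contiguous runs where keep(frame) holds; a rejected frame ends the run."""
    chunks, current = [], []
    for frame in frames:
        if keep(frame):
            current.append(frame)
        elif current:
            chunks.append(current)
            current = []
    if current:
        chunks.append(current)
    return chunks


def has_ten_players_and_ball(frame):
    return len(frame["xy"]) == 11


def shot_clock_usable(frame):
    # Drops None, 24 (reset) and 0 (expired); a stopped clock is not detected here.
    sc = frame["shot_clock"]
    return sc is not None and 0.0 < sc < 24.0


def to_left_half(xy):
    """Rotate a frame lying in the right half by 180 degrees so that every x is in [0, 47]."""
    xy = np.asarray(xy, dtype=float)
    if np.all(xy[:, 0] >= HALF):
        return np.column_stack([COURT_LENGTH - xy[:, 0], COURT_WIDTH - xy[:, 1]])
    if np.all(xy[:, 0] <= HALF):
        return xy
    raise ValueError("frame straddles half court; drop it before normalizing")


def defenders_first(team_ids, xy, defending_team_id):
    """Reorder player rows so the defending team's rows come first (original order kept within each team)."""
    team_ids = np.asarray(team_ids)
    order = np.concatenate([np.flatnonzero(team_ids == defending_team_id),
                            np.flatnonzero(team_ids != defending_team_id)])
    return np.asarray(xy)[order]
```

Steps 1 and 2 combine into one pass, e.g. split_on(frames, lambda f: has_ten_players_and_ball(f) and shot_clock_usable(f)); each returned chunk is then treated as its own event.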

Originally we wanted to use the play-by-play data for preprocessing, but it turns out the play-by-play data itself is not accurate. For example, in game 0021500196, event 2 has 'time_left': [705, 704, 685, 684] and 'event_str': ['miss', 'rebound', 'miss', 'rebound'].

At 685.0 the shot clock is at 21.77, by which time the shot had already been missed a while ago, and the defending team had got the rebound and was already switching to offense. The miss should have been marked right after the 24 s shot clock reset. A human viewer can tolerate this kind of offset, but it would certainly affect the model's learning.

Features

  1. Besides the Cartesian coordinates of the basketball and all players from the data, we also add polar coordinates.
  2. The distance of each player to the ball and to the hoop, in polar coordinates.
  3. Velocities of both the players and the basketball (in Cartesian coordinates).

You can find the details in the create_static_features and create_dynamic_features functions in features.py.
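A NumPy sketch of the three feature groups. It is not create_static_features or create_dynamic_features: the hoop position (5.25, 25.0) in half-court feet, the 25 frames-per-second rate used to scale velocities, and measuring polar coordinates from the hoop and from the ball are all assumptions, and the resulting 86 columns are not meant to reproduce the repo's feature set.

```python
import numpy as np

HOOP_XY = np.array([5.25, 25.0])  # assumed hoop centre after half-court normalization (feet)
FPS = 25.0                        # assumed tracking frame rate (frames per second)


def polar(xy, origin):
    """(r, theta) of points xy (..., 2) relative to a broadcastable origin."""
    d = np.asarray(xy, dtype=float) - origin
    return np.stack([np.hypot(d[..., 0], d[..., 1]), np.arctan2(d[..., 1], d[..., 0])], axis=-1)


def velocity(xy, fps=FPS):
    """Finite-difference velocity along time (axis 0); the first frame gets zero velocity."""
    xy = np.asarray(xy, dtype=float)
    v = np.diff(xy, axis=0) * fps
    return np.concatenate([np.zeros_like(xy[:1]), v], axis=0)


def frame_features(players_xy, ball_xy):
    """players_xy: (T, 10, 2), ball_xy: (T, 2) -> (T, 86) feature matrix."""
    T = players_xy.shape[0]
    parts = [
        players_xy.reshape(T, -1),                               # Cartesian, players
        ball_xy,                                                 # Cartesian, ball
        polar(players_xy, HOOP_XY).reshape(T, -1),               # players relative to hoop
        polar(players_xy, ball_xy[:, None, :]).reshape(T, -1),   # players relative to ball
        polar(ball_xy, HOOP_XY),                                 # ball relative to hoop
        velocity(players_xy).reshape(T, -1),                     # player velocities
        velocity(ball_xy),                                       # ball velocity
    ]
    return np.concatenate(parts, axis=1)
```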

Below is an example plot of a game event.

Blue is the defending team, red is the opponent and green is the basketball. Arrows show each player's velocity vector. The black circle is the hoop. Smaller dots mark earlier positions in the sequence.

Hidden Structure Learning

Finally, we get to how we build the model. The obvious approach is to feed the input sequence into an LSTM, where the label at each time step is the input at the next time step. However, there are two major issues:

  1. Since we train on input data that contains multiple agents, we need to consider the order of the agents in the input.
  2. A standard one-to-one or many-to-one model would have little practical use, since in a real game we want predictions for at least several future time steps rather than one prediction at a time.
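A minimal sketch of how next-step training pairs could be built from one preprocessed event: each frame is paired with the defenders' positions in the following frame, and long events are cut into overlapping windows. The column layout (defenders' x, y first, as after reordering) and the reading of overlap as the number of steps shared by consecutive windows are assumptions, and the function names are hypothetical; the defaults seq_len=50 and overlap=25 echo the sequence_length and overlap values in the training log under Issues.

```python
import numpy as np


def next_step_pairs(event, n_defenders=5):
    """Inputs are frames 0..T-2; targets are the defenders' (x, y) in frames 1..T-1.

    Assumes the first 2 * n_defenders columns hold the defenders' x, y.
    """
    event = np.asarray(event, dtype=float)
    return event[:-1], event[1:, : 2 * n_defenders]


def overlapping_windows(X, Y, seq_len=50, overlap=25):
    """Cut (X, Y) into windows of seq_len steps; consecutive windows share `overlap` steps."""
    step = seq_len - overlap
    starts = range(0, len(X) - seq_len + 1, step)
    xs = [X[s : s + seq_len] for s in starts]
    ys = [Y[s : s + seq_len] for s in starts]
    if not xs:  # event shorter than one window
        return np.empty((0, seq_len, X.shape[1])), np.empty((0, seq_len, Y.shape[1]))
    return np.stack(xs), np.stack(ys)
```

A 101-frame event gives 100 pairs and, with the defaults, three windows starting at steps 0, 25 and 50.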

In this section we mainly address the first issue. The input data point at each time step looks like this:

We need to feed the model data in a consistent order; otherwise the model will have a hard time learning anything. This is the problem of an "index-free" multi-agent system. How do we define the order, then? By the players' height or weight, or by their assigned roles, e.g. power forward or point guard? Pre-defined roles sound more reasonable, but they may change during actual play. So instead of using fixed roles, the authors of this paper suggest learning hidden states/roles for each player.

Here we use the hmmlearn library (pomegranate looks like a good option too). We train a hidden Markov model that predicts the hidden state at each time step; training uses the Baum–Welch algorithm, which gives us the emission probabilities of each hidden role.

At first glance we do not need to bother with the emission distribution: the Viterbi algorithm finds the most likely sequence of hidden roles. However, since we assign a hidden role to each player separately, different players can be assigned the same hidden role (this indeed happened when I ran Viterbi to get the sequences of assigned roles). More concretely, for each player at each time step we assign a hidden role:

Notice how players 1 and 2 are both assigned hidden role 1 at the initial time step, and players 2 and 5 are assigned the same hidden role 3. We cannot allow this, because we need the hidden roles to order the players. So instead of a hard assignment for each player, we use a linear assignment technique, specifically the Hungarian algorithm, to assign the hidden roles.

We do so by first computing the Euclidean distance (you can also try cosine similarity) from each player's data point at a given time step to the center of each hidden role's distribution, which we assume to be a (mixture of) multivariate Gaussian(s). We then use these distances as the cost matrix and apply the Hungarian algorithm.
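A sketch of the assignment step for a single time step, assuming the role centers are already available (for example from the means_ attribute of a fitted hmmlearn GaussianHMM). It uses SciPy's linear_sum_assignment, which solves the same minimum-cost assignment problem as the Hungarian algorithm; assign_roles and order_by_role are hypothetical names.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def assign_roles(players, role_means):
    """One-to-one role assignment for one time step.

    players:    (5, D) feature vectors of the five defenders.
    role_means: (5, D) centers of the hidden-role distributions.
    Returns roles, where roles[i] is the role given to player i.
    """
    # Euclidean distance from every player to every role center: (5, 5) cost matrix.
    cost = np.linalg.norm(players[:, None, :] - role_means[None, :, :], axis=-1)
    rows, cols = linear_sum_assignment(cost)  # minimizes total distance, each role used once
    roles = np.empty(len(players), dtype=int)
    roles[rows] = cols
    return roles


def order_by_role(players, roles):
    """Reorder players so that row k holds the player assigned to role k."""
    return players[np.argsort(roles)]
```

Unlike independent nearest-center decoding, two players close to the same center cannot share it: the cheaper overall assignment wins and the other player takes the remaining role.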


Imitation Learning

We hope the model can learn to mimic trajectories by training on player tracking data. Naturally we use an LSTM for this task. One common LSTM architecture takes a sequence of T states S and outputs the action for each next time step.

However, the first obvious issue is that in a real game we do not have the sequence of player states (unlike machine translation, where the complete input sentence is available); these states are exactly the values we are trying to predict. If we had them, we would not need to predict them. In other words, the input sequence itself is not available at run time.

What one can do is train the model on the available data and, at run time, use the predicted output of the current time step as the input of the next time step; that is, instead of the true value, the next input is the output from the previous time step.

This is doable and looks fine, but at run time the model gets thrown off by drifting, or compounding, error. As the prediction runs for more time steps, the error grows larger and larger until the predictions are far from realistic trajectories. This happens even though the training loss is small.

We demonstrate this with a simple experiment. Below is a sine signal with added Gaussian noise of mean 2 and standard deviation 1.

First we apply a regular LSTM that uses the ground truth as the input at every time step; the prediction looks pretty "good":

But this is deceptive, because in real settings we need to predict multiple steps ahead instead of relying on the ground truth. If we take the trained model and simply make each prediction from the previous result, the prediction quickly converges to the mean of the Gaussian noise:
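The drift is easy to reproduce without an LSTM. This sketch swaps in a one-step linear predictor, y[t+1] ≈ a·y[t] + b, fitted by least squares on ground-truth inputs (the teacher-forced setting), on a noisy sine with noise mean 2 and standard deviation 1; the series length, number of periods and seed are arbitrary choices. Fed its own outputs, the predictor converges geometrically to its fixed point b / (1 - a), which is essentially the series mean, i.e. the noise mean of 2.

```python
import numpy as np


def noisy_sine(n=2000, periods=20, noise_mean=2.0, noise_std=1.0, seed=0):
    rng = np.random.default_rng(seed)
    t = np.linspace(0, 2 * np.pi * periods, n)
    return np.sin(t) + rng.normal(noise_mean, noise_std, size=n)


def fit_one_step(y):
    """Least-squares fit of y[t+1] ~ a * y[t] + b using ground-truth inputs (teacher forcing)."""
    A = np.column_stack([y[:-1], np.ones(len(y) - 1)])
    (a, b), *_ = np.linalg.lstsq(A, y[1:], rcond=None)
    return a, b


def free_run(a, b, y0, steps):
    """Run time: feed each prediction back in as the next input."""
    preds = [y0]
    for _ in range(steps):
        preds.append(a * preds[-1] + b)
    return np.array(preds[1:])
```

With y = noisy_sine() and a, b = fit_one_step(y), free_run(a, b, y[0], 50) flattens out near 2 after a handful of steps, regardless of where the sine actually is.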

So the paper proposes letting the model look ahead over longer horizons and experience this drifting error during training. We first train the regular LSTM model, where every input is ground truth. Then we extend the horizon, i.e. during training we use the current time step's output as the next step's input. We increase the horizon by 1 and repeat. This gives the model experience in handling drifting error at training time, which leads to better performance in the real run-time setting.

For the sine wave example, the test result becomes much better when we gradually increase the horizon from 0 to 6:

To illustrate this using network connections:

[Figure: network connections at Step 1, Step 2 and Step 3]
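A framework-free, forward-pass sketch of the horizon curriculum. It assumes one concrete reading of "horizon": at horizon h the model receives ground truth at every (h + 1)-th step and its own previous prediction in between, so horizon 0 is ordinary teacher forcing. step_fn, rollout_with_horizon and curriculum are hypothetical names; in the real model the predictions stay inside the computation graph so gradients flow through them.

```python
import numpy as np


def rollout_with_horizon(step_fn, x_true, horizon, state=None):
    """Run step_fn over a sequence, feeding ground truth only every (horizon + 1) steps.

    step_fn(x, state) -> (prediction, new_state); prediction has the same shape as x.
    """
    preds, prev_pred = [], None
    for t, x in enumerate(x_true):
        inp = x if t % (horizon + 1) == 0 else prev_pred
        prev_pred, state = step_fn(inp, state)
        preds.append(prev_pred)
    return np.stack(preds), state


def curriculum(train_one_epoch, max_horizon, epochs_per_horizon):
    """Train at horizon 0, then 1, 2, ... up to max_horizon."""
    for h in range(max_horizon + 1):
        for _ in range(epochs_per_horizon):
            train_one_epoch(h)
```

With a toy step that predicts input + 1 on an all-zero sequence of length 6, horizon 2 yields 1, 2, 3, 1, 2, 3: the error compounds for three steps and resets at the next ground-truth input, while horizon 0 yields 1 at every step.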


coordinated-multi-agent-imitation-learning's Issues

Add more game data

Currently only a single game's data is used; we need to include all the games.

Play-by-play data is not accurate

For example, in game 0021500196, event 2 has 'time_left': [705, 704, 685, 684] and 'event_str': ['miss', 'rebound', 'miss', 'rebound'].
When these are matched against the shot clock and the court visualization, the time_left values in the play-by-play do not seem to describe the event segmentation correctly.

At 685.0 the shot clock is at 21.77, by which time the shot had already been missed a while ago, and the defending team had got the rebound and was already switching to offense. The miss should be marked right after the 24 s shot clock reset.

Overfitting

2018-05-19 20:15:18,769 | INFO : Training with hyper parameters:
{'use_model': 'dynamic_rnn_layer_norm', 'batch_size': 64, 'sequence_length': 50, 'overlap': 25, 'state_size': [128, 128], 'use_peepholes': None, 'input_dim': 179, 'dropout_rate': 0.6, 'learning_rate': 0.0001, 'n_epoch': 1000}

2018-05-19 20:15:22,532 | INFO : Horizon 0 ==========
2018-05-19 20:15:58,637 | INFO : Epoch 0 | loss: 424.64 | time took: 34.02s | validation loss: 370.09
2018-05-19 20:21:29,900 | INFO : Epoch 10 | loss: 13.52 | time took: 32.61s | validation loss: 10.31
2018-05-19 20:26:59,878 | INFO : Epoch 20 | loss: 3.62 | time took: 32.94s | validation loss: 3.74
2018-05-19 20:32:30,784 | INFO : Epoch 30 | loss: 3.08 | time took: 32.63s | validation loss: 2.86
2018-05-19 20:37:59,144 | INFO : Epoch 40 | loss: 2.90 | time took: 32.77s | validation loss: 2.74
2018-05-19 20:43:26,475 | INFO : Epoch 50 | loss: 2.79 | time took: 32.46s | validation loss: 2.58
2018-05-19 20:48:53,346 | INFO : Epoch 60 | loss: 2.70 | time took: 32.50s | validation loss: 2.70
2018-05-19 20:54:20,320 | INFO : Epoch 70 | loss: 2.63 | time took: 32.46s | validation loss: 2.56
2018-05-19 20:59:47,275 | INFO : Epoch 80 | loss: 2.57 | time took: 32.45s | validation loss: 2.60
2018-05-19 21:05:14,118 | INFO : Epoch 90 | loss: 2.50 | time took: 32.50s | validation loss: 2.59
2018-05-19 21:10:43,281 | INFO : Epoch 100 | loss: 2.43 | time took: 32.98s | validation loss: 2.55
2018-05-19 21:16:13,949 | INFO : Epoch 110 | loss: 2.37 | time took: 33.29s | validation loss: 2.57
2018-05-19 21:21:45,966 | INFO : Epoch 120 | loss: 2.30 | time took: 32.93s | validation loss: 2.51
2018-05-19 21:27:17,978 | INFO : Epoch 130 | loss: 2.24 | time took: 32.95s | validation loss: 2.44
2018-05-19 21:32:49,640 | INFO : Epoch 140 | loss: 2.18 | time took: 32.94s | validation loss: 2.72
2018-05-19 21:38:21,280 | INFO : Epoch 150 | loss: 2.12 | time took: 32.89s | validation loss: 2.48
2018-05-19 21:43:51,814 | INFO : Epoch 160 | loss: 2.07 | time took: 32.71s | validation loss: 2.62
2018-05-19 21:49:22,015 | INFO : Epoch 170 | loss: 2.02 | time took: 32.80s | validation loss: 2.57
2018-05-19 21:54:51,845 | INFO : Epoch 180 | loss: 1.96 | time took: 32.79s | validation loss: 2.57
2018-05-19 22:00:21,549 | INFO : Epoch 190 | loss: 1.91 | time took: 32.69s | validation loss: 2.62
2018-05-19 22:05:51,438 | INFO : Epoch 200 | loss: 1.86 | time took: 32.75s | validation loss: 2.52
2018-05-19 22:11:21,148 | INFO : Epoch 210 | loss: 1.82 | time took: 32.73s | validation loss: 2.60
2018-05-19 22:16:50,945 | INFO : Epoch 220 | loss: 1.78 | time took: 32.74s | validation loss: 2.67
2018-05-19 22:22:21,564 | INFO : Epoch 230 | loss: 1.74 | time took: 32.80s | validation loss: 2.65
2018-05-19 22:27:51,485 | INFO : Epoch 240 | loss: 1.70 | time took: 32.78s | validation loss: 2.84
2018-05-19 22:33:21,232 | INFO : Epoch 250 | loss: 1.66 | time took: 32.75s | validation loss: 2.59
2018-05-19 22:38:51,611 | INFO : Epoch 260 | loss: 1.62 | time took: 33.05s | validation loss: 2.66
2018-05-19 22:44:21,437 | INFO : Epoch 270 | loss: 1.59 | time took: 32.74s | validation loss: 2.68
2018-05-19 22:49:51,152 | INFO : Epoch 280 | loss: 1.55 | time took: 32.75s | validation loss: 2.61
2018-05-19 22:55:20,930 | INFO : Epoch 290 | loss: 1.52 | time took: 32.71s | validation loss: 2.55
2018-05-19 23:00:50,665 | INFO : Epoch 300 | loss: 1.49 | time took: 32.75s | validation loss: 2.74
2018-05-19 23:06:20,603 | INFO : Epoch 310 | loss: 1.46 | time took: 32.75s | validation loss: 2.80
2018-05-19 23:11:50,676 | INFO : Epoch 320 | loss: 1.43 | time took: 32.75s | validation loss: 2.75
2018-05-19 23:17:20,512 | INFO : Epoch 330 | loss: 1.40 | time took: 32.76s | validation loss: 2.78
2018-05-19 23:22:50,682 | INFO : Epoch 340 | loss: 1.37 | time took: 32.76s | validation loss: 3.07
2018-05-19 23:28:20,826 | INFO : Epoch 350 | loss: 1.34 | time took: 32.75s | validation loss: 2.82
2018-05-19 23:33:50,696 | INFO : Epoch 360 | loss: 1.32 | time took: 32.74s | validation loss: 2.77
2018-05-19 23:39:20,657 | INFO : Epoch 370 | loss: 1.30 | time took: 32.94s | validation loss: 2.82
2018-05-19 23:44:50,417 | INFO : Epoch 380 | loss: 1.27 | time took: 32.84s | validation loss: 2.94
2018-05-19 23:50:20,729 | INFO : Epoch 390 | loss: 1.24 | time took: 32.97s | validation loss: 2.79
2018-05-19 23:55:55,456 | INFO : Epoch 400 | loss: 1.22 | time took: 33.06s | validation loss: 2.88
2018-05-20 00:01:30,456 | INFO : Epoch 410 | loss: 1.20 | time took: 32.58s | validation loss: 2.94
2018-05-20 00:07:07,730 | INFO : Epoch 420 | loss: 1.18 | time took: 33.55s | validation loss: 3.09
2018-05-20 00:12:44,634 | INFO : Epoch 430 | loss: 1.16 | time took: 33.37s | validation loss: 2.78
2018-05-20 00:18:14,081 | INFO : Epoch 440 | loss: 1.14 | time took: 32.67s | validation loss: 3.27
2018-05-20 00:23:47,005 | INFO : Epoch 450 | loss: 1.12 | time took: 33.36s | validation loss: 2.82
2018-05-20 00:29:12,431 | INFO : Epoch 460 | loss: 1.10 | time took: 32.09s | validation loss: 3.01
2018-05-20 00:34:35,043 | INFO : Epoch 470 | loss: 1.08 | time took: 32.03s | validation loss: 2.63
2018-05-20 00:39:58,028 | INFO : Epoch 480 | loss: 1.07 | time took: 32.10s | validation loss: 2.92
2018-05-20 00:45:20,698 | INFO : Epoch 490 | loss: 1.05 | time took: 32.07s | validation loss: 3.15
2018-05-20 00:50:43,453 | INFO : Epoch 500 | loss: 1.03 | time took: 32.09s | validation loss: 2.89
2018-05-20 00:56:06,164 | INFO : Epoch 510 | loss: 1.01 | time took: 32.08s | validation loss: 2.87
2018-05-20 01:01:28,880 | INFO : Epoch 520 | loss: 1.00 | time took: 32.04s | validation loss: 3.02
2018-05-20 01:06:52,427 | INFO : Epoch 530 | loss: 0.98 | time took: 32.94s | validation loss: 2.94
2018-05-20 01:12:15,404 | INFO : Epoch 540 | loss: 0.96 | time took: 32.09s | validation loss: 2.90
2018-05-20 01:17:38,158 | INFO : Epoch 550 | loss: 0.96 | time took: 32.09s | validation loss: 2.80
2018-05-20 01:23:01,047 | INFO : Epoch 560 | loss: 0.93 | time took: 32.13s | validation loss: 3.14
2018-05-20 01:28:23,858 | INFO : Epoch 570 | loss: 0.93 | time took: 32.08s | validation loss: 2.94
2018-05-20 01:33:46,562 | INFO : Epoch 580 | loss: 0.91 | time took: 32.07s | validation loss: 2.79
2018-05-20 01:39:09,246 | INFO : Epoch 590 | loss: 0.90 | time took: 32.07s | validation loss: 3.16
2018-05-20 01:44:31,918 | INFO : Epoch 600 | loss: 0.89 | time took: 32.08s | validation loss: 2.85
2018-05-20 01:49:54,587 | INFO : Epoch 610 | loss: 0.87 | time took: 32.09s | validation loss: 2.90
2018-05-20 01:55:17,323 | INFO : Epoch 620 | loss: 0.86 | time took: 32.03s | validation loss: 2.67
2018-05-20 02:00:40,017 | INFO : Epoch 630 | loss: 0.85 | time took: 32.07s | validation loss: 3.18
2018-05-20 02:06:02,918 | INFO : Epoch 640 | loss: 0.84 | time took: 32.09s | validation loss: 3.01
2018-05-20 02:11:25,634 | INFO : Epoch 650 | loss: 0.83 | time took: 32.05s | validation loss: 2.93
2018-05-20 02:16:48,348 | INFO : Epoch 660 | loss: 0.82 | time took: 32.06s | validation loss: 3.02
2018-05-20 02:22:11,146 | INFO : Epoch 670 | loss: 0.81 | time took: 32.07s | validation loss: 2.79
2018-05-20 02:27:33,876 | INFO : Epoch 680 | loss: 0.80 | time took: 32.12s | validation loss: 2.88
2018-05-20 02:32:56,727 | INFO : Epoch 690 | loss: 0.79 | time took: 32.13s | validation loss: 2.95
2018-05-20 02:38:19,538 | INFO : Epoch 700 | loss: 0.77 | time took: 32.07s | validation loss: 2.66
2018-05-20 02:43:42,278 | INFO : Epoch 710 | loss: 0.76 | time took: 32.05s | validation loss: 2.89
2018-05-20 02:49:05,039 | INFO : Epoch 720 | loss: 0.76 | time took: 32.11s | validation loss: 2.81
2018-05-20 02:54:27,839 | INFO : Epoch 730 | loss: 0.75 | time took: 32.08s | validation loss: 2.98
2018-05-20 02:59:50,629 | INFO : Epoch 740 | loss: 0.74 | time took: 32.08s | validation loss: 2.97
2018-05-20 03:05:13,247 | INFO : Epoch 750 | loss: 0.73 | time took: 32.07s | validation loss: 2.82
2018-05-20 03:10:35,949 | INFO : Epoch 760 | loss: 0.72 | time took: 32.04s | validation loss: 2.81
2018-05-20 03:15:58,732 | INFO : Epoch 770 | loss: 0.72 | time took: 32.07s | validation loss: 2.86
2018-05-20 03:21:21,429 | INFO : Epoch 780 | loss: 0.70 | time took: 32.07s | validation loss: 2.81
2018-05-20 03:26:44,166 | INFO : Epoch 790 | loss: 0.70 | time took: 32.07s | validation loss: 2.73
2018-05-20 03:32:06,811 | INFO : Epoch 800 | loss: 0.69 | time took: 32.07s | validation loss: 2.93
2018-05-20 03:37:29,600 | INFO : Epoch 810 | loss: 0.68 | time took: 32.13s | validation loss: 3.20
2018-05-20 03:42:52,533 | INFO : Epoch 820 | loss: 0.67 | time took: 32.09s | validation loss: 2.85
2018-05-20 03:48:15,239 | INFO : Epoch 830 | loss: 0.66 | time took: 32.05s | validation loss: 2.91
2018-05-20 03:53:38,068 | INFO : Epoch 840 | loss: 0.66 | time took: 32.11s | validation loss: 2.95
2018-05-20 03:59:00,843 | INFO : Epoch 850 | loss: 0.65 | time took: 32.06s | validation loss: 2.93
2018-05-20 04:04:23,524 | INFO : Epoch 860 | loss: 0.64 | time took: 32.05s | validation loss: 3.17
2018-05-20 04:09:46,182 | INFO : Epoch 870 | loss: 0.64 | time took: 32.06s | validation loss: 2.82
2018-05-20 04:15:08,897 | INFO : Epoch 880 | loss: 0.63 | time took: 32.09s | validation loss: 2.85
2018-05-20 04:20:31,599 | INFO : Epoch 890 | loss: 0.63 | time took: 32.07s | validation loss: 2.86
2018-05-20 04:25:54,164 | INFO : Epoch 900 | loss: 0.62 | time took: 32.08s | validation loss: 2.72
2018-05-20 04:31:16,932 | INFO : Epoch 910 | loss: 0.61 | time took: 32.08s | validation loss: 2.88
2018-05-20 04:36:39,548 | INFO : Epoch 920 | loss: 0.61 | time took: 32.05s | validation loss: 2.95
2018-05-20 04:42:02,369 | INFO : Epoch 930 | loss: 0.60 | time took: 32.03s | validation loss: 2.87
2018-05-20 04:47:25,195 | INFO : Epoch 940 | loss: 0.60 | time took: 32.07s | validation loss: 2.83
2018-05-20 04:52:47,975 | INFO : Epoch 950 | loss: 0.59 | time took: 32.13s | validation loss: 2.86
2018-05-20 04:58:10,736 | INFO : Epoch 960 | loss: 0.59 | time took: 32.06s | validation loss: 3.23
2018-05-20 05:03:33,409 | INFO : Epoch 970 | loss: 0.58 | time took: 32.09s | validation loss: 3.18
2018-05-20 05:08:56,194 | INFO : Epoch 980 | loss: 0.58 | time took: 32.10s | validation loss: 3.03
2018-05-20 05:14:18,889 | INFO : Epoch 990 | loss: 0.57 | time took: 32.08s | validation loss: 2.96
2018-05-20 05:19:07,667 | INFO : Total time took: 9.06hrs
2018-05-20 05:19:07,824 | INFO : Done saving model for policy 0
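The log above shows the classic overfitting signature: training loss keeps falling (2.24 at epoch 130, 0.57 at epoch 990) while validation loss reaches its lowest logged value, 2.44, at epoch 130 and then fluctuates between roughly 2.5 and 3.3. A small sketch that parses lines in this log format and applies patience-based early stopping, one common remedy; the regex assumes exactly the line format shown here, and parse_log and early_stop_epoch are hypothetical names.

```python
import re

EPOCH_RE = re.compile(
    r"Epoch (\d+) \| loss: ([\d.]+) \| time took: [\d.]+s \| validation loss: ([\d.]+)"
)


def parse_log(text):
    """Return (epoch, train_loss, val_loss) tuples for every epoch line in the log."""
    return [(int(e), float(tr), float(va)) for e, tr, va in EPOCH_RE.findall(text)]


def early_stop_epoch(history, patience):
    """Best (epoch, val_loss), stopping once `patience` logged evaluations pass without improvement."""
    best_epoch, best_val, waited = None, float("inf"), 0
    for epoch, _train, val in history:
        if val < best_val:
            best_epoch, best_val, waited = epoch, val, 0
        else:
            waited += 1
            if waited >= patience:
                break
    return best_epoch, best_val
```

Since this log records validation loss only every 10 epochs, patience here counts logged evaluations, not epochs.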

Roll out horizon

Currently the model only uses dynamic RNN; the roll-out implemented with raw_rnn is not correct and needs to be fixed.

The result of a roll-out with raw_rnn and horizon=0 should be equivalent to a regular RNN, but the two produce different results:

  1. Using regular dynamic_rnn

Horizon Tensor("Placeholder_1:0", dtype=int32) ==========
Epoch 0 | loss: 231.87 | time took: 0.72s | validation loss: 158.90
Epoch 100 | loss: 7.85 | time took: 0.54s | validation loss: 10.80
Epoch 200 | loss: 6.92 | time took: 0.54s | validation loss: 10.02
Epoch 300 | loss: 6.06 | time took: 0.54s | validation loss: 8.95
Epoch 400 | loss: 5.92 | time took: 0.54s | validation loss: 8.73
Epoch 500 | loss: 5.38 | time took: 0.54s | validation loss: 8.24
Epoch 600 | loss: 5.32 | time took: 0.54s | validation loss: 8.67
Epoch 700 | loss: 5.01 | time took: 0.55s | validation loss: 8.05
Epoch 800 | loss: 5.33 | time took: 0.54s | validation loss: 9.63
Epoch 900 | loss: 4.78 | time took: 0.54s | validation loss: 7.76
Total time took: 0.15hrs

  2. raw_rnn with horizon=0

Epoch 0 | loss: 229.10 | time took: 0.83s | validation loss: 154.11
Epoch 100 | loss: 16.10 | time took: 0.67s | validation loss: 32.90
Epoch 200 | loss: 11.69 | time took: 0.67s | validation loss: 32.35
Epoch 300 | loss: 11.01 | time took: 0.67s | validation loss: 31.40
Epoch 400 | loss: 8.76 | time took: 0.67s | validation loss: 28.07
Epoch 500 | loss: 7.99 | time took: 0.67s | validation loss: 26.93
Epoch 600 | loss: 7.22 | time took: 0.67s | validation loss: 29.64
Epoch 700 | loss: 8.19 | time took: 0.67s | validation loss: 32.17
Epoch 800 | loss: 10.46 | time took: 0.67s | validation loss: 29.54
Epoch 900 | loss: 9.72 | time took: 0.67s | validation loss: 27.48
Total time took: 0.19hrs

This means there is probably something off in the raw_rnn implementation.
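One way to narrow down the bug is to run both code paths on the same cell and the same weights and compare them directly: with horizon 0 they must agree to floating-point precision, so any difference points to the roll-out loop itself rather than to training noise. A framework-free sketch, with a plain tanh RNN cell standing in for the LSTM and assuming ground truth is fed at every (horizon + 1)-th step and the previous prediction otherwise; all names are hypothetical.

```python
import numpy as np


def cell(x, h, W, U, b):
    """Plain tanh RNN cell standing in for the LSTM."""
    return np.tanh(x @ W + h @ U + b)


def unroll_teacher_forced(xs, params):
    """Reference path, analogous to dynamic_rnn: ground truth at every step."""
    W, U, b, V = params
    h, outs = np.zeros(U.shape[0]), []
    for x in xs:
        h = cell(x, h, W, U, b)
        outs.append(h @ V)
    return np.array(outs)


def unroll_with_horizon(xs, params, horizon):
    """Roll-out path, analogous to the raw_rnn loop: ground truth only every horizon + 1 steps."""
    W, U, b, V = params
    h, outs, prev = np.zeros(U.shape[0]), [], None
    for t, x in enumerate(xs):
        inp = x if t % (horizon + 1) == 0 else prev
        h = cell(inp, h, W, U, b)
        prev = h @ V
        outs.append(prev)
    return np.array(outs)


rng = np.random.default_rng(0)
D, H, T = 3, 8, 12  # input/output size, hidden size, sequence length
params = (rng.normal(size=(D, H)), 0.5 * rng.normal(size=(H, H)), np.zeros(H), rng.normal(size=(H, D)))
xs = rng.normal(size=(T, D))

# With horizon 0 both paths must agree to floating-point precision.
assert np.allclose(unroll_teacher_forced(xs, params), unroll_with_horizon(xs, params, horizon=0))
```

The same comparison can be done in TensorFlow by building both graphs on one shared cell and fetching their outputs for a single batch before any training step.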

Could not read index 1 twice because it was cleared after a previous read (perhaps try setting clear_after_read = false?)

In JointTraining, the nested policy with nested horizon produces the error below:

Wroking on policy 0
Horizon 0 ==========
Epoch 0 | loss: 85.31 | time took: 1.91s | validation loss: 36.48
Total time took: 0.05hrs
Horizon 2 ==========

InvalidArgumentError Traceback (most recent call last)
C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
1360 try:
-> 1361 return fn(*args)
1362 except errors.OpError as e:

C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
1339 return tf_session.TF_Run(session, options, feed_dict, fetch_list,
-> 1340 target_list, status, run_metadata)
1341

C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\framework\errors_impl.py in exit(self, type_arg, value_arg, traceback_arg)
515 compat.as_text(c_api.TF_Message(self.status.status)),
--> 516 c_api.TF_GetCode(self.status.status))
517 # Delete the underlying status object from memory otherwise it stays alive

InvalidArgumentError: TensorArray TensorArray_3809: Could not read index 1 twice because it was cleared after a previous read (perhaps try setting clear_after_read = false?).
[[Node: rnn/while/cond/cond/TensorArrayReadV3_2 = TensorArrayReadV3[dtype=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](rnn/while/cond/cond/TensorArrayReadV3_1/Switch, rnn/while/cond/cond/TensorArrayReadV3_1/Switch_1, rnn/while/cond/cond/TensorArrayReadV3_1/Switch_2)]]

During handling of the above exception, another exception occurred:

InvalidArgumentError Traceback (most recent call last)
in ()
1 batch_size = 32
----> 2 train_all_single_policies(batch_size, sequence_length, train_game, train_target, test_game, test_target, models_path)

C:\Users\sshi\Desktop\raptors\code\train.py in train_all_single_policies(batch_size, sequence_length, train_game, train_target, test_game, test_target, models_path)
31 for batch in iterate_minibatches(train_game, train_target, batch_size, shuffle=False):
32 train_xi, train_yi = batch
---> 33 p, l, _, train_sum = model.train(train_xi, train_yi, k)
34 model.train_writer.add_summary(train_sum, train_step)
35 epoch_loss += l/n_train_batch

C:\Users\sshi\Desktop\raptors\code\model.py in train(self, train_xi, train_yi, k)
124 def train(self, train_xi, train_yi, k):
125 return self.sess.run([self.pred, self.loss, self.opt, self.train_summary],
--> 126 feed_dict={self.X: train_xi, self.Y: train_yi, self.h: k})
127
128 def validate(self, val_xi, val_yi, k):

C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
903 try:
904 result = self._run(None, fetches, feed_dict, options_ptr,
--> 905 run_metadata_ptr)
906 if run_metadata:
907 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1135 if final_fetches or final_targets or (handle and feed_dict_tensor):
1136 results = self._do_run(handle, final_targets, final_fetches,
-> 1137 feed_dict_tensor, options, run_metadata)
1138 else:
1139 results = []

C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
1353 if handle is None:
1354 return self._do_call(_run_fn, self._session, feeds, fetches, targets,
-> 1355 options, run_metadata)
1356 else:
1357 return self._do_call(_prun_fn, self._session, handle, feeds, fetches)

C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
1372 except KeyError:
1373 pass
-> 1374 raise type(e)(node_def, op, message)
1375
1376 def _extend_graph(self):

InvalidArgumentError: TensorArray TensorArray_3809: Could not read index 1 twice because it was cleared after a previous read (perhaps try setting clear_after_read = false?).
[[Node: rnn/while/cond/cond/TensorArrayReadV3_2 = TensorArrayReadV3[dtype=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](rnn/while/cond/cond/TensorArrayReadV3_1/Switch, rnn/while/cond/cond/TensorArrayReadV3_1/Switch_1, rnn/while/cond/cond/TensorArrayReadV3_1/Switch_2)]]

Caused by op 'rnn/while/cond/cond/TensorArrayReadV3_2', defined at:
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\runpy.py", line 193, in _run_module_as_main
"main", mod_spec)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\runpy.py", line 85, in _run_code
exec(code, run_globals)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\ipykernel_launcher.py", line 16, in
app.launch_new_instance()
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\traitlets\config\application.py", line 658, in launch_instance
app.start()
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\ipykernel\kernelapp.py", line 477, in start
ioloop.IOLoop.instance().start()
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\zmq\eventloop\ioloop.py", line 177, in start
super(ZMQIOLoop, self).start()
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tornado\ioloop.py", line 888, in start
handler_func(fd_obj, events)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
return fn(*args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\zmq\eventloop\zmqstream.py", line 440, in _handle_events
self._handle_recv()
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\zmq\eventloop\zmqstream.py", line 472, in _handle_recv
self._run_callback(callback, msg)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\zmq\eventloop\zmqstream.py", line 414, in _run_callback
callback(*args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
return fn(*args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\ipykernel\kernelbase.py", line 283, in dispatcher
return self.dispatch_shell(stream, msg)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\ipykernel\kernelbase.py", line 235, in dispatch_shell
handler(stream, idents, msg)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\ipykernel\kernelbase.py", line 399, in execute_request
user_expressions, allow_stdin)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\ipykernel\ipkernel.py", line 196, in do_execute
res = shell.run_cell(code, store_history=store_history, silent=silent)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\ipykernel\zmqshell.py", line 533, in run_cell
return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 2717, in run_cell
interactivity=interactivity, compiler=compiler, result=result)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 2827, in run_ast_nodes
if self.run_code(code, result):
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 2881, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "", line 2, in <module>
train_all_single_policies(batch_size, sequence_length, train_game, train_target, test_game, test_target, models_path)
File "C:\Users\sshi\Desktop\raptors\code\train.py", line 14, in train_all_single_policies
learning_rate=0.01, seq_len=sequence_length-1)
File "C:\Users\sshi\Desktop\raptors\code\model.py", line 85, in __init__
policy_number=self.policy_number)
File "C:\Users\sshi\Desktop\raptors\code\model.py", line 41, in dynamic_raw_rnn
outputs_ta, last_state, _ = tf.nn.raw_rnn(cell, loop_fn)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\rnn.py", line 1154, in raw_rnn
swap_memory=swap_memory)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3096, in while_loop
result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2874, in BuildLoop
pred, body, original_loop_vars, loop_vars, shape_invariants)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2814, in _BuildLoop
body_result = body(*packed_vars_for_body)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\rnn.py", line 1115, in body
next_time, next_output, cell_state, loop_state)
File "C:\Users\sshi\Desktop\raptors\code\model.py", line 33, in loop_fn
lambda: tf.cond(tf.equal(tf.mod(time, horizon+1), tf.constant(0)),
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\util\deprecation.py", line 432, in new_func
return func(*args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2027, in cond
orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 1868, in BuildCondBranch
original_result = fn()
File "C:\Users\sshi\Desktop\raptors\code\model.py", line 35, in <lambda>
lambda: tf.concat((inputs_ta.read(time)[:, :policy_number*player_fts],
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\util\deprecation.py", line 432, in new_func
return func(*args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2027, in cond
orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 1868, in BuildCondBranch
original_result = fn()
File "C:\Users\sshi\Desktop\raptors\code\model.py", line 37, in <lambda>
inputs_ta.read(time)[:, policy_number*player_fts+2:]), axis=1)))
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py", line 58, in fn
return method(self, *args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py", line 58, in fn
return method(self, *args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py", line 58, in fn
return method(self, *args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\tensor_array_ops.py", line 861, in read
return self._implementation.read(index, name=name)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\tensor_array_ops.py", line 260, in read
name=name)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\gen_data_flow_ops.py", line 4970, in _tensor_array_read_v3
dtype=dtype, name=name)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 3271, in create_op
op_def=op_def)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 1650, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): TensorArray TensorArray_3809: Could not read index 1 twice because it was cleared after a previous read (perhaps try setting clear_after_read = false?).
[[Node: rnn/while/cond/cond/TensorArrayReadV3_2 = TensorArrayReadV3[dtype=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](rnn/while/cond/cond/TensorArrayReadV3_1/Switch, rnn/while/cond/cond/TensorArrayReadV3_1/Switch_1, rnn/while/cond/cond/TensorArrayReadV3_1/Switch_2)]]
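The error message points at the likely cause: in the false branch of the inner tf.cond (model.py lines 35 and 37 in the traceback), inputs_ta.read(time) is called twice, once for each slice, and a tf.TensorArray created with its default clear_after_read=True releases an element after its first read. Two possible fixes, untested here, are to create inputs_ta with clear_after_read=False, as the message suggests, or to read the step once and slice that single tensor twice. The NumPy sketch below illustrates the second option and what the branch appears to compute. It rests on assumptions the traceback does not confirm: the middle element of the concat is the model's predicted (x, y) for the policy player, the two columns starting at policy_number*player_fts hold that player's coordinates, the other branch feeds the unmodified input, and build_step_input is a hypothetical helper name.

```python
import numpy as np


def build_step_input(step_input, prediction, time, horizon, policy_number, player_fts):
    """Sketch of the per-step rollout input (assumed semantics, not the repo's code).

    step_input: (batch, n_features) ground-truth features for this step, read ONCE.
    prediction: (batch, 2) assumed predicted (x, y) for the policy player.
    Every (horizon + 1) steps the ground truth is fed unchanged; otherwise the
    policy player's two coordinate columns are replaced by the prediction.
    """
    if time % (horizon + 1) == 0:
        return step_input
    # Assumed: this player's block starts at policy_number * player_fts with (x, y) first.
    start = policy_number * player_fts
    # Both slices come from the same array, so the step is only "read" once.
    return np.concatenate(
        (step_input[:, :start], prediction, step_input[:, start + 2:]), axis=1
    )
```

In the TensorFlow 1.x graph, the analogue of "read once" is replacing the lambda with a small named function that stores inputs_ta.read(time) in a local variable and slices it twice.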

To-dos

  1. Add a basketball trajectory visualization with players.
  2. Randomize the train and test split.
  3. Double-check the randomness in iterbatch.
  4. Maybe try a different train and test split, e.g. reduce the test set to a 1:9 test:train ratio.
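Items 2 to 4 could be approached with a seeded shuffle. The sketch below is a hypothetical illustration, not the repo's iterbatch: it assumes the games and targets are NumPy arrays aligned along axis 0 (like train_game and train_target), splits them at a 1:9 test:train ratio, and yields shuffled mini-batches from a caller-supplied random generator so that each epoch sees a different order. The names random_split and iterate_batches are made up for this example.

```python
import numpy as np


def random_split(games, targets, test_fraction=0.1, seed=0):
    """Shuffle sequences and split them into train and test sets.

    test_fraction=0.1 corresponds to a 1:9 test:train ratio.
    """
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(games))  # random order of sequence indices
    n_test = int(round(len(games) * test_fraction))
    test_idx, train_idx = idx[:n_test], idx[n_test:]
    # Index games and targets with the same indices so pairs stay aligned.
    return games[train_idx], targets[train_idx], games[test_idx], targets[test_idx]


def iterate_batches(x, y, batch_size, rng):
    """Yield shuffled (x, y) mini-batches; the last incomplete batch is dropped.

    Reusing the same rng across epochs gives a different order each epoch.
    """
    idx = rng.permutation(len(x))
    for start in range(0, len(x) - batch_size + 1, batch_size):
        batch = idx[start:start + batch_size]
        yield x[batch], y[batch]


# Usage (hypothetical array names):
# train_game, train_target, test_game, test_target = random_split(game, target)
# rng = np.random.default_rng(42)
# for epoch in range(n_epochs):
#     for xb, yb in iterate_batches(train_game, train_target, batch_size, rng):
#         ...
```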

Multiple players sharing the same (conventional) role

Sometimes several players on the court are all F (forwards). Initially I tried to compensate for this by adding empty slots, i.e. creating two more player slots for each role,
(p1x, p1y, p2x, p2y, ..., p7x, p7y), where the first 3 slots are reserved for F, so that when several players share the same role they each still have a slot.

However, after some visualization, I realized that the extra slots that are not needed are just zeros. This is misleading to the model: these zeros do not represent any trajectory, they only encode that the play does not contain certain roles.

So, for now, it seems we should just keep 5 roles. This probably has a negative effect on learning: when several players on the court share the same role, the role assignment has to assign the duplicates to other roles even when they are not very similar.
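To make the last point concrete, here is a small sketch of a one-to-one role assignment that minimizes total cost. The cost values, the 3-slot example, and the name assign_roles are hypothetical, and this is not necessarily how this repo assigns roles. Exhaustive search over permutations is used only because it is simple and exact for 5 players (5! = 120 orderings); a Hungarian-algorithm solver reaches the same minimum cost faster.

```python
from itertools import permutations


def assign_roles(cost):
    """Return a one-to-one player -> role-slot assignment minimizing total cost.

    cost[i][j] is an assumed dissimilarity between player i and role slot j,
    e.g. a distance between the player's average position and a role template.
    """
    n = len(cost)
    best_perm, best_total = None, float("inf")
    for perm in permutations(range(n)):
        # perm[i] is the role slot given to player i
        total = sum(cost[i][perm[i]] for i in range(n))
        if total < best_total:
            best_perm, best_total = perm, total
    return list(best_perm), best_total


# Hypothetical 3-slot example (slots: F, G, C); players 0 and 1 both look like forwards.
cost = [
    [0.0, 5.0, 4.0],  # player 0: forward-like
    [0.5, 5.0, 3.0],  # player 1: also forward-like
    [5.0, 0.0, 6.0],  # player 2: guard-like
]
assignment, total = assign_roles(cost)
```

Here only one F slot exists, so player 1 ends up in the C slot (assignment [0, 2, 1], total cost 3.0) even though it matches F far better, which is the effect described above.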
