
Code Explanation

This is a PyTorch script for financial time-series forecasting using a Temporal Graph Neural Network (TAGN) that combines a graph attention (GAT) layer with a GRU cell.

Specifically, this implementation (tagn2d_imp.py) is a "Volatility Sieve" (波动率大筛子). Unlike standard price-prediction models, it does not try to predict tomorrow's exact price. Instead, it classifies whether a stock will experience significant volatility (a large price swing up or down) within a fixed window of upcoming days.

Below is a detailed review of how the code structures its training and validation processes.


1. The Goal: Predicting High Volatility

Before diving into the loops, it is crucial to understand what the model is training to predict.

In feature_engineering, the labels (Y_tensor) are generated as follows:

  1. It looks ahead PREDICTION_WINDOW days (set to 20 days in the config).
  2. It calculates the maximum percentage rise and the maximum percentage fall within that future window.
  3. The condition: if the maximum rise exceeds LABEL_THRESHOLD (9%) OR the maximum fall goes below -LABEL_THRESHOLD (-9%), the label is set to 1 (Volatile). Otherwise, it is 0 (Calm).

The training objective is a binary classification task: Will this stock experience a >9% swing in the next 20 days?
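The labeling rule can be sketched in pure Python for a single stock's closing prices. This is a minimal sketch, not the code in feature_engineering. It assumes that the rise and fall are measured from day t's close to the highest and lowest close over days t+1 through t+PREDICTION_WINDOW. It also leaves the final days that lack a full look-ahead window unlabeled (None). The helper name make_volatility_labels is hypothetical.

```python
from typing import List, Optional

PREDICTION_WINDOW = 20   # look-ahead horizon in days (value from the config)
LABEL_THRESHOLD = 0.09   # 9% move in either direction


def make_volatility_labels(closes: List[float],
                           window: int = PREDICTION_WINDOW,
                           threshold: float = LABEL_THRESHOLD) -> List[Optional[int]]:
    """Hypothetical helper: label day t as 1 (Volatile) if the next `window`
    closes contain a rise above +threshold or a fall below -threshold, else 0 (Calm).
    Assumption: moves are measured relative to day t's close."""
    labels: List[Optional[int]] = []
    for t, price in enumerate(closes):
        future = closes[t + 1 : t + 1 + window]
        if len(future) < window:
            labels.append(None)                # not enough future data to label this day
            continue
        max_rise = max(future) / price - 1.0   # highest future close vs. today
        max_fall = min(future) / price - 1.0   # lowest future close vs. today (negative if it dropped)
        volatile = max_rise > threshold or max_fall < -threshold
        labels.append(1 if volatile else 0)
    return labels
```

For example, calling make_volatility_labels([100.0, 101.0, 102.0, 90.0, 100.0, 100.0], window=2) marks day 1 as Volatile. Within its window the price falls from 101 to 90, a move of about -10.9%.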


2. Pre-Training Preparation (Crucial Context)

The training and validation loops depend on data preparation that avoids "data leakage" (letting the model see information from the future).

A. Split-Aware Normalization

In feature_engineering, the split_aware_zscore function handles data normalization carefully:

  • Training data (first 80%): normalized using the mean and standard deviation calculated only from the training period. These stats are static.
  • Validation data (last 20%): normalized using an expanding window. For validation day t, the stats are calculated over days 0 through t. This simulates real-time trading, where only past history is known.
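Below is a minimal pure-Python sketch of the split-aware scheme for one feature series of one stock. It reuses the name split_aware_zscore, but the signature, the eps guard against division by zero, and the use of the population standard deviation are assumptions. The real function most likely operates on whole arrays at once.

```python
import statistics
from typing import List


def split_aware_zscore(series: List[float], train_split: int, eps: float = 1e-8) -> List[float]:
    """Z-score one feature series without looking ahead (sketch).

    Days [0, train_split): static mean/std from the training period only.
    Days [train_split, end): expanding-window mean/std over days 0..t inclusive.
    """
    train = series[:train_split]
    mu = statistics.fmean(train)
    sd = statistics.pstdev(train)                 # population std (an assumption)
    out = [(x - mu) / (sd + eps) for x in train]

    for t in range(train_split, len(series)):
        history = series[: t + 1]                 # everything known on day t, nothing later
        m = statistics.fmean(history)
        s = statistics.pstdev(history)
        out.append((series[t] - m) / (s + eps))
    return out
```

Recomputing the statistics from scratch on each validation day costs time quadratic in the series length. Running sums would give the same values in linear time, but the plain loop makes the no-look-ahead rule easy to check: changing a later value never alters an earlier output.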

B. Temporal Splitting

In main(), the data tensors (X, Y) and the list of graph structures (graphs) are split temporally:

train_split = int(total_days * TRAIN_RATIO) # e.g., Day 0 to Day 1999
val_days = total_days - train_split         # e.g., Day 2000 to end
The training loop only sees the first part; the validation loop only evaluates the second part.
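The split can be sketched as a helper that cuts X, Y, and graphs at the same day, so each day's features, labels, and graph stay aligned. TRAIN_RATIO = 0.8 matches the 80/20 split described above, and the helper name temporal_split is hypothetical. Slicing behaves the same on Python lists and on tensors whose first dimension is time.

```python
TRAIN_RATIO = 0.8  # first 80% of days for training, last 20% for validation


def temporal_split(X, Y, graphs, train_ratio=TRAIN_RATIO):
    """Hypothetical helper: split day-indexed sequences at a single cut-off day.
    No shuffling, so every validation day comes after every training day."""
    total_days = len(X)
    assert len(Y) == total_days and len(graphs) == total_days
    train_split = int(total_days * train_ratio)
    train = (X[:train_split], Y[:train_split], graphs[:train_split])
    val = (X[train_split:], Y[train_split:], graphs[train_split:])
    return train_split, train, val
```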

C. Class Imbalance Handling

High-volatility events are rare (a minority class). If trained normally, the model could simply always predict "0" (Calm), achieving high accuracy while failing at its actual job.

  • In main(), it calculates the ratio of negative to positive samples in the training set.
  • It uses this ratio (pos_weight) as the weight of the "Volatile" class in nn.CrossEntropyLoss, through the loss's per-class weight argument. This forces the model to pay more attention to rare volatility events during training.
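Here is a sketch of the class weighting, assuming labels are 0 (Calm) and 1 (Volatile). Note that nn.CrossEntropyLoss has no pos_weight argument; that name belongs to nn.BCEWithLogitsLoss. Per-class weighting is passed through CrossEntropyLoss's weight argument, one entry per class. The helper name make_weighted_loss is hypothetical.

```python
import torch
import torch.nn as nn


def make_weighted_loss(y_train: torch.Tensor) -> nn.CrossEntropyLoss:
    """Hypothetical helper: weight the Volatile class (label 1) by the
    negative/positive ratio measured on the training labels only."""
    num_pos = (y_train == 1).sum().item()
    num_neg = (y_train == 0).sum().item()
    pos_weight = num_neg / max(num_pos, 1)          # guard against zero positives
    class_weights = torch.tensor([1.0, pos_weight])  # one weight per class: [Calm, Volatile]
    return nn.CrossEntropyLoss(weight=class_weights)
```

With the default reduction='mean', PyTorch divides by the sum of the weights actually applied. The weight therefore changes how much each Volatile sample counts relative to the Calm samples in the same batch, rather than scaling the loss as a whole.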


3. Detailed Training Process (train_epoch)

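The training uses Truncated Backpropagation Through Time (TBPTT) so that it can run over a long sequence without keeping the whole computation graph in memory. Gradients only flow back within a fixed-length chunk of days, which bounds memory use and the cost of each backward pass.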

The Flow:

  1. Initialization:

    • Call model.train() (enables dropout).
    • Initialize the RNN hidden state (hidden) to zeros. This state will carry temporal memory forward.
  2. Outer Loop (Batch Chunks / TBPTT):

    • The code iterates through time in chunks of BATCH_SIZE consecutive days (e.g., 64 days at a time). Here BATCH_SIZE is a chunk length, not a number of independent samples.
    • for start_t in range(0, train_days, BATCH_SIZE):
  3. Inner Loop (Step-by-Step within Chunk):

    • Inside a chunk, it iterates one day at a time (t).
    • Data Loading: Loads features X[t] and the correlation graph graphs[t] for that specific day to the device (GPU/MPS).
    • Forward Pass: out, hidden = model(X[t], edge_index, hidden)
      • The GAT layer uses edge_index to let stocks exchange information with highly correlated neighbors (Spatial step).
      • The GRU cell takes the GAT output and the previous day's hidden state to compute the new hidden state and output (Temporal step).
    • Loss Calculation: Computes the weighted CrossEntropyLoss between the prediction out and the true labels Y[t].
    • Accumulation: The loss for this specific day is added to batch_loss.
  4. End of Chunk (Backpropagation):

    • After processing 64 days (the BATCH_SIZE), it averages the accumulated loss.
    • batch_loss.backward(): Computes gradients with respect to the weights. Because the hidden state is carried from step to step in the inner loop, gradients flow back through all 64 time steps of the chunk.
    • Gradient Clipping: torch.nn.utils.clip_grad_norm_(...). This is vital in RNNs to prevent gradients from becoming too large and destabilizing training.
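    • The Truncation Step: hidden = hidden.detach(). This is the most critical line for TBPTT. It tells PyTorch: "Stop tracking gradients for the hidden state here." The hidden state values are kept for the next chunk so temporal context continues, but the gradient history is cut off. Without this line, memory would grow with every chunk. The next backward() would also try to reach into the previous chunk's graph, which the earlier backward() already freed, and raise a RuntimeError.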
    • optimizer.step(): Updates model weights.
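The training flow can be condensed into a runnable sketch. To keep it self-contained, ToyTemporalGraphModel is a hypothetical stand-in for the real model. A simple mean-over-neighbors aggregation along edge_index replaces the GAT layer (avoiding a torch_geometric dependency), and a GRUCell carries the hidden state. Several details are also assumptions:

  • a max_norm of 1.0 for clipping;
  • CPU-only tensors;
  • a graphs list holding one edge_index per day.

```python
import torch
import torch.nn as nn


class ToyTemporalGraphModel(nn.Module):
    """Hypothetical stand-in: neighbor-mean aggregation instead of GAT, then a GRUCell."""

    def __init__(self, in_dim: int, hidden_dim: int, num_classes: int = 2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.proj = nn.Linear(in_dim, hidden_dim)
        self.gru = nn.GRUCell(hidden_dim, hidden_dim)
        self.head = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, hidden):
        h = torch.relu(self.proj(x))                          # (num_stocks, hidden_dim)
        src, dst = edge_index                                 # edges point src -> dst
        neigh_sum = h.new_zeros(h.shape).index_add(0, dst, h[src])
        deg = torch.bincount(dst, minlength=h.size(0)).clamp(min=1).to(h.dtype)
        h = h + neigh_sum / deg.unsqueeze(1)                  # spatial step (GAT stand-in)
        hidden = self.gru(h, hidden)                          # temporal step
        return self.head(hidden), hidden                      # logits: (num_stocks, num_classes)


def train_epoch(model, X, Y, graphs, criterion, optimizer, train_days,
                batch_size=64, max_norm=1.0):
    model.train()                                             # enables dropout, if any
    hidden = torch.zeros(X[0].size(0), model.hidden_dim)      # fresh memory each epoch
    chunk_losses = []
    for start_t in range(0, train_days, batch_size):          # outer loop: TBPTT chunks
        end_t = min(start_t + batch_size, train_days)
        optimizer.zero_grad()
        batch_loss = 0.0
        for t in range(start_t, end_t):                       # inner loop: one day at a time
            out, hidden = model(X[t], graphs[t], hidden)
            batch_loss = batch_loss + criterion(out, Y[t])
        batch_loss = batch_loss / (end_t - start_t)           # average over the chunk
        batch_loss.backward()                                 # gradients span this chunk only
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        hidden = hidden.detach()                              # keep values, cut gradient history
        optimizer.step()
        chunk_losses.append(batch_loss.item())
    return sum(chunk_losses) / len(chunk_losses)
```

The essential ordering is to accumulate loss over the chunk, call backward(), and then detach hidden before the next chunk starts. Deleting the detach() line makes the second chunk's backward() fail, because it would reach into the first chunk's already-freed graph.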

4. Detailed Validation Process (validate_epoch)

Because the model is recurrent (a GRU), the validation process has one critical requirement: warm-up.

The Flow:

  1. Initialization:

    • Set model to model.eval() (disables dropout).
    • Wrap the loop in torch.no_grad() so no computation graph is built (saves memory and speeds up inference).
    • Initialize hidden state to zeros.
  2. The Warm-up Phase (Crucial Step):

    • Before evaluating any validation data, the model must process the entire training sequence sequentially.
    • for t in range(train_days): ... model(X[t], ...)
    • Why? The GRU's hidden state at the start of the validation period (e.g., day 2000) depends on what happened on days 0-1999. If prediction starts at day 2000 with a zero hidden state, the model lacks temporal context and its predictions will be unreliable.
    • During this phase, outputs are ignored, and loss is not calculated. Only the hidden state is updated.
  3. The Evaluation Phase:

    • It iterates from train_days to the end of the data.
    • It performs the forward pass using the warmed-up hidden state.
    • It calculates loss for monitoring purposes.
    • It compares predictions vs. targets (Y[t]) to compute accuracy and F1-score.
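The validation flow can be sketched in the same way. The sketch works with any module whose forward takes (x, edge_index, hidden) and returns (out, hidden), matching the call shown in the training section. Three choices are assumptions of this sketch:

  • passing hidden_dim explicitly;
  • treating class 1 (Volatile) as the positive class for F1;
  • computing F1 by hand instead of with, for example, sklearn.metrics.f1_score.

```python
import torch


@torch.no_grad()                                     # no graph is built inside this function
def validate_epoch(model, X, Y, graphs, criterion, train_days, hidden_dim):
    model.eval()                                     # disables dropout
    hidden = torch.zeros(X[0].size(0), hidden_dim)

    # Warm-up: replay the training days so the hidden state reflects days 0..train_days-1.
    for t in range(train_days):
        _, hidden = model(X[t], graphs[t], hidden)   # outputs ignored, no loss

    # Evaluation: continue from the warmed-up state through the validation days.
    losses, preds, targets = [], [], []
    for t in range(train_days, len(X)):
        out, hidden = model(X[t], graphs[t], hidden)
        losses.append(criterion(out, Y[t]).item())   # for monitoring only
        preds.append(out.argmax(dim=1))
        targets.append(Y[t])
    preds, targets = torch.cat(preds), torch.cat(targets)

    accuracy = (preds == targets).float().mean().item()
    tp = ((preds == 1) & (targets == 1)).sum().item()
    fp = ((preds == 1) & (targets == 0)).sum().item()
    fn = ((preds == 0) & (targets == 1)).sum().item()
    f1 = 2 * tp / max(2 * tp + fp + fn, 1)           # F1 of the Volatile class (label 1)
    return sum(losses) / len(losses), accuracy, f1
```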

Summary of Key Differences

| Feature | Training (train_epoch) | Validation (validate_epoch) |
|---|---|---|
| Mode | model.train() (dropout on) | model.eval() (dropout off) |
| Gradients | Tracked (backward() called) | Not tracked (torch.no_grad()) |
| Time flow | Iterates through the training split only. | Warm-up: runs through the training split. Eval: runs through the validation split. |
| Hidden state | detach() every BATCH_SIZE steps (TBPTT). | Never detached; runs continuously from t = 0 to the end. |
| Weights | Updated by the optimizer. | Frozen. |