Miscellaneous notes

This article is identical to the code in the reference; it only removes unused code and tightens the spacing so the whole flow is easier to scan.

Table of contents

- Loading the Traffic data and grouping it into an 862-variate time series
- Preparing a function that adds time features
- Preparing a function that samples windows
- Preparing the data loader
- Creating the model and checking it on one batch

Loading the Traffic data and grouping it into an 862-variate time series

from datasets import load_dataset
import pandas as pd
import numpy as np
from functools import lru_cache
from functools import partial
from gluonts.dataset.multivariate_grouper import MultivariateGrouper

# Memoize the string -> pd.Period conversion; the same start dates recur often.
@lru_cache(10_000)
def convert_to_pandas_period(date, freq):
    return pd.Period(date, freq)

# Convert each example's 'start' field to a pd.Period of the given frequency.
def transform_start_field(batch, freq):
    batch['start'] = [convert_to_pandas_period(date, freq) for date in batch['start']]
    return batch


dataset = load_dataset('monash_tsf', 'traffic_hourly')
train_dataset = dataset['train']
test_dataset = dataset['test']
num_of_variates = len(train_dataset)  # 862
freq = '1H'
prediction_length = 48

train_dataset.set_transform(partial(transform_start_field, freq=freq))
test_dataset.set_transform(partial(transform_start_field, freq=freq))

# Stack the 862 individual hourly series into a single multivariate series.
train_grouper = MultivariateGrouper(max_target_dim=num_of_variates)
test_grouper = MultivariateGrouper(max_target_dim=num_of_variates, num_test_dates=1)
multi_variate_train_dataset = train_grouper(train_dataset)
multi_variate_test_dataset = test_grouper(test_dataset)
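
As a quick sanity check (a sketch of mine, not from the reference), the grouped training set should now hold a single entry whose target stacks all 862 series:

# One multivariate entry: 'target' is a (num_of_variates, num_timesteps) array.
entry = next(iter(multi_variate_train_dataset))
print(entry["target"].shape)  # (862, T)
print(entry["start"])         # shared series start as a pd.Period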

Preparing a function that adds time features

from transformers import PretrainedConfig
from gluonts.dataset.field_names import FieldName
from gluonts.transform import *
from gluonts.time_feature import time_features_from_frequency_str

def create_transformation(freq: str, config: PretrainedConfig) -> Transformation:
    # Remove feature fields the model is not configured to use.
    remove_field_names = []
    if config.num_static_real_features == 0:
        remove_field_names.append(FieldName.FEAT_STATIC_REAL)
    if config.num_dynamic_real_features == 0:
        remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)
    if config.num_static_categorical_features == 0:
        remove_field_names.append(FieldName.FEAT_STATIC_CAT)
    return Chain(
        [RemoveFields(field_names=remove_field_names)]
        + ([AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=int)]
               if config.num_static_categorical_features > 0 else [])
        + ([AsNumpyArray(field=FieldName.FEAT_STATIC_REAL, expected_ndim=1)]
               if config.num_static_real_features > 0 else [])
        + [AsNumpyArray(field=FieldName.TARGET, expected_ndim=1 if config.input_size == 1 else 2),
           # Record which target values were actually observed (for masking NaNs).
           AddObservedValuesIndicator(target_field=FieldName.TARGET,
                                      output_field=FieldName.OBSERVED_VALUES),
           # Calendar features (hour of day, day of week, ...) derived from the frequency.
           AddTimeFeatures(start_field=FieldName.START, target_field=FieldName.TARGET,
                           output_field=FieldName.FEAT_TIME,
                           time_features=time_features_from_frequency_str(freq),
                           pred_length=config.prediction_length),
           # A log-scaled "age" feature that grows with the position in the series.
           AddAgeFeature(target_field=FieldName.TARGET, output_field=FieldName.FEAT_AGE,
                         pred_length=config.prediction_length, log_scale=True),
           # Stack all time-varying features into a single FEAT_TIME array.
           VstackFeatures(output_field=FieldName.FEAT_TIME,
                          input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE] + \
                            ([FieldName.FEAT_DYNAMIC_REAL]
                               if config.num_dynamic_real_features > 0 else [])),
           # Rename fields to the argument names the transformers Informer model expects.
           RenameFields(mapping={
               FieldName.FEAT_STATIC_CAT: "static_categorical_features",
               FieldName.FEAT_STATIC_REAL: "static_real_features",
               FieldName.FEAT_TIME: "time_features",
               FieldName.TARGET: "values",
               FieldName.OBSERVED_VALUES: "observed_mask"}),
          ]
    )
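
To see what the transformation emits, you can apply it to the grouped data and look at one transformed entry (a sketch; it borrows the `config` that is only defined in the final section):

# Assumes `config` from the final section already exists.
transformation = create_transformation(freq, config)
transformed_data = transformation.apply(multi_variate_train_dataset, is_train=True)
example = next(iter(transformed_data))
print(sorted(example.keys()))  # expect at least 'observed_mask', 'start', 'time_features', 'values'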

Preparing a function that samples windows

from gluonts.transform.sampler import InstanceSampler
from typing import Optional

def create_instance_splitter(
    config: PretrainedConfig, mode: str,
    train_sampler: Optional[InstanceSampler] = None,
    validation_sampler: Optional[InstanceSampler] = None,
) -> Transformation:
    # Windows are sampled at random positions during training, and at the end of
    # each series for validation/test.
    instance_sampler = {
        "train": train_sampler
        or ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=config.prediction_length),
        "validation": validation_sampler
        or ValidationSplitSampler(min_future=config.prediction_length),
        "test": TestSplitSampler(),
    }[mode]
    return InstanceSplitter(
        target_field="values", is_pad_field=FieldName.IS_PAD,
        start_field=FieldName.START, forecast_start_field=FieldName.FORECAST_START,
        instance_sampler=instance_sampler,
        # The past window also has to cover the largest lag used as a feature.
        past_length=config.context_length + max(config.lags_sequence),
        future_length=config.prediction_length,
        time_series_fields=["time_features", "observed_mask"],
    )
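
One detail worth noting: `past_length` is not just the context length; it leaves room for the largest lag so that lag features can be computed for every position in the window. With the numbers configured in the final section, each sampled window covers 264 past steps:

context_length = 48 * 2               # prediction_length * 2
past_length = context_length + max([1, 24 * 7])
print(past_length)                    # 96 + 168 = 264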

Preparing the data loader

import torch
from typing import Iterable
from gluonts.itertools import Cached, Cyclic
from gluonts.dataset.loader import as_stacked_batches

def create_train_dataloader(
    config: PretrainedConfig, freq, data, batch_size: int,
    num_batches_per_epoch: int, shuffle_buffer_length: Optional[int] = None,
    cache_data: bool = True, **kwargs,
) -> Iterable:
    # Inputs the model needs at prediction time; training additionally needs the
    # ground-truth future window.
    PREDICTION_INPUT_NAMES = [
        "past_time_features", "past_values", "past_observed_mask", "future_time_features"]
    if config.num_static_categorical_features > 0:
        PREDICTION_INPUT_NAMES.append("static_categorical_features")
    if config.num_static_real_features > 0:
        PREDICTION_INPUT_NAMES.append("static_real_features")
    TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + ["future_values", "future_observed_mask"]

    transformation = create_transformation(freq, config)
    transformed_data = transformation.apply(data, is_train=True)
    if cache_data:
        transformed_data = Cached(transformed_data)

    # Cycle over the dataset indefinitely and sample training windows from the stream.
    instance_splitter = create_instance_splitter(config, "train")
    stream = Cyclic(transformed_data).stream()
    training_instances = instance_splitter.apply(stream, is_train=True)

    return as_stacked_batches(
        training_instances, batch_size=batch_size, shuffle_buffer_length=shuffle_buffer_length,
        field_names=TRAINING_INPUT_NAMES, output_type=torch.tensor,
        num_batches_per_epoch=num_batches_per_epoch,
    )

Creating the model and checking it on one batch

from transformers import InformerConfig, InformerForPrediction
from gluonts.time_feature import time_features_from_frequency_str

time_features = time_features_from_frequency_str(freq)  # hour_of_day, day_of_week, ...
config = InformerConfig(
    # One input dimension per sensor; predict 48 hours from a 96-hour context.
    input_size=num_of_variates, prediction_length=prediction_length,
    context_length=prediction_length * 2,
    lags_sequence=[1, 24 * 7],  # lag features: one hour ago and one week ago
    num_time_features=len(time_features) + 1,  # +1 for the age feature
    dropout=0.1, encoder_layers=6, decoder_layers=4, d_model=64,
)
model = InformerForPrediction(config)

train_dataloader = create_train_dataloader(
    config=config, freq=freq, data=multi_variate_train_dataset,
    batch_size=256, num_batches_per_epoch=100,
)
batch = next(iter(train_dataloader))
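
Before calling the model, it is worth confirming the tensor shapes (the sizes follow from batch_size=256, past_length=264, prediction_length=48, input_size=862, and 4 calendar features plus the age feature):

for name, tensor in batch.items():
    print(name, tensor.shape)
# past_time_features    torch.Size([256, 264, 5])
# past_values           torch.Size([256, 264, 862])
# past_observed_mask    torch.Size([256, 264, 862])
# future_time_features  torch.Size([256, 48, 5])
# future_values         torch.Size([256, 48, 862])
# future_observed_mask  torch.Size([256, 48, 862])
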
outputs = model(
    past_values=batch["past_values"],
    past_time_features=batch["past_time_features"],
    past_observed_mask=batch["past_observed_mask"],
    static_categorical_features=batch["static_categorical_features"]
      if config.num_static_categorical_features > 0 else None,
    static_real_features=batch["static_real_features"]
      if config.num_static_real_features > 0 else None,
    future_values=batch["future_values"],
    future_time_features=batch["future_time_features"],
    future_observed_mask=batch["future_observed_mask"],
    output_hidden_states=True,
)
print("Loss:", outputs.loss.item())