Skip to content

Multi table - BasketballMen dataset

In the following we present an example script using the aindo.rdml library to generate synthetic data in the multi-table case. We make use of the BasketballMen dataset.

import json
from pathlib import Path

import pandas as pd

from aindo.rdml.eval import compute_privacy_stats, report
from aindo.rdml.relational import Column, ForeignKey, PrimaryKey, RelationalData, Schema, Table
from aindo.rdml.synth import (
    TabularDataset,
    TabularModel,
    TabularPreproc,
    TabularTrainer,
    Validation,
)

# Data and output
DATA_DIR = Path("path/to/data/dir")  # Must contain players.csv, season.csv and all_star.csv
OUTPUT_DIR = Path("./output")
# Fix: the script writes checkpoints, TensorBoard logs, CSVs, the PDF report and the
# JSON stats into OUTPUT_DIR, but never created it — make sure it exists up front.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Model settings
MODEL_SIZE = "small"
DEVICE = None  # Device to None means it will be set to CUDA if the latter is available, otherwise CPU
# Training settings
N_EPOCHS = 1_000  # One and only one between N_EPOCHS and N_STEPS should be an integer, and the other should be None.
N_STEPS = None
MEMORY = 4096  # NOTE(review): presumably a memory budget used to size batches — confirm units in the aindo.rdml docs
VALID_EACH = 200  # Validation cadence, in steps (see Validation(trigger="step"))

# Load the three CSV tables and declare the relational schema.
# "players" is the root table; "season" and "all_star" are children,
# each linked back to it through the playerID foreign key.
data = {
    name: pd.read_csv(DATA_DIR / f"{name}.csv")
    for name in ("players", "season", "all_star")
}
schema = Schema(
    players=Table(
        playerID=PrimaryKey(),
        pos=Column.CATEGORICAL,
        height=Column.NUMERIC,
        weight=Column.NUMERIC,
        college=Column.CATEGORICAL,
        race=Column.CATEGORICAL,
        birthCity=Column.CATEGORICAL,
        birthState=Column.CATEGORICAL,
        birthCountry=Column.CATEGORICAL,
    ),
    season=Table(
        playerID=ForeignKey(parent="players"),
        year=Column.INTEGER,
        stint=Column.INTEGER,
        tmID=Column.CATEGORICAL,
        lgID=Column.CATEGORICAL,
        GP=Column.INTEGER,
        points=Column.INTEGER,
        GS=Column.INTEGER,
        assists=Column.INTEGER,
        steals=Column.INTEGER,
        minutes=Column.INTEGER,
    ),
    all_star=Table(
        playerID=ForeignKey(parent="players"),
        conference=Column.CATEGORICAL,
        league_id=Column.CATEGORICAL,
        points=Column.INTEGER,
        rebounds=Column.INTEGER,
        assists=Column.INTEGER,
        blocks=Column.INTEGER,
    ),
)
# Bundle the raw DataFrames with the schema into one relational dataset object.
data = RelationalData(data=data, schema=schema)

# Fit the preprocessor on the full relational dataset.
preproc = TabularPreproc.from_schema(schema=schema).fit(data=data)

# Hold out 10% of the data for testing, then 10% of the remainder for validation.
holdout_frac = 0.1
data_train_valid, data_test = data.split(ratio=holdout_frac)
data_train, data_valid = data_train_valid.split(ratio=holdout_frac)

# Build the model from the fitted preprocessor and select the device.
model = TabularModel.build(preproc=preproc, size=MODEL_SIZE)
model.device = DEVICE  # None -> CUDA when available, otherwise CPU

# Datasets: the larger training split is cached on disk; validation stays in memory.
dataset_train = TabularDataset.from_data(data=data_train, preproc=preproc, on_disk=True)
dataset_valid = TabularDataset.from_data(data=data_valid, preproc=preproc)

# Validate every VALID_EACH steps, keeping the best checkpoint and TensorBoard logs.
validation = Validation(
    dataset=dataset_valid,
    early_stop="normal",
    save_best=OUTPUT_DIR / "best.pt",
    tensorboard=OUTPUT_DIR / "tb",
    each=VALID_EACH,
    trigger="step",
)
TabularTrainer(model=model).train(
    dataset=dataset_train,
    n_epochs=N_EPOCHS,
    n_steps=N_STEPS,
    memory=MEMORY,
    valid=validation,
)

# Generate one synthetic root row per real player, then write the synthetic tables out.
# NOTE(review): `data` is the RelationalData built above — assumes indexing it by table
# name yields the underlying DataFrame; confirm against aindo.rdml docs.
n_players = data["players"].shape[0]
data_synth = model.generate(n_samples=n_players, batch_size=512)
data_synth.to_csv(OUTPUT_DIR / "synth")

# Write the PDF report comparing real (train/test) data with the synthetic data.
report_path = OUTPUT_DIR / "report.pdf"
report(
    data_train=data_train,
    data_test=data_test,
    data_synth=data_synth,
    path=report_path,
)

# Compute per-table privacy statistics and persist a small summary as JSON.
privacy_stats = compute_privacy_stats(
    data_train=data_train,
    data_synth=data_synth,
)
privacy_stats_out = {
    table: {
        "privacy_score": stats.privacy_score,
        "privacy_score_std": stats.privacy_score_std,
        "%_points_at_risk": stats.risk * 100,
    }
    for table, stats in privacy_stats.items()
}
(OUTPUT_DIR / "privacy_stats.json").write_text(
    json.dumps(privacy_stats_out), encoding="utf-8"
)