VAST Database Python SDK#

Date: 2024-12-27

Prerequisites#

> VAST Cluster with 5.2 or greater installed.
> VIP Pool configured and accessible from the Docker host.
> VMS IP accessible from the Docker host.

Install and Import required libraries#

# !pip install vastpy
# !pip install vastdb
# !pip install boto3
import vastdb
from vastpy import VASTClient

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq
import pandas as pd
import json
from datetime import datetime, timedelta
import pytz
import random
import time
import psutil
import os
import re
import io
import sys
import uuid
import logging
import ftplib
from IPython.display import display, HTML

Define Variables#

Required VMS Information#

vastvms_endpoint = '10.143.11.204'
vms_username = "admin"
vms_password = "123456"
vip_pool_name = "main"
tenant_id = 1    #Default
#
# Create a Unique ID for this run of the Notebook.
#
demo_suffix = str(uuid.uuid4()).split("-")[-1]

Optional Variables when not using the VAST Cluster and Database Setup#

#
# S3 credentials will be set during Cluster / Database Setup
# If you are going to skip that part of the Notebook, provide the information here.
#
S3_ACCESS_KEY = ''
S3_SECRET_KEY = ''
#
# VAST Cluster Endpoint will be determined during Cluster / Database Setup
# If you are going to skip that part of the Notebook, provide the information here.
#
vastdb_endpoint = ''

#
# VAST DB configuration will be set during Cluster / Database Setup
# If you are going to skip that part of the Notebook, provide the information here.
#
vastdb_bucket = ''
vastdb_schema = ''
vastdb_path = ''

SDK Logging#

# Create a logger
logging.basicConfig(
    level=logging.INFO,    
    format='%(asctime)s - %(levelname)s - %(funcName)s - %(message)s', 
    handlers=[
        logging.FileHandler('vastdb_sdk.log', mode = 'a'), 
        logging.StreamHandler()  # Log to console
    ]
)
logger = logging.getLogger()
log = logging.getLogger(__name__)

Define Functions#

def schema_to_python(schema: pq.ParquetSchema) -> pa.Schema:
    """
    Convert a Parquet file's schema to a PyArrow schema compatible with VAST DB.
    Ensures all fields are nullable, as VAST DB does not support non-nullable fields.
    Raises ValueError if an unsupported Arrow type is encountered.
    """
    # Supported Arrow types in VAST DB
    supported_types = {
        pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64(),
        pa.int8(), pa.int16(), pa.int32(), pa.int64(),
        pa.string(), pa.list_(pa.string()), pa.struct([]), pa.map_(pa.string(), pa.string()),
        pa.bool_(), pa.float32(), pa.float64(), pa.binary(), pa.decimal128(38, 18),
        pa.date32(), pa.timestamp('ns'), pa.time32('s'), pa.time64('ns'), pa.null()
    }
    def process_field(field: pa.Field) -> pa.Field:
        """
        Process a single field to ensure compatibility with VAST DB.
        Recursively handles nested fields.
        """
        field_type = field.type

        # Handle list type
        if pa.types.is_list(field_type): 
            value_field = pa.field("", field_type.value_type)
            return pa.field(field.name, pa.list_(process_field(value_field).type), nullable=True)

        # Handle struct type
        if pa.types.is_struct(field_type):
            struct_fields = [process_field(sub_field) for sub_field in field_type]
            return pa.field(field.name, pa.struct(struct_fields), nullable=True)

        # Handle map type
        if pa.types.is_map(field_type):
            key_field = process_field(pa.field("", field_type.key_type))
            item_field = process_field(pa.field("", field_type.item_type))
            return pa.field(field.name, pa.map_(key_field.type, item_field.type), nullable=True)
        
        # Handle supported primitive types
        if field.type in supported_types:
            return field.with_nullable(True)
        
        # Unsupported type
        raise ValueError(f"Unsupported Arrow type: {field_type} for field '{field.name}'")

    # Process all fields in the schema
    fields = [process_field(schema.field(i)) for i in range(len(schema))]

    # Return the new schema
    return pa.schema(fields)
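
As a quick, self-contained illustration of this conversion, the sketch below runs schema_to_python() on a small, made-up Arrow schema (the field names are hypothetical, not taken from the datasets used later in this notebook). Every field, including nested list and struct members, comes back marked nullable.

#
# Sketch: exercise schema_to_python() on a hypothetical nested schema.
#
example_schema = pa.schema([
    pa.field("id", pa.string(), nullable=False),
    pa.field("scores", pa.list_(pa.int64()), nullable=False),
    pa.field("location", pa.struct([
        pa.field("lat", pa.float64()),
        pa.field("lon", pa.float64()),
    ]), nullable=False),
])
print(schema_to_python(example_schema))  # every field (and nested field) is returned as nullable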

def login_ftp_site(ftp_site):
    """ Create an FTP session """
    ftp = ftplib.FTP(ftp_site)
    ftp.login()
    return ftp

def get_ftp_file_list(ftp_site, directory='/pub'):
    """ List the files in a given FTP directory. """
    ftp = ftplib.FTP(ftp_site)
    ftp.login()
    ftp.cwd(directory)
    files = ftp.nlst() 
    ftp.quit()
    return files

def download_file_to_memory_and_load(ftp, ftp_file, directory='/pub/'): 
    """ Download a Parquet File from FTP and store as a Parquet Table in memory."""
    ftp.cwd(directory)
    file_path = os.path.join(directory,ftp_file) 
    log.debug(f"Downloading file: {file_path}")
        
    # Retrieve the file in binary mode and store in memory (BytesIO object)
    file_buffer = io.BytesIO()
    ftp.retrbinary(f"RETR {file_path}", file_buffer.write)
    file_buffer.seek(0)
    # Read the Parquet file from memory
    try:
        table = pq.read_table(file_buffer)
        log.debug(f"Loaded {file_path} into memory.")
    except Exception as e:
        log.critical(f"Failed to load {file_path} as Parquet: {e}")
        sys.exit(1)
    return table

def bye_ftp(ftp):
    """ Close FTP connection """
    ftp.quit()

Setup VAST Cluster and Database#

Create VMS Client Connection to Cluster#

VASTPY 0.3.3 Documentation

# Connect to the VAST Cluster using VASTPY.
client = VASTClient(vms_username, vms_password, vastvms_endpoint)

Create Identity Policy#

identitypolicy_name = f"SDK-Demo-{demo_suffix}"

def create_identity_policy(client, name, tenant_id):   
    """ Create an Identity Policy for Tabular Access """
    policy = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Sid": "Read_write_ALL_DB",
                "Action": "s3:Tabular*",
                "Effect": "Allow",
                "Resource": "arn:aws:s3:::*"
            }
        ]
    }
    payload = {
        "name": name,
        "policy": json.dumps(policy, indent=2),  
        "tenant_id": tenant_id,
    }   
    
    try:
        identity_policy = client.s3policies.post(**payload)
        return identity_policy
    except Exception as e:
        log.critical(e)
        raise e


identity_policy   = create_identity_policy(client, identitypolicy_name, tenant_id)
identitypolicy_id = identity_policy['id']
log.info(f"Identity Policy ID: {identitypolicy_id}")

Create Local User and Secret / Access Keys#

api_endpoint = f"{vastvms_endpoint}/api/latest/users/"
local_user_name = f"SDK-Demo-User-{demo_suffix}"
local_user_id = "2147483" + str(random.randint(100, 600))

def create_local_user(client, LocalUser, local_user_id, IdentityPolicy_id, tenant_id): 
    """ Create a Local User on Cluster """
    payload = {"name": LocalUser,
              "uid": local_user_id,
        	  "allow_create_bucket":"true",
        	  "allow_delete_bucket":"false",
        	  "s3_policies_ids":[IdentityPolicy_id]}
    try:
        local_user = client.users.post(**payload)
        return local_user
    except Exception as e:
        log.critical(e)
        raise e

#
# Call VMS API
#
local_user = create_local_user(client, local_user_name, local_user_id, identitypolicy_id, tenant_id)
if local_user:
    local_user_id = local_user['id']
    try:
        response = client.users[local_user_id].access_keys.post(id=local_user_id)
        S3_ACCESS_KEY = response['access_key']
        S3_SECRET_KEY = response['secret_key']
        log.info(f"Local User '{local_user_name}' now has S3 Keys.")
    except Exception as e:
        log.critical(e)
        raise e

Create View Policy for Database#

view_policy_name = f"SDK-ViewPolicy-{demo_suffix}"

def create_view_policy(client, LocalUser, name, vip_pool_id, tenant_id):
    """ Create a View Policy for Database access """
    vip_pool_permission = { vip_pool_id : "RW"}
    payload = {"name": name,
                "flavor" : "S3_NATIVE",
                "permission_per_vip_pool": vip_pool_permission,
                "tenant_id": tenant_id,
                "s3_visibility":[LocalUser]
                }
    try:
        view_policy = client.viewpolicies.post(**payload)
        return view_policy
    except Exception as e:
        log.critical(e)
        raise e
#
# Query for VIP Pool
#
vip_pool_id = None
view_policy = None
try: 
    vip_pool = client.vippools.get(name=vip_pool_name)
    vip_pool_id = vip_pool[0]['id']
    vip_pool_ip = vip_pool[0]['start_ip']
    vastdb_endpoint = f"http://{vip_pool_ip}"
except Exception as e:
   log.critical(f"Error during VMS client Query of VIP POOLS: {e}")

if vip_pool_id:
  print(f"The '{vip_pool_name}' VIP Pool was found and has an ID of {vip_pool_id}.")
  print(f"The VAST Database will be accessed using the VIP of {vastdb_endpoint}.") 
  log.info(f"Creating View Policy: {view_policy_name}.")  
  view_policy = create_view_policy(client, local_user_name, view_policy_name, vip_pool_id, tenant_id)
  viewpolicy_id = view_policy['id'] 
    
if view_policy:
  log.info(f"View Policy ID {viewpolicy_id} created.")
else:
  log.critical(f"Failed to create View Policy.")  

Create Database#

vastdb_bucket = f"sdk-demodb-{demo_suffix}"
vastdb_path = f"/{vastdb_bucket}"

def create_database_view(client, vastdb_bucket, viewpolicy_id, vastdb_path,  LocalUser, tenant_id):
    """ Create a VAST Database, assumes directory needs to be created."""
    payload = {"bucket": vastdb_bucket,
               "path": vastdb_path,
               "policy": viewpolicy_id,
               "bucket_owner":LocalUser,
               "policy_id": viewpolicy_id,
               "protocols":["DATABASE","S3"],
               "share_acl":{"acl":[],"enabled":"false"},
               "create_dir":"true",
               "tenant_id": tenant_id}
    try:
        db = client.views.post(**payload)
        return db
    except Exception as e:
        log.critical(e)
        raise e
        
db =  create_database_view(client, 
                          vastdb_bucket, 
                          viewpolicy_id,
                          vastdb_path,
                          local_user_name,
                          tenant_id
            )
if db:
    database_id = db['id']
    log.info(f"The Database was created at {db['created']}")
else:
    log.critical("Failed to create Database view.")

Basic VAST Database Operations#

Connect to VAST Database#

#
# Establish Session with database
#
session = {}
try:
    session = vastdb.connect(
              endpoint=vastdb_endpoint,
              access=S3_ACCESS_KEY,
              secret=S3_SECRET_KEY
             )
except Exception as e:
    log.critical(e)
if session:
   log.info("VAST DB Session started")
else:
   log.critical("Unable to connect to VAST DB.")

Create Schema#

"""
   The VAST DB SDK transaction object implements the python context manager protocol.
   When the transaction object is created (__enter__) then begin_transaction function is called.
   If any exception is raised in the "with" block then when the block is exited (__exit__)
   rollback_transaction will be called. If no exceptions are raised then commit_transaction
   will be called.
"""
with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    new_schema = bucket.create_schema("new_schema", fail_if_exists=True)
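
To make the rollback behaviour concrete, the sketch below deliberately raises an exception inside the "with" block; the schema created there is rolled back and never committed. The schema name "scratch_schema" is illustrative only.

try:
    with session.transaction() as tx:
        bucket = tx.bucket(vastdb_bucket)
        bucket.create_schema("scratch_schema")    # created inside the transaction
        raise RuntimeError("simulated failure")   # triggers rollback_transaction on exit
except RuntimeError as e:
    log.info(f"Transaction rolled back after: {e}")

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    print([s.name for s in bucket.schemas()])     # "scratch_schema" should not appear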

Query Schemas from database#

#
# Retrieve the schemas from the VAST DB
#
with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schemas = bucket.schemas()

print(f"The schemas are returned as a list of Schemas from the database {vastdb_bucket}:")

for schema in schemas:
    print(schema.name)
    
   

Query Specific Schema from Database#

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema("new_schema")
schema    

Create a table in Schema#

#
# Create the pyarrow Schema for the table.
#
address_schema = pa.schema([
    ("street_address", pa.string()),  
    ("city", pa.string()),            
    ("state", pa.string()),           
    ("zip_code", pa.string()),        
    ("country", pa.string()),         
    ("latitude", pa.float64()),       
    ("longitude", pa.float64()),      
])

with session.transaction() as tx:
    log.info(f"Start Transaction id=0x{tx.txid:016x}")
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema("new_schema")
    table = schema.create_table("new_table", address_schema, fail_if_exists=True)  
    log.info(f"Commit Transaction")
    

Retrieve all of the tables from a given schema#

#
# Retrieve the tables from the VAST DB
#
with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema("new_schema")
    tables = schema.tables()

print(f"The Tables in schema '{schema.name}' are:")
for table in tables:
    print(table.name)

Retrieve a specific table from a given schema#

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema("new_schema")
    table = schema.table("new_table")
    
print(type(table))    
print(f"The name of the table is: {table.name}")

Create a record for a given table#

# Create a single record
# Convert data to an Apache Arrow Table using the address schema above.

arrow_table = pa.table(schema = address_schema, data =  [
    ["123 Main St"],       # street_address
    ["San Francisco"],     # city
    ["CA"],                # state
    ["94105"],             # zip_code
    ["USA"],               # country
    [37.7749],             # latitude
    [-122.4194]            # longitude
])

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)    # Start a transaction for this Database
    schema = bucket.schema("new_schema") # Get the Schema for the database to use
    table = schema.table("new_table")    # Get the table from the Schema.
    table.insert(arrow_table)            # Insert the data from the Arrow table into the VAST DB table.

Query all records from a table#

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema("new_schema")
    table = schema.table("new_table")
    reader = table.select()              # Store all the VAST DB records in an Arrow Record Batch Reader Object
    result = reader.read_all()           # Place the contents of the Select into an Arrow Table

df = result.to_pandas()
#
# Print Results in a Pandas Dataframe
df.head(5)

Create 10,000 records in table#

# Generate random data for 10,000 records
def generate_random_address():
    streets = ["Main St", "Smith St", "Briaroaks Rd", "Arbor Forest Ln", "Cindy St"]
    cities = ["Tampa", "Los Angeles", "Las Vegas", "Houston", "Wooster"]
    states = ["FL", "CA", "NV", "TX", "OH"]
    countries = ["USA"]
    
    street_address = f"{random.randint(1, 9999)} {random.choice(streets)}"
    city = random.choice(cities)
    state = random.choice(states)
    zip_code = f"{random.randint(10000, 99999)}"
    country = random.choice(countries)
    latitude = random.uniform(-90, 90)
    longitude = random.uniform(-180, 180)
    
    return street_address, city, state, zip_code, country, latitude, longitude

# Create lists for each column
street_addresses, cities, states, zip_codes, countries, latitudes, longitudes = [], [], [], [], [], [], []

# Create 10K entries for each list.
for _ in range(10000):
    record = generate_random_address()
    street_addresses.append(record[0])
    cities.append(record[1])
    states.append(record[2])
    zip_codes.append(record[3])
    countries.append(record[4])
    latitudes.append(record[5])
    longitudes.append(record[6])

# Create a pyarrow Table and populate with random data.
address_table = pa.Table.from_arrays(
    [
        pa.array(street_addresses, type=pa.string()),
        pa.array(cities, type=pa.string()),
        pa.array(states, type=pa.string()),
        pa.array(zip_codes, type=pa.string()),
        pa.array(countries, type=pa.string()),
        pa.array(latitudes, type=pa.float64()),
        pa.array(longitudes, type=pa.float64()),
    ],
    schema=address_schema
)
mem = psutil.virtual_memory()
print(f"Available memory: {mem.available / (1024**3):.2f} GB")
print(f"Total Memory used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
print(f"The Apache Arrow 'address table' contains {address_table.num_rows} records and is using {address_table.nbytes:,} bytes in memory.")

#
# Insert the Records into VAST DB.
#
start = time.time()
with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)    # Start a transaction for this Database
    schema = bucket.schema("new_schema") # Get the Schema for the database to use
    table = schema.table("new_table")    # Get the table from the Schema.
    table.insert(address_table)  
end=time.time()
log.info("Dataset load took {0:.2f} seconds to run.".format(end-start))    

Query Table using Filters (Predicate pushdown)#

Comparison Operator#

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema("new_schema")
    table = schema.table("new_table")
    reader = table.select(predicate=(table['state'] == 'TX'))  # Store all the VAST DB records in an Arrow Record Batch Reader Object
    result = reader.read_all()                                 # Place the contents of the Select into an Arrow Table

TX_df = result.to_pandas()

record_count = len(TX_df)
print(f"Total number of records with the state of 'TX':  {record_count}")
# Display the top 5 records
print("Top 5 records:")
TX_df.head(5)

“IS IN” Operator#

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema("new_schema")
    table = schema.table("new_table")
    reader = table.select(predicate=(table['state'].isin(['TX','OH'])))    # Store all the VAST DB records in an Arrow Record Batch Reader Object
    result = reader.read_all()                                             # Place the contents of the Select into an Arrow Table

df = result.to_pandas()

record_count = len(df)
print(f"Total number of records with the state of 'TX' or 'OH':  {record_count}")
# Display the top 5 records
print("Top 5 records:")
df.head(5)

Substring Match#

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema("new_schema")
    table = schema.table("new_table")
    reader = table.select(predicate=(table['street_address'].contains('St')))   # Store all the VAST DB records in an Arrow Record Batch Reader Object
    result = reader.read_all()                                                  # Place the contents of the Select into an Arrow Table

df = result.to_pandas()

record_count = len(df)
print(f"Total number of records with the string 'St' in the street address:  {record_count}")
# Display the top 5 records
print("Top 5 records:")
df.head(5)

Between#

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema("new_schema")
    table = schema.table("new_table")
    reader = table.select(predicate=(table['latitude'].between(-30,9)))      # Store all the VAST DB records in an Arrow Record Batch Reader Object
    result = reader.read_all()                                               # Place the contents of the Select into an Arrow Table

df = result.to_pandas()

record_count = len(df)
print(f"Total number of records with a latitude between -30 and 9:  {record_count}")

min_value = df["latitude"].min()
max_value = df["latitude"].max()
print(f"The minimum value for latitude is {min_value} and the naximum is {max_value}")

# Display the top 5 records
print("Top 5 records:")
df.head(5)
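
Predicates can also be combined. The sketch below assumes the SDK's predicate expressions compose with the standard & / | boolean operators (an assumption, as only single predicates are shown elsewhere in this notebook); the columns are the same ones used in the examples above.

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema("new_schema")
    table = schema.table("new_table")
    # Logical AND of two of the predicates demonstrated above (assumed to be supported).
    reader = table.select(predicate=((table['state'] == 'TX') & (table['latitude'].between(-30, 9))))
    result = reader.read_all()

print(f"Records in 'TX' with a latitude between -30 and 9: {result.num_rows}")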

Create a snapshot for the Database#

snapshot_name = f"SDK-DEMO-{demo_suffix}"
utc_now = datetime.now(pytz.utc)
new_time = utc_now + timedelta(minutes=30) # Set Snapshot expiration time + 30 minutes.
expiration_time = new_time.strftime("%Y-%m-%dT%H:%M:%S.") + f"{new_time.microsecond // 1000:03d}Z"
indestructible = False
tenant_id = 1

def create_snapshot(client, name, path, expiration_time, indestructible, tenant_id):
    """ Create a local Snapshot """
    payload = {
        "name": name,
        "path": path,
        "expiration_time": expiration_time,
        "indestructible": indestructible,
        "tenant_id": tenant_id,
    }
    try:
        snapshot = client.snapshots.post(**payload)
        return snapshot
    except Exception as e:
        print(e)
        raise e

# Call VMS API
#
snapshot = create_snapshot(client, 
                          snapshot_name, 
                          vastdb_path, 
                          expiration_time, 
                          indestructible, 
                          tenant_id
            )
if snapshot:
   log.info(f"Snapshot Status: {snapshot['state']}")
   df = pd.DataFrame([snapshot]) 
   display(df) 
else:
   log.error("Failed to create snapshot.")     

Delete Records from Table#

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema("new_schema")
    table = schema.table("new_table")
    reader = table.select(columns=[], 
                          internal_row_id=True,
                          predicate=table['state'].contains('TX'))
    results = reader.read_all() # Are there any ROWIDs to delete?
    if results.num_rows > 0 :
        print(f"Deleting {results.num_rows} rows from {table.name}")
        table.delete(results)
    else:
        print("no records found!")

Query Snapshots from DB#

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    snapShots = bucket.snapshots()
for snap in snapShots:
    snapname = snap.name.split("/")[-1]
    print(snapname)

Query Data from Database snapshot#

df={}
snapshot_table = {}
with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    snapShot = bucket.snapshot(snapshot_name)
    snapschema = snapShot.schema("new_schema")
    snapshot_tables = snapschema.tables()
    snapshot_table = snapschema.table("new_table", fail_if_missing = False)
    if snapshot_table:
      reader = snapshot_table.select(predicate=(snapshot_table['state'] == 'TX')) 
      snapshot_result = reader.read_all()
    else:
      log.error(f"Table 'new_table' was not found in snapshot '{snapshot_name}'.")  
    
# Print Results in a Pandas Dataframe

if snapshot_tables:
    print(f"Tables in the snapshot schema '{schema.name}' are:")
    for table in snapshot_tables:
        print(table.name)
    
    print(f"\nTable '{snapshot_table.name}' has {snapshot_result.num_rows} rows with TX as the state.")        
else:
    print("No Tables in Schema")       

Recover Data from Snapshot#

snapschema = {}
snapshot_table = {}
start = time.time()

with session.transaction() as tx:
    print(f"Start Transaction id=0x{tx.txid:016x}")
    #
    # Query Data from Snapshot
    #
    bucket = tx.bucket(vastdb_bucket)
    snapShot = bucket.snapshot(snapshot_name)    
    snapschema = snapShot.schema("new_schema",fail_if_missing = False)
    if snapschema:
       snapshot_table = snapschema.table("new_table", fail_if_missing = False)
       if snapshot_table: 
          reader = snapshot_table.select(predicate=(snapshot_table['state'] == 'TX'))
          snapshot_result = reader.read_all()
       else:
          raise ValueError(f"The table 'new_table' was not found in snapshot '{snapshot_name}'.")
          
    else:
       raise ValueError(f"The schema 'new_schema' was not found in snapshot '{snapshot_name}'.")
       
    #
    # Write data into live Database
    #
    schema = bucket.schema("new_schema")
    table = schema.create_table("recovered_table", snapshot_result.schema , fail_if_exists=True)  
    print(f"Inserting {snapshot_result.num_rows} into new table 'recovered_table'")
    table.insert(snapshot_result) 
    print(f"Commit Transaction")
end=time.time()
print("Recovery took {0:.2f} seconds to run.".format(end-start))   

Read Recovered Data#

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema("new_schema")
    table = schema.table("recovered_table")
    reader = table.select()              # Store all the VAST DB records in an Arrow Record Batch Reader Object
    result = reader.read_all()           # Place the contents of the Select into an Arrow Table

df = result.to_pandas()
#
# Print Results in a Pandas Dataframe
print(f"The results have the following shape {df.shape}.")
df.head(5)

Import Data via FTP from parquet#

Open Targets Platform Data Download Documentation

# 
# Define FTP site and data location
#
ftp_site = 'ftp.ebi.ac.uk' 
ftp_directory = '/pub/databases/opentargets/platform/24.09/output/etl/parquet/diseases'
# Table name for imported data
vastdb_table  = 'Disease_Phenotype'
vastdb_schema = 'new_schema'

spinner = ['/', '-', '\\', '|']

Create Table if needed for Import#

#
# Retrieve the tables from the VAST DB
# Determine if we will create the table.
#
with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema(vastdb_schema)
    vastdb_tables = schema.tables(table_name=vastdb_table)
    
#
# Connect to FTP and get the list of Parquet Files.
#
processed_parquet = 0
parquet_files = get_ftp_file_list(ftp_site, ftp_directory)
parquet_files = list(filter(lambda file: file.endswith('.parquet'), parquet_files))
print(f"The FTP site has {len(parquet_files)} parquet files that will be imported.")
#
# Download the first parquet file and build a compatible table schema.
# VAST DB doesn't support "not nullable" fields in 5.2, so the Parquet schema is
# converted with schema_to_python(). The converted schema ('new_table') is also
# used by the load loop below to detect and cast schema mismatches.
#
ftp_session = login_ftp_site(ftp_site)
pq_table = download_file_to_memory_and_load(ftp_session, parquet_files[0], ftp_directory)
bye_ftp(ftp_session)
new_table = schema_to_python(pq_table.schema)

if vastdb_tables:
    log.info(f"Table {vastdb_table} already exists.")
else:
    print(f"Creating table '{vastdb_table}' in schema '{vastdb_schema}'")
    #
    # Create the table in VAST DB
    #
    with session.transaction() as tx:
        bucket = tx.bucket(vastdb_bucket)
        schema = bucket.schema(vastdb_schema)
        table = schema.create_table(vastdb_table, new_table)
        log.info(f"Table {vastdb_table} created.")

Load Parquet Data into VAST DB#

#
# Start FTP Session 
#
MAX_FILE = 10
count_file = 0
ftp_session = login_ftp_site(ftp_site)
spin = 0 
rows_loaded = 0
rows_skipped = 0 
files_skipped = []

for prq_file in parquet_files:
    count_file +=1
    if count_file > MAX_FILE:
        break
    spin += 1
    if spin > len(spinner) - 1:
      spin = 0
    
    # Provide feedback to user on progress
    sys.stdout.write(f'\r{spinner[spin]}')  
    sys.stdout.flush()  
    pq_table = download_file_to_memory_and_load(ftp_session, prq_file, ftp_directory)
    # Start transaction on the vastdb_table.
    with session.transaction() as tx:           
        bucket = tx.bucket(vastdb_bucket)
        schema = bucket.schema(vastdb_schema)
        table  = schema.table(vastdb_table) 
        if pq_table.schema != new_table:
            log.debug("Schema mismatch, casting to new schema!")
            # compare_schemas(pq_table.schema, new_table)
            a_table = pq_table.cast(new_table, safe=False)
        else: 
            a_table = pq_table
        try:
            table.insert(a_table)    
            log.debug(f'Loaded {pq_table.shape[0]} rows into table {vastdb_table} from {prq_file}') 
            rows_loaded = rows_loaded + pq_table.shape[0]
        except Exception as e:
            log.critical(f'{pq_table.shape[0]} rows failed to load in table {vastdb_table} from {prq_file}: {e}')
            rows_skipped = rows_skipped + pq_table.shape[0]
            files_skipped.append(prq_file)
                
        

#
# Close FTP Session.
#
bye_ftp(ftp_session)    
#
# Report Stats
#
with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema(vastdb_schema)
    table  = schema.table(vastdb_table)
    table_stats = table.get_stats()
        
print(f'{rows_loaded} rows were loaded into the VASTDB table {vastdb_table}')
if rows_skipped:
    print(f'{rows_skipped} rows were skipped due to errors from these files:')
    print(files_skipped)

print(f"\nVAST Database Table Stats for table {vastdb_table}")
print(f"Endpoints   : {table_stats.endpoints} ")
print(f"Number of Rows: {table_stats.num_rows}")
print(f"Size of table in bytes: {table_stats.size_in_bytes}")

Query Struct Datatype#

pd.set_option("display.max_columns", None)

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema(vastdb_schema)
    table = schema.table(vastdb_table)
    reader = table.select()              
    result = reader.read_all()           

print(f"The results are returned as '{type(result)}'.")
#
# Build a Filter Condition using a field from one of the STRUCT columns
#
filter_condition = pc.equal(pc.struct_field(result['ontology'],'isTherapeuticArea'),True)

filtered_table = result.filter(filter_condition)
df = filtered_table.to_pandas()

#
# Create a new column to store the value of ontology.
#
df['ontology_isTherapeuticArea'] = df['ontology'].apply(lambda struct: struct['isTherapeuticArea'])
#
# Print Results in a Pandas Dataframe
display_columns = ['id', 'name', 'code', 'ontology', 'ontology_isTherapeuticArea','parents', 'ancestors']
df_display = df[display_columns]
display(HTML(df_display.to_html(notebook=True)))
#df_display.sample(n=5)
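
The struct filter above runs client-side on the full result set. As a variation, the 'columns' argument used in the delete example earlier can limit the select to just the columns that are needed before filtering; a short sketch using the 'id', 'name' and 'ontology' columns from the diseases dataset:

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema(vastdb_schema)
    table = schema.table(vastdb_table)
    reader = table.select(columns=['id', 'name', 'ontology'])   # project only the needed columns
    subset = reader.read_all()

filtered_subset = subset.filter(
    pc.equal(pc.struct_field(subset['ontology'], 'isTherapeuticArea'), True)
)
filtered_subset.to_pandas().head(5)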

UPSERT Example#

Create Census Data Table and Load Data#

#
# Read in Census Data from local Parquet File
#
""" Census data was created using https://app.mostly.ai """

pq_table = pq.read_table('census_data.parquet')
census_table_name = 'census_data'
#
# Create an Arrow Table Schema for the data.
#
census_data = schema_to_python(pq_table.schema)

#
# Create the table in VAST DB
#
start=time.time()
with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema(vastdb_schema)
    table = schema.create_table(census_table_name, census_data)
    table.insert(pq_table)

end=time.time()

Msg = f"Table '{census_table_name}' created and {pq_table.num_rows} records loaded in {(end-start):.2f}."
log.info(Msg)
print(Msg)

Create 2 new records#

original_census_data = pq_table.slice(0,5) # Keep only 5 records from the original dataset.
original_census_schema = pq_table.schema.remove_metadata()
#
# Convert Parquet table to Pandas Dataframe.
#
df_original_census_data = original_census_data.to_pandas()

# Data for two new records
new_data = {
    "age": [25, 38],
    "workclass": ["Private", "Self-emp-not-inc"],
    "fnlwgt": [226802, 234721],
    "education": ["Bachelors", "Masters"],
    "education_num": [17, 19],
    "marital_status": ["Never-married", "Married-civ-spouse"],
    "occupation": ["Tech-support", "Exec-managerial"],
    "relationship": ["Not-in-family", "Husband"],
    "race": ["White", "White"],
    "sex": ["Female", "Male"],
    "capital_gain": [0, 0],
    "capital_loss": [0, 0],
    "hours_per_week": [40, 50],
    "native_country": ["United-States", "Spain"],
    "income": [">50K", "<=50K"],
    "SSN": ["123-45-6789", "987-65-4321"]
}
df_new_census_records = pd.DataFrame(new_data)

#
# Concatenate the new records and the first 5 records of the original data.
#
df_updated_census_data = pd.concat([df_new_census_records,df_original_census_data], ignore_index=True)
#
# Display the new dataframe with the 7 records.
#
df_updated_census_data[['SSN', 'education_num','occupation','age']]

Update 2 existing records in a new Dataframe#

#
# Update the occupation and age for 2 records.
#
df_updated_census_data.loc[3,'occupation'] = 'Forestry'
df_updated_census_data.loc[3,'age'] += 1
df_updated_census_data.loc[4,'occupation'] = 'Farming'
df_updated_census_data.loc[4,'age'] += 1
#
# Convert Pandas Dataframe to Arrow Table
#
at_updated_census_data = pa.Table.from_pandas(pd.DataFrame(df_updated_census_data), schema=original_census_schema)
# at_updated_census_data = at_updated_census_data.replace_schema_metadata(None)
updated_ssns=df_updated_census_data['SSN'].tolist()
print("Showing new records and the two that were updated:")
df_updated_census_data.iloc[[0,1,3,4],[15,4,6,0]]

Perform UPSERT#

UPSERT Functions#

def upsert_record(table, record, index):
    """Process a single record for insertion or update."""
    ssn_value = record['SSN'][0].as_py()
    reader = table.select(predicate=(table['SSN'] == ssn_value), internal_row_id=True)
    result = reader.read_all()

    if result.num_rows == 0:
        print(f"Insert new record with SSN {ssn_value}, index {index}.")
        table.insert(record)
    elif result.num_rows == 1:
        handle_update(table, record, result, ssn_value)
    else:
        print(f"Expected a unique row for SSN {ssn_value}, but got {result.num_rows} rows.")

def handle_update(table, record, existing_record, ssn_value):
    """Handle updates to an existing record."""
    update_columns = ['age', 'occupation']
    needs_update = any(
        not pc.all(pc.equal(record[col], existing_record[col])).as_py()
        for col in update_columns
    )

    if needs_update:
        print(f"Update is required for SSN {ssn_value} record.")
        row_id = existing_record['$row_id']
        updated_row = record.append_column('$row_id', row_id)
        table.update(updated_row, update_columns)
    else:
        print(f"No update needed for SSN {ssn_value} record.")
        
UPSERT Operation#

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema(vastdb_schema)
    table = schema.table(census_table_name)

    for i in range(at_updated_census_data.num_rows): # for each of the updated / new records.
       record =  at_updated_census_data.slice(i,1) # create a single record arrow table.
       upsert_record(table, record, i)
    

Query Updated data#

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema(vastdb_schema)
    table = schema.table(census_table_name)
    reader = table.select(predicate=(table['SSN'].isin(updated_ssns))) 
    result = reader.read_all()           # Place the contents of the Select into an Arrow Table

df = result.to_pandas()
#
# Print Results in a Pandas Dataframe
df[['SSN', 'education_num','occupation','age']]

Cleanup Cluster#

Delete Tables#

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema("new_schema")
    tables = schema.tables()
    for table in tables: 
        print(f"Deleting table '{table.name}' in schema '{schema.name}' for database '{vastdb_bucket}'.")
        table.drop()
    # Query new_schema for tables
    tables = schema.tables()
    
if tables:
    print(f"The following Tables are in the Schema '{schema.name}':")
    for table in tables:
        print(table.name)
else:
    print("No Tables in Schema")  

Delete Schema#

with session.transaction() as tx:
    bucket = tx.bucket(vastdb_bucket)
    schema = bucket.schema("new_schema")
    schema.drop()
    schemas = bucket.schemas()

print(f"The schemas in the database {vastdb_bucket} are:")
for schema in schemas:
    print(schema.name)
       

Delete Database#

try:
    client.views[database_id].delete()
except Exception as e:
    print(e)
    

Delete View Policy#

try:
    client.viewpolicies[viewpolicy_id].delete()
except Exception as e:
    print(e)

Delete SDK Demo User#

try:
    client.users[local_user_id].delete()
except Exception as e:
    print(e)

Delete Identity Policy#

try:
    client.s3policies[identitypolicy_id].delete()
except Exception as e:
    print(e)
"""
      END of Script
"""