Update checkpoint.py #336

Open · wants to merge 1 commit into main

checkpoint.py: 73 changes (68 additions, 5 deletions)

@@ -42,7 +42,6 @@
 @contextlib.contextmanager
 def copy_to_shm(file: str):
     if file.startswith("/dev/shm/"):
-        # Nothing to do, the file is already in shared memory.
         yield file
         return
 
@@ -81,7 +80,6 @@ def fast_pickle(obj: Any, path: str) -> None:
 
 
 def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
-    """Loads a set of arrays."""
     pool = ThreadPoolExecutor(max_workers=32)
     fs = list()
     num_tensors = 0
@@ -124,13 +122,11 @@ def get_load_path_str(
     load_rename_rules: Optional[list[tuple[str, str]]] = None,
     load_exclude_rules: Optional[list[str]] = None,
 ) -> Optional[str]:
-    # Exclusion
     if load_exclude_rules is not None:
         for search_pattern in load_exclude_rules:
             if re.search(search_pattern, init_path_str):
                 return None
 
-    # Renaming
     load_path_str = init_path_str
     if load_rename_rules is not None:
         for search_pattern, replacement_pattern in load_rename_rules:
@@ -197,7 +193,6 @@ def restore(
 
     state = jax.tree_util.tree_unflatten(structure, loaded_tensors)
 
-    # Sanity check to give a better error message.
     ckpt_keys = set(state.params.keys())
     code_keys = set(state_sharding.params.keys())
 
@@ -219,3 +214,71 @@
     if params_only:
         state = state.params
     return state
+
+# Database and machine learning integration
+import sqlite3
+from sklearn.model_selection import train_test_split
+from sklearn.linear_model import LinearRegression
+import pandas as pd
+
+def create_database():
+    conn = sqlite3.connect('data_analysis.db')
+    c = conn.cursor()
+    c.execute('''CREATE TABLE IF NOT EXISTS data (
+                 id INTEGER PRIMARY KEY AUTOINCREMENT,
+                 timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
+                 latency REAL,
+                 packet_loss REAL)''')
+    conn.commit()
+    conn.close()
+
+def record_data(latency, packet_loss):
+    conn = sqlite3.connect('data_analysis.db')
+    c = conn.cursor()
+    c.execute('INSERT INTO data (latency, packet_loss) VALUES (?, ?)',
+              (latency, packet_loss))
+    conn.commit()
+    conn.close()
+
+def train_model():
+    conn = sqlite3.connect('data_analysis.db')
+    data = pd.read_sql_query("SELECT * FROM data", conn)
+    conn.close()
+
+    X = data[['latency']]
+    y = data['packet_loss']
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, test_size=0.2, random_state=42)
+
+    model = LinearRegression()
+    model.fit(X_train, y_train)
+    return model
+
+def analyze_task_startup(latency, model):
+    # Wrap the scalar in a DataFrame with the training column name so
+    # scikit-learn does not warn about missing feature names.
+    features = pd.DataFrame({'latency': [latency]})
+    predicted_packet_loss = model.predict(features)[0]
+    if predicted_packet_loss > 10:
+        print("High packet loss predicted:", predicted_packet_loss)
+    else:
+        print("Packet loss within acceptable range:", predicted_packet_loss)
+
+def join_data_with_external_source():
+    external_data = pd.DataFrame({
+        'external_id': [1, 2, 3],
+        'external_info': ['info1', 'info2', 'info3']
+    })
+
+    conn = sqlite3.connect('data_analysis.db')
+    data = pd.read_sql_query("SELECT * FROM data", conn)
+    conn.close()
+
+    # Inner join: keeps only rows whose id matches an external_id.
+    joined_data = data.merge(external_data, left_on='id', right_on='external_id')
+    return joined_data
+
+if __name__ == "__main__":
+    create_database()
+    record_data(50, 5)  # Example data
+    record_data(100, 20)  # Example data
+
+    model = train_model()
+    analyze_task_startup(70, model)
+    joined_data = join_data_with_external_source()
+    print(joined_data)
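
For reviewers who want to exercise the additions end to end without running the module's own demo, here is a minimal smoke-test sketch. It is not part of this PR: it assumes checkpoint.py (with the block above) is importable from the working directory, that pandas and scikit-learn are installed, and the sample latency/packet-loss pairs are made up for illustration.

# Hypothetical smoke test for the helpers added above; not part of this PR.
from checkpoint import (
    create_database,
    record_data,
    train_model,
    analyze_task_startup,
)

create_database()
# Made-up (latency, packet_loss) samples, purely illustrative.
for latency, loss in [(20, 1.0), (60, 8.0), (120, 25.0), (200, 40.0)]:
    record_data(latency, loss)

model = train_model()
for latency in (30, 90, 150):
    analyze_task_startup(latency, model)

Importing rather than executing checkpoint.py keeps the `if __name__ == "__main__"` block above from re-inserting its example rows on every run.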