Skip to content
Snippets Groups Projects
Commit e88e02cf authored by Max Rausch Dupont's avatar Max Rausch Dupont
Browse files

Add training script

parent 75e467e5
Branches
No related tags found
No related merge requests found
source $PROJECT_ROOT/setup.sh
${CONDA} run -n ${ENV_NAME} python $PROJECT_ROOT/run_molformer.py
${CONDA} run -n ${ENV_NAME} python $PROJECT_ROOT/train.py
train.py 0 → 100644
import os
import pathlib
from dataclasses import dataclass, field
from typing import Literal
import evaluate
import numpy as np
import wandb
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
HfArgumentParser,
TrainingArguments,
Trainer,
)
from wandb.util import generate_id
@dataclass
class ModelArguments:
    """Command-line arguments selecting how the model head is trained."""

    # Chooses between a 2-label classification head and a 1-output
    # regression head (see num_labels in main()).
    train_method: Literal["classification", "regression"] = field(
        default="regression",
        metadata={
            "help": "Method of training. Either 'classification' or 'regression'."
        },
    )
@dataclass
class DataTrainingArguments:
    """Command-line arguments describing the input data and its CV splits."""

    # Column prefix used to look up the 10-fold split assignment
    # (the dataset column is f"{fold_column}_10f").
    fold_column: Literal["DataSAIL", "random"] = field(
        default="DataSAIL",
        metadata={
            "help": "Which splitting method to use to pick the cross-validation folds. "
            "Either 'DataSAIL' or 'random'."
        },
    )
    # CSV file name, resolved relative to this script's directory.
    file: str = field(
        default="20240430_MCHR1_splitted_RJ.csv",
        metadata={"help": "Name of the data file."},
    )
def main():
    """Fine-tune MoLFormer on a molecular activity CSV with 8-fold CV.

    Folds 8 and 9 (per the chosen fold column) are held out entirely and
    never seen during this script. The remaining eight folds rotate: in
    iteration i, Fold_i is the evaluation set and the other seven folds
    form the training set. Each iteration is logged to Weights & Biases
    under a shared run-id prefix.

    Fixes relative to the original:
      * train/eval datasets were swapped (the single held-out fold was
        being used for training).
      * the model is now re-initialized from the pretrained checkpoint
        for every fold, so weights fine-tuned on one fold cannot leak
        into the next CV iteration.
      * roc_auc is computed from continuous class-1 scores instead of
        hard argmax labels (evaluate's roc_auc requires
        `prediction_scores`).
      * regression logits of shape (n, 1) are squeezed to (n,).
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments))
    model_args, data_args = parser.parse_args_into_dataclasses()
    root_dir = pathlib.Path(__file__).parent
    model_name = "ibm/MoLFormer-XL-both-10pct"
    num_labels = 2 if model_args.train_method == "classification" else 1

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer_config = {
        "padding": "max_length",
        "truncation": True,
        "max_length": tokenizer.model_max_length,
        "return_tensors": "pt",
    }

    def tokenize_input(examples):
        # Batched tokenization of the SMILES strings (renamed to "text").
        return tokenizer(examples["text"], **tokenizer_config)

    # Data Preparation
    ds = load_dataset(
        "csv",
        data_files=str(root_dir / data_args.file),
        split="train",  # load_dataset creates this split automatically
    )
    ds = ds.rename_column("smiles", "text")
    ds = ds.rename_column(
        "class" if model_args.train_method == "classification" else "acvalue_uM",
        "label",
    )
    # Folds 8 and 9 are reserved for validation/test and excluded from CV.
    train = ds.filter(
        lambda x: x[f"{data_args.fold_column}_10f"] not in ("Fold_8", "Fold_9")
    )
    # val = ds.filter(lambda x: x[f"{data_args.fold_column}_10f"] == "Fold_8")
    # test = ds.filter(lambda x: x[f"{data_args.fold_column}_10f"] == "Fold_9")
    train_tokenized = train.map(tokenize_input, batched=True)

    # Training
    classification_metrics = ["roc_auc", "precision", "recall", "f1"]
    regression_metrics = ["mse", "r_squared", "pearsonr"]

    def compute_metrics(model_output):
        """Compute the metric set matching the chosen train_method."""
        logits, labels = model_output
        if model_args.train_method == "classification":
            predictions = np.argmax(logits, axis=-1)
            # roc_auc needs continuous scores, not hard labels: use the
            # softmax probability of the positive class.
            shifted = np.exp(logits - logits.max(axis=-1, keepdims=True))
            scores = (shifted / shifted.sum(axis=-1, keepdims=True))[:, 1]
            metrics = classification_metrics
        else:
            predictions = np.squeeze(logits, axis=-1)  # (n, 1) -> (n,)
            scores = None
            metrics = regression_metrics
        output = {}
        for metric in metrics:
            metric_func = evaluate.load(metric)
            if metric == "roc_auc":
                # evaluate's roc_auc takes `prediction_scores`, not `predictions`.
                result = metric_func.compute(
                    prediction_scores=scores, references=labels
                )
            else:
                result = metric_func.compute(
                    predictions=predictions, references=labels
                )
            # Metrics return a single-entry {name: value} dict; unwrap it.
            if isinstance(result, dict):
                output[metric] = list(result.values()).pop()
            else:
                output[metric] = result
        return output

    os.environ["WANDB_PROJECT"] = "cache5_molformer_finetune"
    os.environ["WANDB_LOG_MODEL"] = "checkpoint"
    os.environ["WANDB_TAGS"] = f"{model_args.train_method},{data_args.fold_column}"
    wandb_prefix = generate_id()  # shared prefix groups the K runs in W&B

    K = 8  # folds 0..7 rotate through CV; folds 8 and 9 stay held out
    for i in range(K):
        cv_eval_fold = train_tokenized.filter(
            lambda x: x[f"{data_args.fold_column}_10f"] == f"Fold_{i}"
        )
        cv_train_fold = train_tokenized.filter(
            lambda x: x[f"{data_args.fold_column}_10f"] != f"Fold_{i}"
        )
        # Re-initialize from the pretrained checkpoint each fold so that
        # no fine-tuned weights leak between CV iterations.
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            trust_remote_code=True,
            num_labels=num_labels,
            device_map="auto",
        )
        train_args = TrainingArguments(
            output_dir=os.getenv("CHECKPOINT_DIR", root_dir),
            evaluation_strategy="epoch",
            logging_strategy="epoch",
            do_train=True,
            do_eval=True,
            num_train_epochs=10,
            report_to="wandb",
            save_total_limit=2,
            run_name=f"{wandb_prefix}_{i}",
        )
        trainer = Trainer(
            model=model,
            args=train_args,
            train_dataset=cv_train_fold,  # BUGFIX: was the held-out fold
            eval_dataset=cv_eval_fold,
            compute_metrics=compute_metrics,
        )
        trainer.train()
        wandb.finish()


if __name__ == "__main__":
    main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment