Skip to content
Snippets Groups Projects
Commit 5792415a authored by Max Rausch Dupont's avatar Max Rausch Dupont
Browse files

Fix training and test set

parent 8021fa86
Branches
No related tags found
No related merge requests found
......@@ -111,7 +111,7 @@ def main():
predictions = np.argmax(logits, axis=-1)
metrics = classification_metrics
else:
predictions = logits.squeeze() # Regression values
predictions = logits.squeeze() # Regression values
metrics = regression_metrics
return metrics(torch.from_numpy(predictions), torch.from_numpy(labels))
......@@ -142,14 +142,14 @@ def main():
run_name=f"{wandb_prefix}_{i}",
warmup_ratio=0.1,
lr_scheduler_type="linear", # Linear schedule with warmup
weight_decay=0.01
weight_decay=0.01,
)
trainer = Trainer(
model=model,
args=train_args,
train_dataset=cv_test_fold,
eval_dataset=cv_train_fold,
train_dataset=cv_train_fold,
eval_dataset=cv_test_fold,
compute_metrics=compute_metrics,
)
trainer.train()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment