from __future__ import annotations

import argparse
from pathlib import Path


def build_parser() -> argparse.ArgumentParser:
  parser = argparse.ArgumentParser(
    description="LoRA fine-tuning workflow for a compact phishing-triage model.",
  )
  parser.add_argument(
    "--model-name",
    default="microsoft/Phi-3-mini-4k-instruct",
    help="Base model to fine-tune.",
  )
  parser.add_argument(
    "--train-file",
    default=str(Path("ai-model-project/data/sample_train.jsonl")),
    help="Path to training data in JSONL format.",
  )
  parser.add_argument(
    "--output-dir",
    default="outputs/phishing-triage-lora",
    help="Where to save adapters and tokenizer files.",
  )
  parser.add_argument("--epochs", type=int, default=2, help="Training epochs.")
  parser.add_argument("--batch-size", type=int, default=2, help="Per-device batch size.")
  return parser


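# Example invocation (the script filename here is illustrative; every flag has
# a default, so each one is optional):
#
#   python finetune_phishing_lora.py \
#     --train-file ai-model-project/data/sample_train.jsonl \
#     --epochs 2 --batch-size 2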
def main() -> None:
  parser = build_parser()
  args = parser.parse_args()

  # Imports are intentionally local so the file stays readable as a workflow
  # reference even if the training dependencies are not installed in every environment.
  from datasets import load_dataset
  from peft import LoraConfig, get_peft_model
  from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
  )

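  # Each JSONL line must supply the three fields read by format_example below.
  # A representative record (values are purely illustrative):
  #
  #   {"text": "Your account is locked, verify at http://example.test/login",
  #    "label": "phishing",
  #    "summary": "Credential-harvesting lure using urgency and a lookalike link."}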
  dataset = load_dataset("json", data_files=args.train_file, split="train")
  tokenizer = AutoTokenizer.from_pretrained(args.model_name)
  if tokenizer.pad_token is None:
    # Many causal-LM tokenizers ship without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

  def format_example(record: dict[str, str]) -> dict[str, str]:
    prompt = (
      "Classify the following message as phishing or benign and produce a short analyst summary.\n\n"
      f"Message: {record['text']}\n\n"
      f"Label: {record['label']}\n"
      f"Summary: {record['summary']}"
    )
    return {"text": prompt}

  formatted_dataset = dataset.map(format_example)
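
  # For the illustrative record above, format_example yields a single flat
  # training string (loss is computed over the whole sequence, prompt included):
  #
  #   Classify the following message as phishing or benign and produce a short analyst summary.
  #
  #   Message: Your account is locked, verify at http://example.test/login
  #
  #   Label: phishing
  #   Summary: Credential-harvesting lure using urgency and a lookalike link.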

  def tokenize(record: dict[str, str]) -> dict[str, list[int]]:
    # Labels are deliberately not set here: DataCollatorForLanguageModeling
    # (mlm=False) clones input_ids into labels and masks pad positions to -100,
    # so padding never contributes to the loss. (Because pad reuses EOS above,
    # EOS positions are masked as well.) Copying input_ids into labels here
    # would be dead code, since the collator overwrites them.
    return tokenizer(
      record["text"],
      truncation=True,
      max_length=512,
      padding="max_length",
    )

  tokenized_dataset = formatted_dataset.map(
    tokenize,
    remove_columns=formatted_dataset.column_names,  # drop raw string columns the collator cannot pad
  )
  model = AutoModelForCausalLM.from_pretrained(args.model_name)

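  # LoRA freezes the base weights and learns a low-rank update per targeted
  # projection: W' = W + (lora_alpha / r) * B @ A. With r=8 and lora_alpha=16
  # below, that scaling factor is 2.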
  lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
  )

  model = get_peft_model(model, lora_config)
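
  # Sanity check: with rank-8 adapters on the four attention projections, the
  # trainable parameter count should be a tiny fraction of the full model.
  model.print_trainable_parameters()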

  trainer = Trainer(
    model=model,
    args=TrainingArguments(
      output_dir=args.output_dir,
      num_train_epochs=args.epochs,
      per_device_train_batch_size=args.batch_size,
      logging_steps=1,
      save_strategy="epoch",
      # Evaluation defaults to "no"; the kwarg was renamed across transformers
      # releases (evaluation_strategy -> eval_strategy), so it is not passed here.
      learning_rate=2e-4,
      warmup_ratio=0.1,
      report_to="none",
    ),
    train_dataset=tokenized_dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
  )

  trainer.train()
  model.save_pretrained(args.output_dir)
  tokenizer.save_pretrained(args.output_dir)
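
  # save_pretrained on a PEFT model writes only the adapter weights. To reuse
  # them for inference, re-attach the adapter to a freshly loaded base model.
  # A minimal sketch (assumes the same base checkpoint and this run's output_dir):
  #
  #   from peft import PeftModel
  #   base = AutoModelForCausalLM.from_pretrained(args.model_name)
  #   tuned = PeftModel.from_pretrained(base, args.output_dir)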


if __name__ == "__main__":
  main()
