//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TestDialect.h"
#include "TestOps.h"
#include "TestTypes.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/ODSSupport.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Base64.h"
#include "llvm/Support/Casting.h"

#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include <cstdint>
#include <numeric>
#include <optional>

// Include this before the using namespace lines below to test that we don't
// have namespace dependencies.
#include "TestOpsDialect.cpp.inc"

using namespace mlir;
using namespace test;

//===----------------------------------------------------------------------===//
// PropertiesWithCustomPrint
//===----------------------------------------------------------------------===//

/// Populates `prop` from a dictionary attribute of the form
/// `{label = <string>, value = <integer>}`. A non-dictionary input or a
/// missing/mistyped entry is reported through `emitError`; `prop` is only
/// written after every entry has been validated.
LogicalResult
test::setPropertiesFromAttribute(PropertiesWithCustomPrint &prop,
                                 Attribute attr,
                                 function_ref<InFlightDiagnostic()> emitError) {
  auto dictAttr = dyn_cast<DictionaryAttr>(attr);
  if (!dictAttr) {
    emitError() << "expected DictionaryAttr to set TestProperties";
    return failure();
  }

  mlir::StringAttr labelAttr = dictAttr.getAs<mlir::StringAttr>("label");
  if (!labelAttr) {
    emitError() << "expected StringAttr for key `label`";
    return failure();
  }

  IntegerAttr intAttr = dictAttr.getAs<IntegerAttr>("value");
  if (!intAttr) {
    emitError() << "expected IntegerAttr for key `value`";
    return failure();
  }

  // Both entries are valid: commit them.
  prop.label = std::make_shared<std::string>(labelAttr.getValue().str());
  prop.value = intAttr.getValue().getSExtValue();
  return success();
}

/// Converts the properties into a `{label = ..., value = ...}` dictionary;
/// the inverse of setPropertiesFromAttribute. `value` is encoded as an i32.
DictionaryAttr
test::getPropertiesAsAttribute(MLIRContext *ctx,
                               const PropertiesWithCustomPrint &prop) {
  Builder builder(ctx);
  NamedAttribute entries[] = {
      builder.getNamedAttr("label", builder.getStringAttr(*prop.label)),
      builder.getNamedAttr("value", builder.getI32IntegerAttr(prop.value))};
  return builder.getDictionaryAttr(entries);
}

/// Hashes the integer value together with the label's text (not the identity
/// of the shared_ptr holding it). `prop.label` must be non-null.
llvm::hash_code test::computeHash(const PropertiesWithCustomPrint &prop) {
  StringRef labelText = *prop.label;
  return llvm::hash_combine(prop.value, labelText);
}

/// Prints the properties as `<label> is <value>`, emitting the label through
/// printKeywordOrString. customParseProperties reads this form back.
void test::customPrintProperties(OpAsmPrinter &p,
                                 const PropertiesWithCustomPrint &prop) {
  const std::string &labelText = *prop.label;
  p.printKeywordOrString(labelText);
  p << " is ";
  p << prop.value;
}

/// Parses `<label> is <value>`, the form produced by customPrintProperties.
/// The label is only installed into `prop` once the whole form has parsed.
ParseResult test::customParseProperties(OpAsmParser &parser,
                                        PropertiesWithCustomPrint &prop) {
  std::string parsedLabel;
  if (parser.parseKeywordOrString(&parsedLabel))
    return failure();
  if (parser.parseKeyword("is"))
    return failure();
  if (parser.parseInteger(prop.value))
    return failure();
  prop.label = std::make_shared<std::string>(std::move(parsedLabel));
  return success();
}

//===----------------------------------------------------------------------===//
// MyPropStruct
//===----------------------------------------------------------------------===//

/// Exposes the stored string content as a StringAttr.
Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const {
  StringAttr contentAttr = StringAttr::get(ctx, content);
  return contentAttr;
}

/// Inverse of asAttribute(): copies the text of a StringAttr into
/// `prop.content`. Any other attribute kind is rejected through `emitError`.
LogicalResult
MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
                          function_ref<InFlightDiagnostic()> emitError) {
  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
    prop.content = stringAttr.getValue();
    return success();
  }
  emitError() << "Expect StringAttr but got " << attr;
  return failure();
}

/// Hashes the string content.
llvm::hash_code MyPropStruct::hash() const {
  StringRef contentRef(content);
  return hash_value(contentRef);
}

/// Reads a MyPropStruct from bytecode. The encoding is a single string that
/// becomes `prop.content`. `prop` is only modified when the read succeeds.
LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader,
                                         MyPropStruct &prop) {
  StringRef encoded;
  if (succeeded(reader.readString(encoded))) {
    // `encoded` is a non-owning view; copy it into the property's storage.
    prop.content = encoded.str();
    return success();
  }
  return failure();
}

/// Writes a MyPropStruct to bytecode as a single owned string (the layout
/// decoded by the MyPropStruct overload of readFromMlirBytecode).
void test::writeToMlirBytecode(DialectBytecodeWriter &writer,
                               MyPropStruct &prop) {
  StringRef contentRef = prop.content;
  writer.writeOwnedString(contentRef);
}

//===----------------------------------------------------------------------===//
// VersionedProperties
//===----------------------------------------------------------------------===//

/// Populates `prop` from a dictionary of the form
/// `{value1 = <integer>, value2 = <integer>}`. Problems are reported through
/// `emitError`; `prop` is only written once both entries have been validated.
LogicalResult
test::setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
                                 function_ref<InFlightDiagnostic()> emitError) {
  auto dictAttr = dyn_cast<DictionaryAttr>(attr);
  if (!dictAttr) {
    emitError() << "expected DictionaryAttr to set VersionedProperties";
    return failure();
  }

  IntegerAttr firstAttr = dictAttr.getAs<IntegerAttr>("value1");
  if (!firstAttr) {
    emitError() << "expected IntegerAttr for key `value1`";
    return failure();
  }

  IntegerAttr secondAttr = dictAttr.getAs<IntegerAttr>("value2");
  if (!secondAttr) {
    emitError() << "expected IntegerAttr for key `value2`";
    return failure();
  }

  // Both entries validated: commit them.
  prop.value1 = firstAttr.getValue().getSExtValue();
  prop.value2 = secondAttr.getValue().getSExtValue();
  return success();
}

/// Converts VersionedProperties into a `{value1 = ..., value2 = ...}`
/// dictionary with both values encoded as i32.
DictionaryAttr test::getPropertiesAsAttribute(MLIRContext *ctx,
                                              const VersionedProperties &prop) {
  Builder builder(ctx);
  NamedAttribute entries[] = {
      builder.getNamedAttr("value1", builder.getI32IntegerAttr(prop.value1)),
      builder.getNamedAttr("value2", builder.getI32IntegerAttr(prop.value2))};
  return builder.getDictionaryAttr(entries);
}

/// Hashes both stored values, in declaration order.
llvm::hash_code test::computeHash(const VersionedProperties &prop) {
  const auto first = prop.value1;
  const auto second = prop.value2;
  return llvm::hash_combine(first, second);
}

/// Prints VersionedProperties as `<value1> | <value2>`.
void test::customPrintProperties(OpAsmPrinter &p,
                                 const VersionedProperties &prop) {
  p << prop.value1;
  p << " | ";
  p << prop.value2;
}

/// Parses `<value1> | <value2>`, the form produced by customPrintProperties.
ParseResult test::customParseProperties(OpAsmParser &parser,
                                        VersionedProperties &prop) {
  if (parser.parseInteger(prop.value1))
    return failure();
  if (parser.parseVerticalBar())
    return failure();
  if (parser.parseInteger(prop.value2))
    return failure();
  return success();
}

//===----------------------------------------------------------------------===//
// Bytecode Support
//===----------------------------------------------------------------------===//

/// Reads a fixed-size array of int64_t properties from bytecode.
///
/// The encoding is a varint element count followed by one varint per element,
/// as written by the ArrayRef<int64_t> overload of writeToMlirBytecode. The
/// stored count must match `prop.size()` exactly because the destination is a
/// non-resizable view.
LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader,
                                         MutableArrayRef<int64_t> prop) {
  uint64_t size;
  if (failed(reader.readVarInt(size)))
    return failure();
  if (size != prop.size())
    return reader.emitError("array size mismatch when reading properties: ")
           << size << " vs expected " << prop.size();
  for (auto &elt : prop) {
    uint64_t value;
    if (failed(reader.readVarInt(value)))
      return failure();
    // Each element is read as uint64_t and converted back to int64_t.
    elt = value;
  }
  return success();
}

/// Writes an int64_t array as a varint element count followed by one varint
/// per element.
void test::writeToMlirBytecode(DialectBytecodeWriter &writer,
                               ArrayRef<int64_t> prop) {
  const size_t numElements = prop.size();
  writer.writeVarInt(numElements);
  for (size_t idx = 0; idx < numElements; ++idx)
    writer.writeVarInt(prop[idx]);
}

//===----------------------------------------------------------------------===//
// Dynamic operations
//===----------------------------------------------------------------------===//

/// Defines the `dynamic_generic` dynamic op, whose verifier and region
/// verifier accept every operation.
static std::unique_ptr<DynamicOpDefinition>
getDynamicGenericOp(TestDialect *dialect) {
  auto acceptAll = [](Operation *) { return success(); };
  auto acceptAllRegions = [](Operation *) { return success(); };
  return DynamicOpDefinition::get("dynamic_generic", dialect, acceptAll,
                                  acceptAllRegions);
}

/// Defines the `dynamic_one_operand_two_results` dynamic op, which must have
/// exactly one operand and exactly two results. Regions are not constrained.
static std::unique_ptr<DynamicOpDefinition>
getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
  auto verifyArity = [](Operation *op) -> LogicalResult {
    unsigned numOperands = op->getNumOperands();
    if (numOperands != 1) {
      op->emitOpError() << "expected 1 operand, but had " << numOperands;
      return failure();
    }
    unsigned numResults = op->getNumResults();
    if (numResults != 2) {
      op->emitOpError() << "expected 2 results, but had " << numResults;
      return failure();
    }
    return success();
  };
  auto verifyRegions = [](Operation *) -> LogicalResult { return success(); };

  return DynamicOpDefinition::get("dynamic_one_operand_two_results", dialect,
                                  verifyArity, verifyRegions);
}

/// Defines the `dynamic_custom_parser_printer` dynamic op. It must have no
/// operands and no results, and its custom assembly is the op name followed by
/// the keyword `custom_keyword`.
static std::unique_ptr<DynamicOpDefinition>
getDynamicCustomParserPrinterOp(TestDialect *dialect) {
  auto verifyEmpty = [](Operation *op) -> LogicalResult {
    bool hasOperandsOrResults =
        op->getNumOperands() != 0 || op->getNumResults() != 0;
    if (hasOperandsOrResults) {
      op->emitError() << "operation should have no operands and no results";
      return failure();
    }
    return success();
  };
  auto verifyRegions = [](Operation *) -> LogicalResult { return success(); };

  auto parseKeywordOnly = [](OpAsmParser &parser, OperationState &) {
    return parser.parseKeyword("custom_keyword");
  };
  auto printKeywordOnly = [](Operation *op, OpAsmPrinter &printer,
                             llvm::StringRef) {
    printer << op->getName() << " custom_keyword";
  };

  return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect,
                                  verifyEmpty, verifyRegions, parseKeywordOnly,
                                  printKeywordOnly);
}

//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//

/// Adds TestDialect to `registry`.
void test::registerTestDialect(DialectRegistry &registry) {
  registry.insert<TestDialect>();
}

/// Computes the TestEffects for `op`. If the op carries an `effect_parameter`
/// AffineMapAttr, appends a single TestEffects::Concrete effect parameterized
/// by that attribute. Otherwise nothing is appended.
void test::testSideEffectOpGetEffect(
    Operation *op,
    SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
        &effects) {
  if (auto parameter = op->getAttrOfType<AffineMapAttr>("effect_parameter"))
    effects.emplace_back(TestEffects::Concrete::get(), parameter);
}

// Dialect fallback model for `TestEffectOpInterface`. The dialect hands it out
// only for the unregistered `test.unregistered_side_effect_op`.
struct TestOpEffectInterfaceFallback
    : public TestEffectOpInterface::FallbackModel<
          TestOpEffectInterfaceFallback> {
  /// Only `test.unregistered_side_effect_op` should ever be dispatched here;
  /// anything else trips the assertion in builds with assertions enabled.
  static bool classof(Operation *op) {
    StringRef opName = op->getName().getStringRef();
    const bool dispatchedToExpectedOp =
        opName == "test.unregistered_side_effect_op";
    assert(dispatchedToExpectedOp && "Unexpected dispatch");
    return dispatchedToExpectedOp;
  }

  /// Forwards to the shared testSideEffectOpGetEffect computation.
  void
  getEffects(Operation *op,
             SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
                 &effectsOut) const {
    test::testSideEffectOpGetEffect(op, effectsOut);
  }
};

/// Dialect initialization hook. Registers the dialect's attributes, types, and
/// op syntax. It then adds operations: the hand-written ManualCppOpWithFold,
/// those added by registerTestDialectOperations, and the three dynamic ops
/// defined above. Finally it registers interfaces, allows unknown operations,
/// and creates the fallback TestEffectOpInterface model.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  registerOpsSyntax();
  addOperations<ManualCppOpWithFold>();
  registerTestDialectOperations(this);
  registerDynamicOp(getDynamicGenericOp(this));
  registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
  registerDynamicOp(getDynamicCustomParserPrinterOp(this));
  registerInterfaces();
  // Accept operations this dialect does not register, e.g.
  // `test.unregistered_side_effect_op`, which uses the fallback model below.
  allowUnknownOperations();

  // Instantiate the fallback op interface model that
  // getRegisteredInterfaceForOp hands out for that specific unregistered op.
  // The dialect owns this allocation; ~TestDialect deletes it.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}

/// Releases the fallback interface model allocated in initialize().
TestDialect::~TestDialect() {
  // The model is handed out as an untyped pointer (see
  // getRegisteredInterfaceForOp), so cast back to the concrete type before
  // deleting so the correct destructor runs.
  delete static_cast<TestOpEffectInterfaceFallback *>(
      fallbackEffectOpInterfaces);
}

/// Materializes a folded constant `value` of type `type` at `loc` as a
/// TestOpConstant.
Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  auto constantOp = TestOpConstant::create(builder, loc, type, value);
  return constantOp.getOperation();
}

/// Supplies the fallback TestEffectOpInterface model for the unregistered
/// `test.unregistered_side_effect_op`. Every other (op, interface) query
/// yields nullptr.
void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
                                               OperationName opName) {
  bool isSideEffectOp =
      opName.getIdentifier() == "test.unregistered_side_effect_op";
  if (!isSideEffectOp || typeID != TypeID::get<TestEffectOpInterface>())
    return nullptr;
  return fallbackEffectOpInterfaces;
}

/// Rejects the discardable attribute `test.invalid_attr` on any operation;
/// all other dialect attributes are accepted.
LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute namedAttr) {
  if (namedAttr.getName().getValue() != "test.invalid_attr")
    return success();
  return op->emitError() << "invalid to use 'test.invalid_attr'";
}

/// Rejects `test.invalid_attr` when attached to a region argument. The region
/// and argument indices do not influence the result.
LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute namedAttr) {
  if (namedAttr.getName().getValue() != "test.invalid_attr")
    return success();
  return op->emitError() << "invalid to use 'test.invalid_attr'";
}

/// Rejects `test.invalid_attr` when attached to a region result. The region
/// and result indices do not influence the result.
LogicalResult
TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
                                         unsigned resultIndex,
                                         NamedAttribute namedAttr) {
  if (namedAttr.getName().getValue() != "test.invalid_attr")
    return success();
  return op->emitError() << "invalid to use 'test.invalid_attr'";
}

/// Provides dialect-level parse hooks for a few op names whose custom syntax is
/// handled by the dialect itself. Returns std::nullopt for every other name.
///
/// `Dialect::ParseOpHook` does not own the callable it wraps (it is an
/// `llvm::function_ref` in current MLIR; NOTE(review): confirm in
/// mlir/IR/Dialect.h). Each hook therefore refers to a function-local static
/// lambda that lives for the whole program. Wrapping a lambda temporary instead
/// would leave the hook pointing at an object destroyed at the end of the
/// return statement.
std::optional<Dialect::ParseOpHook>
TestDialect::getParseOperationHook(StringRef opName) const {
  static const auto parseCustomFormat = [](OpAsmParser &parser,
                                           OperationState &state) {
    return parser.parseKeyword("custom_format");
  };
  static const auto parseCustomFormatFallback = [](OpAsmParser &parser,
                                                   OperationState &state) {
    return parser.parseKeyword("custom_format_fallback");
  };
  // The dotted op name has no extra syntax after the name.
  static const auto parseNothing = [](OpAsmParser &parser,
                                      OperationState &state) {
    return ParseResult::success();
  };

  if (opName == "test.dialect_custom_printer")
    return ParseOpHook{parseCustomFormat};
  if (opName == "test.dialect_custom_format_fallback")
    return ParseOpHook{parseCustomFormatFallback};
  if (opName == "test.dialect_custom_printer.with.dot")
    return ParseOpHook{parseNothing};
  return std::nullopt;
}

/// Returns a dialect-level printer for the ops that getParseOperationHook also
/// handles with a keyword. The printer writes the keyword after the op name.
/// Returns an empty function for every other op.
llvm::unique_function<void(Operation *, OpAsmPrinter &)>
TestDialect::getOperationPrinter(Operation *op) const {
  StringRef suffix =
      llvm::StringSwitch<StringRef>(op->getName().getStringRef())
          .Case("test.dialect_custom_printer", " custom_format")
          .Case("test.dialect_custom_format_fallback",
                " custom_format_fallback")
          .Default(StringRef());
  if (suffix.empty())
    return {};
  // `suffix` refers to a string literal, so capturing it by value is safe;
  // the returned unique_function owns this lambda.
  return [suffix](Operation *, OpAsmPrinter &printer) {
    printer.getStream() << suffix;
  };
}

/// Dialect-level canonicalization: replaces a TestDialectCanonicalizerOp with
/// an `arith.constant` holding the i32 value 42.
static LogicalResult
dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
                               PatternRewriter &rewriter) {
  IntegerAttr replacementValue = rewriter.getI32IntegerAttr(42);
  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, replacementValue);
  return success();
}

/// Adds the dialect-wide canonicalization pattern, which rewrites
/// TestDialectCanonicalizerOp into an i32 `arith.constant` of 42
/// (see dialectCanonicalizationPattern).
void TestDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add(&dialectCanonicalizationPattern);
}

//===----------------------------------------------------------------------===//
// TestCallWithSegmentsOp
//===----------------------------------------------------------------------===//
// The op `test.call_with_segments` models a call-like operation whose operands
// are divided into 3 variadic segments: `prefix`, `args`, and `suffix`.
// Only the middle segment represents the actual call arguments. The op uses
// the AttrSizedOperandSegments trait, so we can derive segment boundaries from
// the generated `operandSegmentSizes` attribute. We provide custom helpers to
// expose the logical call arguments as both a read-only range and a mutable
// range bound to the proper segment so that insertion/erasure updates the
// attribute automatically.

// Segment layout in the DenseI32ArrayAttr of segment sizes is
// [prefix, args, suffix]; this is the position of the `args` (logical call
// argument) segment within that list.
static constexpr unsigned kTestCallWithSegmentsArgsSegIndex = 1;

/// Returns the logical call arguments: exactly the operands of the middle
/// `args` segment.
Operation::operand_range CallWithSegmentsOp::getArgOperands() {
  // The generated segment accessor yields the operands located between the
  // prefix and suffix segments. That is the same slice of the full operand
  // list as [prefix size, prefix size + args size).
  return getArgs();
}

/// Returns the logical call arguments as a MutableOperandRange bound to the
/// `args` operand segment. Inserting or erasing through the range therefore
/// also updates the op's segment-size attribute.
MutableOperandRange CallWithSegmentsOp::getArgOperandsMutable() {
  Operation *op = getOperation();

  // Obtain the canonical segment size attribute name for this op.
  auto segName =
      CallWithSegmentsOp::getOperandSegmentSizesAttrName(op->getName());
  auto sizesAttr = op->getAttrOfType<DenseI32ArrayAttr>(segName);
  assert(sizesAttr && "missing operandSegmentSizes attribute on op");

  ArrayRef<int32_t> sizes = sizesAttr.asArrayRef();
  assert(sizes.size() > kTestCallWithSegmentsArgsSegIndex &&
         "operandSegmentSizes has no entry for the args segment");

  // Derive the args bounds from the same segment index used for the binding
  // below, instead of hard-coding positions that could drift from it. The
  // segment starts after all preceding segments and spans its own size.
  ArrayRef<int32_t> precedingSizes =
      sizes.take_front(kTestCallWithSegmentsArgsSegIndex);
  unsigned start =
      std::accumulate(precedingSizes.begin(), precedingSizes.end(), 0u);
  unsigned len =
      static_cast<unsigned>(sizes[kTestCallWithSegmentsArgsSegIndex]);

  NamedAttribute segNamed(segName, sizesAttr);
  MutableOperandRange::OperandSegment binding{kTestCallWithSegmentsArgsSegIndex,
                                              segNamed};

  return MutableOperandRange(op, start, len, {binding});
}
