module fundit::dataPuller

use fundit::sqlUtilities

/*
 *  Fetch weekly returns for indexes / factors.
 *
 *  Rows come from pfdb.cm_factor_performance_weekly (keyed by factor_id) and
 *  mfdb.fund_performance_weekly (keyed by fund_id). Both are restricted to
 *  isvalid = 1 and non-null ret_1w, combined with UNION and ordered by year_week.
 *
 *  @param index_ids   comma-separated, single-quoted id list, e.g. "'FA00000WKG','IN0000007G'"
 *  @param start_date  lower bound on price_date (inclusive); null = no lower bound
 *  @param end_date    upper bound on price_date (inclusive); null = no upper bound
 *  @return table with index_id, year_week, price_date, cumulative_nav, ret_1w
 *
 *  get_index_weekly_rets("'FA00000WKG','FA00000WKH','IN0000007G'", 1990.01.01, today())
 */
def get_index_weekly_rets(index_ids, start_date, end_date) {

    // Both halves of the UNION share the same optional date bounds ("" when a bound is null).
    date_bounds = iif(isNull(start_date), "", " AND price_date >= '" + start_date$STRING + "'") +
                  iif(isNull(end_date), "", " AND price_date <= '" + end_date$STRING + "'")

    sql_text = "SELECT factor_id AS index_id, year_week, price_date, factor_value AS cumulative_nav, ret_1w
               FROM pfdb.cm_factor_performance_weekly
               WHERE isvalid = 1
                 AND factor_id IN (" + index_ids + ")" +
                 date_bounds + "
                 AND ret_1w IS NOT NULL
               UNION
               SELECT fund_id AS index_id, year_week, price_date, cumulative_nav, ret_1w
               FROM mfdb.fund_performance_weekly
               WHERE isvalid = 1
                 AND fund_id IN (" + index_ids + ")" +
                 date_bounds + "
                 AND ret_1w IS NOT NULL
               ORDER BY year_week"

    mysql_conn = connect_mysql()
    weekly_rets = odbc::query(mysql_conn, sql_text)
    close(mysql_conn)

    return weekly_rets
}


/*
 *  Fetch weekly fund returns from mfdb.fund_performance_weekly.
 *
 *  @param fund_ids     comma-separated, single-quoted id list, e.g. "'MF00003TMH','MF00003UQM'";
 *                      null = no id filter
 *  @param start_date   lower bound on price_date (inclusive); null = no lower bound
 *  @param end_date     upper bound on price_date (inclusive); null = no upper bound
 *  @param isFromMySQL  not used by this function; data always comes from MySQL
 *  @return table with fund_id, year_week, price_date, cumulative_nav, ret_1w
 *          (isvalid = 1, ret_1w not null), ordered by fund_id, year_week
 *
 *  get_fund_weekly_rets("'MF00003TMH','MF00003UQM'", 1990.01.01, null, true)
 */
def get_fund_weekly_rets(fund_ids, start_date, end_date, isFromMySQL) {

    // Every optional filter yields an " AND ..." clause, or "" when its argument is null.
    id_clause = iif(isNull(fund_ids), "", " AND fund_id IN (" + fund_ids + ")")
    date_clause = iif(isNull(start_date), "", " AND price_date >= '" + start_date$STRING + "'") +
                  iif(isNull(end_date), "", " AND price_date <= '" + end_date$STRING + "'")

    sql_text = "SELECT fund_id, year_week, price_date, cumulative_nav, ret_1w
               FROM mfdb.fund_performance_weekly
               WHERE isvalid = 1 " +
                 id_clause +
                 date_clause + "
                 AND ret_1w IS NOT NULL
               ORDER BY fund_id, year_week"

    mysql_conn = connect_mysql()
    weekly_rets = odbc::query(mysql_conn, sql_text)
    close(mysql_conn)

    return weekly_rets
}

/*
 *  Fetch weekly portfolio returns from pfdb.pf_portfolio_performance_weekly.
 *  TODO: support reading from the local database (isFromMySQL is currently ignored).
 *
 *  @param portfolio_ids  comma-separated portfolio id list, e.g. "166002,364640"; null = no id filter
 *  @param start_date     lower bound on price_date (inclusive); null = no lower bound
 *  @param end_date       upper bound on price_date (inclusive); null = no upper bound
 *  @param isFromMySQL    not used yet; data always comes from MySQL
 *  @return table with portfolio_id, year_week, price_date, cumulative_nav, ret_1w
 *          (isvalid = 1, ret_1w not null), ordered by portfolio_id, year_week
 *
 *  get_portfolio_weekly_rets("166002,364640", 1990.01.01, today(), true)
 */
def get_portfolio_weekly_rets(portfolio_ids, start_date, end_date, isFromMySQL) {

    s_portfolio_id = iif(portfolio_ids.isNull(), "", " AND portfolio_id IN (" + portfolio_ids + ")")

    // A null date drops its bound, as in get_fund_weekly_rets / get_index_weekly_rets.
    // Previously both dates were always spliced into "BETWEEN '...' AND '...'", so a null
    // date produced an empty literal there instead of meaning "no bound".
    s_start_date = iif(start_date.isNull(), "", " AND price_date >= '" + start_date$STRING + "'")
    s_end_date = iif(end_date.isNull(), "", " AND price_date <= '" + end_date$STRING + "'")

    s_query = "SELECT portfolio_id, year_week, price_date, cumulative_nav, ret_1w
               FROM pfdb.pf_portfolio_performance_weekly
               WHERE isvalid = 1 " +
                 s_portfolio_id +
                 s_start_date +
                 s_end_date + "
                 AND ret_1w IS NOT NULL
               ORDER BY portfolio_id, year_week"

    conn = connect_mysql()

    t = odbc::query(conn, s_query)

    conn.close()

    return t
}

/*
 *  Fetch monthly returns for one entity type.
 *
 *  @param entity_type  type code passed to get_performance_table_description (e.g. 'FD', 'IX');
 *                      the description supplies table_name, sec_id_col and cumulative_nav_col
 *  @param entity_ids   either a scalar string of single-quoted, comma-separated ids
 *                      ("'MF00003PW1','MF00003PW2'") or a STRING vector of bare ids
 *  @param start_date   first month to include (formatted with temporalFormat "yyyy-MM")
 *  @param end_date     last month to include (formatted with temporalFormat "yyyy-MM")
 *  @param isFromMySQL  true: query MySQL; otherwise read the local "fundit" copy of the table
 *  @return MySQL: id column, end_date, price_date, ret, nav, ret_ytd_a, ret_incep_a
 *          (isvalid = 1), ordered by id, end_date;
 *          local: id column, end_date, ret, nav, ret_ytd_a, ret_incep_a
 *
 *  NOTE(review): the local branch applies neither the start_date/end_date range nor isvalid = 1.
 *  TODO: confirm whether callers rely on that before adding those filters.
 */
def get_monthly_ret(entity_type, entity_ids, start_date, end_date, isFromMySQL) {

    s_entity_ids = '';

    // entity_ids may be a scalar (already-quoted list) or a vector (quote and join it here)
    if ( entity_ids.form() == 0 ) {
        s_entity_ids = entity_ids;
    } else {
        s_entity_ids = "'" + entity_ids.concat("','") + "'";
    }

    tmp = get_performance_table_description(entity_type);
    id_col = tmp.sec_id_col[0];
    nav_col = tmp.cumulative_nav_col[0];

    yyyymm_start = start_date.temporalFormat("yyyy-MM")
    yyyymm_end = end_date.temporalFormat("yyyy-MM")

    if(isFromMySQL == true) {

        s_query = "SELECT " + id_col + ", end_date, price_date, ret_1m AS ret, " + nav_col + " AS nav, ret_ytd_a, ret_incep_a
                   FROM " + tmp.table_name[0] + "
                   WHERE " + id_col + " IN (" + s_entity_ids + ")
                      AND isvalid = 1
                      AND end_date BETWEEN '" + yyyymm_start + "' AND '" + yyyymm_end + "'
                   ORDER BY " + id_col + ", end_date";

        conn = connect_mysql()

        t = odbc::query(conn, s_query)

        conn.close()

    } else {

        tb_local = load_table_from_local("fundit", tmp.table_name[0])

        // Build the id / NAV column metacode from the table description so they follow
        // entity_type (they used to be hard-coded to fund_id / cumulative_nav).
        s_col = (sqlCol(id_col), sqlCol("end_date"), sqlColAlias(<ret_1m>, "ret"), sqlColAlias(sqlCol(nav_col), "nav"), sqlCol("ret_ytd_a"), sqlCol("ret_incep_a"))

        // trim() tolerates lists written with spaces after the commas, e.g. "'A', 'B'"
        ids = s_entity_ids.strReplace("'", "").split(",").trim()
        s_where = expr(sqlCol(id_col), in, ids)

        t = sql(s_col, tb_local, s_where).eval()

    }

    return t
}


/*
 *  Fetch monthly returns for mutual and private funds (entity type 'FD').
 *
 *  Thin wrapper around get_monthly_ret with entity_type fixed to 'FD'.
 *
 *  get_fund_monthly_ret("'MF00003PW1','MF00003PW1'", 1990.01.01, today(), true)
 */
def get_fund_monthly_ret(fund_ids, start_date, end_date, isFromMySQL) {

    fund_entity_type = 'FD'
    monthly = get_monthly_ret(fund_entity_type, fund_ids, start_date, end_date, isFromMySQL)
    return monthly
}

/*
 * Fetch the monthly risk-free rate.
 *
 * Reads the monthly series of index 'IN0000000M' (entity type 'IX') from MySQL
 * via get_monthly_ret.
 *
 * get_risk_free_rate(1990.01.01, today())
 */
def get_risk_free_rate(start_date, end_date) {

    rf_entity_type = 'IX'
    rf_index_ids = "'IN0000000M'"
    always_mysql = true

    return get_monthly_ret(rf_entity_type, rf_index_ids, start_date, end_date, always_mysql)
}


/*
 * Fetch the latest return and NAV record of each fund from mfdb.fund_latest_nav_performance.
 *
 * @param fund_ids     either a scalar string of single-quoted, comma-separated ids
 *                     ("'HF000004KN','HF00018WXG'") or a STRING vector of bare ids
 * @param isFromMySQL  true: query MySQL (isvalid = 1 rows, ordered by fund_id);
 *                     otherwise read the local "fundit" copy (no isvalid filter)
 * @return all columns of the table for the requested funds
 *
 * get_fund_latest_nav_performance("'HF000004KN','HF00018WXG'", true)
 */
def get_fund_latest_nav_performance(fund_ids, isFromMySQL) {

    // Accept a pre-quoted scalar list or a STRING vector, like get_monthly_ret / get_nav_by_price_date
    s_fund_ids = fund_ids
    if (fund_ids.form() != 0) {
        s_fund_ids = "'" + fund_ids.concat("','") + "'"
    }

    if(isFromMySQL == true) {

        s_query = "SELECT *
                   FROM mfdb.fund_latest_nav_performance
                   WHERE fund_id IN (" + s_fund_ids + ")
                     AND isvalid = 1
                     ORDER BY fund_id"

        conn = connect_mysql()

        t = odbc::query(conn, s_query)

        conn.close()

    } else {

        tb_local = load_table_from_local("fundit", "mfdb.fund_latest_nav_performance")

        s_col = sqlCol("*")

        // trim() tolerates lists written with spaces after the commas
        s_where = expr(<fund_id>, in, s_fund_ids.strReplace("'", "").split(",").trim())

        t = sql(s_col, tb_local, s_where).eval()

    }

    return t

}

/*
 * Fetch NAV records on or after a given price date.
 *
 * Create: 202408                                                    Joey
 *                 TODO: add isvalid and nav > 0 for local version
 *
 * @param entity_type  type code passed to get_nav_table_description (e.g. 'HF'); the description
 *                     supplies table_name, sec_id_col, cumulative_nav_col and nav_col
 * @param entity_ids   either a scalar string of single-quoted, comma-separated ids
 *                     or a STRING vector of bare ids
 * @param price_date   earliest price_date to include (inclusive)
 * @param isFromMySQL  true: query MySQL; otherwise read the local "fundit" copy of the table
 * @return MySQL: id column, price_date, cumulative NAV column, NAV column (isvalid = 1 and
 *         cumulative NAV > 0), ordered by id, price_date; local: all columns of the local table
 *
 * Example: get_nav_by_price_date('HF', "'HF000004KN','HF00018WXG'", 2024.05.01, true)
 */
def get_nav_by_price_date(entity_type, entity_ids, price_date, isFromMySQL) {

    s_entity_ids = '';

    // entity_ids may be a scalar (already-quoted list) or a vector (quote and join it here)
    if ( entity_ids.form() == 0 ) {
        s_entity_ids = entity_ids;
    } else {
        s_entity_ids = "'" + entity_ids.concat("','") + "'";
    }

    tmp = get_nav_table_description(entity_type);
    id_col = tmp.sec_id_col[0];

    if(isFromMySQL == true) {

        s_query = "SELECT " + id_col + ", price_date, " + tmp.cumulative_nav_col[0] + ", " + tmp.nav_col[0] + "
                   FROM " + tmp.table_name[0] + "
                   WHERE " + id_col + " IN (" + s_entity_ids + ")
                     AND isvalid = 1
                     AND " + tmp.cumulative_nav_col[0] + " > 0
                     AND price_date >= '" + price_date$STRING + "'
                   ORDER BY " + id_col + ", price_date";

        conn = connect_mysql();

        t = odbc::query(conn, s_query);

        conn.close();

    } else {

        tb_local = load_table_from_local("fundit", tmp.table_name[0])

        s_col = sqlCol("*")

        // The id column follows the table description instead of a hard-coded fund_id;
        // trim() tolerates lists written with spaces after the commas.
        id_filter = expr(sqlCol(id_col), in, s_entity_ids.strReplace("'", "").split(",").trim())

        // Embed the *value* of the price_date argument. The former metacode
        // <price_date >= price_date> resolved both names to the table column inside the
        // query, so the date filter had no effect.
        date_filter = expr(sqlCol("price_date"), >=, price_date)

        s_where = [id_filter, date_filter]

        t = sql(s_col, tb_local, s_where).eval()

    }

    return t

}


/*
 * Fetch index / factor levels on or after a given price date.
 *
 * Levels are read from mfdb.market_indexes (close), mfdb.indexes_ty_index (index_value) and
 * pfdb.cm_factor_value (factor_value), keeping isvalid = 1 rows with a positive level.
 * The three sources are combined with UNION and ordered by price_date.
 *
 * @param index_ids   comma-separated, single-quoted id list, e.g. "'IN00000008','FA00000WKG'"
 * @param price_date  earliest price_date to include (inclusive); a DATE or a date string
 * @return table with index_id, price_date, cumulative_nav
 *
 * get_index_nav_by_price_date("'IN00000008','FA00000WKG'", 2024.06.01)
 */
def get_index_nav_by_price_date(index_ids, price_date) {

    // Convert explicitly (as the other date filters in this module do) rather than relying on
    // implicit conversion inside string concatenation; a string argument passes through as-is.
    s_price_date = price_date$STRING

    s_query = "SELECT index_id, price_date, close AS cumulative_nav
               FROM mfdb.market_indexes
               WHERE index_id IN (" + index_ids + ")
                 AND isvalid = 1
                 AND close > 0
                 AND price_date >= '" + s_price_date + "'
               UNION
               SELECT index_id AS index_id, price_date, index_value AS cumulative_nav
               FROM mfdb.indexes_ty_index
               WHERE index_id IN (" + index_ids + ")
                 AND isvalid = 1
                 AND index_value > 0
                 AND price_date >= '" + s_price_date + "'
               UNION
               SELECT factor_id AS index_id, price_date, factor_value AS cumulative_nav
               FROM pfdb.cm_factor_value
               WHERE factor_id IN (" + index_ids + ")
                 AND isvalid = 1
                 AND factor_value > 0
                 AND price_date >= '" + s_price_date + "'
               ORDER BY price_date"

    conn = connect_mysql()

    t = odbc::query(conn, s_query)

    conn.close()

    return t

}


/*
 * Fetch basic information for valid funds.
 *
 * Joins mfdb.fund_information with mfdb.fund_strategy (both isvalid = 1); a fund without a
 * valid strategy row is not returned. A null initial_unit_value is reported as 1.
 *
 * @param fund_ids  comma-separated, single-quoted id list
 * @return table with fund_id, inception_date, benchmark_id, ini_value, strategy, substrategy,
 *         ordered by fund_id
 *
 * Example: get_fund_info("'HF000004KN','HF00018WXG'")
 */
def get_fund_info(fund_ids) {

    id_list = "(" + fund_ids + ")"

    sql_text = "SELECT fi.fund_id, fi.inception_date, fi.primary_benchmark_id AS benchmark_id, IFNULL(fi.initial_unit_value, 1) AS ini_value, fs.strategy, fs.substrategy
               FROM mfdb.fund_information fi
               INNER JOIN mfdb.fund_strategy fs ON fi.fund_id = fs.fund_id AND fs.isvalid = 1
               WHERE fi.fund_id IN " + id_list + "
               AND fi.isvalid = 1
               ORDER BY fi.fund_id"

    mysql_conn = connect_mysql()
    fund_info = odbc::query(mysql_conn, sql_text)
    close(mysql_conn)

    return fund_info
}

/*
 *  Fetch information for valid portfolios.
 *
 *  Only portfolios with isvalid = 1 whose owning user (pfdb.cm_user) is also valid are returned.
 *
 *  @param portfolio_ids  comma-separated portfolio id list, e.g. '166002,166114'
 *  @return table with portfolio_id, userid, customer_id, inception_date, portfolio_source,
 *          portfolio_type, ordered by portfolio id
 *
 *  Example: get_portfolio_info('166002,166114');
 */
def get_portfolio_info(portfolio_ids) {

    id_list = "(" + portfolio_ids + ")"

    sql_text = "SELECT cpm.id AS portfolio_id, cpm.userid, cpm.customer_id, cpm.inception_date, cpm.portfolio_source, cpm.portfolio_type
               FROM pfdb.`pf_customer_portfolio_map` cpm
               INNER JOIN pfdb.cm_user u ON cpm.userid = u.userid
               WHERE cpm.id IN " + id_list + "
                 AND cpm.isvalid = 1
                 AND u.isvalid = 1
               ORDER BY cpm.id"

    mysql_conn = connect_mysql()
    portfolio_info = odbc::query(mysql_conn, sql_text)
    close(mysql_conn)

    return portfolio_info
}

/*
 * Fetch funds whose NAV records were updated at or after a given time.
 *
 * For every valid fund with at least one mfdb.nav row having cumulative_nav > 0 and
 * updatetime >= the given time, returns the earliest price_date among those rows together
 * with the fund's basic information.
 *
 * NOTE(review): rows are not filtered on nav.isvalid - presumably intentional so that
 * invalidated records also count as updates; confirm.
 *
 * @param fund_ids    STRING vector of bare fund ids, or null for all funds
 * @param updatetime  lower bound (inclusive) on nav.updatetime
 * @return table with fund_id, price_date, inception_date, benchmark_id, ini_value
 *         (ini_value = initial_unit_value, or 1 when null), ordered by fund_id
 *
 * Example: get_fund_list_by_nav_updatetime(null, 2024.07.19T10:00:00)
 */
def get_fund_list_by_nav_updatetime(fund_ids, updatetime) {

    // isVoid (not isNull) is used because isNull on a vector returns a boolean vector
    fund_filter = ''
    if (isVoid(fund_ids) == false) {
        fund_filter = " AND fi.fund_id IN ('" + concat(fund_ids, "','") + "')"
    }

    sql_text = "SELECT fi.fund_id, MIN(nav.price_date) AS price_date,
                      fi.inception_date, fi.primary_benchmark_id AS benchmark_id, IFNULL(fi.initial_unit_value, 1) AS ini_value
               FROM mfdb.fund_information fi
               INNER JOIN mfdb.nav ON fi.fund_id = nav.fund_id
               WHERE fi.isvalid = 1" +
                 fund_filter + "
                 AND nav.cumulative_nav > 0
                 AND nav.updatetime >= '" + updatetime + "'
               GROUP BY fi.fund_id
               ORDER BY fi.fund_id"

    mysql_conn = connect_mysql()
    updated_funds = odbc::query(mysql_conn, sql_text)
    close(mysql_conn)

    return updated_funds
}


/*
 * Fetch the private (hedge) fund NAVs used for the month-end fund_performance update,
 * by calling the stored procedure pfdb.sp_get_nav_for_fund_performance.
 *
 * @param fund_ids   comma-separated id string in which every id is single-quoted
 * @param month_end  month-end as a 'YYYY-MM' string
 * @return the result set of the stored procedure
 *
 * NOTE(review): fund_ids is spliced into the CALL argument list without an extra layer of
 * quoting, so a value like "'A','B'" would expand into several procedure arguments.
 * Presumably callers already wrap the list in quotes - confirm against the procedure signature.
 */
def get_nav_for_hedge_fund_performance(fund_ids, month_end) {

    // The trailing 1 is passed to the procedure unchanged.
    proc_args = fund_ids + ", '" + month_end + "', 1"
    call_stmt = "CALL pfdb.sp_get_nav_for_fund_performance(" + proc_args + ");"

    mysql_conn = connect_mysql()
    navs = odbc::query(mysql_conn, call_stmt)
    close(mysql_conn)

    return navs
}

/*
 *  Fetch each fund's primary benchmark for every month in a range.
 *  NOTE: the database currently stores only the latest benchmark (a time series will
 *  probably be supported later), so the same benchmark is repeated for every month.
 *
 *  @param fund_ids     comma-separated, single-quoted id list
 *  @param month_start  first month, 'yyyy-MM'
 *  @param month_end    last month, 'yyyy-MM'
 *  @return table with fund_id, end_date (month), benchmark_id; months before a fund's
 *          inception month are dropped
 *
 *  Example: get_fund_primary_benchmark("'MF00003PW2', 'MF00003PW1', 'MF00003PXO'", '1990-01', '2024-06');
 */
def get_fund_primary_benchmark(fund_ids, month_start, month_end) {

    sql_text = "SELECT fund_id, primary_benchmark_id AS benchmark_id, inception_date
               FROM mfdb.fund_information
               WHERE fund_id IN (" + fund_ids + ")
                 AND isvalid = 1;";

    mysql_conn = connect_mysql();
    t = odbc::query(mysql_conn, sql_text);
    close(mysql_conn);

    // NOTE(review): this end_date column on t stays null; the output's end_date comes from d below.
    addColumn(t, 'end_date', MONTH);

    // One row per month in [month_start, month_end]
    first_month = temporalParse(month_start, 'yyyy-MM');
    last_month = temporalParse(month_end, 'yyyy-MM');
    month_grid = table(first_month..last_month AS end_date);

    // Pair every fund with every month from its inception month onward
    bench_by_month = SELECT t.fund_id, d.end_date, t.benchmark_id FROM t JOIN month_grid d WHERE d.end_date >= t.inception_date.month();

    return bench_by_month;
}


/*
 *  Fetch each portfolio's primary benchmark for every month in a range.
 *  NOTE: every portfolio currently defaults to the CSI 300 ('IN00000008') as its primary
 *  benchmark; this is likely to change.
 *
 *  @param portfolio_ids  comma-separated portfolio id list, e.g. "166002,166114"
 *  @param month_start    first month, 'yyyy-MM'
 *  @param month_end      last month, 'yyyy-MM'
 *  @return table with portfolio_id, end_date (month), benchmark_id; months before a
 *          portfolio's inception month are dropped
 *
 *  Example: get_portfolio_primary_benchmark("166002,166114", '1990-01', '2024-08');
 */
def get_portfolio_primary_benchmark(portfolio_ids, month_start, month_end) {

    sql_text = "SELECT id AS portfolio_id, 'IN00000008' AS benchmark_id, inception_date
               FROM pfdb.pf_customer_portfolio_map
               WHERE id IN (" + portfolio_ids + ")
                 AND isvalid = 1;";

    mysql_conn = connect_mysql();
    t = odbc::query(mysql_conn, sql_text);
    close(mysql_conn);

    // NOTE(review): this end_date column on t stays null; the output's end_date comes from d below.
    addColumn(t, 'end_date', MONTH);

    // One row per month in [month_start, month_end]
    first_month = temporalParse(month_start, 'yyyy-MM');
    last_month = temporalParse(month_end, 'yyyy-MM');
    month_grid = table(first_month..last_month AS end_date);

    // Pair every portfolio with every month from its inception month onward
    bench_by_month = SELECT t.portfolio_id, d.end_date, t.benchmark_id FROM t JOIN month_grid d WHERE d.end_date >= t.inception_date.month();

    return bench_by_month;
}


/*
 *  Fetch fund BFI factors for a range of months from
 *  pfdb.pf_fund_factor_bfi_by_category_group (isvalid = 1).
 *
 *  @param fund_ids     comma-separated, single-quoted id list
 *  @param month_start  first end_date to include (inclusive), 'yyyy-MM'
 *  @param month_end    last end_date to include (inclusive), 'yyyy-MM'
 *  @return table with fund_id, end_date, factor_id, ordered by fund_id, end_date, factor_id
 *
 *  Example: get_fund_bfi_factors("'MF00003PW2', 'MF00003PW1', 'MF00003PXO'", '1990-01', '2024-06');
 */
def get_fund_bfi_factors(fund_ids, month_start, month_end) {

    sql_text = "SELECT fund_id, end_date, factor_id
               FROM pfdb.pf_fund_factor_bfi_by_category_group
               WHERE fund_id IN (" + fund_ids + ")
                 AND end_date >= '" + month_start + "'
                 AND end_date <= '" + month_end + "'
                 AND isvalid = 1
               ORDER BY fund_id, end_date, factor_id;";

    mysql_conn = connect_mysql();
    bfi_factors = odbc::query(mysql_conn, sql_text);
    close(mysql_conn);

    return bfi_factors;
}


/*
 *  Fetch portfolio BFI factors for a range of months from
 *  pfdb.pf_portfolio_factor_bfi_by_category_group (isvalid = 1).
 *
 *  @param portfolio_ids  comma-separated portfolio id list, e.g. "166002,166114"
 *  @param month_start    first end_date to include (inclusive), 'yyyy-MM'
 *  @param month_end      last end_date to include (inclusive), 'yyyy-MM'
 *  @return table with portfolio_id, end_date, factor_id,
 *          ordered by portfolio_id, end_date, factor_id
 *
 *  Example: get_portfolio_bfi_factors("166002,166114", '1900-01', '2024-06');
 */
def get_portfolio_bfi_factors(portfolio_ids, month_start, month_end) {

    sql_text = "SELECT portfolio_id, end_date, factor_id
               FROM pfdb.pf_portfolio_factor_bfi_by_category_group
               WHERE portfolio_id IN (" + portfolio_ids + ")
                 AND end_date >= '" + month_start + "'
                 AND end_date <= '" + month_end + "'
                 AND isvalid = 1
               ORDER BY portfolio_id, end_date, factor_id;";

    mysql_conn = connect_mysql();
    bfi_factors = odbc::query(mysql_conn, sql_text);
    close(mysql_conn);

    return bfi_factors;
}

/*
 * Fetch the portfolio transaction history from pfdb.pf_portfolio_fund_history (isvalid = 1).
 *
 * The nav column is derived as ROUND(amount / fund_share, 6).
 *
 * @param portfolio_ids  comma-separated portfolio id list, e.g. "166002,364640"
 * @return table with portfolio_id, holding_date, fund_id, amount, fund_share, nav,
 *         ordered by portfolio_id, holding_date
 *
 * Example: get_portfolio_holding_history("166002,364640")
 */
def get_portfolio_holding_history(portfolio_ids) {

    id_list = "(" + portfolio_ids + ")"

    sql_text = "SELECT portfolio_id, holding_date, fund_id, amount, fund_share, ROUND(amount/fund_share, 6) AS nav
               FROM pfdb.pf_portfolio_fund_history
               WHERE portfolio_id IN " + id_list + "
                 AND isvalid = 1
               ORDER BY portfolio_id, holding_date";

    mysql_conn = connect_mysql();
    holdings = odbc::query(mysql_conn, sql_text);
    close(mysql_conn);

    return holdings;
}

/*
 *  Fetch all NAVs of the given funds / securities after their respective dates,
 *  by calling the stored procedure pfdb.sp_get_nav_after_date.
 *
 *  @param json_query <JSON>: [{sec_id:xxx, holding_date: yyyy-mm-dd}]
 *  @return the result set of the stored procedure
 *
 *  NOTE(review): json_query is placed inside a single-quoted SQL literal without escaping;
 *  a single quote (or, under MySQL's default sql_mode, a backslash) in the JSON would
 *  break or alter the statement.
 */
def get_holding_nav(json_query) {

    quoted_json = "'" + json_query + "'"
    call_stmt = "CALL pfdb.sp_get_nav_after_date(" + quoted_json + ")";

    mysql_conn = connect_mysql();
    holding_navs = odbc::query(mysql_conn, call_stmt);
    close(mysql_conn);

    return holding_navs;
}